/*
 * Decompiled with CFR 0.152.
 */
package edu.iu.dsc.tws.rsched.schedulers.standalone;

import com.google.protobuf.InvalidProtocolBufferException;
import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.config.Context;
import edu.iu.dsc.tws.api.config.MPIContext;
import edu.iu.dsc.tws.api.config.SchedulerContext;
import edu.iu.dsc.tws.api.driver.IScalerPerCluster;
import edu.iu.dsc.tws.api.driver.NullScaler;
import edu.iu.dsc.tws.api.exceptions.JobFaultyException;
import edu.iu.dsc.tws.api.exceptions.Twister2Exception;
import edu.iu.dsc.tws.api.exceptions.Twister2RuntimeException;
import edu.iu.dsc.tws.api.faulttolerance.FaultToleranceContext;
import edu.iu.dsc.tws.api.resource.FSPersistentVolume;
import edu.iu.dsc.tws.api.resource.IPersistentVolume;
import edu.iu.dsc.tws.api.resource.IWorker;
import edu.iu.dsc.tws.api.resource.IWorkerController;
import edu.iu.dsc.tws.api.resource.IWorkerStatusUpdater;
import edu.iu.dsc.tws.common.config.ConfigLoader;
import edu.iu.dsc.tws.common.logging.LoggingHelper;
import edu.iu.dsc.tws.common.util.NetworkUtils;
import edu.iu.dsc.tws.master.IJobTerminator;
import edu.iu.dsc.tws.master.JobMasterContext;
import edu.iu.dsc.tws.master.server.JobMaster;
import edu.iu.dsc.tws.proto.jobmaster.JobMasterAPI;
import edu.iu.dsc.tws.proto.system.job.JobAPI;
import edu.iu.dsc.tws.proto.utils.NodeInfoUtils;
import edu.iu.dsc.tws.proto.utils.WorkerInfoUtils;
import edu.iu.dsc.tws.rsched.core.WorkerRuntime;
import edu.iu.dsc.tws.rsched.schedulers.NullTerminator;
import edu.iu.dsc.tws.rsched.schedulers.nomad.NomadContext;
import edu.iu.dsc.tws.rsched.schedulers.standalone.MPIWorkerController;
import edu.iu.dsc.tws.rsched.utils.JobUtils;
import edu.iu.dsc.tws.rsched.utils.ResourceSchedulerUtils;
import edu.iu.dsc.tws.rsched.worker.MPIWorkerManager;
import java.io.File;
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;
import mpi.Intracomm;
import mpi.MPI;
import mpi.MPIException;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;

public final class MPIWorkerStarter {
    private static final Logger LOG = Logger.getLogger(MPIWorkerStarter.class.getName());
    private Config config;
    private JobMasterAPI.WorkerInfo wInfo;
    private JobAPI.Job job;
    private int globalRank;
    private JobMaster jobMaster;
    private int restartCount;

    public static void main(String[] args) {
        new MPIWorkerStarter(args);
    }

    private MPIWorkerStarter(String[] args) {
        CommandLine cmd;
        try {
            MPI.InitThread((String[])args, (int)MPI.THREAD_MULTIPLE);
            this.globalRank = MPI.COMM_WORLD.getRank();
        }
        catch (MPIException e) {
            LOG.log(Level.SEVERE, "Failed to initialize MPI process", e);
            throw new Twister2RuntimeException("Failed to initialize MPI process", (Throwable)e);
        }
        this.setUncaughtExceptionHandler();
        Options cmdOptions = this.setupOptions();
        DefaultParser parser = new DefaultParser();
        try {
            cmd = parser.parse(cmdOptions, args);
        }
        catch (ParseException e) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.printHelp("MPIWorkerStarter", cmdOptions);
            throw new Twister2RuntimeException("Error parsing command line options: ", (Throwable)e);
        }
        this.config = this.loadConfigurations(cmd);
        LOG.log(Level.FINE, "An MPI worker process is starting with the rank: " + this.globalRank);
        this.setThreadName();
        if (JobMasterContext.isJobMasterUsed((Config)this.config)) {
            this.startWorkerWithJM();
        } else {
            this.startWorkerWithoutJM(this.config, MPI.COMM_WORLD);
        }
        this.finalizeMPI();
    }

    private void setUncaughtExceptionHandler() {
        Thread.setDefaultUncaughtExceptionHandler((thread, throwable) -> {
            LOG.log(Level.SEVERE, "Uncaught exception in thread " + thread + ". Finalizing this worker...", throwable);
            if (!JobMasterContext.isJobMasterUsed((Config)this.config)) {
                System.exit(1);
            }
            if (this.wInfo != null && this.wInfo.getWorkerID() == -1 && this.jobMaster != null) {
                this.jobMaster.jmFailed();
                LOG.severe("!!!!!!!!!!!!!!!!!!!!!!!!!!!!! JM Exiting with failure.");
                System.exit(1);
            }
            if (JobMasterContext.jobMasterRunsInClient((Config)this.config)) {
                this.updateWorkerState(JobMasterAPI.WorkerState.FAILED);
                if (this.restartCount >= FaultToleranceContext.maxMpiJobRestarts((Config)this.config) - 1) {
                    this.sendWorkerFinalStateToJM(JobMasterAPI.WorkerState.FULLY_FAILED);
                }
                System.exit(1);
            }
            if (throwable instanceof JobFaultyException) {
                WorkerRuntime.close();
                LOG.severe("!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Worker Exiting with JobFaultyException failure.");
                System.exit(1);
            } else {
                this.updateWorkerState(JobMasterAPI.WorkerState.FAILED);
                this.sendWorkerFinalStateToJM(JobMasterAPI.WorkerState.FULLY_FAILED);
                LOG.severe("!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Worker Exiting with failure.");
                System.exit(1);
            }
        });
    }

    private void setThreadName() {
        if (JobMasterContext.isJobMasterUsed((Config)this.config) && !JobMasterContext.jobMasterRunsInClient((Config)this.config)) {
            if (this.globalRank == 0) {
                Thread.currentThread().setName("Twister2MPIWorker-JM");
            } else {
                Thread.currentThread().setName("Twister2MPIWorker-" + (this.globalRank - 1));
            }
        } else {
            Thread.currentThread().setName("Twister2MPIWorker-" + this.globalRank);
        }
    }

    private void startWorkerWithoutJM(Config cfg, Intracomm intracomm) {
        this.wInfo = this.createWorkerInfo(this.config, this.globalRank);
        Map<Integer, JobMasterAPI.WorkerInfo> infos = this.createWorkerInfoMap(intracomm);
        MPIWorkerController wc = new MPIWorkerController(this.globalRank, infos, this.restartCount);
        WorkerRuntime.init(cfg, wc);
        this.startWorker(intracomm);
    }

    private void startWorkerWithJM() {
        if (JobMasterContext.jobMasterRunsInClient((Config)this.config)) {
            this.wInfo = this.createWorkerInfo(this.config, this.globalRank);
            WorkerRuntime.init(this.config, this.job, this.wInfo, this.restartCount);
            this.startWorker(MPI.COMM_WORLD);
        } else {
            int splittedRank;
            Intracomm splittedComm;
            int color = this.globalRank == 0 ? 0 : 1;
            try {
                splittedComm = MPI.COMM_WORLD.split(color, this.globalRank);
                splittedRank = splittedComm.getRank();
            }
            catch (MPIException e) {
                throw new Twister2RuntimeException("Can not split MPI.COMM_WORLD", (Throwable)e);
            }
            this.wInfo = this.globalRank == 0 ? this.createWorkerInfo(this.config, -1) : this.createWorkerInfo(this.config, splittedRank);
            this.broadCastMasterInformation(this.globalRank);
            if (this.globalRank == 0) {
                this.startMaster();
            } else {
                WorkerRuntime.init(this.config, this.job, this.wInfo, this.restartCount);
                this.startWorker(splittedComm);
            }
        }
    }

    private void startWorker(Intracomm intracomm) {
        try {
            String twister2Home = Context.twister2Home((Config)this.config);
            this.initLogger(this.config, intracomm.getRank(), twister2Home);
            IWorkerController wc = WorkerRuntime.getWorkerController();
            IPersistentVolume persistentVolume = this.initPersistenceVolume(this.config, this.globalRank);
            MPIContext.addRuntimeObject((String)"comm", (Object)intracomm);
            IWorker worker = JobUtils.initializeIWorker(this.job);
            MPIWorkerManager workerManager = new MPIWorkerManager();
            workerManager.execute(this.config, this.job, wc, persistentVolume, null, worker);
        }
        catch (MPIException e) {
            LOG.log(Level.SEVERE, "Failed to synchronize the workers at the start");
            throw new RuntimeException(e);
        }
    }

    private void startMaster() {
        try {
            int port = JobMasterContext.jobMasterPort((Config)this.config);
            String hostAddress = ResourceSchedulerUtils.getHostIP(this.config);
            LOG.log(Level.INFO, String.format("Starting the job master: %s:%d", hostAddress, port));
            JobMasterAPI.NodeInfo jobMasterNodeInfo = null;
            NullScaler clusterScaler = new NullScaler();
            JobMasterAPI.JobMasterState initialState = JobMasterAPI.JobMasterState.JM_STARTED;
            NullTerminator nt = new NullTerminator();
            this.jobMaster = new JobMaster(this.config, "0.0.0.0", port, (IJobTerminator)nt, this.job, jobMasterNodeInfo, (IScalerPerCluster)clusterScaler, initialState);
            this.jobMaster.startJobMasterBlocking();
            LOG.log(Level.INFO, "JobMaster done... ");
        }
        catch (Twister2Exception e) {
            LOG.log(Level.SEVERE, "Exception when starting Job master: ", e);
            throw new RuntimeException(e);
        }
    }

    private void broadCastMasterInformation(int rank) {
        byte[] workerBytes = this.wInfo.toByteArray();
        int length = workerBytes.length;
        IntBuffer countSend = MPI.newIntBuffer((int)1);
        if (rank == 0) {
            countSend.put(length);
        }
        try {
            MPI.COMM_WORLD.bcast((Object)countSend, 1, MPI.INT, 0);
            length = countSend.get(0);
            ByteBuffer sendBuffer = MPI.newByteBuffer((int)length);
            if (rank == 0) {
                sendBuffer.put(workerBytes);
            }
            MPI.COMM_WORLD.bcast((Object)sendBuffer, length, MPI.BYTE, 0);
            byte[] jmInfoBytes = new byte[length];
            if (rank != 0) {
                sendBuffer.get(jmInfoBytes);
                JobMasterAPI.WorkerInfo masterInfo = ((JobMasterAPI.WorkerInfo.Builder)JobMasterAPI.WorkerInfo.newBuilder().mergeFrom(jmInfoBytes)).build();
                this.config = Config.newBuilder().putAll(this.config).put("twister2.job.master.port", (Object)masterInfo.getPort()).put("twister2.job.master.ip", (Object)masterInfo.getNodeInfo().getNodeIP()).build();
            } else {
                this.config = Config.newBuilder().putAll(this.config).put("twister2.job.master.port", (Object)this.wInfo.getPort()).put("twister2.job.master.ip", (Object)this.wInfo.getNodeInfo().getNodeIP()).build();
            }
        }
        catch (MPIException mpie) {
            throw new Twister2RuntimeException("Error when broadcasting Job Master information", (Throwable)mpie);
        }
        catch (InvalidProtocolBufferException ipbe) {
            throw new Twister2RuntimeException("Error when decoding Job Master information", (Throwable)ipbe);
        }
    }

    private Options setupOptions() {
        Options options = new Options();
        Option configDirectory = Option.builder((String)"d").desc("The class name of the container to launch").longOpt("config_dir").hasArgs().argName("configuration directory").required().build();
        Option twister2Home = Option.builder((String)"t").desc("The class name of the container to launch").longOpt("twister2_home").hasArgs().argName("twister2 home").required().build();
        Option clusterType = Option.builder((String)"n").desc("The clustr type").longOpt("cluster_type").hasArgs().argName("cluster type").required().build();
        Option jobId = Option.builder((String)"j").desc("Job Id").longOpt("job_id").hasArgs().argName("job id").required().build();
        Option jobMasterIP = Option.builder((String)"i").desc("Job master ip").longOpt("job_master_ip").hasArgs().argName("job master ip").required().build();
        Option jobMasterPort = Option.builder((String)"p").desc("Job master ip").longOpt("job_master_port").hasArgs().argName("job master port").required().build();
        Option restoreJob = Option.builder((String)"r").desc("Whether the job is being restored").longOpt("restore_job").hasArgs().argName("restore job").required().build();
        Option restartCountOption = Option.builder((String)"x").desc("number of time the job is restarted after failure").longOpt("restart_count").hasArgs().argName("restart count").required().build();
        options.addOption(twister2Home);
        options.addOption(configDirectory);
        options.addOption(clusterType);
        options.addOption(jobId);
        options.addOption(jobMasterIP);
        options.addOption(jobMasterPort);
        options.addOption(restoreJob);
        options.addOption(restartCountOption);
        return options;
    }

    private Config loadConfigurations(CommandLine cmd) {
        String twister2Home = cmd.getOptionValue("twister2_home");
        String configDir = cmd.getOptionValue("config_dir");
        String clusterType = cmd.getOptionValue("cluster_type");
        String jobId = cmd.getOptionValue("job_id");
        String jIp = cmd.getOptionValue("job_master_ip");
        int jPort = Integer.parseInt(cmd.getOptionValue("job_master_port"));
        boolean restoreJob = Boolean.parseBoolean(cmd.getOptionValue("restore_job"));
        this.restartCount = Integer.parseInt(cmd.getOptionValue("restart_count"));
        Config cfg = ConfigLoader.loadConfig((String)twister2Home, (String)configDir, (String)clusterType);
        Config workerConfig = Config.newBuilder().putAll(cfg).put(MPIContext.TWISTER2_HOME.getKey(), (Object)twister2Home).put("twister2.container.id", (Object)this.globalRank).put("twister2.cluster.type", (Object)clusterType).build();
        String jobDescFile = JobUtils.getJobDescriptionFilePath(jobId, workerConfig);
        this.job = JobUtils.readJobFile(jobDescFile);
        Config updatedConfig = JobUtils.overrideConfigs(this.job, cfg);
        updatedConfig = Config.newBuilder().putAll(updatedConfig).put(MPIContext.TWISTER2_HOME.getKey(), (Object)twister2Home).put("twister2.resource.job.worker.class", (Object)this.job.getWorkerClassName()).put("twister2.container.id", (Object)this.globalRank).put("twister2.job.id", (Object)jobId).put("twister2.job.object", (Object)this.job).put("twister2.cluster.type", (Object)clusterType).put("twister2.job.master.ip", (Object)jIp).put("twister2.job.master.port", (Object)jPort).put("twister2.resource.zookeeper.server.addresses", null).put("twister2.checkpointing.restore.job", (Object)restoreJob).build();
        LOG.log(Level.FINE, String.format("Initializing process with twister_home: %s worker_class: %s config_dir: %s cluster_type: %s", twister2Home, this.job.getWorkerClassName(), configDir, clusterType));
        return updatedConfig;
    }

    public Map<Integer, JobMasterAPI.WorkerInfo> createWorkerInfoMap(Intracomm intracomm) {
        try {
            byte[] workerBytes = this.wInfo.toByteArray();
            int length = workerBytes.length;
            IntBuffer countSend = MPI.newIntBuffer((int)1);
            int worldSize = intracomm.getSize();
            IntBuffer countReceive = MPI.newIntBuffer((int)worldSize);
            countSend.put(length);
            intracomm.allGather((Object)countSend, 1, MPI.INT, (Object)countReceive, 1, MPI.INT);
            int[] receiveSizes = new int[worldSize];
            int[] displacements = new int[worldSize];
            int sum = 0;
            for (int i = 0; i < worldSize; ++i) {
                receiveSizes[i] = countReceive.get(i);
                displacements[i] = sum;
                sum += receiveSizes[i];
            }
            ByteBuffer sendBuffer = MPI.newByteBuffer((int)length);
            ByteBuffer receiveBuffer = MPI.newByteBuffer((int)sum);
            sendBuffer.put(workerBytes);
            intracomm.allGatherv((Object)sendBuffer, length, MPI.BYTE, (Object)receiveBuffer, receiveSizes, displacements, MPI.BYTE);
            HashMap<Integer, JobMasterAPI.WorkerInfo> workerInfoMap = new HashMap<Integer, JobMasterAPI.WorkerInfo>();
            for (int i = 0; i < receiveSizes.length; ++i) {
                byte[] c = new byte[receiveSizes[i]];
                receiveBuffer.get(c);
                JobMasterAPI.WorkerInfo info = ((JobMasterAPI.WorkerInfo.Builder)JobMasterAPI.WorkerInfo.newBuilder().mergeFrom(c)).build();
                workerInfoMap.put(i, info);
                LOG.log(Level.FINE, String.format("Worker %d info: %s", i, workerInfoMap.get(i)));
            }
            return workerInfoMap;
        }
        catch (MPIException e) {
            throw new RuntimeException("Failed to communicate", e);
        }
        catch (InvalidProtocolBufferException e) {
            throw new RuntimeException("Failed to create worker info", e);
        }
    }

    private JobMasterAPI.WorkerInfo createWorkerInfo(Config cfg, int workerId) {
        String workerIP;
        List networkInterfaces = SchedulerContext.networkInterfaces((Config)cfg);
        if (networkInterfaces == null) {
            try {
                workerIP = InetAddress.getLocalHost().getHostAddress();
            }
            catch (UnknownHostException e) {
                throw new RuntimeException("Failed to get ip address", e);
            }
        } else {
            workerIP = ResourceSchedulerUtils.getLocalIPFromNetworkInterfaces(networkInterfaces);
            if (workerIP == null) {
                throw new RuntimeException("Failed to get ip address from network interfaces: " + networkInterfaces);
            }
        }
        JobMasterAPI.NodeInfo nodeInfo = NodeInfoUtils.createNodeInfo((String)workerIP, (String)"default", (String)"default");
        JobAPI.ComputeResource computeResource = JobUtils.getComputeResource(this.job, workerId);
        ArrayList<String> portNames = SchedulerContext.additionalPorts((Config)cfg);
        HashMap freePorts = new HashMap();
        if (portNames == null) {
            portNames = new ArrayList<String>();
        }
        portNames.add("__worker__");
        Map socketMap = NetworkUtils.findFreePorts(portNames);
        try {
            MPI.COMM_WORLD.barrier();
        }
        catch (MPIException e) {
            throw new Twister2RuntimeException("MPI Barrier failed at initialization stage");
        }
        AtomicBoolean closedSuccessfully = new AtomicBoolean(true);
        socketMap.forEach((k, v) -> {
            freePorts.put(k, v.getLocalPort());
            try {
                v.close();
            }
            catch (IOException e) {
                LOG.log(Level.SEVERE, e, () -> "Couldn't close opened server socket : " + k);
                closedSuccessfully.set(false);
            }
        });
        if (!closedSuccessfully.get()) {
            throw new IllegalStateException("Could not release one or more free TCP/IP ports");
        }
        Integer workerPort = (Integer)freePorts.get("__worker__");
        freePorts.remove("__worker__");
        LOG.fine("Worker info host:" + workerIP + ":" + workerPort);
        return WorkerInfoUtils.createWorkerInfo((int)workerId, (String)workerIP, (int)workerPort, (JobMasterAPI.NodeInfo)nodeInfo, (JobAPI.ComputeResource)computeResource, freePorts);
    }

    private void initLogger(Config cfg, int workerID, String logDirectory) {
        String jobWorkingDirectory = NomadContext.workingDirectory(cfg);
        String jobId = NomadContext.jobId((Config)cfg);
        String persistentJobDir = NomadContext.getLoggingSandbox(cfg) ? Paths.get(jobWorkingDirectory, jobId).toString() : logDirectory;
        if (persistentJobDir == null) {
            return;
        }
        String logDir = persistentJobDir + "/logs/worker-" + workerID;
        File directory = new File(logDir);
        if (!directory.exists() && !directory.mkdirs()) {
            throw new RuntimeException("Failed to create log directory: " + logDir);
        }
        LoggingHelper.setupLogging((Config)cfg, (String)logDir, (String)("worker-" + workerID));
        LOG.fine(String.format("Logging is setup with file %s", logDir));
    }

    private IPersistentVolume initPersistenceVolume(Config cfg, int rank) {
        File baseDir = new File(MPIContext.fileSystemMount((Config)cfg));
        while (!baseDir.exists() && !baseDir.mkdirs()) {
            try {
                Thread.sleep(100L);
            }
            catch (InterruptedException e) {
                throw new RuntimeException("Thread interrupted", e);
            }
        }
        return new FSPersistentVolume(baseDir.getAbsolutePath(), rank);
    }

    public void finalizeMPI() {
        try {
            this.sendWorkerFinalStateToJM(JobMasterAPI.WorkerState.COMPLETED);
            MPI.Finalize();
        }
        catch (MPIException mPIException) {
            // empty catch block
        }
    }

    private void updateWorkerState(JobMasterAPI.WorkerState workerState) {
        if (!JobMasterContext.isJobMasterUsed((Config)this.config)) {
            return;
        }
        if (this.wInfo.getWorkerID() == -1) {
            return;
        }
        IWorkerStatusUpdater workerStatusUpdater = WorkerRuntime.getWorkerStatusUpdater();
        if (workerStatusUpdater != null) {
            workerStatusUpdater.updateWorkerStatus(workerState);
        }
    }

    private void sendWorkerFinalStateToJM(JobMasterAPI.WorkerState workerState) {
        LOG.info(String.format("Worker-%d finished executing with the final status: %s", this.wInfo.getWorkerID(), workerState.name()));
        this.updateWorkerState(workerState);
        WorkerRuntime.close();
    }
}

