/*
 * Decompiled with CFR 0.152.
 */
package edu.iu.dsc.tws.rsched.schedulers.k8s.mpi;

import edu.iu.dsc.tws.api.config.Config;
import edu.iu.dsc.tws.api.config.MPIContext;
import edu.iu.dsc.tws.api.config.SchedulerContext;
import edu.iu.dsc.tws.api.faulttolerance.FaultToleranceContext;
import edu.iu.dsc.tws.common.logging.LoggingHelper;
import edu.iu.dsc.tws.master.JobMasterContext;
import edu.iu.dsc.tws.proto.system.job.JobAPI;
import edu.iu.dsc.tws.rsched.schedulers.k8s.K8sEnvVariables;
import edu.iu.dsc.tws.rsched.schedulers.k8s.KubernetesContext;
import edu.iu.dsc.tws.rsched.schedulers.k8s.KubernetesUtils;
import edu.iu.dsc.tws.rsched.schedulers.k8s.PodWatchUtils;
import edu.iu.dsc.tws.rsched.schedulers.k8s.worker.K8sWorkerUtils;
import edu.iu.dsc.tws.rsched.utils.JobUtils;
import edu.iu.dsc.tws.rsched.utils.ProcessUtils;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStreamWriter;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Entry point of the MPI master pod of a Twister2 job on Kubernetes.
 *
 * <p>The master loads the job configuration and job description, resolves the job master IP,
 * waits for all worker pods to reach the running state, writes an MPI {@code hostfile}, optionally
 * verifies password-free ssh among the pods, and finally launches the workers through
 * {@code mpirun}.
 */
public final class MPIMasterStarter {
    private static final Logger LOG = Logger.getLogger(MPIMasterStarter.class.getName());
    private static final String HOSTFILE_NAME = "hostfile";
    private static Config config = null;
    private static String jobID = null;

    private MPIMasterStarter() {
    }

    /**
     * Starts the MPI master.
     *
     * <p>Reads JOB_MASTER_IP, POD_NAME, JVM_MEMORY_MB and JOB_ID from the environment. It throws a
     * {@link RuntimeException} when JOB_ID is missing or the local host address cannot be
     * resolved. On any other failure, such as no job master IP, not all pods running, the hostfile
     * not written, or the ssh check failing, it logs a SEVERE message and returns without running
     * mpirun.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        LoggingHelper.setLoggingFormat("[%1$tF %1$tT] [%4$s] [%7$s] %3$s: %5$s %6$s %n");

        String jobMasterIP = System.getenv(K8sEnvVariables.JOB_MASTER_IP.name());
        String podName = System.getenv(K8sEnvVariables.POD_NAME.name());
        String jvmMemory = System.getenv(K8sEnvVariables.JVM_MEMORY_MB.name());
        jobID = System.getenv(K8sEnvVariables.JOB_ID.name());
        if (jobID == null) {
            throw new RuntimeException("JobID is null");
        }

        String configDir = "/twister2-memory-dir/twister2-job";
        String logPropsFile = configDir + "/" + "common/logger.properties";
        config = K8sWorkerUtils.loadConfig(configDir);
        K8sWorkerUtils.initLogger(config, "mpiMaster");

        String jobDescFileName =
            configDir + "/" + SchedulerContext.createJobDescriptionFileName(jobID);
        JobAPI.Job job = JobUtils.readJobFile(jobDescFileName);
        LOG.info("Job description file is loaded: " + jobDescFileName);
        config = JobUtils.overrideConfigs(job, config);
        config = JobUtils.updateConfigs(job, config);

        String namespace = KubernetesContext.namespace(config);
        int workersPerPod = job.getComputeResource(0).getWorkersPerPod();
        int numberOfPods = KubernetesUtils.numberOfWorkerPods(job);
        String podIP = localHostAddress();
        LOG.info("MPIMaster information summary: \npodName: " + podName + "\npodIP: " + podIP
            + "\njobID: " + jobID + "\nnamespace: " + namespace + "\nnumberOfWorkers: "
            + job.getNumberOfWorkers() + "\nnumberOfPods: " + numberOfPods);

        long start = System.currentTimeMillis();
        final int timeoutSeconds = 100;

        // When the job master does not run in the client, its IP comes from its service or,
        // failing that, from watching its pod until it is running.
        if (!JobMasterContext.jobMasterRunsInClient(config)) {
            jobMasterIP = K8sWorkerUtils.getJobMasterServiceIP(namespace, jobID);
            if (jobMasterIP == null) {
                jobMasterIP = PodWatchUtils.getJobMasterIpByWatchingPodToRunning(
                    namespace, jobID, timeoutSeconds);
            }
            if (jobMasterIP == null) {
                // Release the pod watcher on this failure path too, as is done after the worker
                // watch below.
                PodWatchUtils.close();
                LOG.severe("Could not get job master IP by watching job master pod to running. "
                    + "Aborting. You need to terminate this job and resubmit it....");
                return;
            }
        }
        LOG.info("Job Master IP address: " + jobMasterIP);

        ArrayList<String> podIPs = PodWatchUtils.getWorkerIPsByWatchingPodsToRunning(
            namespace, jobID, numberOfPods, timeoutSeconds);
        PodWatchUtils.close();
        if (podIPs == null) {
            LOG.severe("Could not get IPs of all pods running. Aborting. "
                + "You need to terminate this job and resubmit it....");
            return;
        }

        if (!createHostFile(podIPs, workersPerPod)) {
            LOG.severe("hostfile can not be generated. Aborting. "
                + "You need to terminate this job and resubmit it....");
            return;
        }
        long duration = System.currentTimeMillis() - start;
        LOG.info("Getting all pods running took: " + duration + " ms.");

        String classToRun = "edu.iu.dsc.tws.rsched.schedulers.k8s.mpi.MPIWorkerStarter";
        String[] mpirunCommand =
            generateMPIrunCommand(classToRun, workersPerPod, jobMasterIP, logPropsFile, jvmMemory);

        if (KubernetesContext.checkPwdFreeSsh(config)) {
            start = System.currentTimeMillis();
            // The ssh check targets the other pods only; drop this pod's own address.
            podIPs.remove(podIP);
            boolean pwdFreeSshOk = runScript(generateCheckSshCommand(podIPs));
            duration = System.currentTimeMillis() - start;
            LOG.info("Checking password free access took: " + duration + " ms");
            if (!pwdFreeSshOk) {
                LOG.severe("Password free ssh can not be setup among pods. Not executing mpirun ...");
                return;
            }
        }

        executeMpirun(mpirunCommand);
    }

    /**
     * Returns the IP address of this host as reported by {@link InetAddress#getLocalHost()}.
     *
     * @throws RuntimeException if the local host cannot be resolved
     */
    private static String localHostAddress() {
        try {
            return InetAddress.getLocalHost().getHostAddress();
        } catch (UnknownHostException e) {
            LOG.log(Level.SEVERE, "Cannot get localHost.", e);
            throw new RuntimeException("Cannot get localHost.", e);
        }
    }

    /**
     * Writes the MPI hostfile into the current working directory.
     *
     * <p>Each line has the form {@code "<ip> slots=<workersPerPod>"}.
     *
     * @param ipList IP addresses of the pods, one hostfile line each
     * @param workersPerPod number of MPI slots declared for every host
     * @return true if the file was fully written and closed, false on any exception
     */
    public static boolean createHostFile(ArrayList<String> ipList, int workersPerPod) {
        StringBuilder bufferToLog = new StringBuilder();
        // try-with-resources guarantees the file handle is released even if a write fails.
        try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
                new FileOutputStream(HOSTFILE_NAME), StandardCharsets.UTF_8))) {
            for (String ip : ipList) {
                writer.write(ip + " slots=" + workersPerPod + System.lineSeparator());
                bufferToLog.append(ip).append(System.lineSeparator());
            }
        } catch (Exception e) {
            LOG.log(Level.SEVERE, "Exception when writing the file: hostfile", e);
            return false;
        }
        LOG.info("File: hostfile is written with the content:\n" + bufferToLog.toString());
        return true;
    }

    /**
     * Builds the mpirun command line that starts {@code className} in every worker process.
     *
     * <p>The job ID, job master IP, job submission time, restore flag and Kubernetes service
     * host/port are exported to the workers with {@code -x}. Extra MPI parameters from
     * {@code MPIContext.mpiParams(config)} are split on whitespace and inserted before the java
     * command.
     *
     * @param className fully qualified main class for the worker JVMs
     * @param workersPerPod value passed to {@code -npernode}
     * @param jobMasterIP exported to workers as JOB_MASTER_IP
     * @param logPropsFile passed as {@code -Djava.util.logging.config.file}
     * @param jvmMemory heap size in MB for {@code -Xms}/{@code -Xmx}; if null or blank, these
     *     flags are omitted and the JVM default heap settings apply
     * @return the command as an argument array
     */
    public static String[] generateMPIrunCommand(String className, int workersPerPod,
                                                 String jobMasterIP, String logPropsFile,
                                                 String jvmMemory) {
        String jst = System.getenv(K8sEnvVariables.JOB_SUBMISSION_TIME.name());
        String restore = System.getenv(K8sEnvVariables.RESTORE_JOB.name());

        List<String> cmdList = new ArrayList<String>();
        cmdList.addAll(Arrays.asList(
            "mpirun",
            "--hostfile", HOSTFILE_NAME,
            "--allow-run-as-root",
            "-npernode", workersPerPod + "",
            "-tag-output",
            "-x", "KUBERNETES_SERVICE_HOST=" + System.getenv("KUBERNETES_SERVICE_HOST"),
            "-x", "KUBERNETES_SERVICE_PORT=" + System.getenv("KUBERNETES_SERVICE_PORT"),
            "-x", K8sEnvVariables.JOB_ID.name() + "=" + jobID,
            "-x", K8sEnvVariables.JOB_MASTER_IP.name() + "=" + jobMasterIP,
            "-x", K8sEnvVariables.JOB_SUBMISSION_TIME.name() + "=" + jst,
            "-x", K8sEnvVariables.RESTORE_JOB.name() + "=" + restore));

        String mpiParams = MPIContext.mpiParams(config);
        if (mpiParams != null && !mpiParams.trim().isEmpty()) {
            // Split on any whitespace run so repeated spaces do not produce empty arguments.
            cmdList.addAll(Arrays.asList(mpiParams.trim().split("\\s+")));
        }

        cmdList.add("java");
        if (jvmMemory != null && !jvmMemory.trim().isEmpty()) {
            String heapMb = jvmMemory.trim();
            cmdList.add("-Xms" + heapMb + "m");
            cmdList.add("-Xmx" + heapMb + "m");
        } else {
            // Without a value the flags would become "-Xmsnullm", which the JVM rejects.
            LOG.warning("jvmMemory is not provided; -Xms/-Xmx flags are omitted from the "
                + "worker java command.");
        }
        cmdList.add("-Djava.util.logging.config.file=" + logPropsFile);
        cmdList.add("-cp");
        cmdList.add(System.getenv("CLASSPATH"));
        cmdList.add(className);
        return cmdList.toArray(new String[0]);
    }

    /**
     * Runs the mpirun command synchronously, retrying on a non-zero exit status.
     *
     * <p>At most {@code FaultToleranceContext.maxMpiJobRestarts(config)} attempts are made. If
     * that value is not positive, no attempt is made. The stderr of the last failed attempt is
     * kept for the final failure message.
     *
     * @param command the mpirun command, as built by {@link #generateMPIrunCommand}
     */
    public static void executeMpirun(String[] command) {
        final int maxAttempts = FaultToleranceContext.maxMpiJobRestarts(config);
        final boolean isVerbose = true;
        StringBuilder stderr = new StringBuilder();
        int attempts = 0;
        while (attempts < maxAttempts) {
            // Clear the previous attempt's output only when a new attempt starts, so the final
            // failure message below still carries the last attempt's stderr.
            stderr.setLength(0);
            attempts++;
            LOG.info("mpirun will execute with the command: \n" + commandAsAString(command));
            int status = ProcessUtils.runSyncProcess(false, command, stderr, new File("."), isVerbose);
            if (status == 0) {
                LOG.info("mpirun completed with success...");
                if (stderr.length() != 0) {
                    LOG.info("The output:\n " + stderr.toString());
                }
                return;
            }
            if (attempts < maxAttempts) {
                LOG.severe(String.format("Failed to execute mpirun. Will try again. STDERR=%s", stderr));
            }
        }
        LOG.severe(String.format("Failed to execute mpirun. Tried %s times. STDERR=%s", attempts, stderr));
    }

    /**
     * Joins the command arguments for logging.
     *
     * <p>Each argument is followed by a single space, including the last one.
     *
     * @param commandArray command arguments
     * @return the joined string, or "" for an empty array
     */
    public static String commandAsAString(String[] commandArray) {
        StringBuilder command = new StringBuilder();
        for (String cmd : commandArray) {
            command.append(cmd).append(' ');
        }
        return command.toString();
    }

    /**
     * Builds the command that runs {@code ./check_pwd_free_ssh.sh} with the given IPs as arguments.
     *
     * @param podIPs pod IP addresses, passed in iteration order
     * @return the script path followed by every IP
     */
    public static String[] generateCheckSshCommand(ArrayList<String> podIPs) {
        String[] command = new String[podIPs.size() + 1];
        command[0] = "./check_pwd_free_ssh.sh";
        for (int i = 0; i < podIPs.size(); i++) {
            command[i + 1] = podIPs.get(i);
        }
        return command;
    }

    /**
     * Runs a script command synchronously in the current directory and logs the outcome.
     *
     * @param command script path followed by its arguments
     * @return true if the process exited with status 0
     */
    public static boolean runScript(String[] command) {
        StringBuilder stderr = new StringBuilder();
        final boolean isVerbose = true;
        String commandStr = commandAsAString(command);
        LOG.info("the script will be executed with the command: \n" + commandStr);
        int status = ProcessUtils.runSyncProcess(false, command, stderr, new File("."), isVerbose);
        if (status != 0) {
            LOG.severe(String.format("Failed to execute the script file command=%s, STDERR=%s",
                commandStr, stderr));
            return false;
        }
        LOG.info("script: check_pwd_free_ssh.sh execution completed with success...");
        if (stderr.length() != 0) {
            LOG.info("The error output:\n " + stderr.toString());
        }
        return true;
    }
}

