package org.campagnelab.dl.framework.tools;

import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.floats.FloatArraySet;
import it.unimi.dsi.logging.ProgressLogger;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Properties;
import java.util.stream.Collectors;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.FilenameUtils;
import org.campagnelab.dl.framework.architecture.graphs.ComputationGraphAssembler;
import org.campagnelab.dl.framework.domains.DomainDescriptor;
import org.campagnelab.dl.framework.gpu.InitializeGpu;
import org.campagnelab.dl.framework.gpu.ParameterPrecision;
import org.campagnelab.dl.framework.iterators.MultiDataSetIteratorAdapter;
import org.campagnelab.dl.framework.iterators.cache.CacheHelper;
import org.campagnelab.dl.framework.iterators.cache.FullyInMemoryCache;
import org.campagnelab.dl.framework.mappers.FeatureMapper;
import org.campagnelab.dl.framework.mixup.MixupMultiDataSetPreProcessor;
import org.campagnelab.dl.framework.models.ComputationGraphSaver;
import org.campagnelab.dl.framework.models.ModelLoader;
import org.campagnelab.dl.framework.models.ModelPropertiesHelper;
import org.campagnelab.dl.framework.performance.Metric;
import org.campagnelab.dl.framework.performance.PerformanceLogger;
import org.campagnelab.dl.framework.performance.PerformanceMetricDescriptor;
import org.campagnelab.dl.framework.tools.TrainingArguments;
import org.campagnelab.dl.framework.tools.arguments.ConditionRecordingTool;
import org.campagnelab.dl.framework.training.ParallelTrainerOnGPU;
import org.campagnelab.dl.framework.training.SequentialTrainer;
import org.campagnelab.dl.framework.training.Trainer;
import org.campagnelab.dl.framework.training.WrapInAsyncAttach;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.GraphVertex;
import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.util.DataTypeUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/campagnelab/dl/framework/tools/TrainModel.class */
public abstract class TrainModel<RecordType> extends ConditionRecordingTool<TrainingArguments> {
    private static Logger LOG;
    private String directory;
    private double bestScore;
    private long time;
    protected DomainDescriptor<RecordType> domainDescriptor;
    private String bestMetricName;
    protected PerformanceLogger performanceLogger;
    private ComputationGraph computationGraph;
    static final /* synthetic */ boolean $assertionsDisabled;
    protected FeatureMapper featureMapper = null;
    private CacheHelper<RecordType> cacheHelper = new CacheHelper<>();
    ParameterPrecision precision = ParameterPrecision.FP32;

    /**
     * Supplies the domain descriptor for this training tool.
     *
     * <p>{@link #execute()} calls this once and stores the result in the
     * {@code domainDescriptor} field, which the rest of the class reads.
     *
     * @return the descriptor to train with
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public abstract DomainDescriptor<RecordType> domainDescriptor();

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Tool entry point: initializes the GPU, resolves the domain descriptor, optionally loads
     * advanced model properties, then starts training on the configured training sets.
     *
     * <p>When no training set was given, an error is reported and the method returns without
     * training; later steps read the first training set and cannot work on an empty list.
     * An {@link IOException} raised while training is reported on standard error and not
     * rethrown.
     */
    @Override // org.campagnelab.dl.framework.tools.arguments.AbstractTool
    public void execute() {
        InitializeGpu.initialize();
        TrainingArguments arguments = (TrainingArguments) args2();
        if (arguments.getTrainingSets().length == 0) {
            System.err.println("You must provide training datasets.");
            // Nothing to train on: stop here instead of failing later on trainingSets.get(0).
            return;
        }
        this.domainDescriptor = domainDescriptor();
        if (arguments.advancedModelConfiguration != null) {
            this.domainDescriptor.loadAdvancedModelProperties(arguments.advancedModelConfiguration);
        }
        try {
            this.featureMapper = this.domainDescriptor.getFeatureMapper("input");
            execute(this.featureMapper, arguments.getTrainingSets(), arguments.miniBatchSize);
        } catch (IOException e) {
            System.err.println("An exception occured. Details may be provided below");
            e.printStackTrace();
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Builds the computation graph for the current domain, trains it, and records the results.
     *
     * <p>Steps, in order: device/precision setup, creation of the {@code models/<time>} output
     * directory, configuration of the graph assembler from the domain descriptor, optional UI
     * listener, optional restore of a previous model's updater and parameters, parameter-count
     * report, metric definition, training via {@link #train()}, and finally writing properties,
     * the best-score file and the performance log.
     *
     * @param featureMapper mapper whose class name is printed before and after training
     * @param strArr not referenced in this method body
     * @param i not referenced in this method body
     * @throws IOException if the model directory, properties or score files cannot be written,
     *     or if training fails with an I/O error
     */
    public void execute(FeatureMapper featureMapper, String[] strArr, int i) throws IOException {
        // Pin the current thread to the requested device, except in parallel mode.
        if (((TrainingArguments) args2()).deviceIndex != null && !((TrainingArguments) args2()).parallel) {
            Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread(), ((TrainingArguments) args2()).deviceIndex);
        }
        if (((TrainingArguments) args2()).previousModelPath != null) {
            System.out.println(String.format("Resuming training with %s model parameters from %s %n", ((TrainingArguments) args2()).previousModelName, ((TrainingArguments) args2()).previousModelPath));
        }
        // FP16 is the only non-default precision handled here; the field default is FP32.
        if ("FP16".equals(((TrainingArguments) args2()).precision)) {
            this.precision = ParameterPrecision.FP16;
            System.out.println("Parameter precision set to FP16.");
        }
        if ("FP16".equals(((TrainingArguments) args2()).precision)) {
            DataTypeUtil.setDTypeForContext(DataBuffer.Type.HALF);
        }
        // The start timestamp doubles as the model directory name and the reported model time.
        this.time = new Date().getTime();
        System.out.println("epochs: " + ((TrainingArguments) args2()).maxEpochs);
        System.out.println("FeatureMapper:" + featureMapper.getClass().getTypeName());
        System.out.println("ComputationGraphAssembler:" + ((TrainingArguments) args2()).architectureClassname);
        this.directory = "models/" + Long.toString(this.time);
        FileUtils.forceMkdir(new File(this.directory));
        System.out.println("model directory: " + new File(this.directory).getAbsolutePath());
        this.bestMetricName = "best" + this.domainDescriptor.performanceDescritor().earlyStoppingMetric();
        ComputationGraphAssembler computationalGraph = this.domainDescriptor.getComputationalGraph();
        // Decompiled form of an assert statement: only enforced when assertions are enabled.
        if (!$assertionsDisabled && computationalGraph == null) {
            throw new AssertionError("Computational Graph assembler must be defined.");
        }
        computationalGraph.setArguments((TrainingArguments) args2());
        // Maps each input name to whether its first dimension was enlarged by one below.
        HashMap hashMap = new HashMap();
        for (String str : computationalGraph.getInputNames()) {
            // Work on a copy so the descriptor's own array is not modified.
            int[] iArr = (int[]) this.domainDescriptor.getNumInputs(str).clone();
            // True when resuming from a pretraining model and eosIndex is unset or equals the first dimension.
            boolean z = ((TrainingArguments) args2()).previousModelPretraining && ((((TrainingArguments) args2()).eosIndex != null && ((TrainingArguments) args2()).eosIndex.intValue() == iArr[0]) || ((TrainingArguments) args2()).eosIndex == null);
            hashMap.put(str, Boolean.valueOf(z));
            if (z) {
                // NOTE(review): presumably reserves one extra slot for an end-of-sequence symbol - confirm.
                iArr[0] = iArr[0] + 1;
            }
            computationalGraph.setNumInputs(str, iArr);
        }
        this.domainDescriptor.setInputsPaddedEos(hashMap);
        for (String str2 : computationalGraph.getOutputNames()) {
            computationalGraph.setNumOutputs(str2, this.domainDescriptor.getNumOutputs(str2));
            computationalGraph.setLossFunction(str2, this.domainDescriptor.getOutputLoss(str2));
        }
        for (String str3 : computationalGraph.getComponentNames()) {
            computationalGraph.setNumHiddenNodes(str3, this.domainDescriptor.getNumHiddenNodes(str3));
        }
        this.computationGraph = computationalGraph.createComputationalGraph(this.domainDescriptor);
        this.computationGraph.init();
        if (((TrainingArguments) args2()).addUiListener) {
            UIServer uIServer = UIServer.getInstance();
            // Stats go to memory unless a stats file was configured.
            InMemoryStatsStorage inMemoryStatsStorage = ((TrainingArguments) args2()).uiStatsFile == null ? new InMemoryStatsStorage() : new FileStatsStorage(new File(((TrainingArguments) args2()).uiStatsFile));
            uIServer.attach(inMemoryStatsStorage);
            this.computationGraph.setListeners(new IterationListener[]{new StatsListener(inMemoryStatsStorage)});
        }
        if (((TrainingArguments) args2()).previousModelPath != null) {
            ComputationGraph loadModel = new ModelLoader(((TrainingArguments) args2()).previousModelPath).loadModel(((TrainingArguments) args2()).previousModelName);
            ComputationGraph computationGraph = loadModel instanceof ComputationGraph ? loadModel : null;
            if (loadModel == null || computationGraph == null || computationGraph.getUpdater() == null || computationGraph.params() == null) {
                // A failed load is reported but training still proceeds with the fresh graph.
                System.err.println("Unable to load model or updater from " + ((TrainingArguments) args2()).previousModelPath);
            } else {
                // Carry over the updater and parameters of the previous model into the new graph.
                this.computationGraph.setUpdater(computationGraph.getUpdater());
                this.computationGraph.setParams(loadModel.params());
            }
        }
        // Report the parameter count per layer vertex and the total over all layer vertices.
        int i2 = 0;
        for (GraphVertex graphVertex : this.computationGraph.getVertices()) {
            if (graphVertex instanceof LayerVertex) {
                int numParams = graphVertex.getLayer().numParams();
                System.out.println("Number of parameters in layer " + graphVertex.getVertexName() + ": " + numParams);
                i2 += numParams;
            }
        }
        System.out.println("Total number of network parameters: " + i2);
        // Properties are written once before training and again after it completes.
        writeProperties();
        this.performanceLogger = new PerformanceLogger(this.directory);
        PerformanceMetricDescriptor<RecordType> performanceDescritor = this.domainDescriptor.performanceDescritor();
        String[] performanceMetrics = performanceDescritor.performanceMetrics();
        // One Metric per performance metric name, with its better-direction flag.
        Metric[] metricArr = new Metric[performanceMetrics.length];
        for (int i3 = 0; i3 < metricArr.length; i3++) {
            String str4 = performanceMetrics[i3];
            metricArr[i3] = new Metric(str4, performanceDescritor.largerValueIsBetterPerformance(str4));
        }
        this.performanceLogger.definePerformances(metricArr);
        EarlyStoppingResult<ComputationGraph> train = train();
        System.out.println("Total epochs: " + train.getTotalEpochs());
        System.out.println("Best epoch number: " + train.getBestModelEpoch());
        System.out.println("FeatureMapper:" + featureMapper.getClass().getTypeName());
        System.out.println("ComputationGraphAssembler:" + ((TrainingArguments) args2()).architectureClassname);
        for (String str5 : performanceMetrics) {
            System.out.println(str5 + " at best epoch: " + this.performanceLogger.getBest(str5));
        }
        writeProperties();
        writeBestScoreFile();
        System.out.println("Model completed, saved at time: " + this.time);
        this.performanceLogger.write();
        // Publish the best value of every metric, the best epoch and the model time as results.
        for (String str6 : this.domainDescriptor.performanceDescritor().performanceMetrics()) {
            resultValues().put(str6, Double.valueOf(this.performanceLogger.getBest(str6)));
        }
        resultValues().put("bestModelEpoch", Integer.valueOf(this.performanceLogger.getBestEpoch(this.bestMetricName)));
        resultValues().put("model-time", Long.valueOf(this.time));
    }

    /**
     * Writes the performance logger's current best score, as text, to the {@code bestScore}
     * file inside the model directory, replacing any previous content.
     *
     * @throws IOException if the file cannot be opened or written
     */
    protected void writeBestScoreFile() throws IOException {
        // try-with-resources closes the writer even when append() fails.
        try (FileWriter writer = new FileWriter(this.directory + "/bestScore")) {
            writer.append(Double.toString(this.performanceLogger.getBestScore()));
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Collects the model properties and writes them, together with the domain descriptor's
     * own properties, to the model directory.
     *
     * <p>The collected properties combine the values added by {@link #appendProperties}, the
     * reader properties of the first training set, the descriptor class name and the tag.
     *
     * @throws IOException if reader properties cannot be obtained or a file cannot be written
     */
    protected void writeProperties() throws IOException {
        final ModelPropertiesHelper helper = new ModelPropertiesHelper();
        appendProperties(this.domainDescriptor.getComputationalGraph(), helper);

        // Reader properties are taken from the first training set only.
        final String firstTrainingSet = ((TrainingArguments) args2()).trainingSets.get(0);
        helper.addProperties(getReaderProperties(firstTrainingSet));

        helper.put("domainDescriptor", this.domainDescriptor.getClass().getCanonicalName());
        helper.put("tag", getTag());
        this.domainDescriptor.putProperties(helper.getProperties());

        final String outputDirectory = this.directory;
        helper.writeProperties(outputDirectory);
        this.domainDescriptor.writeProperties(outputDirectory);
    }

    /**
     * Counts the distinct float values found at linear indices {@code 0 .. size(0) - 1} of
     * the given array.
     *
     * @param iNDArray array to scan
     * @return number of distinct values encountered
     */
    protected static int numLabels(INDArray iNDArray) {
        final FloatArraySet distinctValues = new FloatArraySet();
        int index = 0;
        while (index < iNDArray.size(0)) {
            distinctValues.add(iNDArray.getFloat(index));
            index++;
        }
        return distinctValues.size();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Records the assembler's properties and the training hyper-parameters in the given helper.
     *
     * @param computationGraphAssembler assembler asked to save its own properties first
     * @param modelPropertiesHelper helper receiving the feature mapper, training arguments,
     *     start time, precision and the full command line
     */
    public void appendProperties(ComputationGraphAssembler computationGraphAssembler, ModelPropertiesHelper modelPropertiesHelper) {
        final TrainingArguments arguments = (TrainingArguments) args2();

        computationGraphAssembler.saveProperties(modelPropertiesHelper);
        modelPropertiesHelper.setFeatureCalculator(this.featureMapper);

        // Hyper-parameters copied straight from the command-line arguments.
        modelPropertiesHelper.setLearningRate(arguments.learningRate);
        modelPropertiesHelper.setDropoutRate(arguments.dropoutRate);
        modelPropertiesHelper.setMiniBatchSize(arguments.miniBatchSize);
        modelPropertiesHelper.setNumEpochs(arguments.maxEpochs);
        modelPropertiesHelper.setNumTrainingSets(arguments.trainingSets.size());

        modelPropertiesHelper.setTime(this.time);
        modelPropertiesHelper.setSeed(arguments.seed);
        modelPropertiesHelper.setEarlyStopCriterion(arguments.stopWhenEpochsWithoutImprovement);
        modelPropertiesHelper.setRegularization(arguments.regularizationRate);
        modelPropertiesHelper.setPrecision(this.precision);
        modelPropertiesHelper.put("allArguments", getAllCommandLineArguments());
    }

    /**
     * Returns properties that {@link #writeProperties()} merges into the model properties.
     *
     * @param str the first entry of the training-set list (presumably a dataset path - confirm)
     * @return properties to add to the model properties
     * @throws IOException declared for implementations that need to read the dataset
     */
    public abstract Properties getReaderProperties(String str) throws IOException;

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Runs the epoch loop: trains on the (optionally cached) training iterator, saves the
     * latest model after each epoch, and every {@code validateEvery} epochs estimates the
     * performance metrics on the validation set.
     *
     * <p>When the early-stopping metric improves, the model is saved under
     * {@code "best" + metric} and the no-improvement counter is reset; otherwise the counter
     * is incremented. Training ends after {@code maxEpochs} epochs, or earlier once the
     * counter exceeds {@code stopWhenEpochsWithoutImprovement}. The counter advances once per
     * validation round, not once per epoch.
     *
     * <p>If {@code buildCacheAndStop} is set, the JVM exits after the caches are built.
     *
     * @return a result carrying the best epoch reported by the performance logger, the
     *     {@code bestScore} field, {@code maxEpochs} and the trained graph
     * @throws IOException if the validation file does not exist or model/property files
     *     cannot be written
     */
    protected EarlyStoppingResult<ComputationGraph> train() throws IOException {
        String validationPath = ((TrainingArguments) args2()).validationSet;
        if (!new File(validationPath).exists()) {
            throw new IOException("Validation file not found! " + validationPath);
        }
        this.bestScore = Double.MAX_VALUE;
        ComputationGraphSaver saver = new ComputationGraphSaver(this.directory);
        // Never populated here; handed to the result unchanged. Kept raw to match the constructor call.
        HashMap scoreMap = new HashMap();
        System.out.println("errorEnrichment=" + ((TrainingArguments) args2()).errorEnrichment);
        this.performanceLogger.setCondition(((TrainingArguments) args2()).experimentalCondition);
        // Running sum of the values returned by Trainer.train(), reported with each metrics row.
        long trainerTotal = 0;
        int validationsWithoutImprovement = 0;
        System.out.flush();
        PerformanceMetricDescriptor<RecordType> perfDescriptor = this.domainDescriptor.performanceDescritor();
        String earlyStoppingMetric = perfDescriptor.earlyStoppingMetric();
        double bestValue = initializePerformance(perfDescriptor, earlyStoppingMetric);

        // Concatenate all training sets, capped at numTraining records.
        MultiDataSetIteratorAdapter<RecordType> trainingAdapter = new MultiDataSetIteratorAdapter<RecordType>(Iterables.limit(Iterables.concat((Iterable) ((TrainingArguments) args2()).trainingSets.stream().map(str2 -> {
            return this.domainDescriptor.getRecordIterable().apply(str2);
        }).collect(Collectors.toList())), ((TrainingArguments) args2()).numTraining), ((TrainingArguments) args2()).miniBatchSize, this.domainDescriptor, ((TrainingArguments) args2()).previousModelPretraining, ((TrainingArguments) args2()).eosIndex) { // from class: org.campagnelab.dl.framework.tools.TrainModel.1
            /* JADX WARN: Multi-variable type inference failed */
            @Override // org.campagnelab.dl.framework.iterators.MultiDataSetIteratorAdapter
            public String getBasename() {
                return TrainModel.this.buildBaseName(((TrainingArguments) TrainModel.this.args2()).trainingSets);
            }
        };
        MultiDataSetIterator cache = !((TrainingArguments) args2()).ignoreCache ? this.cacheHelper.cache(this.domainDescriptor, trainingAdapter, trainingAdapter.getBasename(), ((TrainingArguments) args2()).numTraining, ((TrainingArguments) args2()).miniBatchSize) : trainingAdapter;
        if (((TrainingArguments) args2()).memoryCacheTraining()) {
            cache = new FullyInMemoryCache(cache);
            LOG.warn("Loading training set in memory.");
            cache.reset();
            LOG.warn("Done.");
        }
        if (((TrainingArguments) args2()).mixupAlpha != null) {
            cache.setPreProcessor(new MixupMultiDataSetPreProcessor(((TrainingArguments) args2()).seed, ((TrainingArguments) args2()).mixupAlpha.doubleValue()));
        }
        long numRecords = Math.min(((TrainingArguments) args2()).numTraining, this.domainDescriptor.getNumRecords(((TrainingArguments) args2()).getTrainingSets()));
        int miniBatchesPerEpoch = (int) (numRecords / ((TrainingArguments) args2()).miniBatchSize);
        System.out.printf("Training with %d minibatches per epoch%n", Integer.valueOf(miniBatchesPerEpoch));
        MultiDataSetIterator validationIterator = readValidationSet();
        System.out.println("Finished loading validation records.");
        if (((TrainingArguments) args2()).buildCacheAndStop) {
            System.out.println("Cache has been built. Exiting now since --build-cache-then-stop was used.");
            System.exit(0);
        }

        ProgressLogger epochLogger = new ProgressLogger(LOG);
        epochLogger.displayLocalSpeed = true;
        epochLogger.itemsName = "epoch";
        epochLogger.expectedUpdates = ((TrainingArguments) args2()).maxEpochs;
        switch (((TrainingArguments) args2()).trackingStyle) {
            case PERFS:
                System.out.println(this.performanceLogger.getMetricHeader());
                break;
            case SPEED:
                epochLogger.start();
                break;
            default:
                System.out.println("Unsupported tracking style: " + ((TrainingArguments) args2()).trackingStyle);
                break;
        }
        Trainer trainer = ((TrainingArguments) args2()).parallel ? new ParallelTrainerOnGPU(this.computationGraph, ((TrainingArguments) args2()).miniBatchSize, (int) numRecords) : new SequentialTrainer();
        trainer.setLogSpeed(((TrainingArguments) args2()).trackingStyle == TrainingArguments.TrackStyle.SPEED);
        MultiDataSetIterator trainingIterator = ((TrainingArguments) args2()).parallel ? cache : WrapInAsyncAttach.wrap(cache);

        for (int epoch = 0; epoch < ((TrainingArguments) args2()).maxEpochs; epoch++) {
            ProgressLogger miniBatchLogger = new ProgressLogger(LOG);
            miniBatchLogger.itemsName = "mini-batch";
            miniBatchLogger.expectedUpdates = miniBatchesPerEpoch;
            if (((TrainingArguments) args2()).trackingStyle == TrainingArguments.TrackStyle.SPEED) {
                miniBatchLogger.start();
            }
            trainerTotal += trainer.train(this.computationGraph, trainingIterator, miniBatchLogger);
            double trainingScore = trainer.getScore();
            saver.saveLatestModel(this.computationGraph, trainingScore);
            writeProperties();
            writeBestScoreFile();

            boolean stopEarly = false;
            if ((epoch + 1) % ((TrainingArguments) args2()).validateEvery == 0) {
                validationIterator.reset();
                // Decompiled form of an assert statement: only enforced when assertions are enabled.
                if (!$assertionsDisabled && !validationIterator.hasNext()) {
                    throw new AssertionError("validation iterator must have datasets. Make sure the latest release of Goby is installed in the maven repo.");
                }
                double[] estimates = perfDescriptor.estimateMetric(this.computationGraph, validationIterator, ((TrainingArguments) args2()).numValidation, perfDescriptor.performanceMetrics());
                DoubleArrayList metricValues = DoubleArrayList.wrap(estimates);
                double validationValue = findMetricValue(perfDescriptor.earlyStoppingMetric(), perfDescriptor.performanceMetrics(), estimates);
                this.performanceLogger.logMetrics("epochs", trainerTotal, epoch, metricValues.toDoubleArray());
                this.performanceLogger.logTrainingScore("epochs", epoch, trainingScore);
                if (((TrainingArguments) args2()).trackingStyle == TrainingArguments.TrackStyle.PERFS) {
                    this.performanceLogger.show("epochs");
                }

                // Strict comparisons: a NaN or equal value never counts as an improvement.
                boolean largerIsBetter = perfDescriptor.largerValueIsBetterPerformance(earlyStoppingMetric);
                boolean improved = largerIsBetter
                        ? !Double.isNaN(bestValue) && validationValue > bestValue
                        : validationValue < bestValue;
                if (improved) {
                    saver.saveModel(this.computationGraph, "best" + earlyStoppingMetric);
                    bestValue = validationValue;
                    this.performanceLogger.logMetrics(this.bestMetricName, trainerTotal, epoch, metricValues.toDoubleArray());
                    validationsWithoutImprovement = 0;
                } else {
                    validationsWithoutImprovement++;
                }
                // Stop only once patience is exhausted; otherwise carry on with the next epoch.
                stopEarly = !Double.isNaN(bestValue)
                        && validationsWithoutImprovement > ((TrainingArguments) args2()).stopWhenEpochsWithoutImprovement;
            }

            if (((TrainingArguments) args2()).trackingStyle == TrainingArguments.TrackStyle.SPEED) {
                miniBatchLogger.stop();
                epochLogger.updateAndDisplay();
            }
            this.performanceLogger.write();
            if (stopEarly) {
                break;
            }
            // The training iterator is only reset here in non-parallel mode.
            if (!((TrainingArguments) args2()).parallel) {
                trainingIterator.reset();
            }
        }
        epochLogger.stop();
        return new EarlyStoppingResult<>(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, "not early stopping", scoreMap, this.performanceLogger.getBestEpoch(this.bestMetricName), this.bestScore, ((TrainingArguments) args2()).maxEpochs, this.computationGraph);
    }

    /**
     * Looks up the value of a named metric in two parallel arrays.
     *
     * @param str metric name to find
     * @param strArr metric names, position-aligned with {@code dArr}
     * @param dArr metric values
     * @return the value at the position of the first name equal to {@code str}
     * @throws RuntimeException if no name in {@code strArr} equals {@code str}
     */
    private double findMetricValue(String str, String[] strArr, double[] dArr) {
        for (int position = 0; position < strArr.length; position++) {
            if (str.equals(strArr[position])) {
                return dArr[position];
            }
        }
        throw new RuntimeException("Metric name not found: " + str);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Derives a base name for a list of training-set files.
     *
     * <p>A single file yields its name without extension. Any other list yields
     * {@code "multiset-"} followed by a number obtained by XOR-ing a fixed seed with the hash
     * code of each file's base name, so the result does not depend on the order of the list.
     *
     * @param list training-set file names
     * @return the base name for the list
     */
    public String buildBaseName(List<String> list) {
        if (list.size() == 1) {
            return FilenameUtils.removeExtension(list.get(0));
        }
        long combinedHash = 8723872838723L;
        // Iterate each entry exactly once; XOR makes the result order-independent.
        for (String trainingSet : list) {
            combinedHash ^= FilenameUtils.getBaseName(trainingSet).hashCode();
        }
        return "multiset-" + Long.toString(combinedHash);
    }

    /**
     * Returns the starting "best so far" value for a metric: the worst possible value in the
     * metric's direction, so that any finite measurement compares as better.
     *
     * @param performanceMetricDescriptor descriptor asked whether larger values are better
     * @param str metric name
     * @return negative infinity when larger is better, positive infinity otherwise
     */
    private double initializePerformance(PerformanceMetricDescriptor performanceMetricDescriptor, String str) {
        if (performanceMetricDescriptor.largerValueIsBetterPerformance(str)) {
            return Double.NEGATIVE_INFINITY;
        }
        return Double.POSITIVE_INFINITY;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Builds the iterator over the validation set, going through the on-disk cache unless
     * {@code ignoreCache} is set and wrapping it in an in-memory cache when
     * {@code memoryCacheValidation()} is true.
     *
     * @return iterator over validation mini-batches
     * @throws RuntimeException wrapping the underlying {@link IOException} when the
     *     validation records cannot be loaded
     */
    private MultiDataSetIterator readValidationSet() {
        try {
            MultiDataSetIteratorAdapter<RecordType> multiDataSetIteratorAdapter = new MultiDataSetIteratorAdapter<RecordType>(this.domainDescriptor.getRecordIterable().apply(((TrainingArguments) args2()).validationSet), ((TrainingArguments) args2()).miniBatchSize, this.domainDescriptor, ((TrainingArguments) args2()).previousModelPretraining, ((TrainingArguments) args2()).eosIndex) { // from class: org.campagnelab.dl.framework.tools.TrainModel.2
                /* JADX WARN: Multi-variable type inference failed */
                @Override // org.campagnelab.dl.framework.iterators.MultiDataSetIteratorAdapter
                public String getBasename() {
                    return ((TrainingArguments) TrainModel.this.args2()).validationSet;
                }
            };
            MultiDataSetIterator cache = ((TrainingArguments) args2()).ignoreCache ? multiDataSetIteratorAdapter : this.cacheHelper.cache(this.domainDescriptor, multiDataSetIteratorAdapter, multiDataSetIteratorAdapter.getBasename(), ((TrainingArguments) args2()).numValidation, ((TrainingArguments) args2()).miniBatchSize);
            return ((TrainingArguments) args2()).memoryCacheValidation() ? new FullyInMemoryCache(cache) : cache;
        } catch (IOException e) {
            // Keep the original exception as the cause so the real failure stays diagnosable.
            throw new RuntimeException("Unable to load validation records from " + ((TrainingArguments) args2()).validationSet, e);
        }
    }

    // Class initializer: sets the assertion flag and the class logger.
    static {
        // True when assertions are disabled for this class; guards the hand-expanded assert checks.
        $assertionsDisabled = !TrainModel.class.desiredAssertionStatus();
        LOG = LoggerFactory.getLogger(TrainModel.class);
    }
}
