package org.campagnelab.dl.framework.tools;

import com.google.common.collect.Iterables;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
import it.unimi.dsi.logging.ProgressLogger;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.io.FilenameUtils;
import org.campagnelab.dl.framework.domains.DomainDescriptor;
import org.campagnelab.dl.framework.domains.DomainDescriptorLoader;
import org.campagnelab.dl.framework.domains.prediction.Prediction;
import org.campagnelab.dl.framework.gpu.InitializeGpu;
import org.campagnelab.dl.framework.iterators.MultiDataSetIteratorAdapter;
import org.campagnelab.dl.framework.iterators.cache.CacheHelper;
import org.campagnelab.dl.framework.models.ModelLoader;
import org.campagnelab.dl.framework.tools.arguments.ConditionRecordingTool;
import org.deeplearning4j.nn.api.Model;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/campagnelab/dl/framework/tools/Predict.class */
/**
 * Base class for tools that load a trained model, run it over a test set in minibatches, and emit
 * per-record predictions followed by a single summary line of statistics.
 *
 * <p>When {@code toFile} is set, predictions are written (UTF-8) to
 * {@code predictions/<last component of modelPath>-<modelName>-<type>-<test-set basename>.tsv} and
 * the summary line is appended to {@code outputFile}. Otherwise both go to standard output, which
 * is flushed but never closed so that later console output is not lost.
 *
 * <p>Subclasses define the prediction-file header, how each record's predictions are written, and
 * which statistics are accumulated, written to the summary line, and reported.
 *
 * @param <RecordType> type of the records produced by the domain descriptor's record iterable
 */
public abstract class Predict<RecordType> extends ConditionRecordingTool<PredictArguments> {
    private static final Logger LOG = LoggerFactory.getLogger(Predict.class);

    /** Last path component of the model path; only set when predictions are written to files. */
    protected String modelTime;
    /** Model name (prefix) given in the arguments; only set when predictions are written to files. */
    protected String modelPrefix;
    /** Test-set file name without directory and extension; only set when writing to files. */
    protected String testSetBasename;
    private final CacheHelper<RecordType> cacheHelper = new CacheHelper<>();
    /** Domain descriptor loaded from the model path when prediction starts. */
    protected DomainDescriptor<RecordType> domainDescriptor;

    @Override // org.campagnelab.dl.framework.tools.arguments.AbstractTool
    /* renamed from: args */
    public PredictArguments args2() {
        return (PredictArguments) this.arguments;
    }

    @Override // org.campagnelab.dl.framework.tools.arguments.AbstractTool
    public PredictArguments createArguments() {
        return new PredictArguments();
    }

    /**
     * Opens the prediction and results writers, runs the predictions, and releases the writers.
     *
     * @throws RuntimeException if a writer cannot be created or prediction fails with an
     *     {@link IOException}
     */
    @Override // org.campagnelab.dl.framework.tools.arguments.AbstractTool
    public void execute() {
        InitializeGpu.initialize();
        PredictArguments options = args2();
        if (options.deviceIndex != null) {
            Nd4j.getAffinityManager().attachThreadToDevice(Thread.currentThread(), options.deviceIndex);
        }
        // Writers that wrap System.out must only be flushed: closing them would close System.out.
        final boolean ownsWriters = options.toFile;
        PrintWriter predictionWriter = null;
        PrintWriter resultWriter = null;
        try {
            String modelDirectoryName = new File(options.modelPath).getName();
            File resultsFile = new File(options.outputFile);
            // A non-empty results file already has its header; new rows are appended below it.
            boolean resultsHeaderPresent = resultsFile.exists() && resultsFile.length() > 0;
            try {
                if (ownsWriters) {
                    new File("predictions").mkdirs();
                    String testSetName = FilenameUtils.getBaseName(options.testSet);
                    String predictionPath = String.format("%s/%s-%s-%s-%s.tsv", "predictions",
                            modelDirectoryName, options.modelName, options.type, testSetName);
                    System.out.println("Writing predictions to " + predictionPath);
                    predictionWriter = new PrintWriter(predictionPath, "UTF-8");
                    resultWriter = new PrintWriter(new FileWriter(options.outputFile, true));
                    this.modelTime = modelDirectoryName;
                    this.modelPrefix = options.modelName;
                    this.testSetBasename = testSetName;
                } else {
                    predictionWriter = new PrintWriter(System.out);
                    resultWriter = new PrintWriter(System.out);
                    resultsHeaderPresent = false;
                }
            } catch (IOException e) {
                throw new RuntimeException("Unable to create result writer", e);
            }
            try {
                printPredictions(options.modelName, options.modelPath, options.testSet,
                        predictionWriter, resultWriter, resultsHeaderPresent);
            } catch (IOException e) {
                throw new RuntimeException("Unable to perform predictions", e);
            }
        } finally {
            // Runs on success and failure alike, so file handles never leak.
            finishWriter(predictionWriter, ownsWriters, "predictions");
            finishWriter(resultWriter, ownsWriters, "results");
        }
    }

    /**
     * Flushes a writer, logs any error it recorded, and closes it when this tool owns the stream.
     *
     * @param writer the writer to finish; {@code null} is ignored
     * @param close whether to close the writer after flushing
     * @param description what the writer carries, used in the error message
     */
    private static void finishWriter(PrintWriter writer, boolean close, String description) {
        if (writer == null) {
            return;
        }
        // PrintWriter swallows IOExceptions; checkError() flushes and exposes them.
        if (writer.checkError()) {
            LOG.error("An I/O error occurred while writing {}", description);
        }
        if (close) {
            writer.close();
        }
    }

    /**
     * Loads the model and runs it over the test set.
     *
     * <p>Each record's predictions are passed to {@link #processPredictions}, and one summary row is
     * then written to {@code resultWriter}. Both writers are flushed but not closed; the caller
     * owns them.
     *
     * @param modelName prefix of the model to load
     * @param modelPath directory containing the model and its domain descriptor
     * @param testSetPath test set to predict on
     * @param predictionWriter destination for per-record predictions
     * @param resultWriter destination for the summary row (and its header, if needed)
     * @param resultsHeaderPresent whether {@code resultWriter} already contains the header line
     * @throws IOException if the model or its properties cannot be read
     */
    private void printPredictions(String modelName, String modelPath, String testSetPath,
                                  PrintWriter predictionWriter, PrintWriter resultWriter,
                                  boolean resultsHeaderPresent) throws IOException {
        ModelLoader modelLoader = new ModelLoader(modelPath);
        String tag = modelLoader.getModelProperties().getProperty("tag");
        if (!resultsHeaderPresent) {
            writeResultsHeader(resultWriter);
        }
        if (modelLoader.loadFeatureMapper(modelLoader.getModelProperties()).getClass().getCanonicalName().contains("Trio")) {
            System.out.println("setting output to trio mode");
        }
        Model model = modelLoader.loadModel(modelName);
        if (model == null) {
            System.err.println("Cannot load model with prefix: " + modelName);
            System.exit(1);
        }
        this.domainDescriptor = DomainDescriptorLoader.load(modelPath);
        PredictWithModel predictWithModel = new PredictWithModel(this.domainDescriptor);
        // Two independent passes over the same records. The first feeds the (possibly cached)
        // minibatch iterator; the second supplies the records matching each minibatch.
        Iterable recordsForBatches = Iterables.limit(this.domainDescriptor.getRecordIterable().apply(testSetPath), args2().scoreN);
        Iterable recordsForLookup = Iterables.limit(this.domainDescriptor.getRecordIterable().apply(testSetPath), args2().scoreN);
        initializeStats(modelName);
        writeHeader(predictionWriter);
        int miniBatchSize = args2().miniBatchSize;
        MultiDataSetIteratorAdapter<RecordType> adapter = new MultiDataSetIteratorAdapter<RecordType>(recordsForBatches, miniBatchSize, this.domainDescriptor, false, null) {
            // The test-set basename is also the name handed to the cache helper below.
            @Override // org.campagnelab.dl.framework.iterators.MultiDataSetIteratorAdapter
            public String getBasename() {
                return FilenameUtils.getBaseName(Predict.this.args2().testSet);
            }
        };
        MultiDataSetIterator minibatches = args2().noCache
                ? adapter
                : this.cacheHelper.cache(this.domainDescriptor, adapter, adapter.getBasename(), args2().scoreN, args2().miniBatchSize);
        ObjectArrayList records = new ObjectArrayList(miniBatchSize);
        Iterator recordIterator = recordsForLookup.iterator();
        // Value returned by makePredictions is fed back into the next call
        // (presumably a running prediction count checked against scoreN — confirm).
        int predictionCount = 0;
        ProgressLogger progressLogger = new ProgressLogger(LOG);
        progressLogger.itemsName = "sites";
        long numRecords = this.domainDescriptor.getNumRecords(new String[]{args2().testSet});
        progressLogger.expectedUpdates = Math.min(args2().scoreN, numRecords);
        progressLogger.displayFreeMemory = false;
        progressLogger.displayLocalSpeed = true;
        progressLogger.start();
        while (minibatches.hasNext() && recordIterator.hasNext()) {
            MultiDataSet minibatch = (MultiDataSet) minibatches.next();
            int exampleCount = minibatch.getFeatures(0).size(0);
            records.clear();
            for (int k = 0; k < exampleCount && recordIterator.hasNext(); k++) {
                records.add(recordIterator.next());
            }
            if (records.size() != exampleCount) {
                System.out.printf("dataset #examples %d and # records (%d) must match. Unable to obtain records for some examples in minibatch. Aborting. %n", exampleCount, records.size());
                break;
            }
            predictionCount = predictWithModel.makePredictions(minibatch, records, model,
                    recordPredictions -> processPredictions(predictionWriter, recordPredictions.record, recordPredictions.predictions),
                    num -> num.intValue() > args2().scoreN,
                    predictionCount);
            progressLogger.update(records.size());
        }
        predictionWriter.flush();
        writeResultsRow(resultWriter, tag, modelName);
        resultWriter.flush();
        progressLogger.stop();
        reportStatistics(modelName);
        System.out.println("Model: " + modelPath + " tag:" + tag);
        modelLoader.writeTestCount(numRecords);
    }

    /**
     * Writes the results header: "tag", "prefix", the subclass's statistic columns, and
     * "arguments", separated by tabs.
     */
    private void writeResultsHeader(PrintWriter resultWriter) {
        resultWriter.append("tag\tprefix");
        for (String column : createOutputHeader()) {
            resultWriter.append(String.format("\t%s", column));
        }
        resultWriter.append("\targuments\n");
    }

    /**
     * Writes one results row: model tag, model name, each statistic formatted with {@code %f}, and
     * the full command line.
     */
    private void writeResultsRow(PrintWriter resultWriter, String tag, String modelName) {
        resultWriter.append(String.format("%s\t%s", tag, modelName));
        for (double statistic : createOutputStatistics()) {
            resultWriter.append(String.format("\t%f", statistic));
        }
        resultWriter.append("\t" + getAllCommandLineArguments());
        resultWriter.append("\n");
    }

    /**
     * Returns the statistic values for the results row.
     *
     * <p>Called once after all predictions are made. Values must be in the same order as
     * {@link #createOutputHeader()}.
     */
    protected abstract double[] createOutputStatistics();

    /** Returns the statistic column names written to the results header, between "prefix" and "arguments". */
    protected abstract String[] createOutputHeader();

    /** Reports the accumulated statistics; called after the prediction loop with the model name. */
    protected abstract void reportStatistics(String str);

    /** Handles the predictions for one record; called once per predicted record with the prediction writer. */
    protected abstract void processPredictions(PrintWriter printWriter, RecordType recordtype, List<Prediction> list);

    /** Writes the header of the prediction output; called once before any prediction is processed. */
    protected abstract void writeHeader(PrintWriter printWriter);

    /** Resets or prepares statistics for the named model; called before any prediction is made. */
    protected abstract void initializeStats(String str);
}
