package org.campagnelab.dl.framework.tools;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.Consumer;
import java.util.function.Predicate;
import org.campagnelab.dl.framework.domains.DomainDescriptor;
import org.campagnelab.dl.framework.domains.prediction.Prediction;
import org.campagnelab.dl.framework.domains.prediction.PredictionInterpreter;
import org.campagnelab.dl.framework.domains.prediction.RecordPredictions;
import org.campagnelab.dl.framework.models.ModelOutputHelper;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;

/* loaded from: input_file:org/campagnelab/dl/framework/tools/PredictWithModel.class */
public class PredictWithModel<RecordType> {
    /** Domain whose computation graph, output count, feature mappers and interpreters drive prediction. */
    protected DomainDescriptor<RecordType> domainDescriptor;

    /** Helper used by the per-record iterator path to run the model and fetch each output. */
    ModelOutputHelper outputHelper;

    /**
     * One interpreter per computation-graph output, in the order returned by
     * {@code getOutputNames()}. Entries may be null; outputs with a null interpreter are skipped
     * when building predictions.
     */
    protected PredictionInterpreter[] interpretors;

    /**
     * Creates a predictor for the given domain. Looks up one prediction interpreter per output
     * name of the domain's computation graph.
     *
     * @param domainDescriptor the domain to predict for
     */
    public PredictWithModel(DomainDescriptor<RecordType> domainDescriptor) {
        this.domainDescriptor = domainDescriptor;
        this.outputHelper = new ModelOutputHelper(domainDescriptor);
        String[] outputNames = domainDescriptor.getComputationalGraph().getOutputNames();
        this.interpretors = new PredictionInterpreter[outputNames.length];
        for (int outputIndex = 0; outputIndex < outputNames.length; outputIndex++) {
            this.interpretors[outputIndex] = domainDescriptor.getPredictionInterpreter(outputNames[outputIndex]);
        }
    }

    /**
     * Predicts each record produced by {@code iterator}, one record at a time, without observing
     * records before prediction. Equivalent to the five-argument overload with a no-op observer.
     *
     * @param iterator                  source of records to predict
     * @param model                     the trained model
     * @param recordPredictionsConsumer receives the predictions for each record
     * @param stopWhen                  tested with the number of records processed so far; iteration
     *                                  stops as soon as it returns true
     */
    public void makePredictions(Iterator<RecordType> iterator, Model model,
                                Consumer<RecordPredictions<RecordType>> recordPredictionsConsumer,
                                Predicate<Integer> stopWhen) {
        makePredictions(iterator, model, record -> {
        }, recordPredictionsConsumer, stopWhen);
    }

    /**
     * Predicts a whole minibatch in one forward pass and emits one {@link RecordPredictions} per
     * record. Row {@code r} of every model output is interpreted against {@code records.get(r)}.
     *
     * @param dataSet                   minibatch whose features are fed to the model
     * @param records                   records corresponding, row by row, to the minibatch
     * @param model                     the trained model; must be a {@link ComputationGraph}
     *                                  (checked only when assertions are enabled)
     * @param recordPredictionsConsumer receives the predictions for each record
     * @param stopWhen                  tested with the running record index after each record;
     *                                  processing stops as soon as it returns true
     * @param index                     running record index to start from; stamped on each prediction
     * @return the running record index after the last processed record
     */
    public int makePredictions(MultiDataSet dataSet, List<RecordType> records, Model model,
                               Consumer<RecordPredictions<RecordType>> recordPredictionsConsumer,
                               Predicate<Integer> stopWhen, int index) {
        assert model instanceof ComputationGraph : "MultiDataSet only work with ComputationGraph";
        INDArray[] outputs = ((ComputationGraph) model).output(false, dataSet.getFeatures());
        int numOutputs = domainDescriptor.getNumModelOutputs();
        for (int row = 0; row < records.size(); row++) {
            RecordType record = records.get(row);
            // A fresh list per record: the list is handed to RecordPredictions, so reusing and
            // clearing one list would corrupt predictions a consumer may have retained.
            List<Prediction> predictions = new ArrayList<>(numOutputs);
            for (int outputIndex = 0; outputIndex < numOutputs; outputIndex++) {
                if (interpretors[outputIndex] != null) {
                    predictions.add(interpretOutput(outputIndex, record, outputs[outputIndex].slice(row), index));
                }
            }
            recordPredictionsConsumer.accept(new RecordPredictions<>(record, predictions));
            index++;
            if (stopWhen.test(index)) {
                break;
            }
        }
        return index;
    }

    /**
     * Predicts each record produced by {@code iterator}, one record at a time.
     *
     * @param iterator                  source of records to predict
     * @param model                     the trained model
     * @param recordObserver            called with each record before it is predicted
     * @param recordPredictionsConsumer receives the predictions for each record
     * @param stopWhen                  tested with the number of records processed so far; iteration
     *                                  stops as soon as it returns true
     */
    public void makePredictions(Iterator<RecordType> iterator, Model model, Consumer<RecordType> recordObserver,
                                Consumer<RecordPredictions<RecordType>> recordPredictionsConsumer,
                                Predicate<Integer> stopWhen) {
        int index = 0;
        while (iterator.hasNext()) {
            RecordType record = iterator.next();
            recordObserver.accept(record);
            outputHelper.predictForNextRecord(model, record, domainDescriptor.featureMappers(true));
            int numOutputs = domainDescriptor.getNumModelOutputs();
            // A fresh list per record (see the MultiDataSet overload for why the list is not reused).
            List<Prediction> predictions = new ArrayList<>(numOutputs);
            for (int outputIndex = 0; outputIndex < numOutputs; outputIndex++) {
                INDArray output = outputHelper.getOutput(outputIndex);
                if (interpretors[outputIndex] != null) {
                    predictions.add(interpretOutput(outputIndex, record, output, index));
                }
            }
            recordPredictionsConsumer.accept(new RecordPredictions<>(record, predictions));
            index++;
            if (stopWhen.test(index)) {
                return;
            }
        }
    }

    /**
     * Interprets one model output for one record and stamps the output and record indices on the
     * resulting prediction. The caller must ensure {@code interpretors[outputIndex]} is non-null.
     */
    private Prediction interpretOutput(int outputIndex, RecordType record, INDArray output, int recordIndex) {
        Prediction prediction = interpretors[outputIndex].interpret(record, output);
        prediction.outputIndex = outputIndex;
        prediction.index = recordIndex;
        return prediction;
    }
}
