package org.campagnelab.dl.framework.models;

import java.util.function.Consumer;
import org.campagnelab.dl.framework.mappers.FeatureMapper;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/campagnelab/dl/framework/models/ModelFeatureHelper.class */
/**
 * Prepares the feature arrays for a single record and hands them to a caller-supplied consumer.
 *
 * <p>Each feature mapper fills a freshly zeroed {@code 1 x numberOfFeatures()} array. The arrays are
 * collected in an internal buffer that is reused across calls.
 *
 * @param <RecordType> type of the record the feature mappers read from
 */
public class ModelFeatureHelper<RecordType> {
    /**
     * Buffer passed to the consumer on every call.
     *
     * <p>It is reused across calls, so its slots are overwritten on the next call. It is reallocated
     * whenever the number of feature mappers changes, so its length always equals the number of
     * mappers supplied to the current call.
     */
    private INDArray[] featuresArray;

    /**
     * Maps the features of {@code recordtype} and passes the resulting arrays to {@code consumer}.
     *
     * <p>For a {@link MultiLayerNetwork}, only the first mapper is used, and the result is stored in slot 0
     * of the buffer. Any other slots are left as they are. For a {@link ComputationGraph}, mapper
     * {@code i} fills slot {@code i}. An empty mapper list yields an empty array.
     *
     * @param model the model the features are prepared for; must be a {@link MultiLayerNetwork} or a
     *     {@link ComputationGraph}
     * @param recordtype the record whose features are mapped
     * @param consumer receives the (reused) features array once all mappers have run
     * @param featureMapperArr the mappers producing each feature array
     * @throws IllegalArgumentException if {@code model} is of an unsupported type (or {@code null}), or
     *     if no mapper is supplied for a {@link MultiLayerNetwork}
     */
    public void prepareForNextRecord(Model model, RecordType recordtype, Consumer<INDArray[]> consumer, FeatureMapper... featureMapperArr) {
        if (this.featuresArray == null || this.featuresArray.length != featureMapperArr.length) {
            // Reallocate instead of reusing a buffer sized for a different mapper count. Reusing a
            // shorter buffer would overflow in the ComputationGraph loop below.
            this.featuresArray = new INDArray[featureMapperArr.length];
        }
        if (model instanceof MultiLayerNetwork) {
            if (featureMapperArr.length == 0) {
                throw new IllegalArgumentException("at least one feature mapper is required for a MultiLayerNetwork");
            }
            // Only the first mapper is used here (presumably because this model type takes a single
            // input array).
            this.featuresArray[0] = mapRecord(featureMapperArr[0], recordtype);
        } else if (model instanceof ComputationGraph) {
            for (int i = 0; i < featureMapperArr.length; i++) {
                this.featuresArray[i] = mapRecord(featureMapperArr[i], recordtype);
            }
        } else {
            // Guard against null so the caller gets the intended IllegalArgumentException, not an NPE.
            String typeName = model == null ? "null" : model.getClass().getCanonicalName();
            throw new IllegalArgumentException("model is not of supported type: " + typeName);
        }
        consumer.accept(this.featuresArray);
    }

    /**
     * Allocates a zeroed {@code 1 x mapper.numberOfFeatures()} array and lets {@code mapper} fill it from
     * {@code record}.
     */
    private INDArray mapRecord(FeatureMapper mapper, RecordType record) {
        INDArray features = Nd4j.zeros(1, mapper.numberOfFeatures());
        // Index 0 is presumably the row within this single-row array; confirm against FeatureMapper.
        mapper.prepareToNormalize(record, 0);
        mapper.mapFeatures(record, features, 0);
        return features;
    }
}
