package org.campagnelab.dl.framework.mappers;

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/campagnelab/dl/framework/mappers/TwoDimensionalConcatFeatureMapper.class */
/**
 * Concatenates several two-dimensional delegate mappers along their time-step dimension.
 *
 * <p>Every delegate must report exactly two dimensions. All delegates must also agree on dimension
 * index 1, which this class treats as the number of features per time step. The time steps
 * (dimension index 2) of all delegates are laid out one after another, in the order the delegates
 * were passed to the constructor. Each time step is followed by {@code zeroPaddingWidth} extra
 * feature slots that always map to {@code 0}.
 *
 * <p>A flat feature index {@code i} addresses concatenated time step
 * {@code i / (featuresPerTimeStep + zeroPaddingWidth)} and feature
 * {@code i % (featuresPerTimeStep + zeroPaddingWidth)} within that step.
 *
 * <p>{@link #mapFeatures} and {@link #maskFeatures} reuse internal index buffers, so a single
 * instance must not be used from several threads at once.
 *
 * @param <RecordType> type of record the delegates map
 */
public class TwoDimensionalConcatFeatureMapper<RecordType> implements FeatureMapper<RecordType> {
    /** Delegate mappers in concatenation order (private copy of the constructor argument). */
    private final FeatureMapper<RecordType>[] delegates;
    /** Number of real (non-padding) features in each time step; identical for all delegates. */
    private final int featuresPerTimeStep;
    /** Sum of {@link FeatureMapper#numberOfFeatures()} over all delegates. */
    private final int numFeatures;
    /**
     * {@code timeStepOffsets[k]} is the first concatenated time step produced by delegate {@code k};
     * the final entry equals {@link #totalTimeSteps}. Non-decreasing, but may contain duplicates
     * when a delegate contributes zero time steps.
     */
    private final int[] timeStepOffsets;
    /** Dimensions reported by this mapper: (features + padding, total time steps). */
    private final MappedDimensions dim;
    /** Scratch index buffer {record, feature, timeStep} used by {@link #mapFeatures}. */
    private final int[] mapperIndices = new int[3];
    /** Scratch index buffer {record, timeStep} used by {@link #maskFeatures}. */
    private final int[] maskerIndices = new int[2];
    /** Number of zero-valued feature slots appended after each time step. */
    private final int zeroPaddingWidth;
    /** Total number of time steps across all delegates. */
    private final int totalTimeSteps;

    /**
     * Creates a mapper that concatenates the delegates' time steps and pads each step with zeros.
     *
     * @param i zero-padding width: number of zero features appended after each time step
     * @param featureMapperArr delegate mappers; at least one is required
     * @throws IllegalArgumentException if {@code i} is negative or no delegate is given
     * @throws RuntimeException if a delegate is not two-dimensional, or if delegates disagree on
     *     the number of features per time step
     */
    @SafeVarargs
    public TwoDimensionalConcatFeatureMapper(int i, FeatureMapper<RecordType>... featureMapperArr) {
        if (i < 0) {
            throw new IllegalArgumentException("zeroPaddingWidth must be >= 0, got " + i);
        }
        if (featureMapperArr == null || featureMapperArr.length == 0) {
            throw new IllegalArgumentException("At least one delegate mapper is required");
        }
        MappedDimensions reference = featureMapperArr[0].dimensions();
        int[] offsets = new int[featureMapperArr.length + 1];
        int featureCount = 0;
        int timeSteps = 0;
        for (int k = 0; k < featureMapperArr.length; k++) {
            MappedDimensions delegateDims = featureMapperArr[k].dimensions();
            if (delegateDims.numDimensions() != 2) {
                throw new RuntimeException("Delegate mappers must be two dimensional mappers");
            }
            if (!reference.equalsDimension(delegateDims, 1)) {
                throw new RuntimeException("Delegate mappers must have same number of features");
            }
            featureCount += featureMapperArr[k].numberOfFeatures();
            timeSteps += delegateDims.numElements(2);
            offsets[k + 1] = timeSteps;
        }
        this.zeroPaddingWidth = i;
        this.featuresPerTimeStep = reference.numElements(1);
        this.numFeatures = featureCount;
        this.totalTimeSteps = timeSteps;
        this.timeStepOffsets = offsets;
        // Copy so later changes to the caller's array cannot desynchronize the offsets above.
        this.delegates = featureMapperArr.clone();
        this.dim = new MappedDimensions(this.featuresPerTimeStep + i, this.totalTimeSteps);
    }

    /**
     * Creates a concatenating mapper without zero padding.
     *
     * @param featureMapperArr delegate mappers; at least one is required
     */
    @SafeVarargs
    public TwoDimensionalConcatFeatureMapper(FeatureMapper<RecordType>... featureMapperArr) {
        this(0, featureMapperArr);
    }

    /**
     * Returns the delegates' total feature count plus one padding block per time step.
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public int numberOfFeatures() {
        return this.numFeatures + (this.totalTimeSteps * this.zeroPaddingWidth);
    }

    /** Returns (featuresPerTimeStep + zeroPaddingWidth, totalTimeSteps) as computed at construction. */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public MappedDimensions dimensions() {
        return this.dim;
    }

    /** Forwards the call to every delegate, in concatenation order. */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void prepareToNormalize(RecordType recordtype, int i) {
        for (FeatureMapper<RecordType> featureMapper : this.delegates) {
            featureMapper.prepareToNormalize(recordtype, i);
        }
    }

    /**
     * Writes all features of {@code recordtype} into {@code iNDArray} at indices
     * {record {@code i}, feature, concatenated time step}; padding slots receive {@code 0}.
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void mapFeatures(RecordType recordtype, INDArray iNDArray, int i) {
        this.mapperIndices[0] = i;
        int width = this.featuresPerTimeStep + this.zeroPaddingWidth;
        for (int k = 0; k < this.delegates.length; k++) {
            FeatureMapper<RecordType> delegate = this.delegates[k];
            int firstStep = this.timeStepOffsets[k];
            int stepCount = this.timeStepOffsets[k + 1] - firstStep;
            for (int localStep = 0; localStep < stepCount; localStep++) {
                this.mapperIndices[2] = firstStep + localStep;
                int base = localStep * this.featuresPerTimeStep;
                for (int f = 0; f < width; f++) {
                    this.mapperIndices[1] = f;
                    float value = f < this.featuresPerTimeStep
                            ? delegate.produceFeature(recordtype, base + f)
                            : 0.0f;
                    iNDArray.putScalar(this.mapperIndices, value);
                }
            }
        }
    }

    /** Always {@code true}: this mapper writes a per-time-step mask. */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public boolean hasMask() {
        return true;
    }

    /**
     * Writes one mask value per concatenated time step at indices {record {@code i}, time step}.
     * The value is {@code 1} when the owning delegate reports the step's first feature as masked,
     * and {@code 0} otherwise.
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void maskFeatures(RecordType recordtype, INDArray iNDArray, int i) {
        this.maskerIndices[0] = i;
        for (int k = 0; k < this.delegates.length; k++) {
            FeatureMapper<RecordType> delegate = this.delegates[k];
            int firstStep = this.timeStepOffsets[k];
            int stepCount = this.timeStepOffsets[k + 1] - firstStep;
            for (int localStep = 0; localStep < stepCount; localStep++) {
                this.maskerIndices[1] = firstStep + localStep;
                boolean masked = delegate.isMasked(recordtype, localStep * this.featuresPerTimeStep);
                iNDArray.putScalar(this.maskerIndices, masked ? 1.0f : 0.0f);
            }
        }
    }

    /**
     * Reports whether flat feature {@code i} is masked by asking the delegate that owns its time
     * step. Padding slots report the mask of their own time step (its first feature), matching
     * {@link #maskFeatures}.
     *
     * @throws IndexOutOfBoundsException if {@code i} is outside this mapper's features
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public boolean isMasked(RecordType recordtype, int i) {
        int width = this.featuresPerTimeStep + this.zeroPaddingWidth;
        if (i < 0) {
            throw new IndexOutOfBoundsException("feature index " + i + " is negative");
        }
        int timeStep = i / width;
        int feature = i % width;
        int k = delegateIndexForTimeStep(timeStep);
        int localStep = timeStep - this.timeStepOffsets[k];
        // Padding slots have no delegate feature; using feature 0 keeps the query inside the same
        // time step instead of spilling into the next one (or past the delegate's last feature).
        int featureInStep = feature < this.featuresPerTimeStep ? feature : 0;
        return this.delegates[k].isMasked(recordtype, localStep * this.featuresPerTimeStep + featureInStep);
    }

    /**
     * Returns flat feature {@code i}: the owning delegate's value, or {@code 0} for padding slots.
     *
     * @throws IndexOutOfBoundsException if {@code i} is outside this mapper's features
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public float produceFeature(RecordType recordtype, int i) {
        int width = this.featuresPerTimeStep + this.zeroPaddingWidth;
        if (i < 0) {
            throw new IndexOutOfBoundsException("feature index " + i + " is negative");
        }
        int timeStep = i / width;
        int feature = i % width;
        int k = delegateIndexForTimeStep(timeStep);
        if (feature >= this.featuresPerTimeStep) {
            return 0.0f;
        }
        // Same delegate-local layout as mapFeatures: localStep * featuresPerTimeStep + feature.
        int localStep = timeStep - this.timeStepOffsets[k];
        return this.delegates[k].produceFeature(recordtype, localStep * this.featuresPerTimeStep + feature);
    }

    /**
     * Finds the delegate that produces concatenated time step {@code timeStep}.
     *
     * @throws IndexOutOfBoundsException if {@code timeStep} is not in [0, totalTimeSteps)
     */
    private int delegateIndexForTimeStep(int timeStep) {
        if (timeStep < 0 || timeStep >= this.totalTimeSteps) {
            throw new IndexOutOfBoundsException(
                    "time step " + timeStep + " out of range [0, " + this.totalTimeSteps + ")");
        }
        int k = Arrays.binarySearch(this.timeStepOffsets, timeStep);
        if (k < 0) {
            // Not an exact offset: the owning delegate starts just before the insertion point.
            k = -(k + 1) - 1;
        }
        // Delegates with zero time steps create duplicate offsets, and binarySearch may land on any
        // of them; advance to the last delegate starting at or before timeStep (the non-empty one).
        while (k + 1 < this.delegates.length && this.timeStepOffsets[k + 1] <= timeStep) {
            k++;
        }
        return k;
    }
}
