package org.campagnelab.dl.framework.mappers;

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/campagnelab/dl/framework/mappers/TwoDimensionalConcatLabelMapper.class */
/**
 * Label mapper that concatenates several two-dimensional delegate label mappers along their
 * second dimension (time steps), optionally appending {@code zeroPaddingWidth} zero-valued
 * labels to every time step.
 *
 * <p>All delegates must be two-dimensional and must agree on the size of dimension 1 (the
 * number of labels per time step). The resulting mapper has dimensions
 * {@code (labelsPerTimeStep + zeroPaddingWidth, totalTimeSteps)}, where {@code totalTimeSteps}
 * is the sum of the delegates' dimension-2 sizes. The time steps of delegate {@code d} occupy
 * the contiguous global range {@code [timeStepOffsets[d], timeStepOffsets[d + 1])}.
 *
 * <p>A flat label index {@code i} addresses time step {@code i / (labelsPerTimeStep + padding)}
 * and position {@code i % (labelsPerTimeStep + padding)} within it. Positions at or beyond
 * {@code labelsPerTimeStep} are padding.
 *
 * <p>Not thread-safe: {@link #mapLabels} and {@link #maskLabels} reuse internal index arrays.
 *
 * @param <RecordType> type of record the delegates map
 */
public class TwoDimensionalConcatLabelMapper<RecordType> implements LabelMapper<RecordType> {
    /** Delegate mappers in concatenation order (private copy of the constructor argument). */
    private final LabelMapper<RecordType>[] delegates;
    /** Number of real (non-padding) labels per time step; identical for all delegates. */
    private final int labelsPerTimeStep;
    /** Sum of {@code numberOfLabels()} over all delegates. */
    private final int numLabels;
    /**
     * {@code timeStepOffsets[d]} is the first global time step of delegate {@code d}; the last
     * entry equals {@link #totalTimeSteps}. Non-decreasing, may contain duplicates when a
     * delegate has no time steps.
     */
    private final int[] timeStepOffsets;
    private final MappedDimensions dim;
    /** Reused (record, label, timeStep) index into the label array. */
    private final int[] mapperIndices = new int[3];
    /** Reused (record, timeStep) index into the mask array. */
    private final int[] maskerIndices = new int[2];
    /** Number of zero labels appended to each time step. */
    private final int zeroPaddingWidth;
    /** Sum of the delegates' time-step counts (dimension 2). */
    private final int totalTimeSteps;

    /**
     * Creates a mapper concatenating {@code delegates} along the time-step dimension.
     *
     * @param zeroPaddingWidth number of zero labels appended after the real labels of each step
     * @param delegates two-dimensional mappers with the same number of labels per time step
     * @throws IllegalArgumentException if no delegate is supplied
     * @throws RuntimeException if a delegate is not two-dimensional or its dimension 1 differs
     *     from the first delegate's
     */
    @SafeVarargs
    public TwoDimensionalConcatLabelMapper(int zeroPaddingWidth, LabelMapper<RecordType>... delegates) {
        if (delegates == null || delegates.length == 0) {
            throw new IllegalArgumentException("At least one delegate mapper is required");
        }
        // Defensive copy: offsets are computed once below, so later mutation of the caller's
        // array must not change which delegates this mapper reads from.
        this.delegates = delegates.clone();
        this.zeroPaddingWidth = zeroPaddingWidth;

        MappedDimensions reference = this.delegates[0].dimensions();
        this.timeStepOffsets = new int[this.delegates.length + 1];
        int labelTotal = 0;
        int timeStepTotal = 0;
        for (int d = 0; d < this.delegates.length; d++) {
            LabelMapper<RecordType> delegate = this.delegates[d];
            MappedDimensions delegateDims = delegate.dimensions();
            if (delegateDims.numDimensions() != 2) {
                throw new RuntimeException("Delegate mappers must be two dimensional mappers");
            }
            if (!reference.equalsDimension(delegateDims, 1)) {
                throw new RuntimeException("Delegate mappers must have same number of labels");
            }
            labelTotal += delegate.numberOfLabels();
            timeStepTotal += delegateDims.numElements(2);
            this.timeStepOffsets[d + 1] = timeStepTotal;
        }
        this.numLabels = labelTotal;
        this.totalTimeSteps = timeStepTotal;
        this.labelsPerTimeStep = reference.numElements(1);
        this.dim = new MappedDimensions(this.labelsPerTimeStep + zeroPaddingWidth, this.totalTimeSteps);
    }

    /**
     * Creates a mapper concatenating {@code delegates} without any zero padding.
     *
     * @param delegates two-dimensional mappers with the same number of labels per time step
     */
    @SafeVarargs
    public TwoDimensionalConcatLabelMapper(LabelMapper<RecordType>... delegates) {
        this(0, delegates);
    }

    /**
     * Returns the delegates' combined label count plus {@code zeroPaddingWidth} padding labels
     * for every time step.
     */
    @Override
    public int numberOfLabels() {
        return this.numLabels + (this.totalTimeSteps * this.zeroPaddingWidth);
    }

    /** Returns {@code (labelsPerTimeStep + zeroPaddingWidth, totalTimeSteps)}, fixed at construction. */
    @Override
    public MappedDimensions dimensions() {
        return this.dim;
    }

    /**
     * Writes all labels of {@code record} into {@code labels} at positions
     * {@code [indexOfRecord, label, globalTimeStep]}, filling padding positions with zero.
     */
    @Override
    public void mapLabels(RecordType record, INDArray labels, int indexOfRecord) {
        final int paddedWidth = labelsPerTimeStep + zeroPaddingWidth;
        mapperIndices[0] = indexOfRecord;
        for (int d = 0; d < delegates.length; d++) {
            LabelMapper<RecordType> delegate = delegates[d];
            int firstStep = timeStepOffsets[d];
            int stepCount = timeStepOffsets[d + 1] - firstStep;
            for (int localStep = 0; localStep < stepCount; localStep++) {
                mapperIndices[2] = firstStep + localStep;
                for (int j = 0; j < paddedWidth; j++) {
                    mapperIndices[1] = j;
                    float value = j < labelsPerTimeStep
                            ? delegate.produceLabel(record, localStep * labelsPerTimeStep + j)
                            : 0.0f;
                    labels.putScalar(mapperIndices, value);
                }
            }
        }
    }

    /**
     * Returns the label at flat index {@code labelIndex}; padding positions yield {@code 0}.
     *
     * @throws IndexOutOfBoundsException if {@code labelIndex} lies outside the mapped range
     */
    @Override
    public float produceLabel(RecordType record, int labelIndex) {
        final int paddedWidth = labelsPerTimeStep + zeroPaddingWidth;
        if (labelIndex < 0) {
            throw new IndexOutOfBoundsException("Negative label index: " + labelIndex);
        }
        int timeStep = labelIndex / paddedWidth;
        int withinStep = labelIndex % paddedWidth;
        // Resolve the delegate first so out-of-range indices are rejected even in padding.
        int d = delegateIndexFor(timeStep);
        if (withinStep >= labelsPerTimeStep) {
            return 0.0f;
        }
        int localStep = timeStep - timeStepOffsets[d];
        // Same delegate-local addressing as mapLabels.
        return delegates[d].produceLabel(record, localStep * labelsPerTimeStep + withinStep);
    }

    @Override
    public boolean hasMask() {
        return true;
    }

    /**
     * Writes one mask value per global time step into {@code mask} at
     * {@code [indexOfRecord, globalTimeStep]}: {@code 1} when the delegate reports the step's
     * first label as masked, {@code 0} otherwise.
     */
    @Override
    public void maskLabels(RecordType record, INDArray mask, int indexOfRecord) {
        maskerIndices[0] = indexOfRecord;
        for (int d = 0; d < delegates.length; d++) {
            LabelMapper<RecordType> delegate = delegates[d];
            int firstStep = timeStepOffsets[d];
            int stepCount = timeStepOffsets[d + 1] - firstStep;
            for (int localStep = 0; localStep < stepCount; localStep++) {
                maskerIndices[1] = firstStep + localStep;
                mask.putScalar(maskerIndices,
                        delegate.isMasked(record, localStep * labelsPerTimeStep) ? 1.0f : 0.0f);
            }
        }
    }

    /**
     * Reports whether the label at flat index {@code labelIndex} is masked. Padding positions
     * have no delegate label of their own and report the mask of their time step (the same
     * label {@link #maskLabels} consults).
     *
     * @throws IndexOutOfBoundsException if {@code labelIndex} lies outside the mapped range
     */
    @Override
    public boolean isMasked(RecordType record, int labelIndex) {
        final int paddedWidth = labelsPerTimeStep + zeroPaddingWidth;
        if (labelIndex < 0) {
            throw new IndexOutOfBoundsException("Negative label index: " + labelIndex);
        }
        int timeStep = labelIndex / paddedWidth;
        int withinStep = labelIndex % paddedWidth;
        int d = delegateIndexFor(timeStep);
        int localStep = timeStep - timeStepOffsets[d];
        // Previously padding positions spilled into the next time step of the delegate.
        int delegateLabel = withinStep < labelsPerTimeStep
                ? localStep * labelsPerTimeStep + withinStep
                : localStep * labelsPerTimeStep;
        return delegates[d].isMasked(record, delegateLabel);
    }

    /** Forwards normalization preparation to every delegate, in order. */
    @Override
    public void prepareToNormalize(RecordType record, int indexOfRecord) {
        for (LabelMapper<RecordType> delegate : delegates) {
            delegate.prepareToNormalize(record, indexOfRecord);
        }
    }

    /**
     * Returns the index of the delegate that owns global time step {@code timeStep}. Delegates
     * with zero time steps are skipped. Duplicate offsets made the previous binary search
     * ambiguous for them.
     *
     * @throws IndexOutOfBoundsException if {@code timeStep} is not in {@code [0, totalTimeSteps)}
     */
    private int delegateIndexFor(int timeStep) {
        if (timeStep >= 0) {
            for (int d = 0; d < delegates.length; d++) {
                if (timeStep < timeStepOffsets[d + 1]) {
                    return d;
                }
            }
        }
        throw new IndexOutOfBoundsException(
                "Time step " + timeStep + " outside [0, " + totalTimeSteps + ")");
    }
}
