package org.campagnelab.dl.framework.mappers;

import java.util.function.Function;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/campagnelab/dl/framework/mappers/OneHotIntegerMapper.class */
/**
 * Feature mapper that one-hot encodes an integer extracted from a record.
 *
 * <p>The integer is obtained by applying a caller-supplied function to the record in
 * {@link #prepareToNormalize(Object, int)} and cached in this instance. Afterwards
 * {@link #produceFeature(Object, int)} yields {@code 1.0f} for the feature whose index equals
 * the cached integer and {@code 0.0f} for every other feature.
 *
 * <p>Because the cached integer is per-instance mutable state, a single instance must not be
 * shared between threads that prepare and map different records concurrently.
 *
 * @param <RecordType> type of the records this mapper consumes
 */
public class OneHotIntegerMapper<RecordType> extends NoMaskFeatureMapper<RecordType> {
    private static final Logger LOG = LoggerFactory.getLogger(OneHotIntegerMapper.class);

    /** Extracts the integer to encode from a record. */
    private final Function<RecordType, Integer> recordToInteger;

    /** Length of the one-hot vector; also the value returned by {@link #numberOfFeatures()}. */
    private final int vectorNumElements;

    /** Integer cached by the last call to prepareToNormalize; -1 until that method is called. */
    private int reducedValue = -1;

    /**
     * Creates a mapper producing one-hot vectors of the given length.
     *
     * @param i        number of elements in the one-hot vector
     * @param function extracts the integer to encode from a record
     */
    public OneHotIntegerMapper(int i, Function<RecordType, Integer> function) {
        this.vectorNumElements = i;
        this.recordToInteger = function;
    }

    /**
     * Returns the number of features produced per record.
     *
     * @return the vector length given at construction
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public int numberOfFeatures() {
        return this.vectorNumElements;
    }

    /**
     * Applies the extraction function to a record and returns the resulting integer.
     *
     * <p>When assertions are enabled, the value is checked to be strictly below the vector
     * length. Negative values are not rejected here.
     *
     * @param recordtype the record to extract the integer from
     * @return the extracted integer
     * @throws NullPointerException if the extraction function returns {@code null}
     */
    public int getIntegerOfBase(RecordType recordtype) {
        int value = this.recordToInteger.apply(recordtype).intValue();
        assert value < this.vectorNumElements
                : String.format("value %d cannot exceed vectorNumElements %d",
                        Integer.valueOf(value), Integer.valueOf(this.vectorNumElements));
        return value;
    }

    /**
     * Extracts and caches the integer of the given record so that subsequent calls to
     * {@link #produceFeature(Object, int)} can encode it.
     *
     * @param recordtype the record about to be mapped
     * @param i          not used by this implementation
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void prepareToNormalize(RecordType recordtype, int i) {
        this.reducedValue = getIntegerOfBase(recordtype);
    }

    /**
     * Writes every feature of the record into {@code iNDArray} at coordinates
     * {@code [i, featureIndex]}.
     *
     * @param recordtype the record being mapped
     * @param iNDArray   destination array, addressed with two indices
     * @param i          first coordinate used for every value written
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public void mapFeatures(RecordType recordtype, INDArray iNDArray, int i) {
        // A method-local coordinate buffer: a static buffer shared by all instances would let
        // concurrent callers overwrite each other's coordinates.
        final int[] indices = {i, 0};
        final int featureCount = numberOfFeatures();
        for (int featureIndex = 0; featureIndex < featureCount; featureIndex++) {
            indices[1] = featureIndex;
            iNDArray.putScalar(indices, produceFeature(recordtype, featureIndex));
        }
    }

    /**
     * Returns the one-hot value of a single feature, based on the integer cached by
     * {@link #prepareToNormalize(Object, int)}.
     *
     * <p>When assertions are enabled, a negative cached integer (including the initial -1)
     * triggers an {@link AssertionError}.
     *
     * @param recordtype not used by this implementation; the cached integer is used instead
     * @param i          index of the feature to produce
     * @return {@code 1.0f} if {@code i} equals the cached integer, {@code 0.0f} otherwise
     */
    @Override // org.campagnelab.dl.framework.mappers.FeatureMapper
    public float produceFeature(RecordType recordtype, int i) {
        assert this.reducedValue >= 0 : "prepareToNormalize must be called before produceFeature.";
        return this.reducedValue == i ? 1.0f : 0.0f;
    }
}
