package org.campagnelab.dl.framework.mappers;

import it.unimi.dsi.fastutil.floats.FloatArrayList;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Feature mapper that centers the features produced by a delegate mapper around their per-record
 * mean, optionally also dividing by the per-record (population) standard deviation.
 *
 * <p>{@link #prepareToNormalize} must be called for a record before {@link #produceFeature}: it
 * reads every delegate feature for that record and computes the statistics used by subsequent
 * {@code produceFeature} calls. The precondition is checked with a Java {@code assert}, so it is
 * only enforced when assertions are enabled.
 *
 * @param <RecordType> type of record the delegate maps
 */
public class MeanNormalizationMapper<RecordType> extends AbstractFeatureMapper1D<RecordType> {
    /** When true, centered values are additionally divided by the standard deviation. */
    private final boolean dividebyStdev;
    /** Mapper producing the raw (un-normalized) feature values. */
    FeatureNameMapper<RecordType> delegate;
    /** Mean of the delegate's features for the record last passed to prepareToNormalize. */
    float mean;
    /** Set by prepareToNormalize and cleared by mapFeatures; checked by produceFeature. */
    boolean normalizedCalled;
    /** Population standard deviation of the delegate's features for the last prepared record. */
    private double stdev;
    /** Scratch buffer holding the delegate's feature values for the last prepared record. */
    private final FloatArrayList values = new FloatArrayList();

    /**
     * Creates a mapper that only subtracts the per-record mean.
     *
     * @param featureNameMapper mapper producing the raw feature values
     */
    public MeanNormalizationMapper(FeatureNameMapper<RecordType> featureNameMapper) {
        this(featureNameMapper, false);
    }

    /**
     * Creates a mapper that subtracts the per-record mean and optionally divides by the
     * per-record standard deviation.
     *
     * @param featureNameMapper mapper producing the raw feature values
     * @param dividebyStdev     whether centered values are divided by the standard deviation
     */
    public MeanNormalizationMapper(FeatureNameMapper<RecordType> featureNameMapper, boolean dividebyStdev) {
        this.delegate = featureNameMapper;
        this.dividebyStdev = dividebyStdev;
        this.mean = 0.0f;
        this.normalizedCalled = false;
    }

    @Override
    public int numberOfFeatures() {
        return delegate.numberOfFeatures();
    }

    /**
     * Maps the record through the superclass, then clears the prepared flag. The computed
     * statistics belong to this record, so a later produceFeature without a fresh
     * prepareToNormalize trips the assertion.
     */
    @Override
    public void mapFeatures(RecordType record, INDArray inputs, int indexOfRecord) {
        super.mapFeatures(record, inputs, indexOfRecord);
        normalizedCalled = false;
    }

    /**
     * Reads every delegate feature for {@code record} and computes the mean and population
     * standard deviation used by {@link #produceFeature}. With zero features, both are 0.
     *
     * @param record        the record about to be mapped
     * @param indexOfRecord index of the record, forwarded to the delegate
     */
    @Override
    public void prepareToNormalize(RecordType record, int indexOfRecord) {
        delegate.prepareToNormalize(record, indexOfRecord);
        values.clear();

        final int featureCount = numberOfFeatures();
        double sum = 0.0;
        for (int i = 0; i < featureCount; i++) {
            float value = delegate.produceFeature(record, i);
            values.add(value);
            sum += value;
        }
        // Guard against 0/0 = NaN when the delegate exposes no features.
        mean = featureCount == 0 ? 0.0f : (float) (sum / featureCount);

        double sumOfSquares = 0.0;
        for (int i = 0; i < values.size(); i++) {
            double deviation = values.getFloat(i) - mean;
            sumOfSquares += deviation * deviation;
        }
        // Population standard deviation: divide by the count before taking the square root.
        stdev = featureCount == 0 ? 0.0 : Math.sqrt(sumOfSquares / featureCount);

        // Re-prepare the delegate so produceFeature calls that follow see a freshly prepared
        // delegate. NOTE(review): possibly redundant; kept in case the delegate is stateful.
        delegate.prepareToNormalize(record, indexOfRecord);
        normalizedCalled = true;
    }

    /**
     * Returns the delegate's feature at {@code featureIndex}, centered on the prepared mean and,
     * if configured, divided by the prepared standard deviation.
     */
    @Override
    public float produceFeature(RecordType record, int featureIndex) {
        assert normalizedCalled : "normalized must be called before produceFeature";
        return normalize(delegate.produceFeature(record, featureIndex));
    }

    @Override
    public String getFeatureName(int featureIndex) {
        return delegate.getFeatureName(featureIndex);
    }

    /** Centers {@code value} on the prepared mean and optionally scales by the prepared stdev. */
    private float normalize(float value) {
        float centered = value - mean;
        if (dividebyStdev) {
            // A zero stdev means every prepared feature equals the mean; avoid NaN/Infinity.
            if (stdev == 0.0) {
                return 0.0f;
            }
            centered = (float) (centered / stdev);
        }
        return centered;
    }
}
