package org.campagnelab.dl.framework.performance;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleListIterator;
import java.util.Collections;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/campagnelab/dl/framework/performance/AreaUnderTheROCCurve.class */
/**
 * Estimates the area under the ROC curve (AUC) from observed decision values and true labels.
 *
 * <p>The AUC is computed as the normalized Mann-Whitney statistic: the fraction of
 * (positive, negative) pairs in which the positive example has the strictly larger decision
 * value, with exact ties counting as half a pair. Any comparison involving NaN is false, so a
 * NaN decision value never wins or ties a pair; NaN predictions therefore always count as wrong.
 *
 * <p>Every positive is compared against every negative, so evaluation costs
 * O(#positives * #negatives). Use {@link #AreaUnderTheROCCurve(int)} to randomly down-sample
 * each class to a bounded number of observations before evaluation.
 *
 * <p>No synchronization is performed; do not share an instance across threads without
 * external locking.
 */
public class AreaUnderTheROCCurve {
    private static final Logger LOG = LoggerFactory.getLogger(AreaUnderTheROCCurve.class);

    /**
     * z-score of a two-sided 95% normal confidence interval. The previous literal
     * (1.9600000381469727) was the float 1.96f widened to double.
     */
    private static final double Z_95 = 1.96;

    private static final String NAN_WARNING =
            "NaN found instead of a decision value. NaN are always interpreted as wrong predictions. ";

    /** Maximum number of observations kept per class when clipping is enabled. */
    private final int maxObservations;
    /** Whether each class is randomly down-sampled to {@code maxObservations} before evaluation. */
    private final boolean clipObservations;
    /** Decision values of examples whose label was >= 0. */
    private final DoubleArrayList positiveDecisions = new DoubleArrayList();
    /** Decision values of examples whose label was < 0 (or NaN). */
    private final DoubleArrayList negativeDecisions = new DoubleArrayList();
    /** AUC produced by the most recent {@link #evaluateStatistic()} call (0 before any call). */
    private double estimatedAUC;
    /** Number of positives used by the most recent {@link #evaluateStatistic()} call. */
    private int numPositive;
    /** Number of negatives used by the most recent {@link #evaluateStatistic()} call. */
    private int numNegative;
    /** Set once a NaN decision value has been observed, so the warning is logged only once. */
    boolean foundNan;

    /** Creates an estimator that keeps every observation. */
    public AreaUnderTheROCCurve() {
        this(Integer.MAX_VALUE, false);
    }

    /**
     * Creates an estimator that, at evaluation time, randomly down-samples each class to at most
     * {@code maxObservations} decision values.
     *
     * @param maxObservations maximum number of observations kept per class; must be non-negative
     * @throws IllegalArgumentException if {@code maxObservations} is negative
     */
    public AreaUnderTheROCCurve(int maxObservations) {
        this(maxObservations, true);
    }

    private AreaUnderTheROCCurve(int maxObservations, boolean clipObservations) {
        if (maxObservations < 0) {
            // DoubleArrayList.size(negative) would corrupt the list during clipping.
            throw new IllegalArgumentException("maxObservations must be non-negative: " + maxObservations);
        }
        this.maxObservations = maxObservations;
        this.clipObservations = clipObservations;
    }

    /**
     * Discards all observations and the results of any previous evaluation, returning the
     * estimator to its freshly-constructed state.
     */
    public void reset() {
        positiveDecisions.clear();
        negativeDecisions.clear();
        foundNan = false;
        // Without this, confidenceInterval95() would keep reporting the previous run's interval.
        estimatedAUC = 0.0;
        numPositive = 0;
        numNegative = 0;
    }

    /**
     * Records one prediction.
     *
     * @param decisionValue the classifier's decision value; NaN is accepted but always loses
     * @param trueLabel the true label; values >= 0 are positive, anything else (including NaN)
     *     is negative
     */
    public void observe(double decisionValue, double trueLabel) {
        if (!foundNan && Double.isNaN(decisionValue)) {
            LOG.warn(NAN_WARNING);
            foundNan = true;
        }
        (isPositive(trueLabel) ? positiveDecisions : negativeDecisions).add(decisionValue);
    }

    /**
     * Computes the AUC of the observations recorded so far and stores it, together with the
     * per-class counts, for {@link #confidenceInterval95()}.
     *
     * <p>When clipping is enabled, this call first permanently truncates each class to a random
     * subset of {@code maxObservations} values.
     *
     * @return the AUC in [0, 1], or NaN when either class has no observations
     */
    public double evaluateStatistic() {
        if (clipObservations) {
            clipObservations();
        }
        numPositive = positiveDecisions.size();
        numNegative = negativeDecisions.size();
        estimatedAUC = (pairScore(positiveDecisions, negativeDecisions) / numPositive) / numNegative;
        return estimatedAUC;
    }

    /**
     * Returns a 95% normal-approximation confidence interval for the AUC, using the standard
     * error of Hanley &amp; McNeil (1982). It is based on the AUC and class sizes from the most
     * recent {@link #evaluateStatistic()} call. The bounds are not clipped to [0, 1].
     *
     * @return {@code {lower, upper}}; both are NaN if no evaluation has been run since
     *     construction or {@link #reset()}
     */
    public double[] confidenceInterval95() {
        final double auc = estimatedAUC;
        final double aucSquared = auc * auc;
        final double q1 = auc / (2.0 - auc);
        final double q2 = (2.0 * aucSquared) / (1.0 + auc);
        final double nPositive = numPositive;
        final double nNegative = numNegative;
        final double variance = ((auc * (1.0 - auc)
                + (nPositive - 1.0) * (q1 - aucSquared)
                + (nNegative - 1.0) * (q2 - aucSquared)) / nPositive) / nNegative;
        final double halfWidth = Z_95 * Math.sqrt(variance);
        return new double[] {auc - halfWidth, auc + halfWidth};
    }

    /** Randomly down-samples each class to at most {@code maxObservations} values, in place. */
    private void clipObservations() {
        clipToMax(positiveDecisions);
        clipToMax(negativeDecisions);
    }

    /** Shuffles {@code decisions} and truncates it to {@code maxObservations} if it is larger. */
    private void clipToMax(DoubleArrayList decisions) {
        if (decisions.size() > maxObservations) {
            Collections.shuffle(decisions);
            decisions.size(maxObservations);
        }
    }

    /**
     * Computes the AUC of a batch of predictions without creating an estimator.
     *
     * @param dArr decision values, one per example; NaN is accepted but always loses
     * @param dArr2 true labels aligned with {@code dArr}; values >= 0 are positive
     * @return the AUC in [0, 1], or NaN when either class is empty
     * @throws IllegalArgumentException if the two arrays have different lengths
     */
    public static double evaluateStatistic(double[] dArr, double[] dArr2) {
        if (dArr.length != dArr2.length) {
            throw new IllegalArgumentException("decision and label arrays differ in length: "
                    + dArr.length + " vs " + dArr2.length);
        }
        final DoubleArrayList positives = new DoubleArrayList();
        final DoubleArrayList negatives = new DoubleArrayList();
        boolean warned = false;
        for (int i = 0; i < dArr.length; i++) {
            final double decision = dArr[i];
            if (!warned && Double.isNaN(decision)) {
                // Warn once per call rather than once per NaN to avoid flooding the log.
                LOG.warn(NAN_WARNING);
                warned = true;
            }
            (isPositive(dArr2[i]) ? positives : negatives).add(decision);
        }
        return (pairScore(positives, negatives) / positives.size()) / negatives.size();
    }

    /** Returns whether a label denotes the positive class (label >= 0; NaN is negative). */
    private static boolean isPositive(double label) {
        return label >= 0.0;
    }

    /**
     * Returns the number of (positive, negative) pairs in which the positive value is strictly
     * larger, plus one half for each exactly tied pair. Integer counters keep the tally exact.
     */
    private static double pairScore(DoubleArrayList positives, DoubleArrayList negatives) {
        long wins = 0;
        long ties = 0;
        final int numNeg = negatives.size();
        for (int i = 0, numPos = positives.size(); i < numPos; i++) {
            final double positive = positives.getDouble(i);
            for (int j = 0; j < numNeg; j++) {
                final double negative = negatives.getDouble(j);
                if (positive > negative) {
                    wins++;
                } else if (positive == negative) {
                    ties++;
                }
            }
        }
        return wins + 0.5 * ties;
    }
}
