/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.metrics;

import java.util.Comparator;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.metrics.SignedProbabilities;
import org.neo4j.gds.ml.metrics.classification.OutOfBagError;

/*
 * Uses 'sealed' constructs - enable with --sealed true
 */
/**
 * Metric enumeration with a single constant, {@link #AUCPR}.
 *
 * <p>Besides the per-constant comparator, this type offers
 * {@link #parseLinkMetric(Object)} for resolving a metric from a loosely typed
 * value, and {@link #compute(SignedProbabilities, double)}, which integrates a
 * curve of (recall, precision) points with the trapezoid rule.
 */
public enum LinkMetric implements Metric {
    /**
     * The only metric of this enum. Scores are compared by their natural ordering.
     * NOTE(review): the name presumably stands for "area under the precision-recall
     * curve" - inferred from the identifier and from {@code compute}.
     */
    AUCPR {
        @Override
        public Comparator<Double> comparator() {
            return Comparator.naturalOrder();
        }
    };

    /**
     * Resolves a metric from either a {@link Metric} instance or its name.
     *
     * @param nameOrMetric a {@code Metric} (returned unchanged) or a {@code String} name
     * @return the given metric, {@code OutOfBagError.OUT_OF_BAG_ERROR} when the name equals
     *         that constant's name, otherwise the {@code LinkMetric} constant with that name
     * @throws IllegalArgumentException if the argument is neither a {@code Metric} nor a
     *         {@code String} (this includes {@code null}), or if the name matches no constant
     *         (thrown by {@code valueOf})
     */
    public static Metric parseLinkMetric(Object nameOrMetric) {
        if (nameOrMetric instanceof Metric) {
            return (Metric) nameOrMetric;
        }
        if (!(nameOrMetric instanceof String)) {
            throw new IllegalArgumentException("Metrics must be strings");
        }

        String metricName = (String) nameOrMetric;
        boolean isOutOfBagError = metricName.equals(OutOfBagError.OUT_OF_BAG_ERROR.name());
        if (isOutOfBagError) {
            return OutOfBagError.OUT_OF_BAG_ERROR;
        }
        return LinkMetric.valueOf(metricName);
    }

    /**
     * Computes the area under the curve of (recall, precision) points derived from the
     * given signed probabilities.
     *
     * <p>The curve starts at the recall/precision obtained when every positive counts as a
     * true positive and every negative as a false positive, gets one further point each
     * time the absolute value of the streamed probabilities changes, and is closed with the
     * fixed point (recall 0.0, precision 1.0).
     * NOTE(review): the bookkeeping treats already-streamed values as lying below the
     * current threshold, so the stream is presumably ordered by ascending absolute value -
     * confirm against {@code SignedProbabilities}.
     *
     * @param signedProbabilities source of the positive/negative counts and the value stream
     * @param negativeClassWeight factor applied to the false-positive count in the precision
     * @return the accumulated area, or {@code 0.0} when there are no positives
     */
    public double compute(SignedProbabilities signedProbabilities, double negativeClassWeight) {
        long totalPositives = signedProbabilities.positiveCount();
        long totalNegatives = signedProbabilities.negativeCount();
        if (totalPositives == 0L) {
            return 0.0;
        }

        CurveConsumer curve = new CurveConsumer();
        SignedProbabilitiesConsumer pointEmitter =
            new SignedProbabilitiesConsumer(curve, totalPositives, totalNegatives, negativeClassWeight);

        // Starting point: nothing has been streamed yet, so all positives are true positives.
        double initialRecall = pointEmitter.recall(totalPositives);
        double initialPrecision = pointEmitter.precision(totalPositives);
        curve.acceptFirstPoint(initialRecall, initialPrecision);

        signedProbabilities.stream().forEach(pointEmitter::accept);

        // Close the curve at recall 0.0 / precision 1.0.
        curve.accept(0.0, 1.0);
        return curve.auc();
    }

    /**
     * Accumulates the area under a piecewise-linear curve, one point at a time, using
     * trapezoids between consecutive points.
     */
    private static class CurveConsumer {
        private double area;
        private double lastY;
        private double lastX;

        private CurveConsumer() {
        }

        /** Sets the starting point of the curve without contributing any area. */
        void acceptFirstPoint(double x, double y) {
            this.lastX = x;
            this.lastY = y;
        }

        /**
         * Adds the trapezoid spanned by the previous point and {@code (x, y)}, then makes
         * {@code (x, y)} the previous point. The width is "previous x minus new x", so the
         * contribution is positive when x decreases.
         */
        void accept(double x, double y) {
            double heightSum = this.lastY + y;
            double width = this.lastX - x;
            this.area += heightSum * width / 2.0;
            this.lastX = x;
            this.lastY = y;
        }

        /** Returns the area accumulated so far. */
        double auc() {
            return this.area;
        }
    }

    /**
     * Consumes signed probabilities one by one and forwards a (recall, precision) point to
     * the wrapped {@link CurveConsumer} whenever the absolute value differs from that of the
     * previously consumed value.
     *
     * <p>A value strictly greater than {@code 0.0} is counted as a positive; every other
     * value is counted as a negative.
     */
    private static final class SignedProbabilitiesConsumer {
        private final CurveConsumer curve;
        private final long totalPositives;
        private final long totalNegatives;
        private final double negativeWeight;
        private double previousThreshold;
        private long seenPositives;
        private long seenNegatives;

        private SignedProbabilitiesConsumer(
            CurveConsumer curve,
            long totalPositives,
            long totalNegatives,
            double negativeWeight
        ) {
            this.curve = curve;
            this.totalPositives = totalPositives;
            this.totalNegatives = totalNegatives;
            this.negativeWeight = negativeWeight;
            this.seenPositives = 0L;
            this.seenNegatives = 0L;
        }

        /**
         * Processes the next signed probability. A curve point is reported before the value
         * is counted, and only if at least one value was consumed earlier and the absolute
         * value has changed.
         */
        void accept(double signedProbability) {
            double threshold = Math.abs(signedProbability);
            boolean anythingSeen = this.seenPositives > 0L || this.seenNegatives > 0L;
            if (anythingSeen && threshold != this.previousThreshold) {
                this.reportPointOnCurve();
            }
            this.previousThreshold = threshold;

            if (signedProbability > 0.0) {
                this.seenPositives++;
            } else {
                this.seenNegatives++;
            }
        }

        /** Emits the point for the current counts; (0.0, 0.0) when no true positives remain. */
        private void reportPointOnCurve() {
            long remainingPositives = this.totalPositives - this.seenPositives;
            if (remainingPositives == 0L) {
                // Avoids evaluating precision with a zero numerator (and possibly zero denominator).
                this.curve.accept(0.0, 0.0);
                return;
            }
            double x = this.recall(remainingPositives);
            double y = this.precision(remainingPositives);
            this.curve.accept(x, y);
        }

        /**
         * Precision for the given true-positive count, where the false positives are the
         * negatives not yet consumed, scaled by the negative class weight.
         */
        private double precision(double truePositives) {
            long remainingNegatives = this.totalNegatives - this.seenNegatives;
            double weightedFalsePositives = this.negativeWeight * (double) remainingNegatives;
            return truePositives / (truePositives + weightedFalsePositives);
        }

        /**
         * Recall for the given true-positive count, where the false negatives are the
         * positives consumed so far.
         */
        private double recall(double truePositives) {
            double falseNegatives = (double) this.seenPositives;
            return truePositives / (truePositives + falseNegatives);
        }
    }
}

