/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.metrics;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.Optional;
import java.util.PrimitiveIterator;
import java.util.stream.DoubleStream;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.termination.TerminationFlag;

/**
 * Collects binary-classification probabilities as <em>signed</em> values: a sample whose
 * label is positive is stored as {@code +p}, a sample whose label is negative as {@code -p},
 * where {@code p} is the predicted probability of the positive class.
 *
 * <p>A probability of exactly zero is replaced by {@link #ALMOST_ZERO} before signing so the
 * sign (and therefore the label) is preserved. {@link #stream()} yields the stored values
 * ordered by ascending absolute value, i.e. by ascending predicted probability.
 *
 * <p>{@link #add(double, boolean)} is {@code synchronized}, so several threads may add
 * concurrently; the other methods are not synchronized.
 */
public abstract class SignedProbabilities {
    /** Stand-in for a zero probability so that negating it still yields a negative value. */
    static double ALMOST_ZERO = 1.0E-100;
    private static final Comparator<Double> ABSOLUTE_VALUE_COMPARATOR = Comparator.comparingDouble(Math::abs);
    private long positiveCount;
    private long negativeCount;

    /**
     * Estimates the memory, in bytes, needed to hold {@code relationshipSetSize} signed
     * probabilities.
     *
     * <p>NOTE(review): the estimate charges one boxed {@code Double} per element but not the
     * {@code ArrayList} backing-array references, and it models only the array-based variant;
     * confirm this is the intended approximation.
     *
     * @param relationshipSetSize number of probabilities that will be stored
     * @return estimated size in bytes
     */
    public static long estimateMemory(long relationshipSetSize) {
        return MemoryUsage.sizeOfInstance(SignedProbabilities.class)
            + MemoryUsage.sizeOfInstance(Optional.class)
            + MemoryUsage.sizeOfInstance(ArrayList.class)
            + MemoryUsage.sizeOfInstance(Double.class) * relationshipSetSize;
    }

    /**
     * Creates an empty instance able to hold {@code capacity} values: an {@code ArrayList}-backed
     * one when the capacity fits in an {@code int}, a {@code HugeDoubleArray}-backed one otherwise.
     */
    static SignedProbabilities create(long capacity) {
        if (capacity > Integer.MAX_VALUE) {
            return new Huge(capacity);
        }
        return new ArrayBased((int) capacity);
    }

    /**
     * Runs {@code classifier} over every batch of {@code evaluationQueue} in parallel and records,
     * for each element, the probability of class index {@code 1} signed by whether its label is 1.
     *
     * @param features        features the classifier predicts on
     * @param labels          label per element id; a label of {@code 1} marks a positive sample
     * @param classifier      model producing a probability matrix per batch
     * @param evaluationQueue batches of element ids to evaluate
     * @param concurrency     number of threads used to consume the queue
     * @param terminationFlag checked by the queue to allow early termination
     * @param progressTracker receives one step per evaluated element
     * @return the collected signed probabilities
     */
    public static SignedProbabilities computeFromLabeledData(
        Features features,
        HugeIntArray labels,
        Classifier classifier,
        BatchQueue evaluationQueue,
        int concurrency,
        TerminationFlag terminationFlag,
        ProgressTracker progressTracker
    ) {
        progressTracker.setSteps(features.size());
        SignedProbabilities signedProbabilities = SignedProbabilities.create(evaluationQueue.totalSize());
        int positiveClassIndex = 1;
        evaluationQueue.parallelConsume(concurrency, ignoredThread -> batch -> {
            Matrix probabilityMatrix = classifier.predictProbabilities((Batch) batch, features);
            // Rows of the probability matrix follow the order of the batch's element ids.
            int row = 0;
            PrimitiveIterator.OfLong elementIds = batch.elementIds();
            while (elementIds.hasNext()) {
                long relationshipIdx = elementIds.nextLong();
                double probabilityOfPositiveEdge = probabilityMatrix.dataAt(row++, positiveClassIndex);
                boolean isEdge = labels.get(relationshipIdx) == 1;
                // add(...) is synchronized, so concurrent batches can write safely.
                signedProbabilities.add(probabilityOfPositiveEdge, isEdge);
            }
            progressTracker.logSteps(batch.size());
        }, terminationFlag);
        return signedProbabilities;
    }

    /**
     * Records one prediction.
     *
     * @param probability predicted probability of the positive class; {@code 0.0} is stored as
     *                    {@link #ALMOST_ZERO} so that its sign survives
     * @param isPositive  whether the sample's true label is positive
     */
    public synchronized void add(double probability, boolean isPositive) {
        double nonZeroProbability = probability == 0.0 ? ALMOST_ZERO : probability;
        double signedProbability = isPositive ? nonZeroProbability : -nonZeroProbability;
        // Counts are derived from the sign of the stored value.
        if (signedProbability > 0.0) {
            ++this.positiveCount;
        } else {
            ++this.negativeCount;
        }
        this.doAdd(signedProbability);
    }

    /** Appends an already-signed probability to the backing storage. */
    abstract void doAdd(double signedProbability);

    /** Returns all recorded signed probabilities ordered by ascending absolute value. */
    public abstract DoubleStream stream();

    /** Number of recorded values with a positive sign. */
    public long positiveCount() {
        return this.positiveCount;
    }

    /** Number of recorded values with a non-positive sign. */
    public long negativeCount() {
        return this.negativeCount;
    }

    /** Storage for capacities that exceed {@code Integer.MAX_VALUE}. */
    static final class Huge
    extends SignedProbabilities {
        private final HugeDoubleArray probabilities;
        /** Number of slots written so far; also the next write position. */
        private long index;

        Huge(long capacity) {
            this.probabilities = HugeDoubleArray.newArray(capacity);
            this.index = 0L;
        }

        @Override
        void doAdd(double signedProbability) {
            this.probabilities.set(this.index++, signedProbability);
        }

        /**
         * Streams only the {@code index} values actually written. Unused trailing slots of the
         * preallocated array are excluded, which keeps this consistent with the counts and with
         * the array-based implementation.
         *
         * <p>NOTE(review): sorting a boxed stream buffers every element, which may not scale to
         * the very large sizes this variant is meant for — confirm.
         */
        @Override
        public DoubleStream stream() {
            return this.probabilities.stream()
                .limit(this.index)
                .boxed()
                .sorted(ABSOLUTE_VALUE_COMPARATOR)
                .mapToDouble(Double::doubleValue);
        }
    }

    /** Storage for capacities that fit in an {@code int}. */
    private static final class ArrayBased
    extends SignedProbabilities {
        private final ArrayList<Double> probabilities;

        private ArrayBased(int capacity) {
            this.probabilities = new ArrayList<>(capacity);
        }

        @Override
        void doAdd(double signedProbability) {
            this.probabilities.add(signedProbability);
        }

        /** Sorts the backing list in place by absolute value, then streams it. */
        @Override
        public DoubleStream stream() {
            this.probabilities.sort(ABSOLUTE_VALUE_COMPARATOR);
            return this.probabilities.stream().mapToDouble(Double::doubleValue);
        }
    }
}

