/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.metrics.classification;

import com.carrotsearch.hppc.BitSet;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.LongAdder;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.ml.decisiontree.DecisionTreePredictor;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.models.Features;

/**
 * Metric singleton plus static helpers for computing an out-of-bag error
 * from per-tree predictions.
 *
 * <p>Votes are collected in a flat {@link HugeAtomicLongArray} where the slot
 * {@code trainSetIdx * numberOfClasses + classIdx} holds how many trees
 * predicted {@code classIdx} for the training example at position
 * {@code trainSetIdx} of the train set. Only trees whose sample did not
 * contain that example contribute a vote (see
 * {@link #addPredictionsForTree}). {@link #evaluate} then turns these votes
 * into the fraction of voted-on examples whose most-voted class differs from
 * the expected label.
 */
public final class OutOfBagError implements Metric {
    public static final OutOfBagError OUT_OF_BAG_ERROR = new OutOfBagError();

    private OutOfBagError() {
    }

    @Override
    public boolean isModelSpecific() {
        return true;
    }

    /**
     * Adds one vote per training example that was not part of the given
     * tree's sample.
     *
     * @param decisionTree       tree used to predict a class index for each out-of-sample example
     * @param numberOfClasses    stride of the flat {@code predictions} array (slots per training example)
     * @param allFeatureVectors  feature lookup, addressed by the ids stored in {@code trainSet}
     * @param trainSet           ids of the training examples; its positions are the {@code trainSetIdx} values
     * @param sampledTrainSet    bit {@code i} set means position {@code i} of {@code trainSet} was sampled
     *                           for this tree and is therefore skipped
     * @param predictions        vote counters that are incremented in place
     */
    public static void addPredictionsForTree(
        DecisionTreePredictor<Integer> decisionTree,
        int numberOfClasses,
        Features allFeatureVectors,
        ReadOnlyHugeLongArray trainSet,
        BitSet sampledTrainSet,
        HugeAtomicLongArray predictions
    ) {
        for (long trainSetIdx = 0L; trainSetIdx < trainSet.size(); ++trainSetIdx) {
            if (sampledTrainSet.get(trainSetIdx)) {
                // The example was in this tree's sample, so it is not out of bag for it.
                continue;
            }
            double[] featureVector = allFeatureVectors.get(trainSet.get(trainSetIdx));
            int prediction = decisionTree.predict(featureVector);
            predictions.getAndAdd(trainSetIdx * (long) numberOfClasses + (long) prediction, 1L);
        }
    }

    /**
     * Computes the share of mispredicted examples among all examples that
     * received at least one vote.
     *
     * @param trainSet         ids of the training examples, in the same order used when collecting votes
     * @param numberOfClasses  stride of the flat {@code predictions} array
     * @param expectedLabels   expected class index, addressed by the ids stored in {@code trainSet}
     * @param concurrency      passed to the range partitioning and to the task runner
     * @param predictions      vote counters filled by {@link #addPredictionsForTree}
     * @return {@code mistakes / votedExamples}, or {@code 0.0} when no example received any vote
     */
    public static double evaluate(
        ReadOnlyHugeLongArray trainSet,
        int numberOfClasses,
        HugeIntArray expectedLabels,
        int concurrency,
        HugeAtomicLongArray predictions
    ) {
        LongAdder totalMistakes = new LongAdder();
        LongAdder totalOutOfAnyBagVectors = new LongAdder();

        // Typed as List<Runnable> rather than a raw List so the task runner
        // call below needs no unchecked cast.
        List<Runnable> tasks = PartitionUtils.rangePartition(
            concurrency,
            trainSet.size(),
            partition -> OutOfBagError.accumulationTask(
                partition,
                numberOfClasses,
                trainSet,
                predictions,
                expectedLabels,
                totalMistakes,
                totalOutOfAnyBagVectors
            ),
            Optional.empty()
        );
        RunWithConcurrency.builder().concurrency(concurrency).tasks(tasks).run();

        if (totalOutOfAnyBagVectors.longValue() == 0L) {
            // Avoid dividing by zero when no example received a vote.
            return 0.0;
        }
        return totalMistakes.doubleValue() / totalOutOfAnyBagVectors.doubleValue();
    }

    /**
     * Builds the task that tallies mistakes and voted-on examples for the
     * train-set positions covered by one partition and adds both counts to
     * the shared adders once at the end.
     */
    private static Runnable accumulationTask(
        Partition partition,
        int numberOfClasses,
        ReadOnlyHugeLongArray trainSet,
        HugeAtomicLongArray predictions,
        HugeIntArray expectedLabels,
        LongAdder totalMistakes,
        LongAdder totalOutOfAnyBagVectors
    ) {
        return () -> {
            long numMistakes = 0L;
            long numOutOfAnyBagVectors = 0L;
            long startOffset = partition.startNode();
            long endOffset = startOffset + partition.nodeCount();
            for (long i = startOffset; i < endOffset; ++i) {
                long innerOffset = i * (long) numberOfClasses;
                long max = 0L;
                int maxClassIdx = 0;
                // Strict comparison: on equal vote counts the lowest class index wins.
                for (int j = 0; j < numberOfClasses; ++j) {
                    long numPredictions = predictions.get(innerOffset + (long) j);
                    if (numPredictions > max) {
                        max = numPredictions;
                        maxClassIdx = j;
                    }
                }
                if (max == 0L) {
                    // No votes at all for this example; it does not count towards the error.
                    continue;
                }
                ++numOutOfAnyBagVectors;
                if (maxClassIdx != expectedLabels.get(trainSet.get(i))) {
                    ++numMistakes;
                }
            }
            totalMistakes.add(numMistakes);
            totalOutOfAnyBagVectors.add(numOutOfAnyBagVectors);
        };
    }

    @Override
    public String name() {
        return "OUT_OF_BAG_ERROR";
    }

    @Override
    public String toString() {
        return this.name();
    }

    // NOTE(review): natural (ascending) ordering on an error value; confirm
    // against how callers of Metric#comparator rank scores before changing.
    @Override
    public Comparator<Double> comparator() {
        return Comparator.naturalOrder();
    }
}

