/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodeClassification;

import java.util.function.LongUnaryOperator;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.metrics.classification.ClassificationMetric;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.ClassifierFactory;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.nodeClassification.ParallelNodeClassifier;
import org.neo4j.gds.termination.TerminationFlag;

/**
 * Computes classification metrics by comparing predicted classes with the true labels of the
 * nodes in an evaluation set.
 *
 * <p>The label array held here is gathered per evaluation-set position: entry {@code i} is the
 * label of the node id stored at index {@code i} of the evaluation set. The predicted-class array
 * comes from {@link ParallelNodeClassifier#predict}; presumably it is aligned the same way
 * (TODO confirm against ParallelNodeClassifier).
 */
public final class ClassificationMetricComputer {

    /** Batch size handed to {@link ParallelNodeClassifier} when predicting the evaluation set. */
    private static final int PREDICTION_BATCH_SIZE = 100;

    private final HugeIntArray predictedClasses;
    private final HugeIntArray labels;

    private ClassificationMetricComputer(HugeIntArray predictedClasses, HugeIntArray labels) {
        this.predictedClasses = predictedClasses;
        this.labels = labels;
    }

    /**
     * Evaluates the given metric on the stored labels and predictions.
     *
     * @param metric the classification metric to evaluate
     * @return the value of {@code metric.compute(labels, predictedClasses)}
     */
    public double score(ClassificationMetric metric) {
        return metric.compute(labels, predictedClasses);
    }

    /**
     * Runs the classifier over the evaluation set and pairs the predictions with the matching
     * labels, producing a computer ready to {@link #score} metrics.
     *
     * @param features        features consumed by the classifier
     * @param labels          labels indexed by node id
     * @param evaluationSet   node ids to evaluate
     * @param classifier      the trained classifier used for prediction
     * @param concurrency     concurrency passed to {@link ParallelNodeClassifier}
     * @param terminationFlag termination flag passed to {@link ParallelNodeClassifier}
     * @param progressTracker progress tracker passed to {@link ParallelNodeClassifier}
     * @return a metric computer over the evaluation set's predictions and labels
     */
    public static ClassificationMetricComputer forEvaluationSet(
        Features features,
        HugeIntArray labels,
        ReadOnlyHugeLongArray evaluationSet,
        Classifier classifier,
        int concurrency,
        TerminationFlag terminationFlag,
        ProgressTracker progressTracker
    ) {
        ParallelNodeClassifier parallelClassifier = new ParallelNodeClassifier(
            classifier,
            features,
            PREDICTION_BATCH_SIZE,
            concurrency,
            terminationFlag,
            progressTracker
        );
        // Predictions are produced first, then the matching labels are gathered.
        HugeIntArray predictions = parallelClassifier.predict(evaluationSet);
        HugeIntArray evaluationTargets = makeLocalTargets(evaluationSet, labels);
        return new ClassificationMetricComputer(predictions, evaluationTargets);
    }

    /**
     * Gathers the target of each node in {@code nodeIds} into a dense array indexed by position.
     *
     * @param nodeIds node ids whose targets are collected
     * @param targets targets indexed by node id
     * @return an array where entry {@code i} equals {@code targets.get(nodeIds.get(i))}
     */
    private static HugeIntArray makeLocalTargets(ReadOnlyHugeLongArray nodeIds, HugeIntArray targets) {
        long evaluationSize = nodeIds.size();
        HugeIntArray gathered = HugeIntArray.newArray(evaluationSize);
        gathered.setAll(position -> targets.get(nodeIds.get(position)));
        return gathered;
    }

    /**
     * Builds the memory estimation for computing metrics. It covers the local target array, the
     * predicted-class array, the classifier's data, and the classifier's runtime overhead.
     *
     * @param config             trainer configuration (its {@code method()} selects the runtime estimate)
     * @param batchSize          batch size used for the classifier runtime estimate
     * @param trainSetSize       maps a node count to the train-set size
     * @param testSetSize        maps a node count to the test-set size
     * @param fudgedClassCount   class count used in the classifier estimates
     * @param fudgedFeatureCount feature count used in the classifier estimates
     * @param isReduced          forwarded unchanged to the classifier estimates
     * @return the combined memory estimation
     */
    public static MemoryEstimation estimateEvaluation(
        TrainerConfig config,
        int batchSize,
        LongUnaryOperator trainSetSize,
        LongUnaryOperator testSetSize,
        int fudgedClassCount,
        int fudgedFeatureCount,
        boolean isReduced
    ) {
        return MemoryEstimations.builder("computing metrics")
            .perNode("local targets", nodeCount -> testSetArrayEstimation(testSetSize, nodeCount))
            .perNode("predicted classes", nodeCount -> testSetArrayEstimation(testSetSize, nodeCount))
            .add(
                "classifier model",
                ClassifierFactory.dataMemoryEstimation(
                    config,
                    trainSetSize,
                    fudgedClassCount,
                    fudgedFeatureCount,
                    isReduced
                )
            )
            .rangePerNode(
                "classifier runtime",
                ignoredNodeCount -> ClassifierFactory.runtimeOverheadMemoryEstimation(
                    config.method(),
                    batchSize,
                    fudgedClassCount,
                    fudgedFeatureCount,
                    isReduced
                )
            )
            .build();
    }

    /**
     * Estimates the memory of one huge array sized to the test-set portion of {@code nodeCount}.
     *
     * @param testSetSize maps a node count to the test-set size
     * @param nodeCount   the node count being estimated for
     * @return the value of {@code HugeLongArray.memoryEstimation(testSetSize(nodeCount))}
     */
    private static long testSetArrayEstimation(LongUnaryOperator testSetSize, long nodeCount) {
        // NOTE(review): the arrays built in this class are HugeIntArrays, but they are sized here
        // as HugeLongArrays, so this over-estimates. Kept as-is to preserve the existing estimates.
        long testSetNodes = testSetSize.applyAsLong(nodeCount);
        return HugeLongArray.memoryEstimation(testSetNodes);
    }
}

