/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodeClassification;

import java.util.Optional;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.api.TrainingMethod;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.ClassifierFactory;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.logisticregression.LogisticRegressionClassifier;
import org.neo4j.gds.ml.nodeClassification.ImmutableNodeClassificationResult;
import org.neo4j.gds.ml.nodeClassification.ParallelNodeClassifier;
import org.neo4j.gds.termination.TerminationFlag;

/**
 * Runs a trained {@link Classifier} over a set of {@link Features} and collects the
 * predicted class per node, optionally together with per-class probabilities.
 *
 * <p>The actual prediction work is delegated to a {@link ParallelNodeClassifier}
 * created in the constructor; this class wires up progress tracking, allocates the
 * optional probabilities array and packages the outcome as a
 * {@link NodeClassificationResult}.
 */
public class NodeClassificationPredict {
    private final Classifier classifier;
    private final Features features;
    private final boolean produceProbabilities;
    private final ProgressTracker progressTracker;
    private final ParallelNodeClassifier predictor;

    /**
     * @param classifier           the classifier used for prediction
     * @param features             the features to predict on; its size determines the output length
     * @param batchSize            forwarded to the {@link ParallelNodeClassifier}
     * @param concurrency          forwarded to the {@link ParallelNodeClassifier}
     * @param produceProbabilities whether {@link #compute()} should also allocate and return probabilities
     * @param progressTracker      tracker notified at the start and end of {@link #compute()}
     * @param terminationFlag      forwarded to the {@link ParallelNodeClassifier}
     */
    public NodeClassificationPredict(Classifier classifier, Features features, int batchSize, int concurrency, boolean produceProbabilities, ProgressTracker progressTracker, TerminationFlag terminationFlag) {
        this.classifier = classifier;
        this.features = features;
        this.produceProbabilities = produceProbabilities;
        this.progressTracker = progressTracker;
        this.predictor = new ParallelNodeClassifier(
            classifier,
            features,
            batchSize,
            concurrency,
            terminationFlag,
            progressTracker
        );
    }

    /**
     * Creates the leaf progress task describing a prediction run.
     *
     * @param nodeCount the volume passed to the leaf task
     * @return a leaf task named "Node classification predict"
     */
    public static Task progressTask(long nodeCount) {
        return Tasks.leaf("Node classification predict", nodeCount);
    }

    /**
     * Builds the memory estimation for a prediction run with a known batch size.
     *
     * <p>Consists of the per-node output arrays plus a fixed "computation graph" component
     * obtained from {@link LogisticRegressionClassifier#sizeOfPredictionsVariableInBytes}.
     *
     * @param produceProbabilities whether the probabilities array is part of the estimation
     * @param batchSize            batch size used for the fixed computation-graph component
     * @param featureCount         number of features per node
     * @param classCount           number of classes
     * @return the assembled estimation
     */
    public static MemoryEstimation memoryEstimation(boolean produceProbabilities, int batchSize, int featureCount, int classCount) {
        MemoryEstimations.Builder estimation = outputArraysEstimation(produceProbabilities, classCount);
        // classCount is intentionally supplied for both trailing arguments, exactly as before.
        long computationGraphBytes = LogisticRegressionClassifier.sizeOfPredictionsVariableInBytes(
            batchSize,
            featureCount,
            classCount,
            classCount
        );
        estimation.fixed("computation graph", computationGraphBytes);
        return estimation.build();
    }

    /**
     * Builds the memory estimation for a prediction run whose batch size is derived from the
     * graph dimensions and the thread count at estimation time.
     *
     * <p>The derived batch size is the smaller of the node count and the value returned by
     * {@link BatchQueue#computeBatchSize}, narrowed to {@code int}.
     *
     * @param method               training method whose runtime overhead is estimated
     * @param produceProbabilities whether the probabilities array is part of the estimation
     * @param minBatchSize         lower bound handed to {@link BatchQueue#computeBatchSize}
     * @param featureCount         number of features per node
     * @param classCount           number of classes
     * @param isReduced            forwarded to {@link ClassifierFactory#runtimeOverheadMemoryEstimation}
     * @return the assembled estimation
     */
    public static MemoryEstimation memoryEstimationWithDerivedBatchSize(TrainingMethod method, boolean produceProbabilities, int minBatchSize, int featureCount, int classCount, boolean isReduced) {
        MemoryEstimations.Builder estimation = outputArraysEstimation(produceProbabilities, classCount);
        estimation.perGraphDimension("classifier runtime", (dim, threads) -> {
            int suggestedBatchSize = BatchQueue.computeBatchSize(dim.nodeCount(), minBatchSize, (int) threads);
            int derivedBatchSize = (int) Math.min(dim.nodeCount(), (long) suggestedBatchSize);
            return ClassifierFactory.runtimeOverheadMemoryEstimation(
                method,
                derivedBatchSize,
                classCount,
                featureCount,
                isReduced
            );
        });
        return estimation.build();
    }

    /**
     * Starts an estimation builder named after this class and registers the per-node
     * output components shared by both public estimation methods.
     */
    private static MemoryEstimations.Builder outputArraysEstimation(boolean produceProbabilities, int classCount) {
        MemoryEstimations.Builder estimation = MemoryEstimations.builder(NodeClassificationPredict.class.getSimpleName());
        if (produceProbabilities) {
            estimation.perNode(
                "predicted probabilities",
                nodeCount -> HugeObjectArray.memoryEstimation(nodeCount, MemoryUsage.sizeOfDoubleArray((long) classCount))
            );
        }
        // NOTE(review): sized as a HugeLongArray although compute() yields a HugeIntArray;
        // presumably a deliberate upper bound - confirm before tightening.
        estimation.perNode("predicted classes", HugeLongArray::memoryEstimation);
        return estimation;
    }

    /**
     * Executes the prediction inside a single progress sub task.
     *
     * @return the predicted classes, plus probabilities when they were requested
     */
    public NodeClassificationResult compute() {
        progressTracker.beginSubTask();
        progressTracker.setSteps(features.size());

        // Null when probabilities were not requested; the predictor receives it as-is.
        HugeObjectArray<double[]> probabilities = initProbabilities();
        HugeIntArray classes = predictor.predict(probabilities);

        progressTracker.endSubTask();
        return NodeClassificationResult.of(classes, probabilities);
    }

    /**
     * Allocates one zero-initialised {@code double[numberOfClasses]} per feature row.
     *
     * @return the allocated array, or {@code null} when probabilities are not requested
     */
    @Nullable
    private HugeObjectArray<double[]> initProbabilities() {
        if (!produceProbabilities) {
            return null;
        }
        int classCount = classifier.numberOfClasses();
        HugeObjectArray<double[]> probabilities = HugeObjectArray.newArray(double[].class, features.size());
        probabilities.setAll(index -> new double[classCount]);
        return probabilities;
    }

    /**
     * Outcome of a prediction run.
     */
    @ValueClass
    public interface NodeClassificationResult {
        /** The predicted classes as returned by the predictor. */
        HugeIntArray predictedClasses();

        /** Per-node class probabilities; empty when they were not requested. */
        Optional<HugeObjectArray<double[]>> predictedProbabilities();

        /**
         * Creates a result from the predicted classes and an optional probabilities array.
         *
         * @param classes       the predicted classes
         * @param probabilities the probabilities, or {@code null} if none were produced
         * @return the immutable result
         */
        static NodeClassificationResult of(HugeIntArray classes, @Nullable HugeObjectArray<double[]> probabilities) {
            Optional<HugeObjectArray<double[]>> maybeProbabilities = Optional.ofNullable(probabilities);
            return ImmutableNodeClassificationResult.builder()
                .predictedProbabilities(maybeProbabilities)
                .predictedClasses(classes)
                .build();
        }
    }
}

