/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodeClassification;

import java.util.function.Consumer;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.batch.BatchTransformer;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.nodeClassification.NodeClassificationPredictConsumer;
import org.neo4j.gds.termination.TerminationFlag;

/**
 * Runs a {@link Classifier} over node features in parallel batches and collects the predicted
 * class of each evaluated element into a {@link HugeIntArray}.
 *
 * <p>Work is split into consecutive batches of {@code batchSize} elements and consumed by
 * {@code concurrency} workers via {@link BatchQueue}; the run observes the supplied
 * {@link TerminationFlag}.
 */
public class ParallelNodeClassifier {
    private final Classifier classifier;
    private final Features features;
    private final int batchSize;
    private final int concurrency;
    private final TerminationFlag terminationFlag;
    private final ProgressTracker progressTracker;

    /**
     * Creates a parallel prediction runner.
     *
     * @param classifier      the trained classifier used for prediction
     * @param features        the feature source the classifier reads from
     * @param batchSize       number of elements per batch handed to a worker
     * @param concurrency     number of workers consuming batches
     * @param terminationFlag flag checked while consuming batches
     * @param progressTracker tracker passed to the batch consumer for progress reporting
     */
    ParallelNodeClassifier(
        Classifier classifier,
        Features features,
        int batchSize,
        int concurrency,
        TerminationFlag terminationFlag,
        ProgressTracker progressTracker
    ) {
        this.classifier = classifier;
        this.features = features;
        this.batchSize = batchSize;
        this.concurrency = concurrency;
        this.terminationFlag = terminationFlag;
        this.progressTracker = progressTracker;
    }

    /**
     * Predicts a class for every entry of {@code evaluationSet}.
     *
     * <p>Batch index {@code i} is translated to {@code evaluationSet.get(i)} before the
     * classifier is consulted (presumably a node id into {@code features} — confirm against
     * the consumer). Predicted probabilities are not collected.
     *
     * @param evaluationSet the ids to classify
     * @return predicted classes, one per entry, indexed by position in {@code evaluationSet}
     */
    public HugeIntArray predict(ReadOnlyHugeLongArray evaluationSet) {
        return predict(evaluationSet.size(), evaluationSet::get, null);
    }

    /**
     * Predicts a class for every element of {@code features}, using batch indices unchanged.
     *
     * @param predictedProbabilities optional output array handed to the batch consumer;
     *                               presumably filled with per-class probabilities when
     *                               non-null (behavior lives in the consumer — confirm)
     * @return predicted classes, one per feature row, of size {@code features.size()}
     */
    public HugeIntArray predict(@Nullable HugeObjectArray<double[]> predictedProbabilities) {
        return predict(features.size(), BatchTransformer.IDENTITY, predictedProbabilities);
    }

    /**
     * Shared driver: allocates the result array and consumes consecutive batches in parallel.
     *
     * @param evaluationSetSize      number of elements to classify (also the result size)
     * @param batchTransformer       maps batch-local indices to the ids the consumer reads
     * @param predictedProbabilities optional probability output, may be {@code null}
     * @return the populated array of predicted classes
     */
    private HugeIntArray predict(
        long evaluationSetSize,
        BatchTransformer batchTransformer,
        @Nullable HugeObjectArray<double[]> predictedProbabilities
    ) {
        HugeIntArray predictedClasses = HugeIntArray.newArray(evaluationSetSize);
        NodeClassificationPredictConsumer consumer = new NodeClassificationPredictConsumer(
            features,
            batchTransformer,
            classifier,
            predictedProbabilities,
            predictedClasses,
            progressTracker
        );
        // The consumer is already a typed batch consumer; no raw Consumer cast is needed.
        BatchQueue
            .consecutive(evaluationSetSize, batchSize, concurrency)
            .parallelConsume(consumer, concurrency, terminationFlag);
        return predictedClasses;
    }
}

