/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodeClassification;

import java.util.PrimitiveIterator;
import java.util.function.Consumer;
import org.jetbrains.annotations.Nullable;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.ha.HugeObjectArray;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.batch.BatchTransformer;
import org.neo4j.gds.ml.core.batch.MappedBatch;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.models.Classifier;
import org.neo4j.gds.ml.models.Features;

/**
 * Batch consumer that runs a {@link Classifier} over the nodes of each incoming {@link Batch} and
 * records, per node id, the index of the class with the highest predicted probability.
 * Optionally, it also records the full probability row for each node.
 */
public class NodeClassificationPredictConsumer implements Consumer<Batch> {
    private final Features features;
    private final BatchTransformer nodeIds;
    private final Classifier classifier;
    /** Receives each node's probability row; when {@code null}, probabilities are not stored. */
    @Nullable
    private final HugeObjectArray<double[]> predictedProbabilities;
    /** Receives each node's predicted class index (-1 when no class qualified, see {@link #argMax}). */
    private final HugeIntArray predictedClasses;
    private final ProgressTracker progressTracker;

    /**
     * @param features               feature source handed to the classifier
     * @param nodeIds                transformer wrapped around each batch (via {@link MappedBatch})
     *                               before prediction
     * @param classifier             model producing one probability row per batch element
     * @param predictedProbabilities optional sink for probability rows, indexed by node id
     * @param predictedClasses       sink for predicted class indices, indexed by node id
     * @param progressTracker        advanced by {@code batch.size()} steps per consumed batch
     */
    NodeClassificationPredictConsumer(
        Features features,
        BatchTransformer nodeIds,
        Classifier classifier,
        @Nullable HugeObjectArray<double[]> predictedProbabilities,
        HugeIntArray predictedClasses,
        ProgressTracker progressTracker
    ) {
        this.features = features;
        this.nodeIds = nodeIds;
        this.classifier = classifier;
        this.predictedProbabilities = predictedProbabilities;
        this.predictedClasses = predictedClasses;
        this.progressTracker = progressTracker;
    }

    /**
     * Predicts class probabilities for every element of {@code batch} and stores the results.
     * Row {@code i} of the probability matrix belongs to the {@code i}-th id yielded by
     * {@code batch.elementIds()}.
     */
    @Override
    public void accept(Batch batch) {
        int classCount = classifier.numberOfClasses();
        // The classifier sees the batch through the nodeIds transformer, but results are
        // written back under the batch's own element ids.
        Batch mappedBatch = new MappedBatch(batch, nodeIds);
        Matrix probabilities = classifier.predictProbabilities(mappedBatch, features);

        PrimitiveIterator.OfLong ids = batch.elementIds();
        for (int row = 0; ids.hasNext(); row++) {
            long nodeId = ids.nextLong();
            if (predictedProbabilities != null) {
                predictedProbabilities.set(nodeId, probabilities.getRow(row));
            }
            predictedClasses.set(nodeId, argMax(probabilities, row, classCount));
        }

        progressTracker.logSteps(batch.size());
    }

    /**
     * Returns the first column index in {@code [0, classCount)} of the given row whose value is
     * the strict maximum among values greater than -1.0. Returns -1 when no value exceeds -1.0,
     * for example when {@code classCount} is 0 or every value is NaN (NaN never compares greater).
     */
    private static int argMax(Matrix matrix, int row, int classCount) {
        int best = -1;
        double bestValue = -1.0;
        for (int column = 0; column < classCount; column++) {
            double value = matrix.dataAt(row, column);
            if (value > bestValue) {
                bestValue = value;
                best = column;
            }
        }
        return best;
    }
}

