/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodePropertyPrediction.regression;

import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.Regressor;
import org.neo4j.gds.termination.TerminationFlag;

/**
 * Runs a {@link Regressor} over every entry of a {@link Features} collection and
 * gathers the predicted values into a {@link HugeDoubleArray}.
 *
 * <p>The work is dispatched through {@code ParallelUtil.parallelForEachNode} using the
 * configured concurrency and termination flag, and is bracketed by a {@code "Predict"}
 * sub task on the supplied {@link ProgressTracker}.
 */
public class NodeRegressionPredict {
    /** Name of the progress task that is created, begun and ended by this class. */
    private static final String PREDICT_TASK_NAME = "Predict";

    private final Regressor regressor;
    private final Features features;
    private final int concurrency;
    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;

    /**
     * Stores the collaborators used by {@link #compute()}; no work is done here.
     *
     * @param regressor       model whose {@code predict} is applied to each feature entry
     * @param features        feature entries, addressed by id in {@code [0, features.size())}
     * @param concurrency     concurrency value handed to {@code ParallelUtil.parallelForEachNode}
     * @param progressTracker tracker on which the {@code "Predict"} sub task is begun and ended
     * @param terminationFlag termination flag handed to {@code ParallelUtil.parallelForEachNode}
     */
    public NodeRegressionPredict(
        Regressor regressor,
        Features features,
        int concurrency,
        ProgressTracker progressTracker,
        TerminationFlag terminationFlag
    ) {
        this.regressor = regressor;
        this.features = features;
        this.concurrency = concurrency;
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
    }

    /**
     * Builds the leaf progress task describing the prediction step.
     *
     * @param nodeCount volume passed to {@code Tasks.leaf} for the task
     * @return a leaf task named {@code "Predict"}
     */
    public static Task progressTask(long nodeCount) {
        return Tasks.leaf(PREDICT_TASK_NAME, nodeCount);
    }

    /**
     * Predicts a target value for every id in {@code [0, features.size())}.
     *
     * @return an array of length {@code features.size()} in which index {@code id} holds
     *         {@code regressor.predict(features.get(id))}
     */
    public HugeDoubleArray compute() {
        progressTracker.beginSubTask(PREDICT_TASK_NAME);

        long featureCount = features.size();
        HugeDoubleArray predictions = HugeDoubleArray.newArray(featureCount);

        // Each id writes only to its own slot of the result array.
        ParallelUtil.parallelForEachNode(
            featureCount,
            concurrency,
            terminationFlag,
            nodeId -> {
                double predicted = regressor.predict(features.get(nodeId));
                predictions.set(nodeId, predicted);
            }
        );

        progressTracker.endSubTask(PREDICT_TASK_NAME);
        return predictions;
    }
}

