/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodemodels.pipeline.predict;

import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.nodemodels.NodeClassificationPredict;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeClassificationResult;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionData;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionPredictor;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationPipeline;
import org.neo4j.gds.ml.nodemodels.pipeline.predict.NodeClassificationPredictPipelineBaseConfig;
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
import org.neo4j.gds.ml.pipeline.Pipeline;
import org.neo4j.gds.ml.pipeline.PipelineExecutor;

public class NodeClassificationPredictPipelineExecutor
extends PipelineExecutor<NodeClassificationPredictPipelineBaseConfig, NodeClassificationPipeline, NodeClassificationResult> {
    private static final int BATCH_SIZE = 100;
    private final NodeLogisticRegressionData modelData;

    NodeClassificationPredictPipelineExecutor(NodeClassificationPipeline pipeline, NodeClassificationPredictPipelineBaseConfig config, ExecutionContext executionContext, GraphStore graphStore, String graphName, ProgressTracker progressTracker, NodeLogisticRegressionData modelData) {
        super((Pipeline)pipeline, (AlgoBaseConfig)config, executionContext, graphStore, graphName, progressTracker);
        this.modelData = modelData;
    }

    public Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> splitDataset() {
        return Map.of(PipelineExecutor.DatasetSplits.FEATURE_INPUT, ImmutableGraphFilter.of((Collection)((NodeClassificationPredictPipelineBaseConfig)this.config).nodeLabelIdentifiers(this.graphStore), (Collection)((NodeClassificationPredictPipelineBaseConfig)this.config).internalRelationshipTypes(this.graphStore)));
    }

    protected NodeClassificationResult execute(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> dataSplits) {
        Graph graph = this.graphStore.getGraph(((NodeClassificationPredictPipelineBaseConfig)this.config).nodeLabelIdentifiers(this.graphStore), ((NodeClassificationPredictPipelineBaseConfig)this.config).internalRelationshipTypes(this.graphStore), Optional.empty());
        NodeClassificationPredict innerAlgo = new NodeClassificationPredict(new NodeLogisticRegressionPredictor(this.modelData, ((NodeClassificationPipeline)this.pipeline).featureProperties()), graph, 100, ((NodeClassificationPredictPipelineBaseConfig)this.config).concurrency(), ((NodeClassificationPredictPipelineBaseConfig)this.config).includePredictedProbabilities(), ((NodeClassificationPipeline)this.pipeline).featureProperties(), this.executionContext.allocationTracker(), this.progressTracker);
        return innerAlgo.compute();
    }
}

