/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodemodels.pipeline.predict;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.nodemodels.NodeClassificationPredict;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeClassificationResult;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionData;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionPredictor;
import org.neo4j.gds.ml.nodemodels.pipeline.predict.NodeClassificationPredictPipelineBaseConfig;
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
import org.neo4j.gds.ml.pipeline.Pipeline;
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPipelineModelInfo;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPipelineTrainConfig;

public class NodeClassificationPredictPipelineExecutor
extends PipelineExecutor<NodeClassificationPredictPipelineBaseConfig, NodeClassificationPipeline, NodeClassificationResult> {
    private static final int MIN_BATCH_SIZE = 100;
    private final NodeLogisticRegressionData modelData;

    NodeClassificationPredictPipelineExecutor(NodeClassificationPipeline pipeline, NodeClassificationPredictPipelineBaseConfig config, ExecutionContext executionContext, GraphStore graphStore, String graphName, ProgressTracker progressTracker, NodeLogisticRegressionData modelData) {
        super((Pipeline)pipeline, (AlgoBaseConfig)config, executionContext, graphStore, graphName, progressTracker);
        this.modelData = modelData;
    }

    public static MemoryEstimation estimate(Model<NodeLogisticRegressionData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> model, NodeClassificationPredictPipelineBaseConfig configuration, ModelCatalog modelCatalog) {
        NodeClassificationPipeline pipeline = ((NodeClassificationPipelineModelInfo)model.customInfo()).trainingPipeline();
        int classCount = ((NodeClassificationPipelineModelInfo)model.customInfo()).classes().size();
        int featureCount = ((Matrix)((NodeLogisticRegressionData)model.data()).weights().data()).totalSize();
        MemoryEstimation nodePropertyStepEstimation = PipelineExecutor.estimateNodePropertySteps((ModelCatalog)modelCatalog, (List)pipeline.nodePropertySteps(), (List)configuration.nodeLabels(), (List)configuration.relationshipTypes());
        MemoryEstimation predictionEstimation = MemoryEstimations.builder().add("Pipeline Predict", NodeClassificationPredict.memoryEstimationWithDerivedBatchSize((boolean)configuration.includePredictedProbabilities(), (int)100, (int)featureCount, (int)classCount)).build();
        return MemoryEstimations.maxEstimation(List.of(nodePropertyStepEstimation, predictionEstimation));
    }

    public Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> splitDataset() {
        return Map.of(PipelineExecutor.DatasetSplits.FEATURE_INPUT, ImmutableGraphFilter.of((Collection)((NodeClassificationPredictPipelineBaseConfig)this.config).nodeLabelIdentifiers(this.graphStore), (Collection)((NodeClassificationPredictPipelineBaseConfig)this.config).internalRelationshipTypes(this.graphStore)));
    }

    protected NodeClassificationResult execute(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> dataSplits) {
        Graph graph = this.graphStore.getGraph(((NodeClassificationPredictPipelineBaseConfig)this.config).nodeLabelIdentifiers(this.graphStore), ((NodeClassificationPredictPipelineBaseConfig)this.config).internalRelationshipTypes(this.graphStore), Optional.empty());
        NodeClassificationPredict innerAlgo = new NodeClassificationPredict(new NodeLogisticRegressionPredictor(this.modelData, ((NodeClassificationPipeline)this.pipeline).featureProperties()), graph, 100, ((NodeClassificationPredictPipelineBaseConfig)this.config).concurrency(), ((NodeClassificationPredictPipelineBaseConfig)this.config).includePredictedProbabilities(), ((NodeClassificationPipeline)this.pipeline).featureProperties(), this.progressTracker);
        return innerAlgo.compute();
    }
}

