/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodemodels.pipeline.predict;

import java.util.List;
import java.util.Map;
import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.nodemodels.NodeClassificationPredictAlgorithmFactory;
import org.neo4j.gds.ml.nodemodels.NodeClassificationPredictConfig;
import org.neo4j.gds.ml.nodemodels.NodeClassificationPredictConfigImpl;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionData;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationPipelineCompanion;
import org.neo4j.gds.ml.nodemodels.pipeline.predict.NodeClassificationPredictPipelineBaseConfig;
import org.neo4j.gds.ml.nodemodels.pipeline.predict.NodeClassificationPredictPipelineExecutor;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPipelineModelInfo;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPipelineTrainConfig;

/**
 * Algorithm factory that creates a {@link NodeClassificationPredictPipelineExecutor} for predicting node
 * classes with a trained node classification pipeline model.
 *
 * <p>The trained pipeline model is looked up in the {@link ModelCatalog} by the configured model name and
 * username. Its progress task consists of one iteration per node property step of the training pipeline,
 * followed by the progress task of the non-pipeline node classification prediction.
 *
 * @param <CONFIG> the pipeline prediction configuration type
 */
public class NodeClassificationPredictPipelineAlgorithmFactory<CONFIG extends NodeClassificationPredictPipelineBaseConfig>
    extends GraphStoreAlgorithmFactory<NodeClassificationPredictPipelineExecutor, CONFIG> {

    private final ModelCatalog modelCatalog;
    private final ExecutionContext executionContext;
    /** Non-pipeline prediction factory; in this class it is only used to derive the prediction progress task. */
    private final NodeClassificationPredictAlgorithmFactory<NodeClassificationPredictConfig> innerFactory;

    /**
     * @param executionContext context handed to every executor built by this factory
     * @param modelCatalog     catalog from which trained pipeline models are resolved
     */
    NodeClassificationPredictPipelineAlgorithmFactory(ExecutionContext executionContext, ModelCatalog modelCatalog) {
        this.modelCatalog = modelCatalog;
        this.innerFactory = new NodeClassificationPredictAlgorithmFactory<>(modelCatalog);
        this.executionContext = executionContext;
    }

    /**
     * Builds the progress task tree: an iterative task with one {@code "step"} leaf per node property step
     * of the trained pipeline, followed by the inner prediction's progress task computed on the union graph.
     *
     * @param graphStore the graph store prediction will run on
     * @param config     the pipeline prediction configuration
     * @return the root progress task named {@link #taskName()}
     */
    public Task progressTask(GraphStore graphStore, CONFIG config) {
        NodeClassificationPipeline trainingPipeline = trainedModel(config).customInfo().trainingPipeline();

        Task nodePropertySteps = Tasks.iterativeFixed(
            "execute node property steps",
            () -> List.of(Tasks.leaf("step")),
            trainingPipeline.nodePropertySteps().size()
        );
        Task prediction = this.innerFactory.progressTask(graphStore.getUnion(), innerConfig(config));

        return Tasks.task(taskName(), nodePropertySteps, prediction);
    }

    /**
     * Translates the pipeline configuration into a non-pipeline {@link NodeClassificationPredictConfig}.
     *
     * <p>All entries of {@code configuration.toMap()} are carried over. {@code includePredictedProbabilities}
     * is set explicitly from the pipeline config, and {@code predictedProbabilityProperty} is removed.
     * NOTE(review): the removal is presumably because the inner config does not accept that key — confirm.
     *
     * @param configuration the pipeline prediction configuration
     * @return the equivalent inner prediction configuration, owned by the same user
     */
    private NodeClassificationPredictConfig innerConfig(CONFIG configuration) {
        CypherMapWrapper innerConfigMap = CypherMapWrapper.create(configuration.toMap())
            .withEntry("includePredictedProbabilities", configuration.includePredictedProbabilities())
            .withoutEntry("predictedProbabilityProperty");
        return new NodeClassificationPredictConfigImpl(configuration.username(), innerConfigMap);
    }

    /** @return the human-readable name of this algorithm's root progress task */
    public String taskName() {
        return "Node Classification Predict Pipeline";
    }

    /**
     * Creates the executor that runs the trained pipeline on {@code graphStore}.
     *
     * @param graphStore        the graph store to predict on
     * @param configuration     the pipeline prediction configuration; also supplies the graph name
     * @param allocationTracker not used by this factory
     * @param progressTracker   tracker passed to the executor
     * @return a new executor for the trained pipeline and its model data
     */
    public NodeClassificationPredictPipelineExecutor build(GraphStore graphStore, CONFIG configuration, AllocationTracker allocationTracker, ProgressTracker progressTracker) {
        Model<NodeLogisticRegressionData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> model = trainedModel(configuration);
        return new NodeClassificationPredictPipelineExecutor(
            model.customInfo().trainingPipeline(),
            configuration,
            this.executionContext,
            graphStore,
            configuration.graphName(),
            progressTracker,
            model.data()
        );
    }

    /**
     * Estimates memory for the pipeline executor of the trained model named in {@code configuration}.
     *
     * @param configuration the pipeline prediction configuration
     * @return a memory estimation with a single {@code "Pipeline executor"} component
     */
    public MemoryEstimation memoryEstimation(CONFIG configuration) {
        Model<NodeLogisticRegressionData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> model = trainedModel(configuration);
        return MemoryEstimations.builder(NodeClassificationPredictPipelineExecutor.class)
            .add("Pipeline executor", NodeClassificationPredictPipelineExecutor.estimate(model, configuration, this.modelCatalog))
            .build();
    }

    /** Resolves the trained node classification pipeline model named by {@code config} for its user. */
    private Model<NodeLogisticRegressionData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> trainedModel(CONFIG config) {
        return NodeClassificationPipelineCompanion.getTrainedNCPipelineModel(this.modelCatalog, config.modelName(), config.username());
    }
}

