/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodemodels;

import java.util.List;
import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationTrainPipelineExecutor;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.train.NodeClassificationTrain;

public class NodeClassificationTrainPipelineAlgorithmFactory
extends GraphStoreAlgorithmFactory<NodeClassificationTrainPipelineExecutor, NodeClassificationPipelineTrainConfig> {
    private final ExecutionContext executionContext;

    public NodeClassificationTrainPipelineAlgorithmFactory(ExecutionContext executionContext) {
        this.executionContext = executionContext;
    }

    public NodeClassificationTrainPipelineExecutor build(GraphStore graphStore, NodeClassificationPipelineTrainConfig configuration, ProgressTracker progressTracker) {
        NodeClassificationPipeline pipeline = (NodeClassificationPipeline)PipelineCatalog.getTyped((String)configuration.username(), (String)configuration.pipeline(), NodeClassificationPipeline.class);
        pipeline.validateBeforeExecution(graphStore, (AlgoBaseConfig)configuration);
        return new NodeClassificationTrainPipelineExecutor(pipeline, configuration, this.executionContext, graphStore, configuration.graphName(), progressTracker);
    }

    public MemoryEstimation memoryEstimation(NodeClassificationPipelineTrainConfig configuration) {
        NodeClassificationPipeline pipeline = (NodeClassificationPipeline)PipelineCatalog.getTyped((String)configuration.username(), (String)configuration.pipeline(), NodeClassificationPipeline.class);
        return MemoryEstimations.builder(NodeClassificationTrainPipelineExecutor.class).add("Pipeline executor", NodeClassificationTrainPipelineExecutor.estimate(pipeline, configuration, this.executionContext.modelCatalog())).build();
    }

    public String taskName() {
        return "Node Classification Train Pipeline";
    }

    public Task progressTask(GraphStore graphStore, NodeClassificationPipelineTrainConfig config) {
        NodeClassificationPipeline pipeline = (NodeClassificationPipeline)PipelineCatalog.getTyped((String)config.username(), (String)config.pipeline(), NodeClassificationPipeline.class);
        return Tasks.task((String)this.taskName(), (Task)Tasks.iterativeFixed((String)"execute node property steps", () -> List.of(Tasks.leaf((String)"step")), (int)pipeline.nodePropertySteps().size()), (Task[])new Task[]{NodeClassificationTrain.progressTask((int)pipeline.splitConfig().validationFolds(), (int)pipeline.numberOfModelCandidates())});
    }
}

