/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.nodePipeline.classification.train;

import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.pipeline.PipelineCatalog;
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureProducer;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.NodeClassificationTrainingPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrain;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationTrainAlgorithm;

/**
 * Factory producing {@link NodeClassificationTrainAlgorithm} instances for a
 * {@link NodeClassificationTrainingPipeline}, together with the matching memory
 * estimation and progress task.
 *
 * <p>Whenever a method receives only a configuration, the pipeline is fetched from
 * {@link PipelineCatalog} using the configuration's {@code username()} and
 * {@code pipeline()} values.
 */
public class NodeClassificationTrainPipelineAlgorithmFactory
extends GraphStoreAlgorithmFactory<NodeClassificationTrainAlgorithm, NodeClassificationPipelineTrainConfig> {
    /** Execution context handed to the feature producer and used to reach the model catalog. */
    private final ExecutionContext executionContext;

    /**
     * Creates a factory bound to the given execution context.
     *
     * @param executionContext context stored for later use by {@code build} and {@code memoryEstimation}
     */
    public NodeClassificationTrainPipelineAlgorithmFactory(ExecutionContext executionContext) {
        this.executionContext = executionContext;
    }

    /**
     * Builds the training algorithm for the pipeline referenced by the configuration.
     *
     * @param graphStore      graph store to train on
     * @param configuration   training configuration; also identifies the pipeline to fetch
     * @param progressTracker tracker passed through to the created algorithm
     * @return the training algorithm for the fetched pipeline
     */
    public NodeClassificationTrainAlgorithm build(
        GraphStore graphStore,
        NodeClassificationPipelineTrainConfig configuration,
        ProgressTracker progressTracker
    ) {
        return build(graphStore, configuration, resolvePipeline(configuration), progressTracker);
    }

    /**
     * Builds the training algorithm for an explicitly supplied pipeline.
     *
     * <p>Before the algorithm is created, the first configured metric is validated against the
     * pipeline, and the context configurations of the pipeline's node property steps are validated.
     *
     * @param graphStore      graph store to train on
     * @param configuration   training configuration
     * @param pipeline        pipeline to train
     * @param progressTracker tracker passed to the feature producer, the trainer and the algorithm
     * @return the training algorithm wrapping the created trainer
     */
    public NodeClassificationTrainAlgorithm build(
        GraphStore graphStore,
        NodeClassificationPipelineTrainConfig configuration,
        NodeClassificationTrainingPipeline pipeline,
        ProgressTracker progressTracker
    ) {
        // The first entry of the configured metrics is the one validated as the main metric.
        String mainMetric = configuration.metrics().get(0).toString();
        PipelineCompanion.validateMainMetric(pipeline, mainMetric);

        NodeFeatureProducer<NodeClassificationPipelineTrainConfig> featureProducer = NodeFeatureProducer.create(
            graphStore,
            configuration,
            this.executionContext,
            progressTracker
        );
        featureProducer.validateNodePropertyStepsContextConfigs(pipeline.nodePropertySteps());

        return new NodeClassificationTrainAlgorithm(
            NodeClassificationTrain.create(graphStore, pipeline, configuration, featureProducer, progressTracker),
            pipeline,
            graphStore,
            configuration,
            progressTracker
        );
    }

    /**
     * Estimates the memory needed to train the pipeline referenced by the configuration.
     *
     * @param configuration training configuration; also identifies the pipeline to fetch
     * @return an estimation named after {@link NodeClassificationTrain} that contains the
     *         estimate reported by {@code NodeClassificationTrain.estimate}
     */
    public MemoryEstimation memoryEstimation(NodeClassificationPipelineTrainConfig configuration) {
        NodeClassificationTrainingPipeline trainingPipeline = resolvePipeline(configuration);
        String estimationName = NodeClassificationTrain.class.getSimpleName();
        return MemoryEstimations.builder(estimationName)
            .add(NodeClassificationTrain.estimate(trainingPipeline, configuration, this.executionContext.modelCatalog()))
            .build();
    }

    /**
     * Returns the display name of this task.
     *
     * @return the fixed task name
     */
    public String taskName() {
        return "Node Classification Train Pipeline";
    }

    /**
     * Creates the progress task for the pipeline referenced by the configuration.
     *
     * @param graphStore graph store whose node count sizes the task
     * @param config     training configuration identifying the pipeline to fetch
     * @return the progress task for the fetched pipeline
     */
    public Task progressTask(GraphStore graphStore, NodeClassificationPipelineTrainConfig config) {
        return progressTask(graphStore, resolvePipeline(config));
    }

    /**
     * Creates the progress task for an explicitly supplied pipeline.
     *
     * @param graphStore graph store whose node count is passed to the trainer's task factory
     * @param pipeline   pipeline to create the task for
     * @return the task produced by {@code NodeClassificationTrain.progressTask}
     */
    public static Task progressTask(GraphStore graphStore, NodeClassificationTrainingPipeline pipeline) {
        return NodeClassificationTrain.progressTask(pipeline, graphStore.nodeCount());
    }

    /** Fetches the configuration's pipeline from the catalog, typed as a node classification training pipeline. */
    private static NodeClassificationTrainingPipeline resolvePipeline(NodeClassificationPipelineTrainConfig configuration) {
        return PipelineCatalog.getTyped(
            configuration.username(),
            configuration.pipeline(),
            NodeClassificationTrainingPipeline.class
        );
    }
}

