/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodemodels.pipeline;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
import org.neo4j.gds.ml.pipeline.Pipeline;
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.train.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.train.NodeClassificationTrain;
import org.neo4j.gds.ml.pipeline.nodePipeline.train.NodeClassificationTrainResult;

public class NodeClassificationTrainPipelineExecutor
extends PipelineExecutor<NodeClassificationPipelineTrainConfig, NodeClassificationPipeline, NodeClassificationTrainResult> {
    public NodeClassificationTrainPipelineExecutor(NodeClassificationPipeline pipeline, NodeClassificationPipelineTrainConfig config, ExecutionContext executionContext, GraphStore graphStore, String graphName, ProgressTracker progressTracker) {
        super((Pipeline)pipeline, (AlgoBaseConfig)config, executionContext, graphStore, graphName, progressTracker);
    }

    public static MemoryEstimation estimate(NodeClassificationPipeline pipeline, NodeClassificationPipelineTrainConfig configuration, ModelCatalog modelCatalog) {
        PipelineExecutor.validateTrainingParameterSpace((Pipeline)pipeline);
        MemoryEstimation nodePropertyStepsEstimation = PipelineExecutor.estimateNodePropertySteps((ModelCatalog)modelCatalog, (List)pipeline.nodePropertySteps(), (List)configuration.nodeLabels(), (List)configuration.relationshipTypes());
        MemoryEstimation trainingEstimation = MemoryEstimations.builder().add("Pipeline Train", NodeClassificationTrain.estimate((NodeClassificationPipeline)pipeline, (NodeClassificationPipelineTrainConfig)configuration)).build();
        return MemoryEstimations.maxEstimation((String)"Pipeline executor", List.of(nodePropertyStepsEstimation, trainingEstimation));
    }

    public Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> splitDataset() {
        return Map.of(PipelineExecutor.DatasetSplits.FEATURE_INPUT, ImmutableGraphFilter.of((Collection)((NodeClassificationPipelineTrainConfig)this.config).nodeLabelIdentifiers(this.graphStore), (Collection)((NodeClassificationPipelineTrainConfig)this.config).internalRelationshipTypes(this.graphStore)));
    }

    protected NodeClassificationTrainResult execute(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> dataSplits) {
        PipelineExecutor.validateTrainingParameterSpace((Pipeline)this.pipeline);
        Collection nodeLabels = ((NodeClassificationPipelineTrainConfig)this.config).nodeLabelIdentifiers(this.graphStore);
        Collection relationshipTypes = ((NodeClassificationPipelineTrainConfig)this.config).internalRelationshipTypes(this.graphStore);
        Graph graph = this.graphStore.getGraph(nodeLabels, relationshipTypes, Optional.empty());
        ((NodeClassificationPipeline)this.pipeline).splitConfig().validateMinNumNodesInSplitSets(graph);
        return NodeClassificationTrain.create((Graph)graph, (NodeClassificationPipeline)((NodeClassificationPipeline)this.pipeline), (NodeClassificationPipelineTrainConfig)((NodeClassificationPipelineTrainConfig)this.config), (ProgressTracker)this.progressTracker).compute();
    }
}

