/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodemodels;

import java.util.List;
import org.neo4j.gds.AlgorithmFactory;
import org.neo4j.gds.BaseProc;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.loading.CatalogRequest;
import org.neo4j.gds.core.loading.GraphStoreCatalog;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.exceptions.MemoryEstimationNotImplementedException;
import org.neo4j.gds.ml.nodemodels.NodeClassificationTrain;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationPipeline;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationPipelineCompanion;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationTrainPipelineExecutor;
import org.neo4j.kernel.database.NamedDatabaseId;

/**
 * Algorithm factory for the node classification training pipeline.
 * <p>
 * The executor built here does not use the {@link Graph} handed to {@link #build}; it works on the
 * named {@link GraphStore} looked up from the graph catalog. For that reason the procedure must be
 * invoked with a named graph rather than an anonymous one.
 */
public class NodeClassificationTrainPipelineAlgorithmFactory
extends AlgorithmFactory<NodeClassificationTrainPipelineExecutor, NodeClassificationPipelineTrainConfig> {
    private final BaseProc caller;
    private final NamedDatabaseId databaseId;
    private final ModelCatalog modelCatalog;

    /**
     * @param caller       procedure instance forwarded to the executor
     * @param databaseId   database used, together with the configured user name, to look up the graph store
     * @param modelCatalog catalog from which the training pipeline is fetched
     */
    public NodeClassificationTrainPipelineAlgorithmFactory(BaseProc caller, NamedDatabaseId databaseId, ModelCatalog modelCatalog) {
        this.caller = caller;
        this.databaseId = databaseId;
        this.modelCatalog = modelCatalog;
    }

    /**
     * Resolves the named graph store and the configured pipeline, validates the pipeline against
     * the graph store, and creates the training executor.
     *
     * @param graph             not used; the executor operates on the catalog graph store instead
     * @param configuration     train configuration; must name a graph and a pipeline
     * @param allocationTracker not used by this factory
     * @param progressTracker   forwarded to the executor
     * @return an executor bound to the resolved pipeline and graph store
     * @throws UnsupportedOperationException if the configuration carries no graph name
     *                                       (i.e. the procedure was called with an anonymous graph)
     */
    protected NodeClassificationTrainPipelineExecutor build(
        Graph graph,
        NodeClassificationPipelineTrainConfig configuration,
        AllocationTracker allocationTracker,
        ProgressTracker progressTracker
    ) {
        // The graph store can only be looked up in the catalog by name.
        String graphName = configuration.graphName().orElseThrow(() -> new UnsupportedOperationException(
            "Node Classification Train Pipeline cannot be used with anonymous graphs, please provide a named graph."
        ));
        GraphStore graphStore = GraphStoreCatalog
            .get(CatalogRequest.of(configuration.username(), databaseId), graphName)
            .graphStore();

        NodeClassificationPipeline pipeline = fetchPipeline(configuration);
        // Validate before the executor is constructed so invalid setups are rejected up front.
        pipeline.validateBeforeExecution(graphStore, configuration);

        return new NodeClassificationTrainPipelineExecutor(
            pipeline,
            configuration,
            caller,
            graphStore,
            graphName,
            progressTracker
        );
    }

    /**
     * Memory estimation is not supported for this pipeline.
     *
     * @throws MemoryEstimationNotImplementedException always
     */
    public MemoryEstimation memoryEstimation(NodeClassificationPipelineTrainConfig configuration) {
        throw new MemoryEstimationNotImplementedException();
    }

    /** @return the root task name used for progress reporting */
    protected String taskName() {
        return "Node Classification Train Pipeline";
    }

    /**
     * Builds the progress task tree. It consists of one iteration per node property step of the
     * configured pipeline, followed by the training task, sized by the pipeline's validation
     * folds and training parameter space.
     *
     * @param graph  not used
     * @param config train configuration naming the pipeline to inspect
     * @return the root progress task
     */
    public Task progressTask(Graph graph, NodeClassificationPipelineTrainConfig config) {
        NodeClassificationPipeline pipeline = fetchPipeline(config);

        Task nodePropertySteps = Tasks.iterativeFixed(
            "execute node property steps",
            () -> List.of(Tasks.leaf("step")),
            pipeline.nodePropertySteps().size()
        );
        Task training = NodeClassificationTrain.progressTask(
            pipeline.splitConfig().validationFolds(),
            pipeline.trainingParameterSpace().size()
        );

        return Tasks.task(taskName(), nodePropertySteps, training);
    }

    /**
     * Fetches the pipeline named in {@code config} from the model catalog, scoped to the
     * configured user name. Shared by {@link #build} and {@link #progressTask} so both see the
     * same lookup.
     */
    private NodeClassificationPipeline fetchPipeline(NodeClassificationPipelineTrainConfig config) {
        return NodeClassificationPipelineCompanion.getNCPipeline(modelCatalog, config.pipeline(), config.username());
    }
}

