/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.train;

import java.util.List;
import org.neo4j.gds.AlgorithmFactory;
import org.neo4j.gds.BaseProc;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.loading.GraphStoreCatalog;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.exceptions.MemoryEstimationNotImplementedException;
import org.neo4j.gds.ml.Training;
import org.neo4j.gds.ml.linkmodels.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.linkmodels.pipeline.PipelineUtils;
import org.neo4j.gds.ml.linkmodels.pipeline.TrainingPipeline;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrain;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainConfig;
import org.neo4j.kernel.database.NamedDatabaseId;

/**
 * Creates {@link LinkPredictionTrain} instances for the link prediction pipeline training procedure.
 *
 * <p>The factory looks up the named graph in the {@link GraphStoreCatalog} and resolves the
 * training pipeline referenced by the configuration. It then wires both into a
 * {@link PipelineExecutor} that the training algorithm uses.
 */
public class LinkPredictionTrainFactory
extends AlgorithmFactory<LinkPredictionTrain, LinkPredictionTrainConfig> {
    /** Database used for the graph catalog lookup and handed to the pipeline executor. */
    private final NamedDatabaseId databaseId;
    /** Procedure that created this factory; forwarded to the pipeline executor. */
    private final BaseProc caller;

    /**
     * @param databaseId database in which the named graph is looked up
     * @param caller     calling procedure, passed on to the {@link PipelineExecutor}
     */
    public LinkPredictionTrainFactory(NamedDatabaseId databaseId, BaseProc caller) {
        this.databaseId = databaseId;
        this.caller = caller;
    }

    /**
     * Builds the link prediction training algorithm for the configured pipeline.
     *
     * <p>The {@code graph} argument is not used. The full graph store is fetched from the catalog by
     * the configured graph name instead, because the pipeline operates on the graph store. The
     * resolved pipeline is validated before the algorithm is created.
     *
     * @param graph             unused
     * @param trainConfig       training configuration; must reference a named catalog graph
     * @param allocationTracker unused
     * @param progressTracker   tracker shared by the pipeline executor and the training algorithm
     * @return a new {@link LinkPredictionTrain} wired to the resolved graph store and pipeline
     * @throws UnsupportedOperationException if the configuration does not name a catalog graph
     */
    public LinkPredictionTrain build(Graph graph, LinkPredictionTrainConfig trainConfig, AllocationTracker allocationTracker, ProgressTracker progressTracker) {
        String graphName = trainConfig.graphName()
            .orElseThrow(() -> new UnsupportedOperationException("Link Prediction Pipeline cannot be used with anonymous graphs. Please load the graph before"));
        String username = trainConfig.username();
        GraphStore graphStore = GraphStoreCatalog.get(username, this.databaseId, graphName).graphStore();

        TrainingPipeline pipeline = loadPipeline(trainConfig);
        pipeline.validate();

        PipelineExecutor pipelineExecutor = new PipelineExecutor(
            pipeline,
            this.caller,
            this.databaseId,
            username,
            graphName,
            progressTracker
        );
        return new LinkPredictionTrain(graphStore, trainConfig, pipeline, pipelineExecutor, progressTracker);
    }

    /** Returns the name of the root progress task for this algorithm. */
    protected String taskName() {
        return "Link Prediction pipeline train";
    }

    /**
     * Describes the progress task tree reported while training runs.
     *
     * <p>The number of node-property-step iterations and the volume of the model selection leaf are
     * taken from the configured pipeline.
     *
     * @param graph  unused
     * @param config training configuration naming the pipeline
     * @return the root task with one child per training phase
     */
    public Task progressTask(Graph graph, LinkPredictionTrainConfig config) {
        TrainingPipeline pipeline = loadPipeline(config);
        return Tasks.task(
            this.taskName(),
            Tasks.leaf("split relationships"),
            Tasks.iterativeFixed(
                "execute node property steps",
                () -> List.of(Tasks.leaf("step")),
                pipeline.nodePropertySteps().size()
            ),
            Tasks.leaf("extract train features"),
            Tasks.leaf("select model", pipeline.parameterSpace().size()),
            Training.progressTask("train best model"),
            Tasks.leaf("compute train metrics"),
            Tasks.task(
                "evaluate on test data",
                Tasks.leaf("extract test features"),
                Tasks.leaf("compute test metrics")
            )
        );
    }

    /**
     * Memory estimation is not supported for this algorithm.
     *
     * @throws MemoryEstimationNotImplementedException always
     */
    public MemoryEstimation memoryEstimation(LinkPredictionTrainConfig configuration) {
        throw new MemoryEstimationNotImplementedException();
    }

    /** Resolves the training pipeline named in {@code config} using the configured username. */
    private static TrainingPipeline loadPipeline(LinkPredictionTrainConfig config) {
        return PipelineUtils.getPipelineModelInfo(config.pipeline(), config.username());
    }
}

