/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import java.util.Collection;
import java.util.List;
import org.neo4j.gds.AlgorithmFactory;
import org.neo4j.gds.BaseProc;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.loading.GraphStoreCatalog;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.exceptions.MemoryEstimationNotImplementedException;
import org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionModelInfo;
import org.neo4j.gds.ml.linkmodels.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.linkmodels.pipeline.PipelineUtils;
import org.neo4j.gds.ml.linkmodels.pipeline.TrainingPipeline;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionData;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.ApproximateLinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.ExhaustiveLinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPipelineBaseConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.similarity.knn.KnnBaseConfig;
import org.neo4j.gds.similarity.knn.KnnFactory;
import org.neo4j.kernel.database.NamedDatabaseId;

/**
 * Builds {@link LinkPrediction} algorithms that apply a trained link prediction pipeline model
 * to a graph that has been loaded into the graph catalog under a name.
 *
 * <p>The variant that is built depends on
 * {@link LinkPredictionPipelineBaseConfig#isApproximateStrategy()}:
 * <ul>
 *   <li>{@code true}: an {@link ApproximateLinkPrediction}, configured with
 *       {@code approximateConfig()};</li>
 *   <li>{@code false}: an {@link ExhaustiveLinkPrediction}, configured with
 *       {@code concurrency()}, {@code topN()} and {@code thresholdOrDefault()}.</li>
 * </ul>
 *
 * @param <CONFIG> the concrete link prediction pipeline configuration type
 */
public class LinkPredictionPipelineAlgorithmFactory<CONFIG extends LinkPredictionPipelineBaseConfig>
extends AlgorithmFactory<LinkPrediction, CONFIG> {
    private final BaseProc caller;
    private final NamedDatabaseId databaseId;

    /**
     * @param caller     procedure handed to the {@link PipelineExecutor} that runs the
     *                   pipeline's node property steps
     * @param databaseId database whose graph catalog entries are looked up and used
     */
    LinkPredictionPipelineAlgorithmFactory(BaseProc caller, NamedDatabaseId databaseId) {
        this.caller = caller;
        this.databaseId = databaseId;
    }

    /**
     * Describes the progress task tree.
     *
     * <p>The tree has one iterative task with a "step" leaf per node property step of the
     * model's training pipeline. This is followed by a task for the chosen prediction strategy.
     *
     * @param graph  graph that predictions are computed for
     * @param config configuration identifying the model (by name and user) and the strategy
     * @return the root progress task
     */
    protected Task progressTask(Graph graph, CONFIG config) {
        TrainingPipeline trainingPipeline = PipelineUtils
            .getLinkPredictionPipeline(config.modelName(), config.username())
            .customInfo()
            .trainingPipeline();

        Task nodePropertyStepsTask = Tasks.iterativeFixed(
            "execute node property steps",
            () -> List.of(Tasks.leaf("step")),
            trainingPipeline.nodePropertySteps().size()
        );

        // taskName() is called virtually so subclasses can rename the root task.
        return Tasks.task(taskName(), nodePropertyStepsTask, predictionTask(graph, config));
    }

    /**
     * Builds the progress task for the prediction phase: a KNN task tree for the approximate
     * strategy, otherwise a single leaf sized by the graph's node count.
     */
    private Task predictionTask(Graph graph, CONFIG config) {
        if (config.isApproximateStrategy()) {
            return Tasks.task(
                "approximate link prediction",
                KnnFactory.knnTaskTree(graph, config.approximateConfig())
            );
        }
        return Tasks.leaf("exhaustive link prediction", graph.nodeCount());
    }

    /** @return the name of the root progress task */
    protected String taskName() {
        return "Link Prediction Pipeline";
    }

    /**
     * Creates the link prediction algorithm for a named catalog graph.
     *
     * @param graph             graph passed by the caller (the algorithm itself is built
     *                          from the catalog {@link GraphStore})
     * @param configuration     pipeline prediction configuration
     * @param allocationTracker unused here
     * @param progressTracker   tracker handed to the pipeline executor and the algorithm
     * @return an approximate or exhaustive link prediction algorithm
     * @throws UnsupportedOperationException if the configuration has no graph name
     *                                       (anonymous graph)
     */
    protected LinkPrediction build(
        Graph graph,
        CONFIG configuration,
        AllocationTracker allocationTracker,
        ProgressTracker progressTracker
    ) {
        String graphName = configuration.graphName().orElseThrow(() -> new UnsupportedOperationException(
            "Link Prediction Pipeline cannot be used with anonymous graphs. Please load the graph before"));

        Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo> model =
            PipelineUtils.getLinkPredictionPipeline(configuration.modelName(), configuration.username());
        GraphStore graphStore = GraphStoreCatalog
            .get(configuration.username(), databaseId, graphName)
            .graphStore();

        PipelineExecutor pipelineExecutor = new PipelineExecutor(
            model.customInfo().trainingPipeline(),
            caller,
            databaseId,
            configuration.username(),
            graphName,
            progressTracker
        );

        Collection<NodeLabel> nodeLabels = configuration.nodeLabelIdentifiers(graphStore);
        Collection<RelationshipType> relationshipTypes = configuration.internalRelationshipTypes(graphStore);

        if (configuration.isApproximateStrategy()) {
            return new ApproximateLinkPrediction(
                model.data(),
                pipelineExecutor,
                nodeLabels,
                relationshipTypes,
                graphStore,
                configuration.approximateConfig(),
                progressTracker
            );
        }

        // NOTE(review): assumes config validation guarantees topN is present for the exhaustive
        // strategy; otherwise this throws NoSuchElementException. TODO confirm.
        return new ExhaustiveLinkPrediction(
            model.data(),
            pipelineExecutor,
            nodeLabels,
            relationshipTypes,
            graphStore,
            configuration.concurrency(),
            configuration.topN().orElseThrow(),
            configuration.thresholdOrDefault(),
            progressTracker
        );
    }

    /**
     * Memory estimation is not supported for link prediction pipelines.
     *
     * @throws MemoryEstimationNotImplementedException always
     */
    public MemoryEstimation memoryEstimation(CONFIG configuration) {
        throw new MemoryEstimationNotImplementedException();
    }
}

