/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import java.util.List;
import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.exceptions.MemoryEstimationNotImplementedException;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionPipelineCompanion;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionData;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineBaseConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineExecutor;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionModelInfo;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionPipeline;
import org.neo4j.gds.similarity.knn.KnnBaseConfig;
import org.neo4j.gds.similarity.knn.KnnFactory;

/**
 * Algorithm factory that creates {@link LinkPredictionPredictPipelineExecutor} instances
 * from a trained link prediction pipeline model looked up in the {@link ModelCatalog}.
 *
 * @param <CONFIG> the concrete predict configuration type
 */
public class LinkPredictionPredictPipelineAlgorithmFactory<CONFIG extends LinkPredictionPredictPipelineBaseConfig>
        extends GraphStoreAlgorithmFactory<LinkPredictionPredictPipelineExecutor, CONFIG> {
    private final ExecutionContext executionContext;
    private final ModelCatalog modelCatalog;

    /**
     * @param executionContext context handed through to every executor this factory builds
     * @param modelCatalog     catalog from which the trained pipeline model is resolved
     */
    LinkPredictionPredictPipelineAlgorithmFactory(ExecutionContext executionContext, ModelCatalog modelCatalog) {
        this.executionContext = executionContext;
        this.modelCatalog = modelCatalog;
    }

    /**
     * Builds the progress task tree for a prediction run.
     *
     * <p>The tree is named after {@link #taskName()} and has two children: an iterative
     * task with one "step" leaf per node property step of the training pipeline, followed
     * by either the approximate (KNN task tree) or the exhaustive prediction task,
     * depending on {@code config.isApproximateStrategy()}.
     *
     * @param graphStore graph store the prediction runs on
     * @param config     predict configuration; supplies model name, username and strategy
     * @return the root task of the progress tree
     */
    public Task progressTask(GraphStore graphStore, CONFIG config) {
        // The model is resolved first so the number of node property steps is known.
        LinkPredictionPipeline pipeline = this.trainedModel(config).customInfo().trainingPipeline();
        String rootName = this.taskName();

        Task nodePropertyStepsTask = Tasks.iterativeFixed(
            "execute node property steps",
            () -> List.of(Tasks.leaf("step")),
            pipeline.nodePropertySteps().size()
        );

        Task predictionTask;
        if (config.isApproximateStrategy()) {
            Task knnTask = KnnFactory.knnTaskTree(graphStore.getUnion(), config.approximateConfig());
            predictionTask = Tasks.task("approximate link prediction", knnTask);
        } else {
            // The exhaustive leaf task is created with the graph store's node count.
            predictionTask = Tasks.leaf("exhaustive link prediction", graphStore.nodeCount());
        }

        return Tasks.task(rootName, nodePropertyStepsTask, predictionTask);
    }

    /**
     * @return the fixed, human-readable name of the root progress task
     */
    public String taskName() {
        return "Link Prediction Predict Pipeline";
    }

    /**
     * Creates an executor wired with the trained pipeline and its model data.
     *
     * @param graphStore        graph store the executor operates on
     * @param configuration     predict configuration; also supplies the graph name
     * @param allocationTracker not used by this factory
     * @param progressTracker   progress tracker passed on to the executor
     * @return a new executor for the trained pipeline
     */
    public LinkPredictionPredictPipelineExecutor build(GraphStore graphStore, CONFIG configuration, AllocationTracker allocationTracker, ProgressTracker progressTracker) {
        Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo> trained = this.trainedModel(configuration);
        LinkPredictionPipeline pipeline = trained.customInfo().trainingPipeline();
        LinkLogisticRegressionData modelData = trained.data();
        return new LinkPredictionPredictPipelineExecutor(
            pipeline,
            modelData,
            configuration,
            this.executionContext,
            graphStore,
            configuration.graphName(),
            progressTracker
        );
    }

    /**
     * Memory estimation is not available for this algorithm.
     *
     * @param configuration ignored
     * @return never returns normally
     * @throws MemoryEstimationNotImplementedException always
     */
    public MemoryEstimation memoryEstimation(CONFIG configuration) {
        throw new MemoryEstimationNotImplementedException();
    }

    /** Resolves the trained pipeline model named in {@code config} for the configured user. */
    private Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo> trainedModel(CONFIG config) {
        return LinkPredictionPipelineCompanion.getTrainedLPPipelineModel(this.modelCatalog, config.modelName(), config.username());
    }
}

