/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.ApproximateLinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.ExhaustiveLinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineBaseConfig;
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
import org.neo4j.gds.ml.pipeline.Pipeline;
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureExtractor;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionPipeline;
import org.neo4j.gds.models.Classifier;

public class LinkPredictionPredictPipelineExecutor
extends PipelineExecutor<LinkPredictionPredictPipelineBaseConfig, LinkPredictionPipeline, LinkPredictionResult> {
    private final Classifier classifier;

    public LinkPredictionPredictPipelineExecutor(LinkPredictionPipeline pipeline, Classifier classifier, LinkPredictionPredictPipelineBaseConfig config, ExecutionContext executionContext, GraphStore graphStore, String graphName, ProgressTracker progressTracker) {
        super((Pipeline)pipeline, (AlgoBaseConfig)config, executionContext, graphStore, graphName, progressTracker);
        this.classifier = classifier;
    }

    public Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> splitDataset() {
        return Map.of(PipelineExecutor.DatasetSplits.FEATURE_INPUT, ImmutableGraphFilter.of((Collection)((LinkPredictionPredictPipelineBaseConfig)this.config).nodeLabelIdentifiers(this.graphStore), (Collection)((LinkPredictionPredictPipelineBaseConfig)this.config).internalRelationshipTypes(this.graphStore)));
    }

    protected LinkPredictionResult execute(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> dataSplits) {
        Graph graph = this.graphStore.getGraph(((LinkPredictionPredictPipelineBaseConfig)this.config).nodeLabelIdentifiers(this.graphStore), ((LinkPredictionPredictPipelineBaseConfig)this.config).internalRelationshipTypes(this.graphStore), Optional.empty());
        LinkFeatureExtractor linkFeatureExtractor = LinkFeatureExtractor.of((Graph)graph, (List)((LinkPredictionPipeline)this.pipeline).featureSteps());
        LinkPrediction linkPrediction = this.getLinkPredictionStrategy(graph, ((LinkPredictionPredictPipelineBaseConfig)this.config).isApproximateStrategy(), linkFeatureExtractor);
        return linkPrediction.compute();
    }

    public static MemoryEstimation estimate(ModelCatalog modelCatalog, LinkPredictionPipeline pipeline, LinkPredictionPredictPipelineBaseConfig configuration, int linkFeatureDimension) {
        MemoryEstimation maxOverNodePropertySteps = PipelineExecutor.estimateNodePropertySteps((ModelCatalog)modelCatalog, (List)pipeline.nodePropertySteps(), (List)configuration.nodeLabels(), (List)configuration.relationshipTypes());
        MemoryEstimation predictEstimation = configuration.isApproximateStrategy() ? ApproximateLinkPrediction.estimate(configuration) : ExhaustiveLinkPrediction.estimate(configuration, linkFeatureDimension);
        return MemoryEstimations.builder(LinkPredictionPredictPipelineExecutor.class).max("Pipeline execution", List.of(maxOverNodePropertySteps, predictEstimation)).build();
    }

    private LinkPrediction getLinkPredictionStrategy(Graph graph, boolean isApproximateStrategy, LinkFeatureExtractor linkFeatureExtractor) {
        if (isApproximateStrategy) {
            return new ApproximateLinkPrediction(this.classifier, linkFeatureExtractor, graph, ((LinkPredictionPredictPipelineBaseConfig)this.config).approximateConfig(), this.progressTracker);
        }
        return new ExhaustiveLinkPrediction(this.classifier, linkFeatureExtractor, graph, ((LinkPredictionPredictPipelineBaseConfig)this.config).concurrency(), ((LinkPredictionPredictPipelineBaseConfig)this.config).topN().orElseThrow(), ((LinkPredictionPredictPipelineBaseConfig)this.config).thresholdOrDefault(), this.progressTracker);
    }
}

