/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
import org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionPipeline;
import org.neo4j.gds.ml.linkmodels.pipeline.linkFeatures.LinkFeatureExtractor;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionData;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.ApproximateLinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.ExhaustiveLinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineBaseConfig;
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
import org.neo4j.gds.ml.pipeline.Pipeline;
import org.neo4j.gds.ml.pipeline.PipelineExecutor;

/**
 * Executes the prediction phase of a link prediction pipeline.
 *
 * <p>The input graph is obtained from the graph store using the node labels and relationship
 * types resolved from the config. Link features are extracted with the pipeline's feature steps.
 * Links are then predicted with the supplied logistic regression model data, using either
 * {@link ApproximateLinkPrediction} or {@link ExhaustiveLinkPrediction`} as selected by
 * {@code config.isApproximateStrategy()}.
 */
public class LinkPredictionPredictPipelineExecutor
extends PipelineExecutor<LinkPredictionPredictPipelineBaseConfig, LinkPredictionPipeline, LinkPredictionResult> {
    /** Model data handed to whichever link prediction strategy is selected. */
    private final LinkLogisticRegressionData linkLogisticRegressionData;

    /**
     * Creates an executor for predicting links with a link prediction pipeline.
     *
     * @param pipeline                   pipeline whose feature steps are used to extract link features
     * @param linkLogisticRegressionData logistic regression model data passed to the prediction strategy
     * @param config                     prediction configuration (label/type filters, strategy, topN,
     *                                   threshold, concurrency)
     * @param executionContext           forwarded to {@link PipelineExecutor}
     * @param graphStore                 graph store from which the input graph is read
     * @param graphName                  forwarded to {@link PipelineExecutor}
     * @param progressTracker            forwarded to {@link PipelineExecutor} and to the prediction strategy
     */
    LinkPredictionPredictPipelineExecutor(
        LinkPredictionPipeline pipeline,
        LinkLogisticRegressionData linkLogisticRegressionData,
        LinkPredictionPredictPipelineBaseConfig config,
        ExecutionContext executionContext,
        GraphStore graphStore,
        String graphName,
        ProgressTracker progressTracker
    ) {
        // Pass the arguments with their declared types: the superclass is parameterized with exactly
        // these types, so upcasting them to raw supertypes would no longer match its constructor.
        super(pipeline, config, executionContext, graphStore, graphName, progressTracker);
        this.linkLogisticRegressionData = linkLogisticRegressionData;
    }

    /**
     * Returns the graph filter to use for each dataset split.
     *
     * <p>All four splits (TEST, TRAIN, FEATURE_INPUT, TEST_COMPLEMENT) receive the same filter.
     * That filter is built from the node labels and relationship types that the config resolves
     * against the graph store.
     *
     * @return an unmodifiable map from every split to the shared filter
     */
    public Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> splitDataset() {
        // The filter is identical for every split, so resolve it once and share the instance.
        PipelineExecutor.GraphFilter filter = ImmutableGraphFilter.of(
            this.config.nodeLabelIdentifiers(this.graphStore),
            this.config.internalRelationshipTypes(this.graphStore)
        );
        return Map.of(
            PipelineExecutor.DatasetSplits.TEST, filter,
            PipelineExecutor.DatasetSplits.TRAIN, filter,
            PipelineExecutor.DatasetSplits.FEATURE_INPUT, filter,
            PipelineExecutor.DatasetSplits.TEST_COMPLEMENT, filter
        );
    }

    /**
     * Predicts links on the graph selected by the config.
     *
     * @param dataSplits not read by this implementation; the graph is obtained directly from the
     *                   config's node labels and relationship types instead
     * @return the result computed by the selected link prediction strategy
     */
    protected LinkPredictionResult execute(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> dataSplits) {
        Graph graph = this.graphStore.getGraph(
            this.config.nodeLabelIdentifiers(this.graphStore),
            this.config.internalRelationshipTypes(this.graphStore),
            // NOTE(review): third argument is presumably an optional relationship property; none is requested here.
            Optional.empty()
        );
        LinkFeatureExtractor linkFeatureExtractor = LinkFeatureExtractor.of(graph, this.pipeline.featureSteps());
        LinkPrediction linkPrediction = this.getLinkPredictionStrategy(
            graph,
            this.config.isApproximateStrategy(),
            linkFeatureExtractor
        );
        return linkPrediction.compute();
    }

    /**
     * Builds the link prediction implementation to run.
     *
     * @param graph                 graph on which links are predicted
     * @param isApproximateStrategy {@code true} to use {@link ApproximateLinkPrediction} (configured via
     *                              {@code config.approximateConfig()}), otherwise {@link ExhaustiveLinkPrediction}
     * @param linkFeatureExtractor  extractor computing link features on {@code graph}
     * @return the selected strategy; {@code compute()} has not been called yet
     * @throws java.util.NoSuchElementException if the exhaustive strategy is selected and
     *                                          {@code config.topN()} is empty
     */
    private LinkPrediction getLinkPredictionStrategy(
        Graph graph,
        boolean isApproximateStrategy,
        LinkFeatureExtractor linkFeatureExtractor
    ) {
        if (isApproximateStrategy) {
            return new ApproximateLinkPrediction(
                this.linkLogisticRegressionData,
                linkFeatureExtractor,
                graph,
                this.config.approximateConfig(),
                this.progressTracker
            );
        }
        return new ExhaustiveLinkPrediction(
            this.linkLogisticRegressionData,
            linkFeatureExtractor,
            graph,
            this.config.concurrency(),
            // The exhaustive strategy needs topN; orElseThrow() fails with NoSuchElementException if it is unset.
            this.config.topN().orElseThrow(),
            this.config.thresholdOrDefault(),
            this.progressTracker
        );
    }
}

