/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.train;

import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrain;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainResult;
import org.neo4j.gds.ml.linkmodels.pipeline.train.RelationshipSplitter;
import org.neo4j.gds.ml.pipeline.ImmutableGraphFilter;
import org.neo4j.gds.ml.pipeline.Pipeline;
import org.neo4j.gds.ml.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionPipeline;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.util.TrainingSetWarnings;

public class LinkPredictionTrainPipelineExecutor
extends PipelineExecutor<LinkPredictionTrainConfig, LinkPredictionPipeline, LinkPredictionTrainResult> {
    private static final int RECOMMENDED_MIN_RELS_PER_SET = 5;
    private final RelationshipSplitter relationshipSplitter;

    public LinkPredictionTrainPipelineExecutor(LinkPredictionPipeline pipeline, LinkPredictionTrainConfig config, ExecutionContext executionContext, GraphStore graphStore, String graphName, ProgressTracker progressTracker) {
        super((Pipeline)pipeline, (AlgoBaseConfig)config, executionContext, graphStore, graphName, progressTracker);
        this.relationshipSplitter = new RelationshipSplitter(graphName, pipeline.splitConfig(), executionContext, progressTracker);
    }

    public Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> splitDataset() {
        this.relationshipSplitter.splitRelationships(this.graphStore, ((LinkPredictionTrainConfig)this.config).relationshipTypes(), ((LinkPredictionTrainConfig)this.config).nodeLabels(), ((LinkPredictionTrainConfig)this.config).randomSeed(), ((LinkPredictionPipeline)this.pipeline).relationshipWeightProperty());
        LinkPredictionSplitConfig splitConfig = ((LinkPredictionPipeline)this.pipeline).splitConfig();
        Collection nodeLabels = ((LinkPredictionTrainConfig)this.config).nodeLabelIdentifiers(this.graphStore);
        Collection trainRelationshipTypes = RelationshipType.listOf((String[])new String[]{splitConfig.trainRelationshipType()});
        Collection testRelationshipTypes = RelationshipType.listOf((String[])new String[]{splitConfig.testRelationshipType()});
        Collection featureInputRelationshipType = RelationshipType.listOf((String[])new String[]{splitConfig.featureInputRelationshipType()});
        return Map.of(PipelineExecutor.DatasetSplits.TRAIN, ImmutableGraphFilter.of((Collection)nodeLabels, (Collection)trainRelationshipTypes), PipelineExecutor.DatasetSplits.TEST, ImmutableGraphFilter.of((Collection)nodeLabels, (Collection)testRelationshipTypes), PipelineExecutor.DatasetSplits.FEATURE_INPUT, ImmutableGraphFilter.of((Collection)nodeLabels, (Collection)featureInputRelationshipType));
    }

    protected LinkPredictionTrainResult execute(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> dataSplits) {
        PipelineExecutor.GraphFilter trainDataSplit = dataSplits.get(PipelineExecutor.DatasetSplits.TRAIN);
        PipelineExecutor.GraphFilter testDataSplit = dataSplits.get(PipelineExecutor.DatasetSplits.TEST);
        Graph trainGraph = this.graphStore.getGraph(trainDataSplit.nodeLabels(), trainDataSplit.relationshipTypes(), Optional.of("label"));
        Graph testGraph = this.graphStore.getGraph(testDataSplit.nodeLabels(), testDataSplit.relationshipTypes(), Optional.of("label"));
        TrainingSetWarnings.warnForSmallRelationshipSets((long)trainGraph.relationshipCount(), (long)testGraph.relationshipCount(), (long)((LinkPredictionPipeline)this.pipeline).splitConfig().validationFolds(), (ProgressTracker)this.progressTracker);
        return new LinkPredictionTrain(trainGraph, testGraph, (LinkPredictionPipeline)this.pipeline, (LinkPredictionTrainConfig)this.config, this.progressTracker).compute();
    }

    private void removeDataSplitRelationships(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> datasets) {
        datasets.values().stream().flatMap(graphFilter -> graphFilter.relationshipTypes().stream()).distinct().collect(Collectors.toList()).forEach(arg_0 -> ((GraphStore)this.graphStore).deleteRelationships(arg_0));
    }

    protected void cleanUpGraphStore(Map<PipelineExecutor.DatasetSplits, PipelineExecutor.GraphFilter> datasets) {
        this.removeDataSplitRelationships(datasets);
        super.cleanUpGraphStore(datasets);
    }
}

