/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.train;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.ExecutionContext;
import org.neo4j.gds.executor.ExecutorSpec;
import org.neo4j.gds.executor.ProcedureExecutor;
import org.neo4j.gds.executor.ProcedureExecutorSpec;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.splitting.SplitRelationshipsBaseConfig;
import org.neo4j.gds.ml.splitting.SplitRelationshipsMutateProc;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Splits the relationships of a named graph into test and train sets by running
 * {@link SplitRelationshipsMutateProc} twice: first with the test split configuration on the
 * requested relationship types, then with the train split configuration on the
 * test-complement relationship type produced by the first run.
 */
public class RelationshipSplitter {
    private static final String SPLIT_ERROR_TEMPLATE = "%s graph contains no relationships. Consider increasing the `%s` or provide a larger graph";
    private final String graphName;
    private final LinkPredictionSplitConfig splitConfig;
    private final ExecutionContext executionContext;
    private final ProgressTracker progressTracker;

    /**
     * @param graphName        name of the graph passed to the split procedure
     * @param splitConfig      configuration supplying the test and train split settings
     * @param executionContext context handed to the {@link ProcedureExecutor}
     * @param progressTracker  tracker on which a sub task is opened and closed around the split
     */
    RelationshipSplitter(String graphName, LinkPredictionSplitConfig splitConfig, ExecutionContext executionContext, ProgressTracker progressTracker) {
        this.graphName = graphName;
        this.splitConfig = splitConfig;
        this.executionContext = executionContext;
        this.progressTracker = progressTracker;
    }

    /**
     * Runs the test split followed by the train split and afterwards deletes the
     * test-complement relationships from the graph store.
     *
     * @param graphStore                 store the split configuration is validated against and
     *                                   from which the test-complement relationships are deleted
     * @param relationshipTypes          relationship types used as input for the test split
     * @param nodeLabels                 node labels forwarded to both split runs
     * @param randomSeed                 optional seed forwarded to both split runs
     * @param relationshipWeightProperty optional weight property forwarded to both split runs
     * @throws IllegalStateException if the test relationship type contains no relationships
     *                               after the test split
     */
    public void splitRelationships(GraphStore graphStore, List<String> relationshipTypes, List<String> nodeLabels, Optional<Long> randomSeed, Optional<String> relationshipWeightProperty) {
        this.progressTracker.beginSubTask();
        this.splitConfig.validateAgainstGraphStore(graphStore);
        String testComplementRelationshipType = this.splitConfig.testComplementRelationshipType();

        // First run: split the requested relationship types using the test split configuration.
        this.relationshipSplit(this.splitConfig.testSplit(), nodeLabels, relationshipTypes, randomSeed, relationshipWeightProperty);
        this.validateTestSplit(graphStore);

        // Second run: the train split takes the test-complement relationships as its input.
        this.relationshipSplit(this.splitConfig.trainSplit(), nodeLabels, List.of(testComplementRelationshipType), randomSeed, relationshipWeightProperty);

        // The test-complement type was only needed as input for the train split.
        graphStore.deleteRelationships(RelationshipType.of(testComplementRelationshipType));
        this.progressTracker.endSubTask();
    }

    /**
     * Fails if the graph restricted to the test relationship type has no relationships.
     */
    private void validateTestSplit(GraphStore graphStore) {
        String testRelationshipType = this.splitConfig.testRelationshipType();
        long testRelationshipCount = graphStore
            .getGraph(new RelationshipType[]{RelationshipType.of(testRelationshipType)})
            .relationshipCount();
        if (testRelationshipCount <= 0L) {
            throw new IllegalStateException(StringFormatting.formatWithLocale(SPLIT_ERROR_TEMPLATE, "Test", "testFraction"));
        }
    }

    /**
     * Executes the split-relationships mutate procedure on {@code graphName} with the given
     * split configuration, extended by the node labels, relationship types and the optional
     * weight property and random seed.
     */
    private void relationshipSplit(SplitRelationshipsBaseConfig splitConfig, List<String> nodeLabels, List<String> relationshipTypes, Optional<Long> randomSeed, Optional<String> relationshipWeightProperty) {
        // Mutable copy of the split map; a plain HashMap avoids an anonymous subclass that
        // would capture the enclosing instance.
        Map<String, Object> splitRelationshipProcConfig = new HashMap<>(splitConfig.toSplitMap());
        splitRelationshipProcConfig.put("nodeLabels", nodeLabels);
        splitRelationshipProcConfig.put("relationshipTypes", relationshipTypes);
        relationshipWeightProperty.ifPresent(weightProperty -> splitRelationshipProcConfig.put("relationshipWeightProperty", weightProperty));
        randomSeed.ifPresent(seed -> splitRelationshipProcConfig.put("randomSeed", seed));

        SplitRelationshipsMutateProc splitRelationshipsMutateProc = new SplitRelationshipsMutateProc();
        new ProcedureExecutor((AlgorithmSpec)splitRelationshipsMutateProc, (ExecutorSpec)new ProcedureExecutorSpec(), this.executionContext).compute(this.graphName, splitRelationshipProcConfig, false, false);
    }
}

