/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.train;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.neo4j.gds.BaseProc;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionSplitConfig;
import org.neo4j.gds.ml.pipeline.proc.ProcedureReflection;
import org.neo4j.gds.ml.splitting.SplitRelationshipsBaseConfig;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Splits the relationships of a graph in a {@link GraphStore} into the relationship sets used for
 * link prediction training.
 *
 * <p>The split is performed in two passes. Each pass reflectively invokes the
 * {@code splitRelationships} procedure on {@code caller}:
 * <ol>
 *   <li>The input relationship types are split using {@link LinkPredictionSplitConfig#testSplit()}.
 *       Afterwards the test relationship type must be non-empty.</li>
 *   <li>The test-complement relationship type is split using
 *       {@link LinkPredictionSplitConfig#trainSplit()}. Afterwards the train and feature-input
 *       relationship types must both be non-empty.</li>
 * </ol>
 * The intermediate test-complement relationships are deleted from the graph store at the end.
 */
public class RelationshipSplitter {
    private static final String SPLIT_ERROR_TEMPLATE = "%s graph contains no relationships. Consider increasing the `%s` or provide a larger graph";

    private final String graphName;
    private final LinkPredictionSplitConfig splitConfig;
    private final BaseProc caller;
    private final ProgressTracker progressTracker;

    RelationshipSplitter(String graphName, LinkPredictionSplitConfig splitConfig, BaseProc caller, ProgressTracker progressTracker) {
        this.graphName = graphName;
        this.splitConfig = splitConfig;
        this.caller = caller;
        this.progressTracker = progressTracker;
    }

    /**
     * Runs both split passes against {@code graphStore} and validates the resulting relationship sets.
     *
     * @param graphStore                 graph store to split; new relationship types are written into it
     * @param relationshipTypes          relationship types that form the input of the first (test) split
     * @param nodeLabels                 node labels passed to both split invocations
     * @param randomSeed                 optional seed forwarded to both split invocations
     * @param relationshipWeightProperty optional weight property forwarded to both split invocations
     * @throws IllegalStateException if the test, train or feature-input set ends up empty
     */
    public void splitRelationships(GraphStore graphStore, List<String> relationshipTypes, List<String> nodeLabels, Optional<Long> randomSeed, Optional<String> relationshipWeightProperty) {
        // NOTE(review): if any step below throws, endSubTask() is not reached and the
        // test-complement relationships stay in the graph store. Confirm callers handle this.
        progressTracker.beginSubTask();

        splitConfig.validateAgainstGraphStore(graphStore);

        String testComplementRelationshipType = splitConfig.testComplementRelationshipType();

        // Pass 1: split the input relationships; the test set must not be empty.
        relationshipSplit(splitConfig.testSplit(), nodeLabels, relationshipTypes, randomSeed, relationshipWeightProperty);
        validateTestSplit(graphStore);

        // Pass 2: split the test complement; train and feature-input sets must not be empty.
        relationshipSplit(splitConfig.trainSplit(), nodeLabels, List.of(testComplementRelationshipType), randomSeed, relationshipWeightProperty);
        validateTrainSplit(graphStore);

        // The test complement is only an intermediate input to pass 2.
        graphStore.deleteRelationships(RelationshipType.of(testComplementRelationshipType));

        progressTracker.endSubTask();
    }

    /** Fails if the test relationship type produced by pass 1 has no relationships. */
    private void validateTestSplit(GraphStore graphStore) {
        if (relationshipCount(graphStore, splitConfig.testRelationshipType()) <= 0L) {
            throw new IllegalStateException(StringFormatting.formatWithLocale(SPLIT_ERROR_TEMPLATE, new Object[]{"Test", "testFraction"}));
        }
    }

    /** Fails if the train or the feature-input relationship type produced by pass 2 has no relationships. */
    private void validateTrainSplit(GraphStore graphStore) {
        if (relationshipCount(graphStore, splitConfig.trainRelationshipType()) <= 0L) {
            throw new IllegalStateException(StringFormatting.formatWithLocale(SPLIT_ERROR_TEMPLATE, new Object[]{"Train", "trainFraction"}));
        }
        if (relationshipCount(graphStore, splitConfig.featureInputRelationshipType()) <= 0L) {
            throw new IllegalStateException(StringFormatting.formatWithLocale("Feature graph contains no relationships. Consider decreasing %s or %s or provide a larger graph.", new Object[]{"testFraction", "trainFraction"}));
        }
    }

    /** Returns the relationship count of the graph obtained from {@code graphStore} for a single relationship type. */
    private static long relationshipCount(GraphStore graphStore, String relationshipType) {
        return graphStore
            .getGraph(new RelationshipType[]{RelationshipType.of(relationshipType)})
            .relationshipCount();
    }

    /**
     * Invokes the {@code splitRelationships} procedure once.
     *
     * <p>The procedure configuration is built from {@code splitStepConfig.toSplitMap()}. It is then
     * extended with the node labels and relationship types, plus the weight property and random seed
     * when present.
     */
    private void relationshipSplit(SplitRelationshipsBaseConfig splitStepConfig, List<String> nodeLabels, List<String> relationshipTypes, Optional<Long> randomSeed, Optional<String> relationshipWeightProperty) {
        // Plain copy + puts instead of an anonymous HashMap subclass ("double-brace" init), which
        // would capture this RelationshipSplitter and, as decompiled, did not even compile.
        Map<String, Object> procConfig = new HashMap<>(splitStepConfig.toSplitMap());
        procConfig.put("nodeLabels", nodeLabels);
        procConfig.put("relationshipTypes", relationshipTypes);
        relationshipWeightProperty.ifPresent(property -> procConfig.put("relationshipWeightProperty", property));
        randomSeed.ifPresent(seed -> procConfig.put("randomSeed", seed));

        ProcedureReflection procReflection = ProcedureReflection.INSTANCE;
        procReflection.invokeProc(caller, graphName, procReflection.findProcedureMethod("splitRelationships"), procConfig);
    }
}

