/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.linkPipeline;

import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.immutables.value.Value;
import org.neo4j.gds.ElementIdentifier;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.config.ElementTypeValidator;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.core.CypherMapAccess;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.GraphDimensions;
import org.neo4j.gds.ml.pipeline.NonEmptySetValidation;
import org.neo4j.gds.ml.pipeline.linkPipeline.ExpectedSetSizes;
import org.neo4j.gds.ml.pipeline.linkPipeline.ImmutableExpectedSetSizes;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionSplitConfigImpl;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.gds.utils.StringJoining;

@Configuration
public interface LinkPredictionSplitConfig
extends ToMapConvertible {
    public static final String TEST_FRACTION_KEY = "testFraction";
    public static final String TRAIN_FRACTION_KEY = "trainFraction";
    public static final LinkPredictionSplitConfig DEFAULT_CONFIG = LinkPredictionSplitConfig.of(CypherMapWrapper.empty());

    @Value.Default
    @Configuration.IntegerRange(min=2)
    default public int validationFolds() {
        return 3;
    }

    @Value.Default
    @Configuration.Key(value="testFraction")
    @Configuration.DoubleRange(min=0.0, minInclusive=false)
    default public double testFraction() {
        return 0.1;
    }

    @Value.Default
    @Configuration.Key(value="trainFraction")
    @Configuration.DoubleRange(min=0.0, minInclusive=false)
    default public double trainFraction() {
        return 0.1;
    }

    @Value.Default
    @Configuration.DoubleRange(min=0.0, minInclusive=false)
    default public double negativeSamplingRatio() {
        return 1.0;
    }

    public Optional<String> negativeRelationshipType();

    @Value.Default
    @Configuration.Ignore
    default public RelationshipType testRelationshipType() {
        return RelationshipType.of((String)"_TEST_");
    }

    @Value.Default
    @Configuration.Ignore
    default public RelationshipType testComplementRelationshipType() {
        return RelationshipType.of((String)"_TEST_COMPLEMENT_");
    }

    @Value.Default
    @Configuration.Ignore
    default public RelationshipType trainRelationshipType() {
        return RelationshipType.of((String)"_TRAIN_");
    }

    @Value.Default
    @Configuration.Ignore
    default public RelationshipType featureInputRelationshipType() {
        return RelationshipType.of((String)"_FEATURE_INPUT_");
    }

    @Configuration.ToMap
    public Map<String, Object> toMap();

    @Configuration.CollectKeys
    default public Collection<String> configKeys() {
        return Collections.emptyList();
    }

    public static LinkPredictionSplitConfig of(CypherMapWrapper config) {
        return new LinkPredictionSplitConfigImpl((CypherMapAccess)config);
    }

    @Configuration.Ignore
    default public void validateAgainstGraphStore(GraphStore graphStore, RelationshipType targetRelationshipType) {
        Stream<RelationshipType> reservedTypes = Stream.of(this.testRelationshipType(), this.trainRelationshipType(), this.featureInputRelationshipType(), this.testComplementRelationshipType());
        List invalidTypes = reservedTypes.filter(arg_0 -> ((GraphStore)graphStore).hasRelationshipType(arg_0)).map(ElementIdentifier::name).collect(Collectors.toList());
        if (!invalidTypes.isEmpty()) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale((String)"The relationship types %s are in the input graph, but are reserved for splitting.", (Object[])new Object[]{StringJoining.join(invalidTypes)}));
        }
        if (this.negativeRelationshipType().isPresent()) {
            String negativeRelType = this.negativeRelationshipType().get();
            ElementTypeValidator.resolveAndValidateTypes((GraphStore)graphStore, List.of(negativeRelType), (String)"negativeRelationshipType");
            if (this.negativeSamplingRatio() != 1.0) {
                throw new IllegalArgumentException("Configuration parameter failure: `negativeSamplingRatio` and `negativeRelationshipType` cannot be used together.");
            }
        }
        ExpectedSetSizes expectedSetSizes = this.expectedSetSizes(graphStore.relationshipCount(targetRelationshipType));
        NonEmptySetValidation.validateRelSetSize(expectedSetSizes.testSize(), 1L, "test", "`testFraction` is too low");
        NonEmptySetValidation.validateRelSetSize(expectedSetSizes.testComplementSize(), 3L, "test-complement", "`testFraction` is too high");
        NonEmptySetValidation.validateRelSetSize(expectedSetSizes.trainSize(), 2L, "train", "`trainFraction` is too low");
        NonEmptySetValidation.validateRelSetSize(expectedSetSizes.featureInputSize(), 1L, "feature-input", "`trainFraction` is too high");
        NonEmptySetValidation.validateRelSetSize(expectedSetSizes.validationFoldSize(), 1L, "validation", "`validationFolds` is too high or the `trainFraction` too low");
    }

    @Value.Derived
    @Configuration.Ignore
    default public ExpectedSetSizes expectedSetSizes(long relationshipCount) {
        long positiveTestSetSize = (long)((double)relationshipCount * this.testFraction() / 2.0);
        long testSetSize = (long)((double)positiveTestSetSize * (1.0 + this.negativeSamplingRatio()));
        long testComplementSize = (long)((double)relationshipCount * (1.0 - this.testFraction()));
        long positiveTrainSetSize = (long)((double)testComplementSize * this.trainFraction() / 2.0);
        long trainSetSize = (long)((double)positiveTrainSetSize * (1.0 + this.negativeSamplingRatio()));
        long featureInputSize = (long)((double)testComplementSize * (1.0 - this.trainFraction()));
        long foldSize = trainSetSize / (long)this.validationFolds();
        return ImmutableExpectedSetSizes.builder().testSize(testSetSize).trainSize(trainSetSize).featureInputSize(featureInputSize).testComplementSize(testComplementSize).validationFoldSize(foldSize).build();
    }

    @Value.Derived
    @Configuration.Ignore
    default public GraphDimensions expectedGraphDimensions(GraphDimensions baseDim, String targetRelType) {
        ExpectedSetSizes expectedSetSizes = this.expectedSetSizes(baseDim.relationshipCounts().getOrDefault(RelationshipType.of((String)targetRelType), baseDim.relCountUpperBound()));
        return GraphDimensions.builder().nodeCount(baseDim.nodeCount()).relCountUpperBound(baseDim.relCountUpperBound()).putRelationshipCount(this.testRelationshipType(), expectedSetSizes.testSize()).putRelationshipCount(this.testComplementRelationshipType(), expectedSetSizes.testComplementSize()).putRelationshipCount(this.trainRelationshipType(), expectedSetSizes.trainSize()).putRelationshipCount(this.featureInputRelationshipType(), expectedSetSizes.featureInputSize()).putAllRelationshipCounts(baseDim.relationshipCounts()).build();
    }
}

