/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.negativeSampling;

import java.util.Collection;
import java.util.Optional;
import java.util.SplittableRandom;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
import org.neo4j.gds.ml.negativeSampling.NegativeSampler;
import org.neo4j.gds.utils.StringFormatting;

/**
 * {@link NegativeSampler} that takes its negative examples from a user-supplied graph instead of
 * generating them.
 *
 * <p>The relationships of {@code negativeExampleGraph} are written, with property value
 * {@code 0.0}, into either the test or the train set. A {@code testTrainFraction} share of the
 * expected relationship count is targeted at the test set.
 */
public class UserInputNegativeSampler implements NegativeSampler {
    private final Graph negativeExampleGraph;
    private final double testTrainFraction;
    private final SplittableRandom rng;

    /**
     * Creates a sampler over the given negative-example graph and validates its relationships.
     *
     * @param negativeExampleGraph graph whose relationships are the negative examples; its schema
     *                             must be undirected
     * @param testTrainFraction    share of negative examples targeted at the test set
     *                             (NOTE(review): not range-checked here; presumably expected in [0, 1])
     * @param randomSeed           seed for the split; when empty, an unseeded {@link SplittableRandom} is used
     * @param sourceLabels         node labels allowed on one end of a negative relationship
     * @param targetLabels         node labels allowed on the other end of a negative relationship
     * @throws IllegalArgumentException if the graph schema is not undirected, or if any relationship
     *                                  does not connect a source-labelled node with a target-labelled
     *                                  node (in either direction)
     */
    public UserInputNegativeSampler(Graph negativeExampleGraph, double testTrainFraction, Optional<Long> randomSeed, Collection<NodeLabel> sourceLabels, Collection<NodeLabel> targetLabels) {
        if (!negativeExampleGraph.schema().isUndirected()) {
            throw new IllegalArgumentException("UserInputNegativeSampler requires graph to be UNDIRECTED.");
        }
        this.negativeExampleGraph = negativeExampleGraph;
        this.testTrainFraction = testTrainFraction;
        this.rng = randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
        // Runs after the fields are assigned because validation reads negativeExampleGraph.
        this.validateNegativeRelationships(sourceLabels, targetLabels);
    }

    /**
     * Splits the relationships of the negative-example graph into the test and train builders.
     *
     * <p>Relationships are visited in graph iteration order. Each one goes to the test set with
     * probability {@code remainingTest / (remainingTest + remainingTrain)}, which is selection
     * sampling. If the number of visited relationships equals {@code relationshipCount() / 2}, this
     * yields exactly {@code (long) (total * testTrainFraction)} test samples.
     *
     * @param testSetBuilder  receives the negative examples assigned to the test set
     * @param trainSetBuilder receives the negative examples assigned to the train set
     */
    @Override
    public void produceNegativeSamples(RelationshipsBuilder testSetBuilder, RelationshipsBuilder trainSetBuilder) {
        // The halving here and the s < t filter below assume each undirected relationship
        // appears once per direction in the adjacency.
        long totalRelationshipCount = this.negativeExampleGraph.relationshipCount() / 2L;
        long testRelationshipCount = (long) ((double) totalRelationshipCount * this.testTrainFraction);
        MutableLong testRelationshipsToAdd = new MutableLong(testRelationshipCount);
        MutableLong trainRelationshipsToAdd = new MutableLong(totalRelationshipCount - testRelationshipCount);

        this.negativeExampleGraph.forEachNode(nodeId -> {
            this.negativeExampleGraph.forEachRelationship(nodeId, (s, t) -> {
                // Keep only one direction of each relationship. This also skips self-loops (s == t).
                // The check runs before the root-id lookups so skipped entries cost nothing extra.
                if (s >= t) {
                    return true;
                }
                long rootS = this.negativeExampleGraph.toRootNodeId(s);
                long rootT = this.negativeExampleGraph.toRootNodeId(t);

                double remainingTest = testRelationshipsToAdd.doubleValue();
                double remainingTrain = trainRelationshipsToAdd.doubleValue();
                // If both budgets are exhausted this is 0/0 = NaN. Any comparison with NaN is
                // false, so such a relationship falls through to the train set.
                double testProbability = remainingTest / (remainingTest + remainingTrain);

                if (this.sample(testProbability)) {
                    testRelationshipsToAdd.decrement();
                    testSetBuilder.addFromInternal(rootS, rootT, 0.0);
                } else {
                    trainRelationshipsToAdd.decrement();
                    trainSetBuilder.addFromInternal(rootS, rootT, 0.0);
                }
                return true;
            });
            return true;
        });
    }

    /**
     * Draws one uniform double from {@link #rng}.
     *
     * @param probability chance of returning {@code true}
     * @return {@code true} with the given probability
     */
    private boolean sample(double probability) {
        return this.rng.nextDouble() < probability;
    }

    /**
     * Checks that every relationship of the negative-example graph connects a node carrying one of
     * {@code validSourceLabel} with a node carrying one of {@code validTargetLabel}. Either
     * orientation is accepted.
     *
     * @throws IllegalArgumentException on the first relationship that violates the constraint
     */
    private void validateNegativeRelationships(Collection<NodeLabel> validSourceLabel, Collection<NodeLabel> validTargetLabel) {
        this.negativeExampleGraph.forEachNode(nodeId -> {
            this.negativeExampleGraph.forEachRelationship(nodeId, (s, t) -> {
                Collection<NodeLabel> sourceNodeLabels = this.negativeExampleGraph.nodeLabels(s);
                Collection<NodeLabel> targetNodeLabels = this.negativeExampleGraph.nodeLabels(t);
                if (!this.nodePairsHaveValidLabels(sourceNodeLabels, targetNodeLabels, validSourceLabel, validTargetLabel)) {
                    throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                        "There is a relationship of negativeRelationshipType between nodes %s and %s. The nodes have types %s and %s. However, they need to be between %s and %s.",
                        this.negativeExampleGraph.toOriginalNodeId(s),
                        this.negativeExampleGraph.toOriginalNodeId(t),
                        sourceNodeLabels,
                        targetNodeLabels,
                        validSourceLabel.toString(),
                        validTargetLabel.toString()
                    ));
                }
                return true;
            });
            return true;
        });
    }

    /**
     * Returns whether a node pair is valid in either orientation: (source-labelled, target-labelled)
     * or (target-labelled, source-labelled).
     */
    private boolean nodePairsHaveValidLabels(Collection<NodeLabel> candidateSource, Collection<NodeLabel> candidateTarget, Collection<NodeLabel> validSourceLabels, Collection<NodeLabel> validTargetLabels) {
        boolean forward = containsAny(candidateSource, validSourceLabels) && containsAny(candidateTarget, validTargetLabels);
        if (forward) {
            return true;
        }
        return containsAny(candidateSource, validTargetLabels) && containsAny(candidateTarget, validSourceLabels);
    }

    /** Returns whether at least one of {@code labels} is contained in {@code allowed}. */
    private static boolean containsAny(Collection<NodeLabel> labels, Collection<NodeLabel> allowed) {
        return labels.stream().anyMatch(allowed::contains);
    }
}

