/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.negativeSampling;

import com.carrotsearch.hppc.predicates.LongPredicate;
import java.util.HashSet;
import java.util.Optional;
import java.util.SplittableRandom;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
import org.neo4j.gds.ml.negativeSampling.NegativeSampler;

/**
 * {@link NegativeSampler} that draws negative (non-existing) relationships at random.
 *
 * <p>For every node of {@link #graph} whose original id is contained in {@link #validSourceNodes},
 * candidate targets are drawn at random from all node ids of the graph. A candidate is accepted
 * when all of the following hold:
 * <ul>
 *   <li>its original id is contained in {@link #validTargetNodes};</li>
 *   <li>it is not reached by one of the source's relationships in {@link #graph};</li>
 *   <li>it is not the source itself.</li>
 * </ul>
 * Each accepted pair is added (mapped to root node ids, weight {@code 0.0}) to either the test or
 * the train builder. The choice is random, with a probability proportional to the number of test
 * vs. train samples still outstanding.
 *
 * <p>Sampling is best-effort. Accepted targets are not de-duplicated, so the same pair may be
 * emitted more than once. After {@link #MAX_RETRIES} rejected candidates for one source node,
 * further rejections are not retried. As a result, fewer than
 * {@code testSampleCount + trainSampleCount} samples may be produced.
 */
public class RandomNegativeSampler
implements NegativeSampler {
    /** Per source node: how many rejected candidates are re-drawn before rejections start consuming sample slots. */
    private static final int MAX_RETRIES = 20;

    private final SplittableRandom rng;
    private final Graph graph;
    /** Requested number of negative samples for the test set. */
    private final long testSampleCount;
    /** Requested number of negative samples for the train set. */
    private final long trainSampleCount;
    /** Nodes (by original id) that may act as the source of a negative sample. */
    private final IdMap validSourceNodes;
    /** Nodes (by original id) that may act as the target of a negative sample. */
    private final IdMap validTargetNodes;

    /**
     * @param graph             graph whose existing relationships must be avoided and whose node ids are sampled
     * @param testSampleCount   number of negative samples requested for the test set
     * @param trainSampleCount  number of negative samples requested for the train set
     * @param validSourceNodes  allowed source nodes, matched via original node ids
     * @param validTargetNodes  allowed target nodes, matched via original node ids
     * @param randomSeed        seed for the random generator; when empty an unseeded {@link SplittableRandom} is used
     */
    public RandomNegativeSampler(Graph graph, long testSampleCount, long trainSampleCount, IdMap validSourceNodes, IdMap validTargetNodes, Optional<Long> randomSeed) {
        this.graph = graph;
        this.testSampleCount = testSampleCount;
        this.trainSampleCount = trainSampleCount;
        this.validSourceNodes = validSourceNodes;
        this.validTargetNodes = validTargetNodes;
        this.rng = randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
    }

    /**
     * Samples negative relationships and adds them to the given builders.
     *
     * <p>The outstanding sample budget is spread over the valid source nodes in iteration order
     * (see {@link #samplesPerNode}). The budget is only reduced when a sample is actually emitted.
     * Any shortfall at one node therefore raises the per-node share of the nodes visited after it.
     *
     * @param testSetBuilder   receives the negative samples assigned to the test set
     * @param trainSetBuilder  receives the negative samples assigned to the train set
     */
    @Override
    public void produceNegativeSamples(RelationshipsBuilder testSetBuilder, RelationshipsBuilder trainSetBuilder) {
        MutableLong remainingTestSamples = new MutableLong(this.testSampleCount);
        MutableLong remainingTrainSamples = new MutableLong(this.trainSampleCount);
        // Valid source nodes not yet visited; the divisor when spreading the remaining budget.
        MutableLong remainingValidSourceNodes = new MutableLong(this.validSourceNodes.nodeCount());
        // Membership in the filter IdMaps is tested on the graph node's original id.
        LongPredicate isValidSourceNodes = nodeId -> this.validSourceNodes.containsOriginalId(this.graph.toOriginalNodeId(nodeId));
        LongPredicate isValidTargetNodes = nodeId -> this.validTargetNodes.containsOriginalId(this.graph.toOriginalNodeId(nodeId));

        this.graph.forEachNode(nodeId -> {
            if (!isValidSourceNodes.apply(nodeId)) {
                return true;
            }
            int masterDegree = this.graph.degree(nodeId);
            // Upper bound for this node: all nodes minus nodeId itself minus its degree.
            long negativeEdgeCount = this.samplesPerNode(
                this.graph.nodeCount() - 1L - (long) masterDegree,
                remainingTestSamples.longValue() + remainingTrainSamples.longValue(),
                remainingValidSourceNodes.getAndDecrement()
            );
            HashSet<Long> neighbours = this.neighboursOf(nodeId, masterDegree);

            int retries = MAX_RETRIES;
            // long index: negativeEdgeCount is a long, and an int index could overflow and never terminate.
            for (long i = 0; i < negativeEdgeCount; i++) {
                long negativeTarget = this.randomNodeId(this.graph);
                if (isValidTargetNodes.apply(negativeTarget) && !neighbours.contains(negativeTarget) && negativeTarget != nodeId) {
                    // Assign to test or train proportionally to how many of each are still outstanding.
                    double testProbability = remainingTestSamples.doubleValue()
                        / (remainingTestSamples.doubleValue() + remainingTrainSamples.doubleValue());
                    if (this.sample(testProbability)) {
                        remainingTestSamples.decrement();
                        testSetBuilder.addFromInternal(this.graph.toRootNodeId(nodeId), this.graph.toRootNodeId(negativeTarget), 0.0);
                    } else {
                        remainingTrainSamples.decrement();
                        trainSetBuilder.addFromInternal(this.graph.toRootNodeId(nodeId), this.graph.toRootNodeId(negativeTarget), 0.0);
                    }
                } else if (retries-- > 0) {
                    // Rejected candidate: re-draw for the same slot. The retry budget is shared by all
                    // slots of this source node, so dense neighbourhoods cannot loop forever. Once it is
                    // exhausted, a rejected candidate simply consumes the slot.
                    i--;
                }
            }
            return true;
        });
    }

    /**
     * Collects the targets of {@code nodeId}'s relationships as exposed by
     * {@code graph.forEachRelationship}.
     *
     * @param nodeId  internal id of the source node
     * @param degree  degree of {@code nodeId}, used as the initial set capacity
     * @return a fresh mutable set of neighbour ids
     */
    private HashSet<Long> neighboursOf(long nodeId, int degree) {
        HashSet<Long> neighbours = new HashSet<>(degree);
        this.graph.forEachRelationship(nodeId, (source, target) -> {
            neighbours.add(target);
            return true;
        });
        return neighbours;
    }

    /**
     * Draws an internal node id in {@code [0, graph.nodeCount())}.
     *
     * <p>NOTE(review): the modulo reduction is very slightly non-uniform;
     * {@code rng.nextLong(bound)} would be exact. It is kept as-is, presumably so that seeded runs
     * keep producing the same samples. Throws {@link ArithmeticException} if the graph has no nodes.
     */
    private long randomNodeId(Graph graph) {
        return Math.abs(this.rng.nextLong() % graph.nodeCount());
    }

    /**
     * Decides how many negative samples to draw for the current source node.
     *
     * <p>Computes the average number of remaining samples per remaining node. The result is the
     * integer part of that average, plus one with probability equal to its fractional part, so its
     * expected value equals the average. The result is capped at {@code maxSamples}.
     *
     * @param maxSamples        upper bound for this node
     * @param remainingSamples  samples (test + train) still outstanding
     * @param remainingNodes    source nodes still to be visited, including the current one
     * @return number of samples to attempt for the current node
     */
    private long samplesPerNode(long maxSamples, double remainingSamples, long remainingNodes) {
        double numSamplesOnAverage = remainingSamples / (double) remainingNodes;
        long wholeSamples = (long) numSamplesOnAverage;
        int extraSample = this.sample(numSamplesOnAverage - (double) wholeSamples) ? 1 : 0;
        return Math.min(maxSamples, wholeSamples + (long) extraSample);
    }

    /** Returns {@code true} with the given probability (a NaN probability always yields {@code false}). */
    private boolean sample(double probability) {
        return this.rng.nextDouble() < probability;
    }
}

