/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.splitting;

import com.carrotsearch.hppc.predicates.LongLongPredicate;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.LongAdder;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.RelationshipWithPropertyConsumer;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.ml.splitting.EdgeSplitter;

public class DirectedEdgeSplitter
extends EdgeSplitter {
    public DirectedEdgeSplitter(Optional<Long> maybeSeed, IdMap rootNodes, IdMap sourceLabels, IdMap targetLabels, RelationshipType selectedRelationshipType, RelationshipType remainingRelationshipType, int concurrency) {
        super(maybeSeed, rootNodes, sourceLabels, targetLabels, selectedRelationshipType, remainingRelationshipType, concurrency);
    }

    @Override
    protected long validPositiveRelationshipCandidateCount(Graph graph, LongLongPredicate isValidNodePair) {
        LongAdder validRelationshipCountAdder = new LongAdder();
        List countValidRelationshipTasks = PartitionUtils.degreePartition((Graph)graph, (int)this.concurrency, partition -> () -> {
            Graph concurrentGraph = graph.concurrentCopy();
            partition.consume(nodeId -> concurrentGraph.forEachRelationship(nodeId, (s, t) -> {
                if (isValidNodePair.apply(s, t)) {
                    validRelationshipCountAdder.add(1L);
                }
                return true;
            }));
        }, Optional.empty());
        RunWithConcurrency.builder().concurrency(this.concurrency).tasks((Iterable)countValidRelationshipTasks).run();
        return validRelationshipCountAdder.longValue();
    }

    @Override
    protected void positiveSampling(Graph graph, RelationshipsBuilder selectedRelsBuilder, RelationshipWithPropertyConsumer remainingRelsConsumer, MutableLong selectedRelCount, MutableLong remainingRelCount, long nodeId, LongLongPredicate isValidNodePair, MutableLong positiveSamplesRemaining, MutableLong candidateEdgesRemaining) {
        graph.forEachRelationship(nodeId, Double.NaN, (source, target, weight) -> {
            if (isValidNodePair.apply(source, target)) {
                if (this.sample(positiveSamplesRemaining.doubleValue() / candidateEdgesRemaining.doubleValue())) {
                    positiveSamplesRemaining.decrementAndGet();
                    selectedRelCount.increment();
                    selectedRelsBuilder.addFromInternal(graph.toRootNodeId(source), graph.toRootNodeId(target), 1.0);
                } else {
                    remainingRelCount.increment();
                    remainingRelsConsumer.accept(source, target, weight);
                }
                candidateEdgesRemaining.addAndGet(-1L);
            }
            return true;
        });
    }
}

