/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.splitting;

import com.carrotsearch.hppc.predicates.LongLongPredicate;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.LongAdder;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.RelationshipWithPropertyConsumer;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.ml.splitting.EdgeSplitter;

/**
 * {@link EdgeSplitter} variant that only considers the orientation of a relationship in which
 * the source id is smaller than the target id, and that adjusts its candidate and sample
 * counters in steps of two per such relationship.
 *
 * <p>NOTE(review): the step of two presumably reflects that an undirected graph exposes every
 * edge in both orientations — confirm against {@code EdgeSplitter}.
 */
public class UndirectedEdgeSplitter
extends EdgeSplitter {
    /**
     * Creates a splitter by forwarding all arguments unchanged to the {@link EdgeSplitter}
     * constructor.
     */
    public UndirectedEdgeSplitter(Optional<Long> maybeSeed, IdMap rootNodes, IdMap sourceNodes, IdMap targetNodes, RelationshipType selectedRelationshipType, RelationshipType remainingRelationshipType, int concurrency) {
        super(maybeSeed, rootNodes, sourceNodes, targetNodes, selectedRelationshipType, remainingRelationshipType, concurrency);
    }

    /**
     * Counts the positive relationship candidates of {@code graph}.
     *
     * <p>The graph is split into partitions via {@code PartitionUtils.degreePartition}; one task
     * per partition walks the relationships of its nodes on a concurrent copy of the graph.
     * Every relationship with {@code source < target} whose node pair is accepted by
     * {@code isValidNodePair} in at least one orientation adds 2 to the result.
     *
     * @param graph           graph whose relationships are inspected
     * @param isValidNodePair predicate deciding whether a (source, target) pair is eligible
     * @return the accumulated candidate count (2 per accepted relationship)
     */
    @Override
    protected long validPositiveRelationshipCandidateCount(Graph graph, LongLongPredicate isValidNodePair) {
        LongAdder validRelationshipCountAdder = new LongAdder();
        // The explicit Runnable cast gives the inner lambda a target type independent of how the
        // generic task type of degreePartition is inferred.
        List<Runnable> countValidRelationshipTasks = PartitionUtils.degreePartition(graph, this.concurrency, partition -> (Runnable) () -> {
            // Each task iterates on its own copy of the graph.
            Graph concurrentGraph = graph.concurrentCopy();
            partition.consume(nodeId -> concurrentGraph.forEachRelationship(nodeId, (s, t) -> {
                // Only the s < t orientation is counted, and it is counted as two.
                if (s < t && (isValidNodePair.apply(s, t) || isValidNodePair.apply(t, s))) {
                    validRelationshipCountAdder.add(2L);
                }
                return true;
            }));
        }, Optional.empty());
        RunWithConcurrency.builder().concurrency(this.concurrency).tasks(countValidRelationshipTasks).run();
        return validRelationshipCountAdder.longValue();
    }

    /**
     * Distributes the relationships of {@code nodeId} between the selected and the remaining set.
     *
     * <p>Only relationships with {@code source < target} whose node pair is accepted by
     * {@code isValidNodePair} in at least one orientation are handled; all others are ignored.
     * For each handled relationship, {@code sample} is called with
     * {@code positiveSamplesRemaining / candidateEdgesRemaining}:
     * <ul>
     *   <li>on {@code true}, {@code positiveSamplesRemaining} is reduced by 2,
     *       {@code selectedRelCount} is incremented, and the relationship is added to
     *       {@code selectedRelsBuilder} with root node ids and the constant value {@code 1.0},
     *       oriented so that the (source, target) pair is the accepted one;</li>
     *   <li>on {@code false}, {@code remainingRelCount} is incremented and the relationship is
     *       passed to {@code remainingRelsConsumer} with its original value.</li>
     * </ul>
     * In both cases {@code candidateEdgesRemaining} is reduced by 2 afterwards.
     *
     * @param graph                    graph providing the relationships and the root id mapping
     * @param selectedRelsBuilder      receives the selected relationships
     * @param remainingRelsConsumer    receives the relationships that were not selected
     * @param selectedRelCount         counter of selected relationships, incremented in place
     * @param remainingRelCount        counter of remaining relationships, incremented in place
     * @param nodeId                   node whose relationships are processed
     * @param isValidNodePair          predicate deciding whether a (source, target) pair is eligible
     * @param positiveSamplesRemaining number of samples still to draw, updated in place
     * @param candidateEdgesRemaining  number of candidates still to visit, updated in place
     */
    @Override
    protected void positiveSampling(Graph graph, RelationshipsBuilder selectedRelsBuilder, RelationshipWithPropertyConsumer remainingRelsConsumer, MutableLong selectedRelCount, MutableLong remainingRelCount, long nodeId, LongLongPredicate isValidNodePair, MutableLong positiveSamplesRemaining, MutableLong candidateEdgesRemaining) {
        // NOTE(review): Double.NaN is presumably the fallback property value — confirm against Graph.
        graph.forEachRelationship(nodeId, Double.NaN, (source, target, weight) -> {
            if (source >= target) {
                return true;
            }
            // Evaluated once and reused for the orientation decision below.
            // NOTE(review): assumes isValidNodePair is side-effect free — confirm.
            boolean forwardValid = isValidNodePair.apply(source, target);
            if (!forwardValid && !isValidNodePair.apply(target, source)) {
                return true;
            }

            double selectionProbability = positiveSamplesRemaining.doubleValue() / candidateEdgesRemaining.doubleValue();
            if (this.sample(selectionProbability)) {
                positiveSamplesRemaining.addAndGet(-2L);
                selectedRelCount.increment();
                // Store the relationship in the orientation accepted by the predicate.
                if (forwardValid) {
                    selectedRelsBuilder.addFromInternal(graph.toRootNodeId(source), graph.toRootNodeId(target), 1.0);
                } else {
                    selectedRelsBuilder.addFromInternal(graph.toRootNodeId(target), graph.toRootNodeId(source), 1.0);
                }
            } else {
                remainingRelCount.increment();
                remainingRelsConsumer.accept(source, target, weight);
            }
            candidateEdgesRemaining.addAndGet(-2L);
            return true;
        });
    }
}

