/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.splitting;

import com.carrotsearch.hppc.predicates.LongLongPredicate;
import com.carrotsearch.hppc.predicates.LongPredicate;
import java.util.List;
import java.util.Optional;
import java.util.Random;
import org.apache.commons.lang3.mutable.MutableLong;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.api.DefaultValue;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.api.PartialIdMap;
import org.neo4j.gds.api.RelationshipWithPropertyConsumer;
import org.neo4j.gds.api.schema.Direction;
import org.neo4j.gds.core.Aggregation;
import org.neo4j.gds.core.concurrency.DefaultPool;
import org.neo4j.gds.core.loading.construction.GraphFactory;
import org.neo4j.gds.core.loading.construction.RelationshipsBuilder;
import org.neo4j.gds.ml.splitting.ImmutableSplitResult;

/**
 * Base class for splitting the relationships of a {@link Graph} into a "selected" (holdout)
 * set and a "remaining" set.
 *
 * <p>This class owns the random number generator, builds the two {@link RelationshipsBuilder}s
 * and drives the per-node iteration. The actual sampling decision for each node is delegated to
 * subclasses through {@link #positiveSampling} and {@link #validPositiveRelationshipCandidateCount}.
 */
public abstract class EdgeSplitter {
    /** Value of the {@value #RELATIONSHIP_PROPERTY} property used to mark a positive example. */
    public static final double POSITIVE = 1.0;
    /** Property key that is always configured on the selected-relationships builder. */
    public static final String RELATIONSHIP_PROPERTY = "label";

    private final Random rng;
    private final RelationshipType selectedRelationshipType;
    private final RelationshipType remainingRelationshipType;
    private final IdMap sourceNodes;
    private final IdMap targetNodes;
    private final IdMap rootNodes;
    // Stored for subclasses; the builders created in this class always use a concurrency of 1.
    final int concurrency;

    /**
     * @param maybeSeed                 seed for the random number generator; when empty the
     *                                  generator keeps the seed chosen by {@code new Random()}
     * @param rootNodes                 id map handed to both relationship builders
     * @param sourceNodes               id map whose original ids define the valid source nodes
     * @param targetNodes               id map whose original ids define the valid target nodes
     * @param selectedRelationshipType  relationship type of the selected (holdout) relationships
     * @param remainingRelationshipType relationship type of the remaining relationships
     * @param concurrency               concurrency made available to subclasses
     */
    EdgeSplitter(
        Optional<Long> maybeSeed,
        IdMap rootNodes,
        IdMap sourceNodes,
        IdMap targetNodes,
        RelationshipType selectedRelationshipType,
        RelationshipType remainingRelationshipType,
        int concurrency
    ) {
        this.rootNodes = rootNodes;
        this.selectedRelationshipType = selectedRelationshipType;
        this.remainingRelationshipType = remainingRelationshipType;
        this.rng = new Random();
        maybeSeed.ifPresent(this.rng::setSeed);
        this.sourceNodes = sourceNodes;
        this.targetNodes = targetNodes;
        this.concurrency = concurrency;
    }

    /**
     * Splits the relationships of {@code graph} into selected and remaining relationships.
     *
     * <p>The number of relationships to select is
     * {@code validPositiveRelationshipCandidateCount(...) * holdoutFraction}, truncated to a
     * {@code long}. Every node of the graph is then passed to {@link #positiveSampling} once.
     *
     * @param graph                   the graph whose relationships are split
     * @param holdoutFraction         fraction of the valid candidate relationships to select
     * @param remainingRelPropertyKey if present, the remaining-relationships builder is configured
     *                                with a double property under this key
     * @return both builders together with the counters that were passed to
     *         {@link #positiveSampling}; this method never modifies those counters itself
     */
    public SplitResult splitPositiveExamples(Graph graph, double holdoutFraction, Optional<String> remainingRelPropertyKey) {
        // A node pair is a candidate only if its source is contained in sourceNodes and its
        // target is contained in targetNodes, both compared by original node id.
        LongPredicate isValidSourceNode = node -> this.sourceNodes.containsOriginalId(graph.toOriginalNodeId(node));
        LongPredicate isValidTargetNode = node -> this.targetNodes.containsOriginalId(graph.toOriginalNodeId(node));
        LongLongPredicate isValidNodePair = (s, t) -> isValidSourceNode.apply(s) && isValidTargetNode.apply(t);

        // Selected relationships are always built as DIRECTED and carry the "label" property.
        RelationshipsBuilder selectedRelsBuilder = EdgeSplitter.newRelationshipsBuilder(
            this.rootNodes,
            this.selectedRelationshipType,
            Direction.DIRECTED,
            Optional.of(RELATIONSHIP_PROPERTY)
        );

        // Remaining relationships keep the direction reported by the input graph's schema.
        Direction remainingRelDirection = graph.schema().direction();
        RelationshipsBuilder remainingRelsBuilder = EdgeSplitter.newRelationshipsBuilder(
            this.rootNodes,
            this.remainingRelationshipType,
            remainingRelDirection,
            remainingRelPropertyKey
        );

        // The builders operate on rootNodes, so ids of `graph` are mapped to root ids first.
        RelationshipWithPropertyConsumer remainingRelsConsumer = (s, t, w) -> {
            remainingRelsBuilder.addFromInternal(graph.toRootNodeId(s), graph.toRootNodeId(t), w);
            return true;
        };

        long validRelationshipCount = this.validPositiveRelationshipCandidateCount(graph, isValidNodePair);
        long positiveSamples = (long) ((double) validRelationshipCount * holdoutFraction);

        MutableLong positiveSamplesRemaining = new MutableLong(positiveSamples);
        MutableLong candidateEdgesRemaining = new MutableLong(validRelationshipCount);
        MutableLong selectedRelCount = new MutableLong(0L);
        MutableLong remainingRelCount = new MutableLong(0L);

        graph.forEachNode(nodeId -> {
            this.positiveSampling(
                graph,
                selectedRelsBuilder,
                remainingRelsConsumer,
                selectedRelCount,
                remainingRelCount,
                nodeId,
                isValidNodePair,
                positiveSamplesRemaining,
                candidateEdgesRemaining
            );
            return true;
        });

        return SplitResult.of(
            remainingRelsBuilder,
            remainingRelCount.longValue(),
            selectedRelsBuilder,
            selectedRelCount.longValue()
        );
    }

    /**
     * Performs the sampling for the relationships of a single node. Called once per node by
     * {@link #splitPositiveExamples}.
     *
     * @param graph                    the graph being split
     * @param selectedRelsBuilder      builder for the selected relationships
     * @param remainingRelsConsumer    consumer that adds a relationship to the remaining set
     * @param selectedRelCount         counter reported as {@link SplitResult#selectedRelCount()}
     * @param remainingRelCount        counter reported as {@link SplitResult#remainingRelCount()}
     * @param nodeId                   the node currently visited
     * @param isValidNodePair          predicate telling whether a (source, target) pair is a candidate
     * @param positiveSamplesRemaining initialised with the number of relationships to select
     * @param candidateEdgesRemaining  initialised with the number of valid candidate relationships
     */
    protected abstract void positiveSampling(
        Graph graph,
        RelationshipsBuilder selectedRelsBuilder,
        RelationshipWithPropertyConsumer remainingRelsConsumer,
        MutableLong selectedRelCount,
        MutableLong remainingRelCount,
        long nodeId,
        LongLongPredicate isValidNodePair,
        MutableLong positiveSamplesRemaining,
        MutableLong candidateEdgesRemaining
    );

    /**
     * Counts the relationships of {@code graph} that are candidates for selection.
     *
     * @param graph           the graph being split
     * @param isValidNodePair predicate telling whether a (source, target) pair is a candidate
     * @return the candidate count used as the basis for the holdout fraction
     */
    protected abstract long validPositiveRelationshipCandidateCount(Graph graph, LongLongPredicate isValidNodePair);

    /**
     * Draws one random double in {@code [0, 1)} and compares it to {@code probability}.
     *
     * @param probability threshold; a {@code NaN} threshold always yields {@code false}
     * @return {@code true} if the drawn value is strictly smaller than {@code probability}
     */
    boolean sample(double probability) {
        return this.rng.nextDouble() < probability;
    }

    /**
     * Computes how many samples to take for one node so that {@code remainingSamples} are spread
     * over {@code remainingNodes}: the integer part of the average is always taken and one extra
     * sample is added with a probability equal to the fractional part.
     *
     * <p>Exactly one random number is consumed per call.
     *
     * @param maxSamples       upper bound for the result
     * @param remainingSamples number of samples still to distribute
     * @param remainingNodes   number of nodes still to visit
     * @return the number of samples for the current node, never more than {@code maxSamples}
     */
    long samplesPerNode(long maxSamples, double remainingSamples, long remainingNodes) {
        double numSamplesOnAverage = remainingSamples / (double) remainingNodes;
        // The cast saturates: an infinite or very large average becomes Long.MAX_VALUE.
        long wholeSamples = (long) numSamplesOnAverage;
        int extraSample = this.sample(numSamplesOnAverage - (double) wholeSamples) ? 1 : 0;
        // Saturating add: with remainingNodes == 0 the average is +Infinity, wholeSamples is
        // Long.MAX_VALUE and a plain "+ 1" would wrap to Long.MIN_VALUE, making the result negative.
        long totalSamples = wholeSamples > Long.MAX_VALUE - extraSample
            ? Long.MAX_VALUE
            : wholeSamples + (long) extraSample;
        return Math.min(maxSamples, totalSamples);
    }

    /**
     * Creates a relationships builder over {@code rootNodes} with {@link Aggregation#SINGLE}, a
     * concurrency of 1 and the default executor pool.
     *
     * @param rootNodes        id map the builder operates on
     * @param relationshipType type of the relationships to build
     * @param direction        direction, converted to the builder's orientation
     * @param propertyKey      if present, a single double property with this key is configured
     */
    private static RelationshipsBuilder newRelationshipsBuilder(
        IdMap rootNodes,
        RelationshipType relationshipType,
        Direction direction,
        Optional<String> propertyKey
    ) {
        // Zero or one property config, depending on whether a property key was supplied.
        Iterable<? extends GraphFactory.PropertyConfig> propertyConfigs = propertyKey
            .map(key -> List.of(GraphFactory.PropertyConfig.of(key, Aggregation.SINGLE, DefaultValue.forDouble())))
            .orElse(List.of());

        return GraphFactory.initRelationshipsBuilder()
            .relationshipType(relationshipType)
            .aggregation(Aggregation.SINGLE)
            .nodes(rootNodes)
            .orientation(direction.toOrientation())
            .addAllPropertyConfigs(propertyConfigs)
            .concurrency(1)
            .executorService(DefaultPool.INSTANCE)
            .build();
    }

    /**
     * Result of {@link #splitPositiveExamples}: the two relationship builders and the counters
     * associated with them.
     */
    @ValueClass
    public interface SplitResult {
        /** @return builder holding the relationships that were not selected */
        RelationshipsBuilder remainingRels();

        /** @return value of the remaining-relationships counter */
        long remainingRelCount();

        /** @return builder holding the selected (holdout) relationships */
        RelationshipsBuilder selectedRels();

        /** @return value of the selected-relationships counter */
        long selectedRelCount();

        /** Creates a result from its four components. */
        static SplitResult of(RelationshipsBuilder remainingRels, long remainingRelCount, RelationshipsBuilder selectedRels, long selectedRelCount) {
            return ImmutableSplitResult.of(remainingRels, remainingRelCount, selectedRels, selectedRelCount);
        }
    }
}

