/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.splitting;

import java.util.Collection;
import java.util.Optional;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.api.IdMap;
import org.neo4j.gds.config.ElementTypeValidator;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.negativeSampling.RandomNegativeSampler;
import org.neo4j.gds.ml.splitting.DirectedEdgeSplitter;
import org.neo4j.gds.ml.splitting.EdgeSplitter;
import org.neo4j.gds.ml.splitting.SplitRelationshipsBaseConfig;
import org.neo4j.gds.ml.splitting.UndirectedEdgeSplitter;

/**
 * Algorithm that runs an {@link EdgeSplitter} over a graph and afterwards hands the
 * selected relationships to a {@link RandomNegativeSampler}.
 *
 * <p>Instances are obtained through {@link #of(GraphStore, SplitRelationshipsBaseConfig)};
 * the constructor is private. The algorithm is always created with
 * {@code ProgressTracker.NULL_TRACKER}.
 */
public final class SplitRelationships extends Algorithm<EdgeSplitter.SplitResult> {
    /** Graph whose relationships are split. */
    private final Graph graph;
    /** Graph over the configured "super" relationship types; given to the negative sampler. */
    private final Graph masterGraph;
    private final SplitRelationshipsBaseConfig config;
    /** Node id mapping of the whole graph store ({@code graphStore.nodes()}). */
    private final IdMap rootNodes;
    /** Id mapping restricted to the configured source node labels. */
    private final IdMap sourceNodes;
    /** Id mapping restricted to the configured target node labels. */
    private final IdMap targetNodes;

    private SplitRelationships(Graph graph, Graph masterGraph, IdMap rootNodes, IdMap sourceNodes, IdMap targetNodes, SplitRelationshipsBaseConfig config) {
        super(ProgressTracker.NULL_TRACKER);
        this.config = config;
        this.graph = graph;
        this.masterGraph = masterGraph;
        this.rootNodes = rootNodes;
        this.sourceNodes = sourceNodes;
        this.targetNodes = targetNodes;
    }

    /**
     * Builds a {@code SplitRelationships} instance from a graph store and a configuration.
     *
     * <p>Four graph views are fetched from the store: the graph to split (configured node
     * labels, internal relationship types and the optional weight property), the "master"
     * graph (same node labels, super relationship types, no property), and one view each
     * for the source and target node labels.
     *
     * @param graphStore store the graph views are read from
     * @param config     split configuration
     * @return a new algorithm instance ready for {@link #compute()}
     */
    public static SplitRelationships of(GraphStore graphStore, SplitRelationshipsBaseConfig config) {
        // NOTE(review): raw Collection types are kept because the element types are not
        // visible from this file; parameterise them once they are confirmed.
        Collection labels = config.nodeLabelIdentifiers(graphStore);
        Collection sourceLabels = ElementTypeValidator.resolve(graphStore, config.sourceNodeLabels());
        Collection targetLabels = ElementTypeValidator.resolve(graphStore, config.targetNodeLabels());
        Collection splitTypes = config.internalRelationshipTypes(graphStore);
        Collection superTypes = ElementTypeValidator.resolveTypes(graphStore, config.superRelationshipTypes());

        Graph graphToSplit = graphStore.getGraph(labels, splitTypes, config.relationshipWeightProperty());
        Graph superGraph = graphStore.getGraph(labels, superTypes, Optional.empty());
        IdMap sourceIdMap = (IdMap) graphStore.getGraph(sourceLabels);
        IdMap targetIdMap = (IdMap) graphStore.getGraph(targetLabels);

        return new SplitRelationships(graphToSplit, superGraph, graphStore.nodes(), sourceIdMap, targetIdMap, config);
    }

    /**
     * Describes the memory estimation for splitting relationships with the given
     * configuration.
     *
     * <p>Two components are reported, both derived from the estimated relationship count
     * of the configured relationship types:
     * <ul>
     *   <li>"Selected relationships": held-out relationships plus negative samples, as a
     *       range from half the count up to the full count;</li>
     *   <li>"Remaining relationships": the relationships not held out, as a fixed value.</li>
     * </ul>
     *
     * @param configuration split configuration supplying the fractions and ratios
     * @return the composed memory estimation named "Relationship splitter"
     */
    public static MemoryEstimation estimate(SplitRelationshipsBaseConfig configuration) {
        // Pessimistic per-relationship size in bytes.
        // NOTE(review): presumably two 8-byte ids plus an optional 8-byte weight - confirm.
        final long bytesPerRelationship;
        if (configuration.hasRelationshipWeightProperty()) {
            bytesPerRelationship = 24L;
        } else {
            bytesPerRelationship = 16L;
        }

        return MemoryEstimations.builder("Relationship splitter")
            .perGraphDimension("Selected relationships", (dimensions, concurrency) -> {
                double estimatedTotal = (double) dimensions.estimatedRelCount(configuration.relationshipTypes());
                double positives = estimatedTotal * configuration.holdoutFraction();
                double negatives = positives * configuration.negativeSamplingRatio();
                long selected = (long) (positives + negatives);
                return MemoryRange.of(selected / 2L, selected).times(bytesPerRelationship);
            })
            .perGraphDimension("Remaining relationships", (dimensions, concurrency) -> {
                double estimatedTotal = (double) dimensions.estimatedRelCount(configuration.relationshipTypes());
                long remaining = (long) (estimatedTotal * (1.0 - configuration.holdoutFraction()));
                return MemoryRange.of(remaining * bytesPerRelationship);
            })
            .build();
    }

    /**
     * Splits the positive relationships of the graph and then runs negative sampling
     * against the selected relationships.
     *
     * <p>The number of negative samples requested is the selected relationship count
     * multiplied by the configured negative sampling ratio, truncated to a {@code long}.
     *
     * @return the split result produced by the edge splitter
     */
    public EdgeSplitter.SplitResult compute() {
        EdgeSplitter edgeSplitter = createSplitter(this.graph.schema().isUndirected());
        EdgeSplitter.SplitResult result = edgeSplitter.splitPositiveExamples(
            this.graph,
            this.config.holdoutFraction(),
            this.config.relationshipWeightProperty()
        );

        long negativeSampleCount = (long) ((double) result.selectedRelCount() * this.config.negativeSamplingRatio());
        RandomNegativeSampler sampler = new RandomNegativeSampler(
            this.masterGraph,
            negativeSampleCount,
            0L,
            this.sourceNodes,
            this.targetNodes,
            this.config.randomSeed()
        );
        sampler.produceNegativeSamples(result.selectedRels(), null);

        return result;
    }

    /** Picks the undirected or directed splitter; both receive the same arguments. */
    private EdgeSplitter createSplitter(boolean undirected) {
        if (undirected) {
            return new UndirectedEdgeSplitter(
                this.config.randomSeed(),
                this.rootNodes,
                this.sourceNodes,
                this.targetNodes,
                this.config.holdoutRelationshipType(),
                this.config.remainingRelationshipType(),
                this.config.concurrency()
            );
        }
        return new DirectedEdgeSplitter(
            this.config.randomSeed(),
            this.rootNodes,
            this.sourceNodes,
            this.targetNodes,
            this.config.holdoutRelationshipType(),
            this.config.remainingRelationshipType(),
            this.config.concurrency()
        );
    }
}

