/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import com.carrotsearch.hppc.LongHashSet;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.stream.LongStream;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.partition.Partition;
import org.neo4j.gds.core.utils.partition.PartitionUtils;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.queue.BoundedLongLongPriorityQueue;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.linkmodels.ExhaustiveLinkPredictionResult;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionData;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionPredictPipelineBaseConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionSimilarityComputer;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureExtractor;

/**
 * Link prediction that scores every candidate node pair {@code (source, target)} with
 * {@code source < target} that is not already connected by a relationship from
 * {@code source} to {@code target}, and keeps the {@code topN} highest-scoring pairs whose
 * score is at least {@code threshold}.
 */
public class ExhaustiveLinkPrediction
extends LinkPrediction {
    private final int topN;
    private final double threshold;

    /**
     * @param modelData            model data, forwarded to the {@link LinkPrediction} base class
     * @param linkFeatureExtractor feature extractor, forwarded to the {@link LinkPrediction} base class
     * @param graph                graph, forwarded to the {@link LinkPrediction} base class
     * @param concurrency          number of partitions / parallel tasks used by {@link #predictLinks}
     * @param topN                 capacity of the bounded priority queue holding the best predictions
     * @param threshold            pairs whose score is strictly below this value are not offered to the queue
     * @param progressTracker      tracker forwarded to the base class and used by the scoring tasks
     */
    public ExhaustiveLinkPrediction(LinkLogisticRegressionData modelData, LinkFeatureExtractor linkFeatureExtractor, Graph graph, int concurrency, int topN, double threshold, ProgressTracker progressTracker) {
        super(modelData, linkFeatureExtractor, graph, concurrency, progressTracker);
        this.topN = topN;
        this.threshold = threshold;
    }

    /**
     * Estimates memory for one prediction run: a bounded priority queue of {@code config.topN()}
     * entries, plus, per thread, one double array of {@code linkFeatureDimension} entries and one
     * {@code LongHashSet} sized by the graph's average degree.
     *
     * @throws java.util.NoSuchElementException if {@code config.topN()} is empty
     */
    public static MemoryEstimation estimate(LinkPredictionPredictPipelineBaseConfig config, int linkFeatureDimension) {
        return MemoryEstimations.builder(ExhaustiveLinkPrediction.class)
            .add("Priority queue", BoundedLongLongPriorityQueue.memoryEstimation(config.topN().orElseThrow()))
            .perGraphDimension("Predict links operation", (dim, threads) -> MemoryRange.of(
                MemoryUsage.sizeOfDoubleArray(linkFeatureDimension)
                + MemoryUsage.sizeOfLongHashSet((long) dim.averageDegree())
            ).times(threads.intValue()))
            .build();
    }

    /**
     * Scores all candidate pairs in parallel and collects the best {@code topN} of them.
     *
     * @param graph                            graph whose node pairs are scored; each task gets a concurrent copy
     * @param linkPredictionSimilarityComputer computes the score of a {@code (source, target)} pair
     * @return the filled priority queue together with the total number of pairs that were scored
     */
    ExhaustiveLinkPredictionResult predictLinks(Graph graph, LinkPredictionSimilarityComputer linkPredictionSimilarityComputer) {
        BoundedLongLongPriorityQueue predictionQueue = BoundedLongLongPriorityQueue.max(this.topN);
        // One task per node-id range; all tasks share the same queue (access is synchronized in the task).
        List<LinkPredictionScoreByIdsConsumer> tasks = PartitionUtils.rangePartition(
            this.concurrency,
            graph.nodeCount(),
            partition -> new LinkPredictionScoreByIdsConsumer(
                graph.concurrentCopy(),
                linkPredictionSimilarityComputer,
                predictionQueue,
                partition,
                this.progressTracker
            ),
            Optional.empty()
        );
        ParallelUtil.runWithConcurrency(this.concurrency, tasks, Pools.DEFAULT);
        // NOTE(review): reading the per-task counters here assumes runWithConcurrency only returns
        // after all tasks completed (and thereby publishes their writes) — confirm in ParallelUtil.
        long linksConsidered = tasks.stream()
            .mapToLong(LinkPredictionScoreByIdsConsumer::linksConsidered)
            .sum();
        return new ExhaustiveLinkPredictionResult(predictionQueue, linksConsidered);
    }

    /**
     * Scores every candidate pair whose source node lies in one {@link Partition}.
     *
     * <p>The {@code linksConsidered} counter is a plain field, so an instance should be run by a
     * single thread. The shared {@code predictionQueue} is guarded by synchronizing on it.
     */
    final class LinkPredictionScoreByIdsConsumer
    implements Runnable {
        private final Graph graph;
        private final LinkPredictionSimilarityComputer linkPredictionSimilarityComputer;
        private final BoundedLongLongPriorityQueue predictionQueue;
        private final ProgressTracker progressTracker;
        private final Partition partition;
        private long linksConsidered;

        LinkPredictionScoreByIdsConsumer(Graph graph, LinkPredictionSimilarityComputer linkPredictionSimilarityComputer, BoundedLongLongPriorityQueue predictionQueue, Partition partition, ProgressTracker progressTracker) {
            this.graph = graph;
            this.linkPredictionSimilarityComputer = linkPredictionSimilarityComputer;
            this.predictionQueue = predictionQueue;
            this.progressTracker = progressTracker;
            this.partition = partition;
            this.linksConsidered = 0L;
        }

        @Override
        public void run() {
            long nodeCount = this.graph.nodeCount();
            this.partition.consume(sourceId -> {
                LongHashSet largerNeighbors = this.largerNeighbors(sourceId);
                // Only targets with a larger id are scored, so each unordered pair is visited at most once.
                for (long targetId = sourceId + 1L; targetId < nodeCount; targetId++) {
                    if (largerNeighbors.contains(targetId)) {
                        // Already connected: not a candidate for a new link.
                        continue;
                    }
                    double probability = this.linkPredictionSimilarityComputer.similarity(sourceId, targetId);
                    this.linksConsidered++;
                    if (probability < ExhaustiveLinkPrediction.this.threshold) {
                        continue;
                    }
                    // The queue is shared between all tasks of one predictLinks call.
                    synchronized (this.predictionQueue) {
                        this.predictionQueue.offer(sourceId, targetId, probability);
                    }
                }
            });
            this.progressTracker.logProgress(this.partition.nodeCount());
        }

        /**
         * Collects the targets of {@code sourceId}'s relationships whose id is larger than
         * {@code sourceId}; these pairs already have a link and are skipped during scoring.
         */
        private LongHashSet largerNeighbors(long sourceId) {
            LongHashSet neighbors = new LongHashSet();
            this.graph.forEachRelationship(sourceId, (src, trg) -> {
                if (src < trg) {
                    neighbors.add(trg);
                }
                return true;
            });
            return neighbors;
        }

        /** Number of pairs this task scored, including those whose score fell below the threshold. */
        long linksConsidered() {
            return this.linksConsidered;
        }
    }
}

