/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import com.carrotsearch.hppc.LongHashSet;
import java.util.Collection;
import java.util.Iterator;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.stream.LongStream;
import org.neo4j.gds.Algorithm;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.ml.core.batch.Batch;
import org.neo4j.gds.ml.core.batch.BatchQueue;
import org.neo4j.gds.ml.core.tensor.Matrix;
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
import org.neo4j.gds.ml.linkmodels.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.linkmodels.pipeline.linkFeatures.LinkFeatureExtractor;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionData;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionPredictor;

/**
 * Predicts links between pairs of nodes that are not yet connected.
 *
 * <p>The algorithm first runs the pipeline's node property steps, then scores every
 * candidate pair {@code (source, target)} with {@code source < target} that is not already
 * a relationship in the filtered graph. Each pair's link features are fed to a logistic
 * regression predictor, and pairs whose probability is not below {@code threshold} are
 * added to the {@link LinkPredictionResult}. The node properties produced by the pipeline
 * are removed again before {@link #compute()} returns or propagates an exception.
 */
public class LinkPrediction
extends Algorithm<LinkPrediction, LinkPredictionResult> {
    private final LinkLogisticRegressionData modelData;
    private final PipelineExecutor pipelineExecutor;
    private final Collection<NodeLabel> nodeLabels;
    private final Collection<RelationshipType> relationshipTypes;
    private final GraphStore graphStore;
    private final int concurrency;
    private final int topN;
    private final double threshold;

    /**
     * @param modelData         trained logistic regression data used to build the predictor
     * @param pipelineExecutor  runs the node property steps and provides the link feature extractor
     * @param nodeLabels        node labels used to filter the graph
     * @param relationshipTypes relationship types used to filter the graph
     * @param graphStore        store the filtered graph is read from and the node properties are removed from
     * @param concurrency       concurrency passed to the batch queue
     * @param topN              passed to the {@link LinkPredictionResult} constructor
     *                          (presumably the number of best predictions to keep; confirm there)
     * @param threshold         predictions with a probability below this value are discarded
     * @param progressTracker   tracker used to report progress
     */
    public LinkPrediction(LinkLogisticRegressionData modelData, PipelineExecutor pipelineExecutor, Collection<NodeLabel> nodeLabels, Collection<RelationshipType> relationshipTypes, GraphStore graphStore, int concurrency, int topN, double threshold, ProgressTracker progressTracker) {
        super(progressTracker);
        this.modelData = modelData;
        this.pipelineExecutor = pipelineExecutor;
        this.nodeLabels = nodeLabels;
        this.relationshipTypes = relationshipTypes;
        this.graphStore = graphStore;
        this.concurrency = concurrency;
        this.topN = topN;
        this.threshold = threshold;
    }

    /**
     * Executes the node property steps, predicts links, and removes the node properties again.
     *
     * <p>The removal runs in a {@code finally} block so that neither a termination detected by
     * {@code assertRunning()} nor a failure during prediction leaves the pipeline's node
     * properties behind in the graph store.
     *
     * @return the predicted links
     */
    @Override
    public LinkPredictionResult compute() {
        this.progressTracker.beginSubTask();
        // Kept outside the try block: if the steps themselves fail, the original code did not
        // attempt a removal either, and the set of properties written so far is unknown here.
        this.pipelineExecutor.executeNodePropertySteps(this.nodeLabels, this.relationshipTypes);
        LinkPredictionResult result;
        try {
            this.assertRunning();
            result = this.predictLinks();
        } finally {
            this.pipelineExecutor.removeNodeProperties(this.graphStore, this.nodeLabels);
        }
        // As before, the sub task is only marked as finished on the success path.
        this.progressTracker.endSubTask();
        return result;
    }

    /**
     * Scores all candidate node pairs of the filtered graph in parallel batches.
     *
     * @return the result object that all batch consumers added their predictions to
     */
    private LinkPredictionResult predictLinks() {
        this.progressTracker.beginSubTask();
        Graph graph = this.graphStore.getGraph(this.nodeLabels, this.relationshipTypes, Optional.empty());
        LinkFeatureExtractor featureExtractor = this.pipelineExecutor.linkFeatureExtractor(graph);
        // Only checked when assertions are enabled (-ea).
        assert (featureExtractor.featureDimension() == ((Matrix)this.modelData.weights().data()).totalSize()) : "Model must contain a weight for each feature.";
        LinkLogisticRegressionPredictor predictor = new LinkLogisticRegressionPredictor(this.modelData);
        LinkPredictionResult result = new LinkPredictionResult(this.topN);
        // NOTE(review): 100 is presumably the batch size; confirm against BatchQueue.
        BatchQueue batchQueue = new BatchQueue(graph.nodeCount(), 100, this.concurrency);
        // Every consumer gets its own graph copy but shares the extractor, predictor and result.
        // NOTE(review): this relies on LinkPredictionResult.add being safe for concurrent calls; confirm.
        batchQueue.parallelConsume(this.concurrency, ignore -> new LinkPredictionScoreByIdsConsumer(graph.concurrentCopy(), featureExtractor, predictor, result, this.progressTracker), this.terminationFlag);
        this.progressTracker.endSubTask();
        return result;
    }

    /**
     * @return this instance
     */
    @Override
    public LinkPrediction me() {
        return this;
    }

    /**
     * Does nothing; this class holds nothing that is released here.
     */
    @Override
    public void release() {
    }

    /**
     * Consumes a batch of source node ids and scores each source against every node with a
     * larger id that is not already one of its neighbors.
     */
    private final class LinkPredictionScoreByIdsConsumer
    implements Consumer<Batch> {
        private final Graph graph;
        private final LinkFeatureExtractor linkFeatureExtractor;
        private final LinkLogisticRegressionPredictor predictor;
        private final LinkPredictionResult predictedLinks;
        private final ProgressTracker progressTracker;

        private LinkPredictionScoreByIdsConsumer(Graph graph, LinkFeatureExtractor linkFeatureExtractor, LinkLogisticRegressionPredictor predictor, LinkPredictionResult predictedLinks, ProgressTracker progressTracker) {
            this.graph = graph;
            this.linkFeatureExtractor = linkFeatureExtractor;
            this.predictor = predictor;
            this.predictedLinks = predictedLinks;
            this.progressTracker = progressTracker;
        }

        /**
         * Scores all candidate targets for every source node in {@code batch} and logs the
         * batch size as progress once the whole batch has been processed.
         *
         * @param batch the source node ids to process
         */
        @Override
        public void accept(Batch batch) {
            long nodeCount = this.graph.nodeCount();
            for (long sourceId : batch.nodeIds()) {
                LongHashSet largerNeighbors = this.largerNeighbors(sourceId);
                // Only targets with a larger id are visited, so each unordered pair is scored once.
                for (long targetId = sourceId + 1L; targetId < nodeCount; ++targetId) {
                    if (largerNeighbors.contains(targetId)) {
                        // Already connected in the filtered graph; nothing to predict.
                        continue;
                    }
                    double[] features = this.linkFeatureExtractor.extractFeatures(sourceId, targetId);
                    double probability = this.predictor.predictedProbability(features);
                    if (probability < LinkPrediction.this.threshold) {
                        continue;
                    }
                    this.predictedLinks.add(sourceId, targetId, probability);
                }
            }
            this.progressTracker.logProgress((long)batch.size());
        }

        /**
         * Collects the targets of {@code sourceId}'s relationships whose id is larger than the
         * source id reported for that relationship.
         *
         * @param sourceId the node whose relationships are visited
         * @return the set of neighbor ids larger than the source id
         */
        private LongHashSet largerNeighbors(long sourceId) {
            LongHashSet neighbors = new LongHashSet();
            this.graph.forEachRelationship(sourceId, (src, trg) -> {
                if (src < trg) {
                    neighbors.add(trg);
                }
                // Returning true on every call, so no relationship ends the traversal early.
                return true;
            });
            return neighbors;
        }
    }
}

