/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import java.util.Collection;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.stream.Stream;
import org.neo4j.gds.NodeLabel;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.api.GraphStore;
import org.neo4j.gds.core.concurrency.Pools;
import org.neo4j.gds.core.utils.mem.AllocationTracker;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.write.ImmutableRelationship;
import org.neo4j.gds.core.write.Relationship;
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
import org.neo4j.gds.ml.linkmodels.PredictedLink;
import org.neo4j.gds.ml.linkmodels.pipeline.PipelineExecutor;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionData;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPrediction;
import org.neo4j.gds.ml.linkmodels.pipeline.predict.LinkPredictionSimilarityComputer;
import org.neo4j.gds.similarity.SimilarityResult;
import org.neo4j.gds.similarity.knn.ImmutableKnnContext;
import org.neo4j.gds.similarity.knn.Knn;
import org.neo4j.gds.similarity.knn.KnnBaseConfig;
import org.neo4j.gds.similarity.knn.SimilarityComputer;
import org.neo4j.values.storable.Value;
import org.neo4j.values.storable.Values;

/**
 * Link prediction strategy that uses the KNN algorithm to find candidate links, instead of scoring
 * every node pair.
 *
 * <p>KNN runs over all nodes of the prediction graph, with the link-prediction similarity
 * computer as its similarity function. Candidate pairs that are already connected in the graph
 * are dropped from the result.
 */
public class ApproximateLinkPrediction extends LinkPrediction {
    private final KnnBaseConfig knnConfig;

    /**
     * Creates an approximate (KNN-based) link prediction.
     *
     * @param modelData         trained logistic-regression model data, passed to the base class
     * @param pipelineExecutor  pipeline executor, passed to the base class
     * @param nodeLabels        node labels to predict over, passed to the base class
     * @param relationshipTypes relationship types to predict over, passed to the base class
     * @param graphStore        graph store, passed to the base class
     * @param knnConfig         KNN configuration used by {@link #predictLinks}; its
     *                          {@code concurrency()} is also handed to the base class
     * @param progressTracker   progress tracker, passed to the base class and to KNN
     */
    public ApproximateLinkPrediction(
        LinkLogisticRegressionData modelData,
        PipelineExecutor pipelineExecutor,
        Collection<NodeLabel> nodeLabels,
        Collection<RelationshipType> relationshipTypes,
        GraphStore graphStore,
        KnnBaseConfig knnConfig,
        ProgressTracker progressTracker
    ) {
        super(
            modelData,
            pipelineExecutor,
            nodeLabels,
            relationshipTypes,
            graphStore,
            knnConfig.concurrency(),
            progressTracker
        );
        this.knnConfig = knnConfig;
    }

    /**
     * Runs KNN over every node of {@code graph} and keeps only those pairs that are not already
     * connected.
     *
     * @param graph                            graph to predict new links for
     * @param linkPredictionSimilarityComputer similarity function used by KNN to score node pairs
     * @return a result whose prediction stream is lazily filtered, plus KNN sampling statistics
     */
    @Override
    LinkPredictionResult predictLinks(
        Graph graph,
        LinkPredictionSimilarityComputer linkPredictionSimilarityComputer
    ) {
        Knn.Result knnResult = new Knn(
            graph.nodeCount(),
            knnConfig,
            linkPredictionSimilarityComputer,
            ImmutableKnnContext.of(Pools.DEFAULT, AllocationTracker.empty(), progressTracker)
        ).compute();

        // Only *new* links are predictions: drop pairs that already have a relationship.
        Stream<SimilarityResult> predictions = knnResult
            .streamSimilarityResult()
            .filter(result -> !graph.exists(result.sourceNodeId(), result.targetNodeId()));

        return new Result(
            predictions,
            knnResult.nodePairsConsidered(),
            knnResult.ranIterations(),
            knnResult.didConverge()
        );
    }

    /**
     * Link prediction result backed by a KNN similarity stream.
     *
     * <p>NOTE(review): {@link #stream()} and {@link #relationshipStream()} both map the same
     * underlying single-use {@link Stream}. Only one of them can be consumed, and only once. A
     * second call fails with {@link IllegalStateException} when its stream is consumed.
     */
    static class Result implements LinkPredictionResult {
        private final Stream<SimilarityResult> predictions;
        private final Map<String, Object> samplingStats;

        /**
         * @param predictions     predicted pairs, each with a similarity score
         * @param linksConsidered number of node pairs KNN considered
         * @param ranIterations   number of KNN iterations that were run
         * @param didConverge     whether KNN converged
         */
        Result(
            Stream<SimilarityResult> predictions,
            long linksConsidered,
            long ranIterations,
            boolean didConverge
        ) {
            this.predictions = predictions;
            this.samplingStats = Map.of(
                "strategy", "approximate",
                "linksConsidered", linksConsidered,
                "ranIterations", ranIterations,
                "didConverge", didConverge
            );
        }

        /** Maps each predicted pair to a {@link PredictedLink} carrying its similarity score. */
        @Override
        public Stream<PredictedLink> stream() {
            return predictions.map(result -> PredictedLink.of(
                result.sourceNodeId(),
                result.targetNodeId(),
                result.similarity
            ));
        }

        /**
         * Maps each predicted pair to a {@link Relationship} whose single property value is the
         * similarity score.
         */
        @Override
        public Stream<Relationship> relationshipStream() {
            // Use the same accessors as stream() and the filter in predictLinks, for consistency.
            return predictions.map(result -> ImmutableRelationship.of(
                result.sourceNodeId(),
                result.targetNodeId(),
                new Value[]{Values.doubleValue(result.similarity)}
            ));
        }

        /** Returns the immutable KNN sampling statistics. */
        @Override
        public Map<String, Object> samplingStats() {
            return samplingStats;
        }
    }
}

