/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import org.neo4j.gds.api.Graph;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionPredictor;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureExtractor;
import org.neo4j.gds.similarity.knn.NeighborFilter;
import org.neo4j.gds.similarity.knn.NeighborFilterFactory;
import org.neo4j.gds.similarity.knn.SimilarityComputer;

class LinkPredictionSimilarityComputer
implements SimilarityComputer {
    private final LinkFeatureExtractor linkFeatureExtractor;
    private final LinkLogisticRegressionPredictor predictor;
    private final Graph graph;

    LinkPredictionSimilarityComputer(LinkFeatureExtractor linkFeatureExtractor, LinkLogisticRegressionPredictor predictor, Graph graph) {
        this.linkFeatureExtractor = linkFeatureExtractor;
        this.predictor = predictor;
        this.graph = graph;
    }

    public double similarity(long sourceId, long targetId) {
        double[] features = this.linkFeatureExtractor.extractFeatures(sourceId, targetId);
        return this.predictor.predictedProbability(features);
    }

    static class LinkFilterFactory
    implements NeighborFilterFactory {
        private final Graph graph;

        LinkFilterFactory(Graph graph) {
            this.graph = graph;
        }

        public NeighborFilter create() {
            return new LinkFilter(this.graph.concurrentCopy());
        }
    }

    static final class LinkFilter
    implements NeighborFilter {
        private final Graph graph;

        private LinkFilter(Graph graph) {
            this.graph = graph;
        }

        public boolean excludeNodePair(long firstNodeId, long secondNodeId) {
            if (firstNodeId == secondNodeId) {
                return true;
            }
            return this.graph.exists(firstNodeId, secondNodeId);
        }

        public long lowerBoundOfPotentialNeighbours(long node) {
            return this.graph.nodeCount() - 1L - (long)this.graph.degree(node);
        }
    }
}

