/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.predict;

import org.neo4j.gds.api.Graph;
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkFeatureExtractor;
import org.neo4j.gds.models.Classifier;
import org.neo4j.gds.models.FeaturesFactory;
import org.neo4j.gds.similarity.knn.NeighborFilter;
import org.neo4j.gds.similarity.knn.NeighborFilterFactory;
import org.neo4j.gds.similarity.knn.metrics.SimilarityComputer;

class LinkPredictionSimilarityComputer
implements SimilarityComputer {
    private final LinkFeatureExtractor linkFeatureExtractor;
    private final Classifier classifier;
    private final int positiveClassLocalId;

    LinkPredictionSimilarityComputer(LinkFeatureExtractor linkFeatureExtractor, Classifier classifier) {
        this.linkFeatureExtractor = linkFeatureExtractor;
        this.classifier = classifier;
        this.positiveClassLocalId = classifier.classIdMap().toMapped(1L);
    }

    public double similarity(long sourceId, long targetId) {
        double[] features = this.linkFeatureExtractor.extractFeatures(sourceId, targetId);
        return this.classifier.predictProbabilities(0L, FeaturesFactory.wrap((double[])features))[this.positiveClassLocalId];
    }

    static class LinkFilterFactory
    implements NeighborFilterFactory {
        private final Graph graph;

        LinkFilterFactory(Graph graph) {
            this.graph = graph;
        }

        public NeighborFilter create() {
            return new LinkFilter(this.graph.concurrentCopy());
        }
    }

    static final class LinkFilter
    implements NeighborFilter {
        private final Graph graph;

        private LinkFilter(Graph graph) {
            this.graph = graph;
        }

        public boolean excludeNodePair(long firstNodeId, long secondNodeId) {
            if (firstNodeId == secondNodeId) {
                return true;
            }
            return this.graph.exists(firstNodeId, secondNodeId);
        }

        public long lowerBoundOfPotentialNeighbours(long node) {
            return this.graph.nodeCount() - 1L - (long)this.graph.degree(node);
        }
    }
}

