/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels;

import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import org.neo4j.gds.AlgoBaseProc;
import org.neo4j.gds.AlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.config.AlgoBaseConfig;
import org.neo4j.gds.config.GraphCreateConfig;
import org.neo4j.gds.config.GraphCreateConfigValidations;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.ml.linkmodels.ExhaustiveLinkPredictionResult;
import org.neo4j.gds.ml.linkmodels.LinkPredictionPredict;
import org.neo4j.gds.ml.linkmodels.LinkPredictionPredictFactory;
import org.neo4j.gds.ml.linkmodels.LinkPredictionPredictStreamConfig;
import org.neo4j.gds.results.MemoryEstimateResult;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Stream-mode procedure that applies a link prediction model to the node pairs of a graph.
 *
 * <p>The algorithm instance is built by {@link LinkPredictionPredictFactory}, which receives the
 * injected {@link ModelCatalog}. Predicted links are emitted with original (database) node ids.
 */
public class LinkPredictionPredictStreamProc
extends AlgoBaseProc<LinkPredictionPredict, ExhaustiveLinkPredictionResult, LinkPredictionPredictStreamConfig> {

    /** Model catalog injected by the procedure framework; handed to the algorithm factory. */
    @Context
    public ModelCatalog modelCatalog;

    /**
     * Runs the prediction and streams one row per predicted link.
     *
     * @param graphNameOrConfig name of a catalog graph, or an implicit graph-creation configuration
     * @param configuration     procedure configuration map
     * @return a stream of predicted links, or an empty stream when the graph is empty
     */
    @Procedure(name="gds.alpha.ml.linkPrediction.predict.stream", mode=Mode.READ)
    @Description(value="Predicts relationships for all node pairs based on a previously trained link prediction model.")
    public Stream<Result> stream(@Name(value="graphName") Object graphNameOrConfig, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        // Typed result: avoids the raw ComputationResult and the unchecked downcast of result().
        AlgoBaseProc.ComputationResult<LinkPredictionPredict, ExhaustiveLinkPredictionResult, LinkPredictionPredictStreamConfig> computation =
            this.compute(graphNameOrConfig, configuration);
        Graph graph = computation.graph();
        if (computation.isGraphEmpty()) {
            // Nothing to predict on: free the loaded graph and return no rows.
            graph.release();
            return Stream.empty();
        }
        // NOTE(review): the graph is only released explicitly on the empty path above; confirm the
        // non-empty path's release is handled by the base class / caller.
        ExhaustiveLinkPredictionResult predictions = computation.result();
        return predictions.stream().map(predictedLink -> new Result(
            // Translate internal (in-memory) node ids back to the ids callers see in the database.
            graph.toOriginalNodeId(predictedLink.sourceId()),
            graph.toOriginalNodeId(predictedLink.targetId()),
            predictedLink.probability()
        ));
    }

    /**
     * Estimates the memory needed to run {@link #stream(Object, Map)} with the same arguments.
     *
     * @param graphNameOrConfig name of a catalog graph, or an implicit graph-creation configuration
     * @param configuration     procedure configuration map
     * @return the memory estimation rows produced by the base class
     */
    @Procedure(name="gds.alpha.ml.linkPrediction.predict.stream.estimate", mode=Mode.READ)
    @Description(value="Estimates memory for applying a linkPrediction model")
    public Stream<MemoryEstimateResult> estimate(@Name(value="graphName") Object graphNameOrConfig, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        return this.computeEstimate(graphNameOrConfig, configuration);
    }

    /**
     * Validates the configuration before the graph is loaded.
     *
     * <p>Delegates to {@link GraphCreateConfigValidations#validateIsUndirectedGraph}, presumably
     * rejecting graph projections that are not undirected.
     */
    @Override
    protected void validateConfigsBeforeLoad(GraphCreateConfig graphCreateConfig, LinkPredictionPredictStreamConfig config) {
        GraphCreateConfigValidations.validateIsUndirectedGraph(graphCreateConfig, config);
    }

    /** Builds the typed procedure configuration from the raw user-supplied map. */
    @Override
    protected LinkPredictionPredictStreamConfig newConfig(String username, Optional<String> graphName, Optional<GraphCreateConfig> maybeImplicitCreate, CypherMapWrapper config) {
        return LinkPredictionPredictStreamConfig.of(username, graphName, maybeImplicitCreate, config);
    }

    /** Creates the factory for the prediction algorithm, backed by the injected model catalog. */
    @Override
    protected AlgorithmFactory<LinkPredictionPredict, LinkPredictionPredictStreamConfig> algorithmFactory() {
        return new LinkPredictionPredictFactory(this.modelCatalog);
    }

    /**
     * One output row of the stream procedure.
     *
     * <p>Fields are public because this class is used as a procedure output record.
     */
    public static final class Result {
        /** Original node id of the source node of the predicted link. */
        public final long node1;
        /** Original node id of the target node of the predicted link. */
        public final long node2;
        /** Probability reported by the model for this link. */
        public final double probability;

        public Result(long node1, long node2, double probability) {
            this.node1 = node1;
            this.node2 = node2;
            this.probability = probability;
        }
    }
}

