/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels;

import java.util.Map;
import java.util.stream.Stream;
import org.neo4j.gds.AlgoBaseProc;
import org.neo4j.gds.AlgorithmFactory;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.api.Graph;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.executor.AlgorithmSpec;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ComputationResultConsumer;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.executor.validation.ValidationConfiguration;
import org.neo4j.gds.ml.linkmodels.ExhaustiveLinkPredictionResult;
import org.neo4j.gds.ml.linkmodels.LinkPredictionPredict;
import org.neo4j.gds.ml.linkmodels.LinkPredictionPredictCompanion;
import org.neo4j.gds.ml.linkmodels.LinkPredictionPredictFactory;
import org.neo4j.gds.ml.linkmodels.LinkPredictionPredictStreamConfig;
import org.neo4j.gds.results.MemoryEstimateResult;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Procedure {@code gds.alpha.ml.linkPrediction.predict.stream}: predicts relationships for all node
 * pairs based on a previously trained link prediction model and streams each predicted link as a
 * {@link Result} row.
 */
@GdsCallable(name="gds.alpha.ml.linkPrediction.predict.stream", description="Predicts relationships for all node pairs based on a previously trained link prediction model.", executionMode=ExecutionMode.STREAM)
public class LinkPredictionPredictStreamProc
extends AlgoBaseProc<LinkPredictionPredict, ExhaustiveLinkPredictionResult, LinkPredictionPredictStreamConfig, Result> {

    /**
     * Runs link prediction on the named graph and streams the predicted links.
     *
     * @param graphName     name of the graph to predict on
     * @param configuration procedure configuration; defaults to an empty map
     * @return one {@link Result} per predicted link, produced by {@link #computationResultConsumer()}
     */
    @Procedure(name="gds.alpha.ml.linkPrediction.predict.stream", mode=Mode.READ)
    @Description(value="Predicts relationships for all node pairs based on a previously trained link prediction model.")
    public Stream<Result> stream(@Name(value="graphName") String graphName, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        // Keep the generic type so the consumer's Stream<Result> return needs no raw, unchecked cast.
        ComputationResult<LinkPredictionPredict, ExhaustiveLinkPredictionResult, LinkPredictionPredictStreamConfig> computationResult =
            this.compute(graphName, configuration);
        return this.computationResultConsumer().consume(computationResult, this.executionContext());
    }

    /**
     * Estimates the memory needed to run this procedure; delegates to {@code computeEstimate}.
     *
     * @param graphNameOrConfiguration graph name or graph configuration to estimate against
     * @param algoConfiguration        algorithm configuration
     * @return the memory estimation rows
     */
    @Procedure(name="gds.alpha.ml.linkPrediction.predict.stream.estimate", mode=Mode.READ)
    @Description(value="Estimates memory for applying a linkPrediction model")
    public Stream<MemoryEstimateResult> estimate(@Name(value="graphNameOrConfiguration") Object graphNameOrConfiguration, @Name(value="algoConfiguration") Map<String, Object> algoConfiguration) {
        return this.computeEstimate(graphNameOrConfiguration, algoConfiguration);
    }

    /**
     * Returns the validation rules shared by the link prediction predict procedures.
     */
    public ValidationConfiguration<LinkPredictionPredictStreamConfig> validationConfig() {
        return LinkPredictionPredictCompanion.getValidationConfig();
    }

    /**
     * Installs the given model catalog on this procedure and returns this instance as the spec.
     *
     * @param modelCatalog catalog from which the trained model is looked up
     * @return this procedure
     */
    public AlgorithmSpec<LinkPredictionPredict, ExhaustiveLinkPredictionResult, LinkPredictionPredictStreamConfig, Stream<Result>, AlgorithmFactory<?, LinkPredictionPredict, LinkPredictionPredictStreamConfig>> withModelCatalog(ModelCatalog modelCatalog) {
        this.setModelCatalog(modelCatalog);
        return this;
    }

    /**
     * Builds the stream configuration from the user-supplied map.
     *
     * @param username name of the calling user
     * @param config   wrapped procedure configuration
     * @return the parsed configuration
     */
    protected LinkPredictionPredictStreamConfig newConfig(String username, CypherMapWrapper config) {
        return LinkPredictionPredictStreamConfig.of(username, config);
    }

    /**
     * Returns the algorithm factory, bound to this procedure's model catalog.
     */
    public GraphAlgorithmFactory<LinkPredictionPredict, LinkPredictionPredictStreamConfig> algorithmFactory() {
        return new LinkPredictionPredictFactory(this.modelCatalog());
    }

    /**
     * Converts the computation result into procedure output rows.
     *
     * <p>If the input graph is empty, the graph is released and an empty stream is returned.
     * Otherwise, each predicted link is mapped to a {@link Result}, with internal node ids translated
     * back to original node ids via {@link Graph#toOriginalNodeId(long)}.
     */
    public ComputationResultConsumer<LinkPredictionPredict, ExhaustiveLinkPredictionResult, LinkPredictionPredictStreamConfig, Stream<Result>> computationResultConsumer() {
        return (computationResult, executionContext) -> {
            Graph graph = computationResult.graph();
            if (computationResult.isGraphEmpty()) {
                graph.release();
                return Stream.empty();
            }
            // NOTE(review): the graph is only released on the empty path. The stream returned below is
            // lazy and still reads from `graph` when consumed, so releasing here would be premature —
            // confirm where the graph is released otherwise.
            return computationResult.result().stream().map(predictedLink -> new Result(
                graph.toOriginalNodeId(predictedLink.sourceId()),
                graph.toOriginalNodeId(predictedLink.targetId()),
                predictedLink.probability()
            ));
        };
    }

    /**
     * One output row: a predicted link between two nodes, identified by original node ids,
     * together with its predicted probability.
     */
    public static final class Result {
        /** Original id of the link's source node. */
        public final long node1;
        /** Original id of the link's target node. */
        public final long node2;
        /** Probability reported for this link by the prediction result. */
        public final double probability;

        public Result(long node1, long node2, double probability) {
            this.node1 = node1;
            this.node2 = node2;
            this.probability = probability;
        }
    }
}

