/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.train;

import java.util.HashMap;
import java.util.Map;
import java.util.stream.Stream;
import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.TrainProc;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.ml.MLTrainResult;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainPipelineAlgorithmFactory;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainPipelineExecutor;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainResult;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Neo4j procedure {@code gds.alpha.ml.pipeline.linkPrediction.train}: trains a link prediction
 * model based on a pipeline and streams the training outcome as {@link LPTrainResult} rows.
 */
@GdsCallable(name="gds.alpha.ml.pipeline.linkPrediction.train", description="Trains a link prediction model based on a pipeline", executionMode=ExecutionMode.TRAIN)
public class LinkPredictionPipelineTrainProc
extends TrainProc<LinkPredictionTrainPipelineExecutor, LinkPredictionTrainResult, LinkPredictionTrainConfig, LPTrainResult> {

    /**
     * Trains a link prediction model on the given graph.
     *
     * <p>The caller's configuration map is copied before {@code graphName} is added to it. The
     * procedure argument is therefore never mutated in place, and an unmodifiable input map
     * cannot cause an {@link UnsupportedOperationException}.
     *
     * @param graphName name of the graph to train on; also injected into the configuration
     * @param config    user-supplied procedure configuration (defaults to an empty map)
     * @return the stream produced by {@code trainAndStoreModelWithResult} for this computation
     */
    @Procedure(name="gds.alpha.ml.pipeline.linkPrediction.train", mode=Mode.READ)
    @Description(value="Trains a link prediction model based on a pipeline")
    public Stream<LPTrainResult> train(@Name(value="graphName") String graphName, @Name(value="configuration", defaultValue="{}") Map<String, Object> config) {
        // Defensive copy: the map belongs to the caller, so add graphName to our own copy.
        Map<String, Object> configWithGraphName = new HashMap<>(config);
        configWithGraphName.put("graphName", graphName);
        return this.trainAndStoreModelWithResult(this.compute(graphName, configWithGraphName));
    }

    /**
     * Builds the typed training configuration from the wrapped Cypher configuration map.
     *
     * @param username user the configuration is created for
     * @param config   wrapped procedure configuration
     * @return the parsed {@link LinkPredictionTrainConfig}
     */
    protected LinkPredictionTrainConfig newConfig(String username, CypherMapWrapper config) {
        // Honour the username handed in by the caller instead of ignoring the parameter.
        return LinkPredictionTrainConfig.of(username, config);
    }

    /**
     * Creates the factory that produces the pipeline-training executor for this procedure.
     *
     * @return a {@link LinkPredictionTrainPipelineAlgorithmFactory} bound to this procedure's
     *         execution context and model catalog
     */
    public GraphStoreAlgorithmFactory<LinkPredictionTrainPipelineExecutor, LinkPredictionTrainConfig> algorithmFactory() {
        return new LinkPredictionTrainPipelineAlgorithmFactory(this.executionContext(), this.modelCatalog);
    }

    /**
     * Returns the model type label used for models trained by this procedure.
     *
     * @return the constant {@code "Link prediction pipeline"}
     */
    protected String modelType() {
        return "Link prediction pipeline";
    }

    /**
     * Extracts the trained model from the algorithm result.
     *
     * @param algoResult result produced by the training executor
     * @return the model held by {@code algoResult}
     */
    protected Model<?, ?, ?> extractModel(LinkPredictionTrainResult algoResult) {
        return algoResult.model();
    }

    /**
     * Converts a finished computation into the procedure's result row.
     *
     * @param computationResult the completed computation; its compute time is reported as
     *                          the training time
     * @return a new {@link LPTrainResult}
     */
    protected LPTrainResult constructProcResult(ComputationResult<LinkPredictionTrainPipelineExecutor, LinkPredictionTrainResult, LinkPredictionTrainConfig> computationResult) {
        // result() is already typed as LinkPredictionTrainResult; no cast needed.
        return new LPTrainResult(computationResult.result(), computationResult.computeMillis());
    }

    /**
     * Result row of the link prediction train procedure. It extends {@link MLTrainResult}
     * with the model selection statistics of the training run.
     */
    public static class LPTrainResult
    extends MLTrainResult {
        /** Model selection statistics, as produced by {@code modelSelectionStatistics().toMap()}. */
        public final Map<String, Object> modelSelectionStats;

        /**
         * Creates a result row from the algorithm result.
         *
         * @param algoResult  result of the training executor
         * @param trainMillis time spent on training, in milliseconds
         */
        public LPTrainResult(LinkPredictionTrainResult algoResult, long trainMillis) {
            super(algoResult.model(), trainMillis);
            this.modelSelectionStats = algoResult.modelSelectionStatistics().toMap();
        }
    }
}

