/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodemodels.pipeline.predict;

import java.util.Map;
import java.util.stream.Stream;
import org.neo4j.gds.GraphStoreAlgorithmFactory;
import org.neo4j.gds.TrainProc;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.ml.MLTrainResult;
import org.neo4j.gds.ml.PipelineCompanion;
import org.neo4j.gds.ml.nodemodels.NodeClassificationTrainPipelineAlgorithmFactory;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationPipelineTrainResult;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationTrainPipelineExecutor;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.results.MemoryEstimateResult;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Procedure entry points for training a node classification model from a pipeline, plus
 * the matching memory-estimation procedure.
 *
 * <p>The algorithm instance is a {@link NodeClassificationTrainPipelineExecutor}, created by
 * {@link NodeClassificationTrainPipelineAlgorithmFactory}. Each training run is reported as
 * an {@link NCTrainResult} row.
 */
@GdsCallable(
    name = "gds.alpha.ml.pipeline.nodeClassification.train",
    description = "Trains a node classification model based on a pipeline",
    executionMode = ExecutionMode.TRAIN
)
public class NodeClassificationPipelineTrainProc
extends TrainProc<NodeClassificationTrainPipelineExecutor, NodeClassificationPipelineTrainResult, NodeClassificationPipelineTrainConfig, NCTrainResult> {

    /** Model type identifier handed out by {@link #modelType()}. */
    private static final String MODEL_TYPE = "nodeLogisticRegression";

    /**
     * Trains a model on the named graph and stores it, streaming one result row.
     *
     * @param graphName     name of the graph to train on
     * @param configuration user-supplied training configuration (defaults to an empty map)
     * @return a stream with the training result row
     */
    @Procedure(name = "gds.alpha.ml.pipeline.nodeClassification.train", mode = Mode.READ)
    @Description(value = "Trains a node classification model based on a pipeline")
    public Stream<NCTrainResult> train(
        @Name(value = "graphName") String graphName,
        @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
    ) {
        // Let the pipeline companion prepare the configuration before computing.
        // NOTE(review): presumably this may adjust the map in place — confirm in PipelineCompanion.
        PipelineCompanion.prepareTrainConfig(graphName, configuration);
        ComputationResult<NodeClassificationTrainPipelineExecutor, NodeClassificationPipelineTrainResult, NodeClassificationPipelineTrainConfig> computationResult =
            this.compute(graphName, configuration);
        return this.trainAndStoreModelWithResult(computationResult);
    }

    /**
     * Estimates the memory needed by {@link #train(String, Map)} for the given input.
     *
     * @param graphNameOrConfiguration a graph name or an inline graph configuration
     * @param algoConfiguration        the training configuration
     * @return a stream of memory estimation rows
     */
    @Procedure(name = "gds.alpha.ml.pipeline.nodeClassification.train.estimate", mode = Mode.READ)
    @Description(value = "Estimates memory for training a node classification model based on a pipeline")
    public Stream<MemoryEstimateResult> estimate(
        @Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
        @Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
    ) {
        // Same configuration preparation as train(), so estimation sees the same config.
        PipelineCompanion.prepareTrainConfig(graphNameOrConfiguration, algoConfiguration);
        return this.computeEstimate(graphNameOrConfiguration, algoConfiguration);
    }

    /**
     * Parses the wrapped user configuration into a typed training config.
     *
     * @param username the user running the procedure
     * @param config   the wrapped configuration map
     * @return the parsed training configuration
     */
    protected NodeClassificationPipelineTrainConfig newConfig(String username, CypherMapWrapper config) {
        return NodeClassificationPipelineTrainConfig.of(username, config);
    }

    /**
     * Creates the factory that builds the pipeline executor, wired to this procedure's
     * execution context and model catalog.
     */
    public GraphStoreAlgorithmFactory<NodeClassificationTrainPipelineExecutor, NodeClassificationPipelineTrainConfig> algorithmFactory() {
        return new NodeClassificationTrainPipelineAlgorithmFactory(executionContext(), modelCatalog());
    }

    /** Returns the constant model type identifier {@value #MODEL_TYPE}. */
    protected String modelType() {
        return MODEL_TYPE;
    }

    /**
     * Converts a finished computation into the procedure's result row.
     *
     * @param computationResult the completed computation, carrying the train result and timing
     * @return the row to stream back to the caller
     */
    protected NCTrainResult constructProcResult(ComputationResult<NodeClassificationTrainPipelineExecutor, NodeClassificationPipelineTrainResult, NodeClassificationPipelineTrainConfig> computationResult) {
        NodeClassificationPipelineTrainResult trainResult = computationResult.result();
        long computeMillis = computationResult.computeMillis();
        return new NCTrainResult(trainResult, computeMillis);
    }

    /**
     * Pulls the trained model out of the algorithm result.
     *
     * @param algoResult the pipeline training result
     * @return the model contained in {@code algoResult}
     */
    protected Model<?, ?, ?> extractModel(NodeClassificationPipelineTrainResult algoResult) {
        return algoResult.model();
    }

    /**
     * Result row for {@code train}. It adds model-selection statistics to the fields
     * inherited from {@link MLTrainResult}.
     */
    public static class NCTrainResult
    extends MLTrainResult {
        /** Model-selection statistics of the training run, converted to a map. */
        public final Map<String, Object> modelSelectionStats;

        /**
         * Builds a result row from a training result.
         *
         * @param algoResult  the pipeline training result
         * @param trainMillis time spent computing, in milliseconds
         */
        public NCTrainResult(NodeClassificationPipelineTrainResult algoResult, long trainMillis) {
            super(algoResult.model(), trainMillis);
            this.modelSelectionStats = algoResult.modelSelectionStatistics().toMap();
        }
    }
}

