/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodemodels.pipeline.predict;

import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import org.neo4j.gds.AlgorithmFactory;
import org.neo4j.gds.BaseProc;
import org.neo4j.gds.TrainProc;
import org.neo4j.gds.config.GraphCreateConfig;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.ml.MLTrainResult;
import org.neo4j.gds.ml.nodemodels.NodeClassificationTrainPipelineAlgorithmFactory;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionData;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationPipelineModelInfo;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.ml.nodemodels.pipeline.NodeClassificationTrainPipelineExecutor;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

public class NodeClassificationPipelineTrainProc
extends TrainProc<NodeClassificationTrainPipelineExecutor, NodeLogisticRegressionData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> {
    @Context
    public ModelCatalog modelCatalog;

    @Procedure(name="gds.alpha.ml.pipeline.nodeClassification.train", mode=Mode.READ)
    @Description(value="Trains a node classification model based on a pipeline")
    public Stream<MLTrainResult> train(@Name(value="graphName") Object graphNameOrConfig, @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration) {
        return this.trainAndStoreModelWithResult(graphNameOrConfig, configuration, (model, result) -> new MLTrainResult((Model<?, ?, ?>)model, result.computeMillis()));
    }

    protected NodeClassificationPipelineTrainConfig newConfig(String username, Optional<String> graphName, Optional<GraphCreateConfig> maybeImplicitCreate, CypherMapWrapper config) {
        return NodeClassificationPipelineTrainConfig.of(graphName, maybeImplicitCreate, (String)username, (CypherMapWrapper)config);
    }

    protected AlgorithmFactory<NodeClassificationTrainPipelineExecutor, NodeClassificationPipelineTrainConfig> algorithmFactory() {
        return new NodeClassificationTrainPipelineAlgorithmFactory((BaseProc)this, this.databaseId(), this.modelCatalog);
    }

    protected String modelType() {
        return "nodeLogisticRegression";
    }
}

