/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels.pipeline.train;

import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import org.neo4j.gds.AlgorithmFactory;
import org.neo4j.gds.BaseProc;
import org.neo4j.gds.TrainProc;
import org.neo4j.gds.config.GraphCreateConfig;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.ml.MLTrainResult;
import org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionModelInfo;
import org.neo4j.gds.ml.linkmodels.pipeline.logisticRegression.LinkLogisticRegressionData;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrain;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.linkmodels.pipeline.train.LinkPredictionTrainFactory;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

public class LinkPredictionPipelineTrainProc
extends TrainProc<LinkPredictionTrain, LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo> {
    @Procedure(name="gds.alpha.ml.pipeline.linkPrediction.train", mode=Mode.READ)
    @Description(value="Trains a link prediction model based on a pipeline")
    public Stream<MLTrainResult> train(@Name(value="graphName") Object graphNameOrConfig, @Name(value="configuration", defaultValue="{}") Map<String, Object> config) {
        return this.trainAndStoreModelWithResult(graphNameOrConfig, config, (model, result) -> new MLTrainResult((Model<?, ?, ?>)model, result.computeMillis()));
    }

    protected LinkPredictionTrainConfig newConfig(String username, Optional<String> graphName, Optional<GraphCreateConfig> maybeImplicitCreate, CypherMapWrapper config) {
        return LinkPredictionTrainConfig.of(this.username(), graphName, maybeImplicitCreate, config);
    }

    protected AlgorithmFactory<LinkPredictionTrain, LinkPredictionTrainConfig> algorithmFactory() {
        return new LinkPredictionTrainFactory(this.databaseId(), (BaseProc)this);
    }

    protected String modelType() {
        return "Link prediction pipeline";
    }
}

