/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.linkmodels;

import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import org.neo4j.gds.GraphAlgorithmFactory;
import org.neo4j.gds.RelationshipType;
import org.neo4j.gds.TrainProc;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.executor.ComputationResult;
import org.neo4j.gds.executor.ExecutionMode;
import org.neo4j.gds.executor.GdsCallable;
import org.neo4j.gds.executor.validation.BeforeLoadValidation;
import org.neo4j.gds.executor.validation.GraphProjectConfigValidations;
import org.neo4j.gds.executor.validation.ValidationConfiguration;
import org.neo4j.gds.ml.MLTrainResult;
import org.neo4j.gds.ml.linkmodels.ImmutableLinkPredictionTrainConfig;
import org.neo4j.gds.ml.linkmodels.LinkPredictionModelInfo;
import org.neo4j.gds.ml.linkmodels.LinkPredictionTrain;
import org.neo4j.gds.ml.linkmodels.LinkPredictionTrainConfig;
import org.neo4j.gds.ml.linkmodels.LinkPredictionTrainFactory;
import org.neo4j.gds.ml.linkmodels.logisticregression.LinkLogisticRegressionData;
import org.neo4j.gds.results.MemoryEstimateResult;
import org.neo4j.gds.utils.StringFormatting;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Procedure facade for training a link prediction model.
 *
 * <p>Registers {@code gds.alpha.ml.linkPrediction.train} together with its
 * {@code .estimate} companion, and supplies the configuration, validation and
 * algorithm-factory hooks that the {@link TrainProc} base class works with.
 */
@GdsCallable(
    name = "gds.alpha.ml.linkPrediction.train",
    description = "Trains a link prediction model",
    executionMode = ExecutionMode.TRAIN
)
public class LinkPredictionTrainProc
    extends TrainProc<
        LinkPredictionTrain,
        Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo>,
        LinkPredictionTrainConfig,
        MLTrainResult> {

    /**
     * Runs the training computation and hands its result to
     * {@code trainAndStoreModelWithResult}.
     *
     * @param graphName     name of the graph to train on
     * @param configuration user supplied procedure configuration; defaults to an empty map
     * @return the stream produced by {@code trainAndStoreModelWithResult}
     */
    @Procedure(name = "gds.alpha.ml.linkPrediction.train", mode = Mode.READ)
    @Description("Trains a link prediction model")
    public Stream<MLTrainResult> train(
        @Name("graphName") String graphName,
        @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
    ) {
        var trainingOutcome = this.compute(graphName, configuration);
        return this.trainAndStoreModelWithResult(trainingOutcome);
    }

    /**
     * Delegates to {@code computeEstimate} with the given arguments.
     *
     * @param graphNameOrConfiguration passed through unchanged to {@code computeEstimate}
     * @param algoConfiguration        passed through unchanged to {@code computeEstimate}
     * @return the stream produced by {@code computeEstimate}
     */
    @Procedure(name = "gds.alpha.ml.linkPrediction.train.estimate", mode = Mode.READ)
    @Description("Estimates memory for training a link prediction model")
    public Stream<MemoryEstimateResult> estimate(
        @Name("graphNameOrConfiguration") Object graphNameOrConfiguration,
        @Name("algoConfiguration") Map<String, Object> algoConfiguration
    ) {
        return this.computeEstimate(graphNameOrConfiguration, algoConfiguration);
    }

    /**
     * Parses the user configuration and returns a copy whose relationship types are
     * replaced by the configured train and test relationship types (in that order)
     * and whose relationship weight property is set to {@code "label"}.
     *
     * @param username user on whose behalf the configuration is created
     * @param config   raw procedure configuration
     * @return the adjusted training configuration
     */
    protected LinkPredictionTrainConfig newConfig(String username, CypherMapWrapper config) {
        LinkPredictionTrainConfig parsed = LinkPredictionTrainConfig.of(username, config);
        RelationshipType trainSplit = parsed.trainRelationshipType();
        RelationshipType testSplit = parsed.testRelationshipType();

        // Everything else is copied from the parsed configuration; only these two
        // settings are overridden.
        return ImmutableLinkPredictionTrainConfig
            .builder()
            .from(parsed)
            .relationshipTypes(List.of(trainSplit.name, testSplit.name))
            .relationshipWeightProperty("label")
            .build();
    }

    /**
     * @return the fixed model type label {@code "Link Prediction"}
     */
    protected String modelType() {
        return "Link Prediction";
    }

    /**
     * Wraps the trained model and the measured compute time into the procedure result row.
     *
     * @param computationResult outcome of the training computation
     * @return a new {@link MLTrainResult} built from the model and compute millis
     */
    protected MLTrainResult constructProcResult(
        ComputationResult<
            LinkPredictionTrain,
            Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo>,
            LinkPredictionTrainConfig> computationResult
    ) {
        Model<?, ?, ?> trainedModel = computationResult.result();
        long elapsedMillis = computationResult.computeMillis();
        return new MLTrainResult(trainedModel, elapsedMillis);
    }

    /**
     * The algorithm result already is the model, so it is returned unchanged.
     *
     * @param algoResult result of the training algorithm
     * @return {@code algoResult} itself
     */
    protected Model<?, ?, ?> extractModel(
        Model<LinkLogisticRegressionData, LinkPredictionTrainConfig, LinkPredictionModelInfo> algoResult
    ) {
        return algoResult;
    }

    /**
     * Supplies the before-load validations for this procedure: an
     * {@code UndirectedGraphValidation} followed by a check that the configuration
     * lists at least one model candidate in {@code params}.
     *
     * @return the validation configuration for link prediction training
     */
    public ValidationConfiguration<LinkPredictionTrainConfig> validationConfig() {
        return new ValidationConfiguration<LinkPredictionTrainConfig>() {

            public List<BeforeLoadValidation<LinkPredictionTrainConfig>> beforeLoadValidations() {
                // Rejects configurations that do not contain any model candidate.
                BeforeLoadValidation<LinkPredictionTrainConfig> candidatesPresent = (graphProjectConfig, config) -> {
                    if (!config.params().isEmpty()) {
                        return;
                    }
                    throw new IllegalArgumentException(
                        StringFormatting.formatWithLocale("No model candidates (params) specified, we require at least one")
                    );
                };
                return List.of(new GraphProjectConfigValidations.UndirectedGraphValidation(), candidatesPresent);
            }
        };
    }

    /**
     * @return a fresh {@link LinkPredictionTrainFactory}
     */
    public GraphAlgorithmFactory<LinkPredictionTrain, LinkPredictionTrainConfig> algorithmFactory() {
        return new LinkPredictionTrainFactory();
    }
}

