/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodemodels.logisticregression;

import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import org.neo4j.gds.ml.MLTrainResult;
import org.neo4j.gds.ml.nodemodels.NodeClassificationTrain;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeClassificationTrainConfig;
import org.neo4j.gds.ml.nodemodels.multiclasslogisticregression.MultiClassNLRData;
import org.neo4j.graphalgo.AlgoBaseProc;
import org.neo4j.graphalgo.AlgorithmFactory;
import org.neo4j.graphalgo.TrainProc;
import org.neo4j.graphalgo.api.Graph;
import org.neo4j.graphalgo.api.GraphStore;
import org.neo4j.graphalgo.api.GraphStoreValidation;
import org.neo4j.graphalgo.config.AlgoBaseConfig;
import org.neo4j.graphalgo.config.GraphCreateConfig;
import org.neo4j.graphalgo.core.CypherMapWrapper;
import org.neo4j.graphalgo.core.loading.GraphStoreWithConfig;
import org.neo4j.graphalgo.core.model.Model;
import org.neo4j.graphalgo.core.model.ModelCatalog;
import org.neo4j.graphalgo.core.utils.mem.AllocationTracker;
import org.neo4j.graphalgo.core.utils.mem.MemoryEstimation;
import org.neo4j.graphalgo.core.utils.progress.ProgressEventTracker;
import org.neo4j.graphalgo.exceptions.MemoryEstimationNotImplementedException;
import org.neo4j.graphalgo.utils.StringFormatting;
import org.neo4j.graphalgo.utils.StringJoining;
import org.neo4j.logging.Log;
import org.neo4j.procedure.Description;
import org.neo4j.procedure.Mode;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

/**
 * Procedure class exposing {@code gds.alpha.ml.nodeClassification.train}.
 *
 * <p>Invoking the procedure runs the computation, registers the resulting {@link Model}
 * in the {@link ModelCatalog} and streams back a single {@link MLTrainResult} row.
 */
public class NodeClassificationTrainProc
extends TrainProc<NodeClassificationTrain, MultiClassNLRData, NodeClassificationTrainConfig> {

    /**
     * Trains a node classification model and stores it in the model catalog.
     *
     * @param graphNameOrConfig value of the {@code graphName} procedure argument, handed to {@code compute} unchanged
     * @param configuration     procedure configuration map; the procedure signature defaults it to {@code {}}
     * @return a one-element stream holding the trained model together with the reported compute time
     */
    @Procedure(name="gds.alpha.ml.nodeClassification.train", mode=Mode.READ)
    @Description(value="Trains a node classification model")
    public Stream<MLTrainResult> train(
        @Name(value="graphName") Object graphNameOrConfig,
        @Name(value="configuration", defaultValue="{}") Map<String, Object> configuration
    ) {
        AlgoBaseProc.ComputationResult computation = this.compute(graphNameOrConfig, configuration);
        // The computation result is the trained model; the same instance is registered and returned.
        Model trainedModel = (Model) computation.result();
        ModelCatalog.set(trainedModel);
        MLTrainResult row = new MLTrainResult(trainedModel, computation.computeMillis());
        return Stream.of(row);
    }

    /**
     * Validates the training configuration against the graph store.
     *
     * <p>Checks run in this order: the target property must exist on the configured node labels,
     * then {@link GraphStoreValidation#validate} runs, then {@code params} and {@code metrics}
     * must each be non-empty.
     *
     * @param graphStoreWithConfig graph store (plus its config) the training will run on
     * @param config               training configuration to validate
     * @throws IllegalArgumentException if the target property is missing, or if no model
     *                                  candidates or no metrics are configured
     */
    protected void validateConfigsAndGraphStore(GraphStoreWithConfig graphStoreWithConfig, NodeClassificationTrainConfig config) {
        GraphStore graphStore = graphStoreWithConfig.graphStore();
        Collection filterLabels = config.nodeLabelIdentifiers(graphStore);

        boolean targetPropertyPresent = graphStore.hasNodeProperty(filterLabels, config.targetProperty());
        if (!targetPropertyPresent) {
            // List the property keys available on the selected labels to make the error actionable.
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "`%s`: `%s` not found in graph with node properties: %s",
                "targetProperty",
                config.targetProperty(),
                StringJoining.join(graphStore.nodePropertyKeys(filterLabels))
            ));
        }

        GraphStoreValidation.validate(graphStoreWithConfig, config);

        if (config.params().isEmpty()) {
            throw new IllegalArgumentException(
                StringFormatting.formatWithLocale("No model candidates (params) specified, we require at least one")
            );
        }
        if (config.metrics().isEmpty()) {
            throw new IllegalArgumentException(
                StringFormatting.formatWithLocale("No metrics specified, we require at least one")
            );
        }
    }

    /**
     * Builds the training configuration by delegating to {@code NodeClassificationTrainConfig.of}.
     *
     * @param username            passed through to the config factory
     * @param graphName           optional graph name, passed through to the config factory
     * @param maybeImplicitCreate optional graph-create config, passed through to the config factory
     * @param config              wrapped procedure configuration map
     * @return the configuration produced by the factory
     */
    protected NodeClassificationTrainConfig newConfig(
        String username,
        Optional<String> graphName,
        Optional<GraphCreateConfig> maybeImplicitCreate,
        CypherMapWrapper config
    ) {
        return NodeClassificationTrainConfig.of(graphName, maybeImplicitCreate, username, config);
    }

    /**
     * Supplies the factory that creates {@link NodeClassificationTrain} instances.
     *
     * @return a factory whose {@code build} constructs the algorithm and whose
     *         {@code memoryEstimation} always throws
     */
    protected AlgorithmFactory<NodeClassificationTrain, NodeClassificationTrainConfig> algorithmFactory() {
        AlgorithmFactory<NodeClassificationTrain, NodeClassificationTrainConfig> factory =
            new AlgorithmFactory<NodeClassificationTrain, NodeClassificationTrainConfig>() {

                /**
                 * Creates the training algorithm from the graph, configuration and log.
                 * {@code tracker} and {@code eventTracker} are not used here.
                 */
                public NodeClassificationTrain build(
                    Graph graph,
                    NodeClassificationTrainConfig configuration,
                    AllocationTracker tracker,
                    Log log,
                    ProgressEventTracker eventTracker
                ) {
                    return new NodeClassificationTrain(graph, configuration, log);
                }

                /**
                 * Memory estimation is not available for this algorithm.
                 *
                 * @throws MemoryEstimationNotImplementedException always
                 */
                public MemoryEstimation memoryEstimation(NodeClassificationTrainConfig configuration) {
                    throw new MemoryEstimationNotImplementedException();
                }
            };
        return factory;
    }
}

