/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.nodemodels.pipeline;

import java.util.List;
import java.util.Map;
import org.neo4j.gds.core.model.Model;
import org.neo4j.gds.core.model.ModelCatalog;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionData;
import org.neo4j.gds.ml.nodemodels.logisticregression.NodeLogisticRegressionTrainCoreConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPipeline;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPipelineModelInfo;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeClassificationPipelineTrainConfig;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Static constants and lookup helpers for node classification pipelines stored in a {@link ModelCatalog}.
 *
 * <p>Not instantiable.
 */
public final class NodeClassificationPipelineCompanion {
    /** Description text for predicting node classes with a trained pipeline model. */
    public static final String PREDICT_DESCRIPTION = "Predicts classes for all nodes based on a previously trained pipeline model";

    /** Description text for estimating the memory needed to predict node classes with a trained pipeline model. */
    public static final String ESTIMATE_PREDICT_DESCRIPTION = "Estimates memory for predicting classes for all nodes based on a previously trained pipeline model";

    /** The {@code algoType} a catalog model must have to be treated as a node classification training pipeline. */
    public static final String PIPELINE_MODEL_TYPE = "Node classification training pipeline";

    /** Default split settings: {@code testFraction = 0.3} and {@code validationFolds = 3}. */
    static final Map<String, Object> DEFAULT_SPLIT_CONFIG = Map.of("testFraction", 0.3, "validationFolds", 3);

    /**
     * Default parameter space.
     *
     * <p>Holds a single entry: the map form of {@link NodeLogisticRegressionTrainCoreConfig#defaultConfig()}.
     */
    static final List<Map<String, Object>> DEFAULT_PARAM_CONFIG = List.of(NodeLogisticRegressionTrainCoreConfig.defaultConfig().toMap());

    private NodeClassificationPipelineCompanion() {
        // static helpers only; no instances
    }

    /**
     * Looks up a node classification training pipeline in the model catalog.
     *
     * @param modelCatalog the catalog to read from
     * @param pipelineName the catalog name of the pipeline
     * @param username     the user passed to the catalog lookup
     * @return the pipeline held as the model's custom info
     * @throws IllegalArgumentException if the model's {@code algoType} is not {@link #PIPELINE_MODEL_TYPE}
     * @throws IllegalStateException    if the model has the pipeline type but its custom info is not a
     *                                  {@link NodeClassificationPipeline}
     */
    public static NodeClassificationPipeline getNCPipeline(ModelCatalog modelCatalog, String pipelineName, String username) {
        Model<?, ?, ?> model = modelCatalog.getUntypedOrThrow(username, pipelineName);
        // NOTE(review): the "OrThrow" lookup is presumably never null here; kept as a cheap sanity check.
        assert model != null;

        String actualType = model.algoType();
        // Constant on the left so a null algoType yields the descriptive error below instead of an NPE.
        if (!PIPELINE_MODEL_TYPE.equals(actualType)) {
            throw new IllegalArgumentException(StringFormatting.formatWithLocale(
                "Expected a model of type `%s`. But model `%s` is of type `%s`.",
                PIPELINE_MODEL_TYPE,
                pipelineName,
                actualType
            ));
        }

        Object info = model.customInfo();
        // An explicit check instead of `assert`: assertions are off by default, and the bare cast
        // would otherwise fail with an uninformative ClassCastException.
        if (!(info instanceof NodeClassificationPipeline)) {
            throw new IllegalStateException(StringFormatting.formatWithLocale(
                "Model `%s` has type `%s` but its custom info is `%s`, not a node classification pipeline.",
                pipelineName,
                PIPELINE_MODEL_TYPE,
                info == null ? "null" : info.getClass().getName()
            ));
        }
        return (NodeClassificationPipeline) info;
    }

    /**
     * Fetches a trained node classification pipeline model from the catalog.
     *
     * <p>Delegates to {@link ModelCatalog#get} with the expected data, train-config and model-info classes.
     * NOTE(review): behavior for a missing model or a class mismatch is defined by the catalog
     * (presumably it throws); confirm against {@code ModelCatalog}.
     *
     * @param modelCatalog the catalog to read from
     * @param modelName    the catalog name of the trained model
     * @param username     the user passed to the catalog lookup
     * @return the typed trained model
     */
    public static Model<NodeLogisticRegressionData, NodeClassificationPipelineTrainConfig, NodeClassificationPipelineModelInfo> getTrainedNCPipelineModel(ModelCatalog modelCatalog, String modelName, String username) {
        return modelCatalog.get(
            username,
            modelName,
            NodeLogisticRegressionData.class,
            NodeClassificationPipelineTrainConfig.class,
            NodeClassificationPipelineModelInfo.class
        );
    }
}

