/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.pipeline.nodePipeline.classification.train;

import java.util.List;
import java.util.stream.Collectors;
import org.neo4j.gds.annotation.Configuration;
import org.neo4j.gds.collections.LongMultiSet;
import org.neo4j.gds.core.CypherMapAccess;
import org.neo4j.gds.core.CypherMapWrapper;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.metrics.classification.ClassificationMetric;
import org.neo4j.gds.ml.metrics.classification.ClassificationMetricSpecification;
import org.neo4j.gds.ml.pipeline.nodePipeline.NodePropertyPipelineBaseTrainConfig;
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfigImpl;

/**
 * Training configuration for node classification pipelines.
 *
 * <p>Extends {@link NodePropertyPipelineBaseTrainConfig} with the list of
 * classification metrics used during training. Instances are produced by the
 * generated {@code NodeClassificationPipelineTrainConfigImpl} (see {@link #of}).
 */
@Configuration
public interface NodeClassificationPipelineTrainConfig extends NodePropertyPipelineBaseTrainConfig {

    long serialVersionUID = 0x42L;

    /**
     * Returns the metric specifications requested by the user.
     *
     * <p>Raw input is parsed with {@code ClassificationMetricSpecification.Parser#parse}.
     * When the config is rendered back to a map, the value is produced by
     * {@code ClassificationMetricSpecification#specificationsToString}.
     *
     * @return the configured metric specifications
     */
    @Configuration.ConvertWith(method = "org.neo4j.gds.ml.metrics.classification.ClassificationMetricSpecification.Parser#parse")
    @Configuration.ToMapValue(value = "org.neo4j.gds.ml.metrics.classification.ClassificationMetricSpecification#specificationsToString")
    List<ClassificationMetricSpecification> metrics();

    /**
     * Expands every configured specification into concrete {@link Metric} instances.
     *
     * <p>Annotated with {@code @Configuration.Ignore}, presumably so the config
     * generator does not treat this overload as a configuration key.
     *
     * @param classIdMap  mapping passed through to each specification's {@code createMetrics}
     * @param classCounts class counts passed through to each specification's {@code createMetrics}
     * @return all metrics produced by the specifications, in specification order
     */
    @Configuration.Ignore
    default List<Metric> metrics(LocalIdMap classIdMap, LongMultiSet classCounts) {
        return metrics()
            .stream()
            .flatMap(specification -> specification.createMetrics(classIdMap, classCounts))
            .collect(Collectors.toList());
    }

    /**
     * Keeps only the metrics that are not model-specific and views them as
     * {@link ClassificationMetric}s.
     *
     * @param metrics the metrics to filter
     * @return the non-model-specific metrics, in their original order
     * @throws ClassCastException if a non-model-specific metric is not a {@link ClassificationMetric}
     */
    static List<ClassificationMetric> classificationMetrics(List<Metric> metrics) {
        return metrics
            .stream()
            .filter(candidate -> !candidate.isModelSpecific())
            .map(ClassificationMetric.class::cast)
            .collect(Collectors.toList());
    }

    /**
     * Builds a configuration from the given user name and Cypher map.
     *
     * @param username the user the configuration belongs to
     * @param config   the raw configuration values
     * @return a new generated configuration instance
     */
    static NodeClassificationPipelineTrainConfig of(String username, CypherMapWrapper config) {
        // Widen explicitly so the CypherMapAccess constructor overload is selected, as before.
        CypherMapAccess access = config;
        return new NodeClassificationPipelineTrainConfigImpl(username, access);
    }
}

