/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.metrics.classification;

import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.eclipse.collections.api.block.function.primitive.LongIntToObjectFunction;
import org.intellij.lang.annotations.RegExp;
import org.neo4j.gds.collections.LongMultiSet;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.core.subgraph.LocalIdMap;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.metrics.classification.Accuracy;
import org.neo4j.gds.ml.metrics.classification.ClassificationMetric;
import org.neo4j.gds.ml.metrics.classification.F1Macro;
import org.neo4j.gds.ml.metrics.classification.F1Score;
import org.neo4j.gds.ml.metrics.classification.F1Weighted;
import org.neo4j.gds.ml.metrics.classification.GlobalAccuracy;
import org.neo4j.gds.ml.metrics.classification.OutOfBagError;
import org.neo4j.gds.ml.metrics.classification.Precision;
import org.neo4j.gds.ml.metrics.classification.Recall;
import org.neo4j.gds.utils.StringFormatting;

public final class ClassificationMetricSpecification {
    private final String stringRepresentation;
    private final BiFunction<LocalIdMap, LongMultiSet, Stream<Metric>> metricFactory;

    private ClassificationMetricSpecification(String stringRepresentation, BiFunction<LocalIdMap, LongMultiSet, Stream<Metric>> metricFactory) {
        this.stringRepresentation = stringRepresentation;
        this.metricFactory = metricFactory;
    }

    private static ClassificationMetricSpecification createSpecification(BiFunction<LocalIdMap, LongMultiSet, Stream<Metric>> metricFactory, String stringRepresentation) {
        return new ClassificationMetricSpecification(stringRepresentation, metricFactory);
    }

    public Stream<Metric> createMetrics(LocalIdMap classIdMap, LongMultiSet classCounts) {
        return this.metricFactory.apply(classIdMap, classCounts);
    }

    public String toString() {
        return this.stringRepresentation;
    }

    public boolean equals(Object obj) {
        if (!(obj instanceof ClassificationMetricSpecification)) {
            return false;
        }
        return this.toString().equals(obj.toString());
    }

    public int hashCode() {
        return this.toString().hashCode();
    }

    public static MemoryEstimation memoryEstimation(int numberOfClasses) {
        return MemoryEstimations.builder().rangePerNode("metrics", __ -> {
            long sizeOfRepresentativeMetric = MemoryUsage.sizeOf((Object)new F1Score(1L, 1));
            return MemoryRange.of((long)sizeOfRepresentativeMetric, (long)((long)numberOfClasses * sizeOfRepresentativeMetric));
        }).build();
    }

    public static List<String> specificationsToString(List<ClassificationMetricSpecification> specifications) {
        return specifications.stream().map(ClassificationMetricSpecification::toString).collect(Collectors.toList());
    }

    public static final class Parser {
        private static final List<String> MODEL_SPECIFIC_METRICS = List.of(OutOfBagError.OUT_OF_BAG_ERROR.name());
        private static final Map<String, LongIntToObjectFunction<ClassificationMetric>> SINGLE_CLASS_METRIC_FACTORIES = Map.of("F1", F1Score::new, "PRECISION", Precision::new, "RECALL", Recall::new, "ACCURACY", Accuracy::new);
        private static final Map<String, BiFunction<LocalIdMap, LongMultiSet, ClassificationMetric>> ALL_CLASS_METRIC_FACTORIES = Map.of("F1_WEIGHTED", F1Weighted::new, "F1_MACRO", (classIdMap, ignore) -> new F1Macro((LocalIdMap)classIdMap), "ACCURACY", (ignored1, ignored2) -> new GlobalAccuracy());
        @RegExp
        private static final String NUMBER_OR_STAR = "(-?[\\d]+|\\*)";
        @RegExp
        private static final String CLASS_NAME_PATTERN = "(.+)";
        private static final Pattern SINGLE_CLASS_METRIC_PATTERN = Pattern.compile("(.+)\\(\\s*CLASS\\s*=\\s*(-?[\\d]+|\\*)\\s*\\)");

        private Parser() {
        }

        public static Iterable<String> singleClassMetrics() {
            return SINGLE_CLASS_METRIC_FACTORIES.keySet();
        }

        public static Iterable<String> allClassMetrics() {
            return ALL_CLASS_METRIC_FACTORIES.keySet();
        }

        public static List<ClassificationMetricSpecification> parse(List<?> userSpecifications) {
            List<String> badSpecifications;
            if (userSpecifications.isEmpty()) {
                throw new IllegalArgumentException(StringFormatting.formatWithLocale((String)"No metrics specified, we require at least one", (Object[])new Object[0]));
            }
            if (userSpecifications.get(0) instanceof ClassificationMetricSpecification) {
                return userSpecifications;
            }
            List<ClassificationMetricSpecification> stringInput = userSpecifications;
            String mainMetric = ((String)((Object)stringInput.get(0))).toUpperCase(Locale.ENGLISH);
            ArrayList<String> errors = new ArrayList<String>();
            if (mainMetric.contains("*")) {
                errors.add(StringFormatting.formatWithLocale((String)"The primary (first) metric provided must be one of %s.", (Object[])new Object[]{String.join((CharSequence)", ", Parser.validPrimaryMetricExpressions())}));
            }
            if (!(badSpecifications = stringInput.stream().filter(Parser::invalidSpecification).collect(Collectors.toList())).isEmpty()) {
                errors.add(Parser.errorMessage(badSpecifications));
            }
            if (!errors.isEmpty()) {
                throw new IllegalArgumentException(String.join((CharSequence)" ", errors));
            }
            return userSpecifications.stream().map(Parser::parse).distinct().collect(Collectors.toList());
        }

        public static ClassificationMetricSpecification parse(Object userSpecification) {
            if (userSpecification instanceof ClassificationMetricSpecification) {
                return (ClassificationMetricSpecification)userSpecification;
            }
            if (userSpecification instanceof String) {
                String input = (String)userSpecification;
                String upperCaseSpecification = StringFormatting.toUpperCaseWithLocale((String)input);
                if (upperCaseSpecification.equals(OutOfBagError.OUT_OF_BAG_ERROR.name())) {
                    return ClassificationMetricSpecification.createSpecification((ignored, ignored2) -> Stream.of(OutOfBagError.OUT_OF_BAG_ERROR), upperCaseSpecification);
                }
                Matcher matcher = SINGLE_CLASS_METRIC_PATTERN.matcher(upperCaseSpecification);
                if (!matcher.matches()) {
                    BiFunction<LocalIdMap, LongMultiSet, ClassificationMetric> allClassMetricGenerator = ALL_CLASS_METRIC_FACTORIES.get(upperCaseSpecification);
                    if (allClassMetricGenerator == null) {
                        throw new IllegalArgumentException(Parser.errorMessage(List.of(input)));
                    }
                    return ClassificationMetricSpecification.createSpecification((classIdMap, classCounts) -> Stream.of((Metric)allClassMetricGenerator.apply((LocalIdMap)classIdMap, (LongMultiSet)classCounts)), upperCaseSpecification);
                }
                String metricType = matcher.group(1);
                String classId = matcher.group(2);
                LongIntToObjectFunction<ClassificationMetric> metricGenerator = SINGLE_CLASS_METRIC_FACTORIES.get(metricType);
                if (metricGenerator == null) {
                    throw new IllegalArgumentException(Parser.errorMessage(List.of(input)));
                }
                Function<LocalIdMap, Stream> metricsFactory = classId.equals("*") ? classIdMap -> classIdMap.getMappings().map(idMap -> (Metric)metricGenerator.value(idMap.key, idMap.value)) : classIdMap -> Stream.of((Metric)metricGenerator.value(Long.parseLong(classId), classIdMap.toMapped(Long.parseLong(classId))));
                return ClassificationMetricSpecification.createSpecification((classIdMap, ignored) -> (Stream)metricsFactory.apply((LocalIdMap)classIdMap), StringFormatting.formatWithLocale((String)"%s(class=%s)", (Object[])new Object[]{metricType, classId}));
            }
            throw new IllegalArgumentException(StringFormatting.formatWithLocale((String)"Expected MetricSpecification or String. Got %s.", (Object[])new Object[]{userSpecification.getClass().getSimpleName()}));
        }

        private static List<String> allValidMetricExpressions() {
            return Parser.validMetricExpressions(true);
        }

        private static List<String> validPrimaryMetricExpressions() {
            return Parser.validMetricExpressions(false);
        }

        private static List<String> validMetricExpressions(boolean includeSyntacticSugarMetrics) {
            LinkedList<String> validExpressions = new LinkedList<String>(MODEL_SPECIFIC_METRICS);
            Set<String> allClassExpressions = ALL_CLASS_METRIC_FACTORIES.keySet();
            validExpressions.addAll(allClassExpressions);
            for (String singleClassMetric : Parser.singleClassMetrics()) {
                if (includeSyntacticSugarMetrics) {
                    validExpressions.add(singleClassMetric + "(class=*)");
                }
                validExpressions.add(singleClassMetric + "(class=<class value>)");
            }
            return validExpressions;
        }

        private static String errorMessage(List<String> specifications) {
            return StringFormatting.formatWithLocale((String)"Invalid metric expression%s %s. Available metrics are %s (case insensitive and space allowed between brackets).", (Object[])new Object[]{specifications.size() == 1 ? "" : "s", specifications.stream().map(s -> "`" + s + "`").collect(Collectors.joining(", ")), String.join((CharSequence)", ", Parser.allValidMetricExpressions())});
        }

        private static boolean invalidSpecification(String userSpecification) {
            String upperCaseSpecification = userSpecification.toUpperCase(Locale.ENGLISH);
            if (MODEL_SPECIFIC_METRICS.contains(upperCaseSpecification)) {
                return false;
            }
            Matcher matcher = SINGLE_CLASS_METRIC_PATTERN.matcher(upperCaseSpecification);
            if (matcher.matches()) {
                return false;
            }
            return !ALL_CLASS_METRIC_FACTORIES.containsKey(upperCaseSpecification);
        }
    }
}

