/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.metrics;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.immutables.value.Value;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.config.ToMapConvertible;
import org.neo4j.gds.ml.metrics.EvaluationScores;
import org.neo4j.gds.ml.metrics.ImmutableModelCandidateStats;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.models.TrainerConfig;

/**
 * Immutable statistics for a single model candidate: its {@link TrainerConfig} together with
 * the {@link EvaluationScores} it obtained per {@link Metric} on training and validation data.
 */
@ValueClass
public interface ModelCandidateStats
extends ToMapConvertible {
    /** The trainer configuration this candidate was trained with. */
    TrainerConfig trainerConfig();

    /** Training scores of this candidate, keyed by metric. */
    Map<Metric, EvaluationScores> trainingStats();

    /** Validation scores of this candidate, keyed by metric. */
    Map<Metric, EvaluationScores> validationStats();

    /**
     * Renders this candidate as a map with two entries: {@code "parameters"} (the trainer
     * config including its trainer method) and {@code "metrics"} (train/validation scores only,
     * see {@link #renderMetrics(Map, Map)} for the layout).
     */
    @Value.Auxiliary
    @Value.Derived
    default Map<String, Object> toMap() {
        return Map.of("parameters", this.trainerConfig().toMapWithTrainerMethod(), "metrics", this.renderMetrics());
    }

    /**
     * Renders all metrics of this candidate, additionally attaching the given test and
     * outer-train scores.
     *
     * <p>The result is keyed by {@code metric.toString()}. Each value is a map that may contain
     * {@code "train"} and {@code "validation"} (each an {@link EvaluationScores#toMap()} result),
     * plus {@code "test"} and {@code "outerTrain"} (the raw {@code Double} score), each present
     * only when the corresponding source has an entry for that metric. Only metrics that appear
     * in {@link #trainingStats()} or {@link #validationStats()} are rendered. Entries in
     * {@code testMetrics} / {@code outerTrainMetrics} for other metrics are ignored.
     *
     * @param testMetrics       test scores per metric; {@code null} is treated as "none"
     * @param outerTrainMetrics outer-train scores per metric; {@code null} is treated as "none"
     * @return the rendered metrics
     */
    // NOTE(review): @Value.Derived normally targets no-arg attribute methods; kept as in the
    // original - confirm whether it is needed on this parameterised method.
    @Value.Derived
    default Map<String, Map<String, Object>> renderMetrics(Map<Metric, Double> testMetrics, Map<Metric, Double> outerTrainMetrics) {
        return this.renderMetrics(Optional.ofNullable(testMetrics), Optional.ofNullable(outerTrainMetrics));
    }

    /** Renders train/validation scores only, without any test or outer-train entries. */
    private Map<String, Map<String, Object>> renderMetrics() {
        return this.renderMetrics(Optional.empty(), Optional.empty());
    }

    /** Shared implementation of the public and no-arg {@code renderMetrics} variants. */
    private Map<String, Map<String, Object>> renderMetrics(Optional<Map<Metric, Double>> testMetrics, Optional<Map<Metric, Double>> outerTrainMetrics) {
        return this.metrics().stream().collect(Collectors.toMap(Object::toString, metric -> {
            // Values are heterogeneous: nested score maps for train/validation and raw Doubles
            // for test/outerTrain, hence Object as the value type.
            Map<String, Object> result = new HashMap<>();
            if (this.trainingStats().containsKey(metric)) {
                result.put("train", this.trainingStats().get(metric).toMap());
            }
            if (this.validationStats().containsKey(metric)) {
                result.put("validation", this.validationStats().get(metric).toMap());
            }
            testMetrics
                .filter(test -> test.containsKey(metric))
                .ifPresent(test -> result.put("test", test.get(metric)));
            outerTrainMetrics
                .filter(outerTrain -> outerTrain.containsKey(metric))
                .ifPresent(outerTrain -> result.put("outerTrain", outerTrain.get(metric)));
            return result;
        }));
    }

    /** All metrics appearing in the training or validation stats, without duplicates. */
    private List<Metric> metrics() {
        return Stream.concat(this.trainingStats().keySet().stream(), this.validationStats().keySet().stream()).distinct().collect(Collectors.toList());
    }

    /**
     * Creates a new immutable instance.
     *
     * @param trainerConfig   the candidate's trainer configuration
     * @param trainStats      training scores per metric
     * @param validationStats validation scores per metric
     * @return the stats instance
     */
    static ModelCandidateStats of(TrainerConfig trainerConfig, Map<Metric, EvaluationScores> trainStats, Map<Metric, EvaluationScores> validationStats) {
        return ImmutableModelCandidateStats.of(trainerConfig, trainStats, validationStats);
    }
}

