/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.training;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.jetbrains.annotations.TestOnly;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.metrics.EvaluationScores;
import org.neo4j.gds.ml.metrics.ImmutableEvaluationScores;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.metrics.ModelCandidateStats;
import org.neo4j.gds.ml.models.TrainerConfig;

/**
 * Accumulates statistics for a set of model candidates (trials), plus the test and outer-train scores
 * recorded for the winning model.
 *
 * <p>The first metric passed to the constructor is the evaluation (main) metric. Candidates are ranked by
 * the average of that metric over their validation statistics, using the metric's own comparator.
 *
 * <p>NOTE(review): the internal collections are plain {@link ArrayList}/{@link HashMap} without
 * synchronization; presumably an instance is confined to one thread — confirm before sharing.
 */
public final class TrainingStatistics {
    /**
     * Class count assumed by {@link #memoryEstimationStatsMap(int, int)}. Presumably a generous placeholder
     * for when the real number of classes is not known — NOTE(review): confirm.
     */
    private static final int FUDGED_NUMBER_OF_CLASSES = 1000;

    private final List<ModelCandidateStats> modelCandidateStats = new ArrayList<ModelCandidateStats>();
    private final List<? extends Metric> metrics;
    private final Map<Metric, Double> testScores;
    private final Map<Metric, Double> outerTrainScores;

    /**
     * @param metrics the metrics to track; the first element is used as the evaluation metric. The list is
     *                stored as given (not copied).
     */
    public TrainingStatistics(List<? extends Metric> metrics) {
        this.metrics = metrics;
        this.testScores = new HashMap<Metric, Double>();
        this.outerTrainScores = new HashMap<Metric, Double>();
    }

    /** Returns, per candidate in insertion order, its training scores for {@code metric} (may contain nulls). */
    @TestOnly
    public List<EvaluationScores> getTrainStats(Metric metric) {
        return this.modelCandidateStats.stream().map(stats -> stats.trainingStats().get(metric)).collect(Collectors.toList());
    }

    /** Returns, per candidate in insertion order, its validation scores for {@code metric} (may contain nulls). */
    @TestOnly
    public List<EvaluationScores> getValidationStats(Metric metric) {
        return this.modelCandidateStats.stream().map(stats -> stats.validationStats().get(metric)).collect(Collectors.toList());
    }

    /** Returns the recorded test score for {@code metric}, or {@code null} if none was added. */
    @TestOnly
    public Double getTestScore(Metric metric) {
        return this.testScores.get(metric);
    }

    /**
     * Summarizes these statistics as a map with the keys {@code "bestParameters"} (the best candidate's
     * trainer config via {@code toMapWithTrainerMethod()}), {@code "bestTrial"} (the 1-based index of the best
     * candidate) and {@code "modelCandidates"} (each candidate's {@code toMap()}, in insertion order).
     *
     * @throws IllegalStateException if no candidate statistics have been added
     */
    public Map<String, Object> toMap() {
        // Resolve the best trial once; previously it was recomputed for each of the first two entries.
        int bestTrialIdx = this.getBestTrialIdx();
        ModelCandidateStats bestCandidate = this.modelCandidateStats.get(bestTrialIdx);
        return Map.of(
            "bestParameters", bestCandidate.trainerConfig().toMapWithTrainerMethod(),
            "bestTrial", bestTrialIdx + 1,
            "modelCandidates", this.modelCandidateStats.stream().map(ModelCandidateStats::toMap).collect(Collectors.toList())
        );
    }

    /**
     * Returns the average validation score of the evaluation metric for the given trial.
     *
     * @param trial zero-based candidate index
     */
    public double getMainMetric(int trial) {
        return this.mainMetricAvg(this.modelCandidateStats.get(trial));
    }

    /** Returns the average validation score per metric for the given zero-based trial. */
    public Map<Metric, Double> validationMetricsAvg(int trial) {
        return this.extractAverage(this.modelCandidateStats.get(trial).validationStats());
    }

    /** Returns the average training score per metric for the given zero-based trial. */
    public Map<Metric, Double> trainMetricsAvg(int trial) {
        return this.extractAverage(this.modelCandidateStats.get(trial).trainingStats());
    }

    /** Maps each metric to the {@code avg()} of its evaluation scores. */
    private Map<Metric, Double> extractAverage(Map<Metric, EvaluationScores> statsMap) {
        return statsMap.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().avg()));
    }

    /**
     * Returns the metric used to rank candidates, i.e. the first metric given to the constructor.
     *
     * @throws IndexOutOfBoundsException if the metric list is empty
     */
    public Metric evaluationMetric() {
        return this.metrics.get(0);
    }

    /** Appends the statistics of one more candidate; its index is the number of candidates added before it. */
    public void addCandidateStats(ModelCandidateStats statistics) {
        this.modelCandidateStats.add(statistics);
    }

    /** Records (or overwrites) the winning model's test score for {@code metric}. */
    public void addTestScore(Metric metric, double score) {
        this.testScores.put(metric, score);
    }

    /** Records (or overwrites) the winning model's outer-train score for {@code metric}. */
    public void addOuterTrainScore(Metric metric, double score) {
        this.outerTrainScores.put(metric, score);
    }

    /**
     * Returns the internal map of test scores (not a copy). Later {@link #addTestScore} calls are visible
     * through it, and writes to it modify this object.
     */
    public Map<Metric, Double> winningModelTestMetrics() {
        return this.testScores;
    }

    /**
     * Returns the internal map of outer-train scores (not a copy). Later {@link #addOuterTrainScore} calls
     * are visible through it, and writes to it modify this object.
     */
    public Map<Metric, Double> winningModelOuterTrainMetrics() {
        return this.outerTrainScores;
    }

    /**
     * Returns the zero-based index of the best candidate: the first one whose main-metric validation average
     * equals the maximum under the evaluation metric's comparator.
     *
     * @throws IllegalStateException if no candidate statistics have been added
     */
    public int getBestTrialIdx() {
        List<Double> scores = this.mainMetricScores();
        return scores.indexOf(this.bestScore(scores));
    }

    /**
     * Returns the statistics of the best candidate (see {@link #getBestTrialIdx()}).
     *
     * @throws IllegalStateException if no candidate statistics have been added
     */
    public ModelCandidateStats bestCandidate() {
        return this.modelCandidateStats.get(this.getBestTrialIdx());
    }

    /**
     * Returns the maximum, under the evaluation metric's comparator, of the candidates' main-metric
     * validation averages.
     *
     * @throws IllegalStateException if no candidate statistics have been added
     */
    public double getBestTrialScore() {
        return this.bestScore(this.mainMetricScores());
    }

    /**
     * Returns the trainer config of the best candidate.
     *
     * @throws IllegalStateException if no candidate statistics have been added
     */
    public TrainerConfig bestParameters() {
        return this.bestCandidate().trainerConfig();
    }

    /** Main-metric validation average of a single candidate. */
    private double mainMetricAvg(ModelCandidateStats stats) {
        return stats.validationStats().get(this.evaluationMetric()).avg();
    }

    /** Main-metric validation averages of all candidates, in insertion order. */
    private List<Double> mainMetricScores() {
        return this.modelCandidateStats.stream().map(this::mainMetricAvg).collect(Collectors.toList());
    }

    /** Maximum of {@code scores} under the evaluation metric's comparator; ties keep the first occurrence. */
    private double bestScore(List<Double> scores) {
        return scores.stream().max(this.evaluationMetric().comparator()).orElseThrow(() -> new IllegalStateException("Empty validation stats."));
    }

    /**
     * Same as {@link #memoryEstimationStatsMap(int, int, int)} with an assumed class count of
     * {@value #FUDGED_NUMBER_OF_CLASSES}.
     */
    public static MemoryEstimation memoryEstimationStatsMap(int numberOfMetricsSpecifications, int numberOfModelCandidates) {
        return TrainingStatistics.memoryEstimationStatsMap(numberOfMetricsSpecifications, numberOfModelCandidates, FUDGED_NUMBER_OF_CLASSES);
    }

    /**
     * Estimates the memory used by the statistics map: the shallow size of one {@link ArrayList} plus one
     * {@code ImmutableEvaluationScores} instance per (metric specification x class x model candidate).
     */
    public static MemoryEstimation memoryEstimationStatsMap(int numberOfMetricsSpecifications, int numberOfModelCandidates, int numberOfClasses) {
        // Multiply in long: the int product of three counts can silently overflow.
        long numberOfMetrics = (long) numberOfMetricsSpecifications * numberOfClasses;
        long numberOfModelStats = numberOfMetrics * numberOfModelCandidates;
        long sizeOfOneModelStatsInBytes = MemoryUsage.sizeOfInstance(ImmutableEvaluationScores.class);
        long sizeOfAllModelStatsInBytes = sizeOfOneModelStatsInBytes * numberOfModelStats;
        return MemoryEstimations.builder("StatsMap")
            .fixed("array list", MemoryUsage.sizeOfInstance(ArrayList.class))
            .fixed("model stats", sizeOfAllModelStatsInBytes)
            .build();
    }
}

