/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.training;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.SortedSet;
import org.eclipse.collections.api.block.function.primitive.LongToLongFunction;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.core.utils.progress.tasks.Task;
import org.neo4j.gds.core.utils.progress.tasks.Tasks;
import org.neo4j.gds.ml.metrics.Metric;
import org.neo4j.gds.ml.metrics.MetricConsumer;
import org.neo4j.gds.ml.metrics.ModelCandidateStats;
import org.neo4j.gds.ml.metrics.ModelSpecificMetricsHandler;
import org.neo4j.gds.ml.metrics.ModelStatsBuilder;
import org.neo4j.gds.ml.models.TrainerConfig;
import org.neo4j.gds.ml.splitting.StratifiedKFoldSplitter;
import org.neo4j.gds.ml.splitting.TrainingExamplesSplit;
import org.neo4j.gds.ml.training.TrainingStatistics;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Evaluates a sequence of model candidates with k-fold cross-validation.
 *
 * <p>The outer training set is split once into stratified folds. Every candidate
 * configuration is then trained on each fold's train part and evaluated on both
 * the fold's test part and its train part; the collected statistics are appended
 * to the supplied {@link TrainingStatistics}.
 *
 * @param <MODEL_TYPE> type of the model produced by the {@link ModelTrainer}
 *                     and consumed by the {@link ModelEvaluator}
 */
public class CrossValidation<MODEL_TYPE> {
    // Names of the progress sub tasks; used both when declaring the task tree
    // and when beginning/ending the tasks, so they must stay in sync.
    private static final String CREATE_FOLDS_TASK = "Create validation folds";
    private static final String SELECT_MODEL_TASK = "Select best model";
    private static final String TRIAL_TASK = "Trial";

    private final ProgressTracker progressTracker;
    private final TerminationFlag terminationFlag;
    private final List<? extends Metric> metrics;
    private final int validationFolds;
    private final Optional<Long> randomSeed;
    private final ModelTrainer<MODEL_TYPE> modelTrainer;
    private final ModelEvaluator<MODEL_TYPE> modelEvaluator;

    /**
     * Builds the progress tasks that {@link #selectModel} begins and ends.
     *
     * <p>The result holds a leaf task for fold creation (volume: half the train
     * set size, at least 1) followed by an iterative task containing one
     * {@code "Trial"} leaf per model selection trial (volume:
     * {@code 5 * validationFolds * trainSetSize}).
     * NOTE(review): the rationale for the factors 0.5 and 5 is not visible in this file.
     *
     * @param validationFolds              number of cross-validation folds
     * @param numberOfModelSelectionTrials number of iterations of the trial task
     * @param trainSetSize                 size of the training set used to weight the tasks
     * @return the two tasks, in the order they are executed
     */
    public static List<Task> progressTasks(int validationFolds, int numberOfModelSelectionTrials, long trainSetSize) {
        long foldCreationVolume = Math.max((long) (0.5 * trainSetSize), 1L);
        long trialVolume = 5L * validationFolds * trainSetSize;
        return List.of(
            Tasks.leaf(CREATE_FOLDS_TASK, foldCreationVolume),
            Tasks.iterativeFixed(
                SELECT_MODEL_TASK,
                () -> List.of(Tasks.leaf(TRIAL_TASK, trialVolume)),
                numberOfModelSelectionTrials
            )
        );
    }

    /**
     * @param progressTracker receives sub task boundaries, step counts and log messages
     * @param terminationFlag checked at the start of every trial
     * @param metrics         metrics handed to the model specific metrics handler
     * @param validationFolds number of folds passed to the stratified splitter
     * @param randomSeed      optional seed passed to the stratified splitter
     * @param modelTrainer    trains a model for one candidate on one fold
     * @param modelEvaluator  evaluates a trained model on a set of examples
     */
    public CrossValidation(ProgressTracker progressTracker, TerminationFlag terminationFlag, List<? extends Metric> metrics, int validationFolds, Optional<Long> randomSeed, ModelTrainer<MODEL_TYPE> modelTrainer, ModelEvaluator<MODEL_TYPE> modelEvaluator) {
        this.progressTracker = progressTracker;
        this.terminationFlag = terminationFlag;
        this.metrics = metrics;
        this.validationFolds = validationFolds;
        this.randomSeed = randomSeed;
        this.modelTrainer = modelTrainer;
        this.modelEvaluator = modelEvaluator;
    }

    /**
     * Cross-validates every candidate supplied by {@code modelCandidates} and
     * records its statistics in {@code trainingStatistics}.
     *
     * <p>Trials are counted from zero in the order the iterator yields them, and
     * that counter is used to look up the trial's averages in
     * {@code trainingStatistics} right after its stats were added. Once the
     * iterator is exhausted the best trial is logged using a one-based number.
     *
     * @param outerTrainSet           examples to split into validation folds
     * @param targets                 function handed to the splitter together with the examples
     * @param distinctInternalTargets distinct target values handed to the splitter
     * @param trainingStatistics      collector that receives one entry per candidate
     * @param modelCandidates         candidate configurations to evaluate
     */
    public void selectModel(ReadOnlyHugeLongArray outerTrainSet, LongToLongFunction targets, SortedSet<Long> distinctInternalTargets, TrainingStatistics trainingStatistics, Iterator<TrainerConfig> modelCandidates) {
        List<TrainingExamplesSplit> folds = createValidationFolds(outerTrainSet, targets, distinctInternalTargets);

        progressTracker.beginSubTask(SELECT_MODEL_TASK);
        for (int trialIdx = 0; modelCandidates.hasNext(); trialIdx++) {
            progressTracker.beginSubTask(TRIAL_TASK);
            // One step is logged per fold, so the trial's step count is the number of splits.
            progressTracker.setSteps((long) folds.size());
            terminationFlag.assertRunning();

            TrainerConfig candidate = modelCandidates.next();
            progressTracker.logInfo(StringFormatting.formatWithLocale(
                "Method: %s, Parameters: %s",
                candidate.method(),
                candidate.toMap()
            ));

            trainingStatistics.addCandidateStats(evaluateCandidate(candidate, folds));
            logTrialSummary(trainingStatistics, trialIdx);
            progressTracker.endSubTask(TRIAL_TASK);
        }

        // The index is zero-based; the log message numbers trials starting at one.
        int bestTrialNumber = trainingStatistics.getBestTrialIdx() + 1;
        double bestScore = trainingStatistics.getBestTrialScore();
        progressTracker.logInfo(StringFormatting.formatWithLocale(
            "Best trial was Trial %d with main validation metric %.4f",
            bestTrialNumber,
            bestScore
        ));
        progressTracker.endSubTask(SELECT_MODEL_TASK);
    }

    /** Splits the outer training set into stratified folds, wrapped in its own progress sub task. */
    private List<TrainingExamplesSplit> createValidationFolds(ReadOnlyHugeLongArray outerTrainSet, LongToLongFunction targets, SortedSet<Long> distinctInternalTargets) {
        progressTracker.beginSubTask(CREATE_FOLDS_TASK);
        StratifiedKFoldSplitter splitter = new StratifiedKFoldSplitter(
            validationFolds,
            outerTrainSet,
            targets,
            randomSeed,
            distinctInternalTargets
        );
        List<TrainingExamplesSplit> folds = splitter.splits();
        progressTracker.endSubTask(CREATE_FOLDS_TASK);
        return folds;
    }

    /** Trains and evaluates one candidate on every fold and bundles the resulting train and validation stats. */
    private ModelCandidateStats evaluateCandidate(TrainerConfig candidate, List<TrainingExamplesSplit> folds) {
        int foldCount = folds.size();
        ModelStatsBuilder validationStats = new ModelStatsBuilder(foldCount);
        ModelStatsBuilder trainStats = new ModelStatsBuilder(foldCount);
        // The handler is bound to the validation stats builder and shared by all folds of this candidate.
        ModelSpecificMetricsHandler metricsHandler = ModelSpecificMetricsHandler.of(metrics, validationStats);

        int foldNumber = 0;
        for (TrainingExamplesSplit fold : folds) {
            foldNumber++;
            ReadOnlyHugeLongArray foldTrainSet = fold.trainSet();
            ReadOnlyHugeLongArray foldValidationSet = fold.testSet();

            progressTracker.logDebug("Starting fold " + foldNumber + " training");
            MODEL_TYPE model = modelTrainer.train(foldTrainSet, candidate, metricsHandler, LogLevel.DEBUG);
            progressTracker.logDebug("Finished fold " + foldNumber + " training");

            // Validation score first, then the score on the data the model was fitted on.
            modelEvaluator.evaluate(foldValidationSet, model, validationStats::update);
            modelEvaluator.evaluate(foldTrainSet, model, trainStats::update);
            progressTracker.logSteps(1L);
        }

        return ModelCandidateStats.of(candidate, trainStats.build(), validationStats.build());
    }

    /** Logs the main metric plus the averaged validation and training metrics of the given trial. */
    private void logTrialSummary(TrainingStatistics trainingStatistics, int trialIdx) {
        Map<Metric, Double> validationAverages = trainingStatistics.validationMetricsAvg(trialIdx);
        Map<Metric, Double> trainAverages = trainingStatistics.trainMetricsAvg(trialIdx);
        double mainMetricValue = trainingStatistics.getMainMetric(trialIdx);

        progressTracker.logInfo(StringFormatting.formatWithLocale(
            "Main validation metric (%s): %.4f",
            trainingStatistics.evaluationMetric(),
            mainMetricValue
        ));
        progressTracker.logInfo(StringFormatting.formatWithLocale("Validation metrics: %s", validationAverages));
        progressTracker.logInfo(StringFormatting.formatWithLocale("Training metrics: %s", trainAverages));
    }

    /**
     * Trains a model for a single candidate configuration.
     *
     * @param <MODEL_TYPE> type of the trained model
     */
    @FunctionalInterface
    public static interface ModelTrainer<MODEL_TYPE> {
        /**
         * @param var1 examples to train on (the train part of a fold when called from {@code selectModel})
         * @param var2 candidate configuration to train with
         * @param var3 metrics handler made available to the training
         * @param var4 log level for the training ({@code selectModel} passes {@link LogLevel#DEBUG})
         * @return the trained model
         */
        public MODEL_TYPE train(ReadOnlyHugeLongArray var1, TrainerConfig var2, ModelSpecificMetricsHandler var3, LogLevel var4);
    }

    /**
     * Evaluates a trained model on a set of examples.
     *
     * @param <MODEL_TYPE> type of the model being evaluated
     */
    @FunctionalInterface
    public static interface ModelEvaluator<MODEL_TYPE> {
        /**
         * @param var1 examples to evaluate on
         * @param var2 model to evaluate
         * @param var3 consumer receiving the evaluation results ({@code selectModel} passes a stats builder's {@code update})
         */
        public void evaluate(ReadOnlyHugeLongArray var1, MODEL_TYPE var2, MetricConsumer var3);
    }
}

