/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.randomforest;

import com.carrotsearch.hppc.BitSet;
import java.util.List;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.decisiontree.DecisionTreePredictor;
import org.neo4j.gds.ml.decisiontree.DecisionTreeRegressorTrainer;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainerConfig;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainerConfigImpl;
import org.neo4j.gds.ml.decisiontree.FeatureBagger;
import org.neo4j.gds.ml.decisiontree.ImpurityCriterion;
import org.neo4j.gds.ml.decisiontree.SplitMeanSquaredError;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.RegressorTrainer;
import org.neo4j.gds.ml.models.randomforest.DatasetBootstrapper;
import org.neo4j.gds.ml.models.randomforest.RandomForestRegressor;
import org.neo4j.gds.ml.models.randomforest.RandomForestRegressorData;
import org.neo4j.gds.ml.models.randomforest.RandomForestRegressorTrainerConfig;
import org.neo4j.gds.ml.models.randomforest.RandomForestTrainerConfig;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.utils.StringFormatting;

/**
 * Trains a {@link RandomForestRegressor}: an ensemble of regression decision trees.
 *
 * <p>Each tree is trained by its own {@link TrainDecisionTreeTask}. The task uses a {@link DecisionTreeRegressorTrainer}
 * with {@link SplitMeanSquaredError} as split criterion and a {@link FeatureBagger} built from {@code maxFeaturesRatio}.
 * Depending on {@code numberOfSamplesRatio}, each tree sees either the whole training set or a bootstrapped sample of
 * it. All tree-training tasks are executed through {@link RunWithConcurrency}.
 */
public class RandomForestRegressorTrainer implements RegressorTrainer {
    private final RandomForestRegressorTrainerConfig config;
    private final int concurrency;
    private final SplittableRandom random;
    private final TerminationFlag terminationFlag;
    private final ProgressTracker progressTracker;
    private final LogLevel messageLogLevel;

    /**
     * Creates a trainer.
     *
     * @param concurrency     concurrency used when running the tree-training tasks
     * @param config          random forest hyper-parameters (number of trees, tree depth, min split size,
     *                        sampling ratio, feature ratio)
     * @param randomSeed      seed for all randomness used during training; when empty, a seed is drawn from a
     *                        freshly created {@link SplittableRandom}
     * @param terminationFlag passed to {@link RunWithConcurrency} together with the training tasks
     * @param progressTracker receives one log message per trained tree
     * @param messageLogLevel log level of those per-tree messages
     */
    public RandomForestRegressorTrainer(int concurrency, RandomForestRegressorTrainerConfig config, Optional<Long> randomSeed, TerminationFlag terminationFlag, ProgressTracker progressTracker, LogLevel messageLogLevel) {
        this.config = config;
        this.concurrency = concurrency;
        this.random = new SplittableRandom(randomSeed.orElseGet(() -> new SplittableRandom().nextLong()));
        this.terminationFlag = terminationFlag;
        this.progressTracker = progressTracker;
        this.messageLogLevel = messageLogLevel;
    }

    /**
     * Estimates the memory needed to train a random forest regressor.
     *
     * <p>The estimate is the sum of three parts:
     * <ul>
     *   <li>the trained model data;</li>
     *   <li>the mean squared error split criterion;</li>
     *   <li>the decision-tree training memory, multiplied by the concurrency.</li>
     * </ul>
     *
     * @param numberOfTrainingSamples maps a graph's node count to the number of training samples
     * @param featureDimension        range of possible feature dimensions
     * @param config                  training configuration
     * @return the memory estimation for training
     */
    public static MemoryEstimation memoryEstimation(LongUnaryOperator numberOfTrainingSamples, MemoryRange featureDimension, RandomForestRegressorTrainerConfig config) {
        // The feature dimension is only known as a range, and the number of bagged features depends on it.
        // Per-tree memory is therefore estimated for both ends of the range and the union is taken.
        int minNumberOfBaggedFeatures = (int) Math.ceil(config.maxFeaturesRatio((int) featureDimension.min) * (double) featureDimension.min);
        int maxNumberOfBaggedFeatures = (int) Math.ceil(config.maxFeaturesRatio((int) featureDimension.max) * (double) featureDimension.max);

        return MemoryEstimations.builder("Training")
            .add(RandomForestRegressorData.memoryEstimation(numberOfTrainingSamples, config))
            .rangePerNode("Mean Squared Error Loss", nodeCount -> SplitMeanSquaredError.memoryEstimation())
            .perGraphDimension("Decision tree training", (dim, concurrency) -> {
                long trainingSamples = numberOfTrainingSamples.applyAsLong(dim.nodeCount());
                MemoryRange withMinFeatures = TrainDecisionTreeTask.memoryEstimation(config, trainingSamples, minNumberOfBaggedFeatures, config.numberOfSamplesRatio());
                MemoryRange withMaxFeatures = TrainDecisionTreeTask.memoryEstimation(config, trainingSamples, maxNumberOfBaggedFeatures, config.numberOfSamplesRatio());
                // Scaled by concurrency, since up to that many tree-training tasks may be in flight at once.
                return withMinFeatures.union(withMaxFeatures).times((long) concurrency.intValue());
            })
            .build();
    }

    /**
     * Trains {@code numberOfDecisionTrees} regression trees and bundles them into a {@link RandomForestRegressor}.
     *
     * @param allFeatureVectors feature vectors of all samples
     * @param targets           regression targets (also used to build the shared split criterion)
     * @param trainSet          the samples to train on; presumably indices into {@code allFeatureVectors} and
     *                          {@code targets} (TODO confirm)
     * @return the trained random forest
     */
    @Override
    public RandomForestRegressor train(Features allFeatureVectors, HugeDoubleArray targets, ReadOnlyHugeLongArray trainSet) {
        DecisionTreeTrainerConfig decisionTreeTrainConfig = DecisionTreeTrainerConfigImpl.builder()
            .maxDepth(this.config.maxDepth())
            .minSplitSize(this.config.minSplitSize())
            .build();
        int numberOfDecisionTrees = this.config.numberOfDecisionTrees();
        // A single criterion instance is shared by every task.
        // NOTE(review): it must therefore be safe for concurrent use.
        SplitMeanSquaredError impurityCriterion = new SplitMeanSquaredError(targets);
        AtomicInteger numberOfTreesTrained = new AtomicInteger(0);

        // SplittableRandom is not thread-safe, so every task receives its own split instance. The splits are
        // created here, sequentially on the calling thread, which also keeps the seeds deterministic.
        List<TrainDecisionTreeTask> tasks = IntStream.range(0, numberOfDecisionTrees)
            .mapToObj(unused -> new TrainDecisionTreeTask(
                decisionTreeTrainConfig,
                this.config,
                this.random.split(),
                allFeatureVectors,
                targets,
                impurityCriterion,
                trainSet,
                this.progressTracker,
                this.messageLogLevel,
                numberOfTreesTrained
            ))
            .collect(Collectors.toList());

        RunWithConcurrency.builder()
            .concurrency(this.concurrency)
            .tasks(tasks)
            .terminationFlag(this.terminationFlag)
            .run();

        // Each task's result is read only after RunWithConcurrency has returned.
        List<DecisionTreePredictor<Double>> decisionTrees = tasks.stream()
            .map(TrainDecisionTreeTask::trainedTree)
            .collect(Collectors.toList());
        return new RandomForestRegressor(decisionTrees, allFeatureVectors.featureDimension());
    }

    /**
     * Trains a single regression decision tree.
     *
     * <p>The tree is trained on either the whole training set or a bootstrapped sample of it. The result is
     * exposed via {@link #trainedTree()} once {@link #run()} has completed.
     */
    static class TrainDecisionTreeTask implements Runnable {
        private DecisionTreePredictor<Double> trainedTree;
        private final DecisionTreeTrainerConfig decisionTreeTrainConfig;
        private final RandomForestTrainerConfig randomForestTrainConfig;
        private final SplittableRandom random;
        private final Features allFeatureVectors;
        private final HugeDoubleArray targets;
        private final ImpurityCriterion impurityCriterion;
        private final ReadOnlyHugeLongArray trainSet;
        private final ProgressTracker progressTracker;
        private final LogLevel messageLogLevel;
        private final AtomicInteger numberOfTreesTrained;

        TrainDecisionTreeTask(DecisionTreeTrainerConfig decisionTreeTrainConfig, RandomForestTrainerConfig randomForestTrainConfig, SplittableRandom random, Features allFeatureVectors, HugeDoubleArray targets, ImpurityCriterion impurityCriterion, ReadOnlyHugeLongArray trainSet, ProgressTracker progressTracker, LogLevel messageLogLevel, AtomicInteger numberOfTreesTrained) {
            this.decisionTreeTrainConfig = decisionTreeTrainConfig;
            this.randomForestTrainConfig = randomForestTrainConfig;
            this.random = random;
            this.allFeatureVectors = allFeatureVectors;
            this.targets = targets;
            this.impurityCriterion = impurityCriterion;
            this.trainSet = trainSet;
            this.progressTracker = progressTracker;
            this.messageLogLevel = messageLogLevel;
            this.numberOfTreesTrained = numberOfTreesTrained;
        }

        /**
         * Estimates the memory of training one tree.
         *
         * <p>The estimate covers the task instance, its feature bagger, the decision tree trainer and the
         * bootstrapped dataset.
         *
         * @param config                  decision tree configuration
         * @param numberOfTrainingSamples size of the full training set
         * @param numberOfBaggedFeatures  number of features the bagger selects from
         * @param numberOfSamplesRatio    bootstrap sampling ratio; {@code 0} means the whole training set is used
         *                                without sampling
         * @return the memory range for one tree-training task
         */
        public static MemoryRange memoryEstimation(DecisionTreeTrainerConfig config, long numberOfTrainingSamples, int numberOfBaggedFeatures, double numberOfSamplesRatio) {
            long bootstrappedNumberOfSamples = (long) Math.ceil(numberOfSamplesRatio * (double) numberOfTrainingSamples);
            // With ratio 0, bootstrappedDataset() hands the whole training set to the tree trainer without making a
            // copy. The tree trainer must then be sized for every training sample, not for ceil(0 * n) = 0.
            long samplesSeenByTree = Double.compare(numberOfSamplesRatio, 0.0) == 0
                ? numberOfTrainingSamples
                : bootstrappedNumberOfSamples;

            MemoryRange bootstrappedDatasetEstimation = MemoryRange.of(HugeLongArray.memoryEstimation(bootstrappedNumberOfSamples))
                .add(MemoryUsage.sizeOfBitset(bootstrappedNumberOfSamples));

            return MemoryRange.of(MemoryUsage.sizeOfInstance(TrainDecisionTreeTask.class))
                .add(FeatureBagger.memoryEstimation(numberOfBaggedFeatures))
                .add(DecisionTreeRegressorTrainer.memoryEstimation(config, samplesSeenByTree))
                .add(bootstrappedDatasetEstimation);
        }

        /**
         * Returns the tree trained by {@link #run()}.
         *
         * @return the trained tree; presumably {@code null} if {@link #run()} never completed (TODO confirm callers)
         */
        public DecisionTreePredictor<Double> trainedTree() {
            return this.trainedTree;
        }

        @Override
        public void run() {
            int featureDimension = this.allFeatureVectors.featureDimension();
            // Construction order is kept as-is: the bagger and the bootstrapper both draw from this.random,
            // so reordering them would change results for a fixed seed.
            FeatureBagger featureBagger = new FeatureBagger(this.random, featureDimension, this.randomForestTrainConfig.maxFeaturesRatio(featureDimension));
            DecisionTreeRegressorTrainer decisionTree = new DecisionTreeRegressorTrainer(this.impurityCriterion, this.allFeatureVectors, this.targets, this.decisionTreeTrainConfig, featureBagger);
            this.trainedTree = decisionTree.train(this.bootstrappedDataset());
            this.progressTracker.logMessage(
                this.messageLogLevel,
                StringFormatting.formatWithLocale(
                    "trained decision tree %d out of %d",
                    this.numberOfTreesTrained.incrementAndGet(),
                    this.randomForestTrainConfig.numberOfDecisionTrees()
                )
            );
        }

        /**
         * Selects the samples this tree is trained on.
         *
         * @return the whole training set if {@code numberOfSamplesRatio} is {@code 0}; otherwise the result of
         *         {@link DatasetBootstrapper#bootstrap}
         */
        private ReadOnlyHugeLongArray bootstrappedDataset() {
            double numberOfSamplesRatio = this.randomForestTrainConfig.numberOfSamplesRatio();
            if (Double.compare(numberOfSamplesRatio, 0.0) == 0) {
                // Ratio 0 disables sampling: every training vector is used as-is.
                return this.trainSet;
            }
            // bootstrap() requires a BitSet argument. This regressor does not read it afterwards.
            BitSet sampledTrainSetIndices = new BitSet(this.trainSet.size());
            return DatasetBootstrapper.bootstrap(this.random, numberOfSamplesRatio, this.trainSet, sampledTrainSetIndices);
        }
    }
}

