/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.models.randomforest;

import com.carrotsearch.hppc.BitSet;
import java.util.List;
import java.util.Optional;
import java.util.SplittableRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.LongUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.haa.HugeAtomicLongArray;
import org.neo4j.gds.collections.haa.PageCreator;
import org.neo4j.gds.core.concurrency.RunWithConcurrency;
import org.neo4j.gds.core.utils.mem.MemoryEstimation;
import org.neo4j.gds.core.utils.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.ParalleLongPageCreator;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.core.utils.progress.tasks.LogLevel;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.decisiontree.ClassifierImpurityCriterionType;
import org.neo4j.gds.ml.decisiontree.DecisionTreeClassifierTrainer;
import org.neo4j.gds.ml.decisiontree.DecisionTreePredictor;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainerConfig;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainerConfigImpl;
import org.neo4j.gds.ml.decisiontree.Entropy;
import org.neo4j.gds.ml.decisiontree.FeatureBagger;
import org.neo4j.gds.ml.decisiontree.GiniIndex;
import org.neo4j.gds.ml.decisiontree.ImpurityCriterion;
import org.neo4j.gds.ml.metrics.ModelSpecificMetricsHandler;
import org.neo4j.gds.ml.metrics.classification.OutOfBagError;
import org.neo4j.gds.ml.models.ClassifierTrainer;
import org.neo4j.gds.ml.models.Features;
import org.neo4j.gds.ml.models.randomforest.DatasetBootstrapper;
import org.neo4j.gds.ml.models.randomforest.ImmutableBootstrappedDataset;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifier;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierData;
import org.neo4j.gds.ml.models.randomforest.RandomForestClassifierTrainerConfig;
import org.neo4j.gds.ml.models.randomforest.RandomForestTrainerConfig;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.utils.StringFormatting;

public class RandomForestClassifierTrainer
implements ClassifierTrainer {
    private final int numberOfClasses;
    private final RandomForestClassifierTrainerConfig config;
    private final int concurrency;
    private final SplittableRandom random;
    private final ProgressTracker progressTracker;
    private final LogLevel messageLogLevel;
    private final TerminationFlag terminationFlag;
    private Optional<Double> outOfBagError = Optional.empty();
    private final ModelSpecificMetricsHandler metricsHandler;

    public RandomForestClassifierTrainer(int concurrency, int numberOfClasses, RandomForestClassifierTrainerConfig config, Optional<Long> randomSeed, ProgressTracker progressTracker, LogLevel messageLogLevel, TerminationFlag terminationFlag, ModelSpecificMetricsHandler metricsHandler) {
        this.numberOfClasses = numberOfClasses;
        this.config = config;
        this.concurrency = concurrency;
        this.random = new SplittableRandom(randomSeed.orElseGet(() -> new SplittableRandom().nextLong()));
        this.progressTracker = progressTracker;
        this.messageLogLevel = messageLogLevel;
        this.terminationFlag = terminationFlag;
        this.metricsHandler = metricsHandler;
    }

    public static MemoryEstimation memoryEstimation(LongUnaryOperator numberOfTrainingSamples, int numberOfClasses, MemoryRange featureDimension, RandomForestClassifierTrainerConfig config) {
        int minNumberOfBaggedFeatures = (int)Math.ceil(config.maxFeaturesRatio((int)featureDimension.min) * (double)featureDimension.min);
        int maxNumberOfBaggedFeatures = (int)Math.ceil(config.maxFeaturesRatio((int)featureDimension.max) * (double)featureDimension.max);
        return MemoryEstimations.builder((String)"Training").add(RandomForestClassifierData.memoryEstimation(numberOfTrainingSamples, config)).rangePerNode("Impurity computation data", nodeCount -> config.criterion() == ClassifierImpurityCriterionType.GINI ? GiniIndex.memoryEstimation(numberOfTrainingSamples.applyAsLong(nodeCount)) : Entropy.memoryEstimation(numberOfTrainingSamples.applyAsLong(nodeCount))).perGraphDimension("Decision tree training", (dim, concurrency) -> TrainDecisionTreeTask.memoryEstimation(config, numberOfTrainingSamples.applyAsLong(dim.nodeCount()), numberOfClasses, minNumberOfBaggedFeatures, config.numberOfSamplesRatio()).union(TrainDecisionTreeTask.memoryEstimation(config, numberOfTrainingSamples.applyAsLong(dim.nodeCount()), numberOfClasses, maxNumberOfBaggedFeatures, config.numberOfSamplesRatio())).times((long)concurrency.intValue())).build();
    }

    @Override
    public RandomForestClassifier train(Features allFeatureVectors, HugeIntArray allLabels, ReadOnlyHugeLongArray trainSet) {
        Optional<HugeAtomicLongArray> maybePredictions = this.metricsHandler.isRequested(OutOfBagError.OUT_OF_BAG_ERROR) ? Optional.of(HugeAtomicLongArray.of((long)((long)this.numberOfClasses * trainSet.size()), (PageCreator.LongPageCreator)ParalleLongPageCreator.passThrough((int)this.concurrency))) : Optional.empty();
        DecisionTreeTrainerConfig decisionTreeTrainConfig = DecisionTreeTrainerConfigImpl.builder().maxDepth(this.config.maxDepth()).minSplitSize(this.config.minSplitSize()).build();
        int numberOfDecisionTrees = this.config.numberOfDecisionTrees();
        ImpurityCriterion impurityCriterion = this.initializeImpurityCriterion(allLabels);
        AtomicInteger numberOfTreesTrained = new AtomicInteger(0);
        List tasks = IntStream.range(0, numberOfDecisionTrees).mapToObj(unused -> new TrainDecisionTreeTask(maybePredictions, decisionTreeTrainConfig, this.config, this.random.split(), allFeatureVectors, allLabels, this.numberOfClasses, impurityCriterion, trainSet, this.progressTracker, this.messageLogLevel, numberOfTreesTrained)).collect(Collectors.toList());
        RunWithConcurrency.builder().concurrency(this.concurrency).tasks(tasks).terminationFlag(this.terminationFlag).run();
        maybePredictions.ifPresent(predictions -> {
            this.outOfBagError = Optional.of(OutOfBagError.evaluate(trainSet, this.numberOfClasses, allLabels, this.concurrency, predictions));
            this.metricsHandler.handle(OutOfBagError.OUT_OF_BAG_ERROR, this.outOfBagError());
        });
        List<DecisionTreePredictor<Integer>> decisionTrees = tasks.stream().map(TrainDecisionTreeTask::trainedTree).collect(Collectors.toList());
        return new RandomForestClassifier(decisionTrees, this.numberOfClasses, allFeatureVectors.featureDimension());
    }

    double outOfBagError() {
        return this.outOfBagError.orElseThrow(() -> new IllegalAccessError("Out of bag error has not been computed."));
    }

    private ImpurityCriterion initializeImpurityCriterion(HugeIntArray allLabels) {
        switch (this.config.criterion()) {
            case GINI: {
                return new GiniIndex(allLabels, this.numberOfClasses);
            }
            case ENTROPY: {
                return new Entropy(allLabels, this.numberOfClasses);
            }
        }
        throw new IllegalStateException("Invalid decision tree classifier impurity criterion.");
    }

    static class TrainDecisionTreeTask
    implements Runnable {
        private final int numberOfClasses;
        private DecisionTreePredictor<Integer> trainedTree;
        private final Optional<HugeAtomicLongArray> maybePredictions;
        private final DecisionTreeTrainerConfig decisionTreeTrainConfig;
        private final RandomForestTrainerConfig randomForestTrainConfig;
        private final SplittableRandom random;
        private final Features allFeatureVectors;
        private final HugeIntArray allLabels;
        private final ImpurityCriterion impurityCriterion;
        private final ReadOnlyHugeLongArray trainSet;
        private final ProgressTracker progressTracker;
        private final LogLevel messageLogLevel;
        private final AtomicInteger numberOfTreesTrained;

        TrainDecisionTreeTask(Optional<HugeAtomicLongArray> maybePredictions, DecisionTreeTrainerConfig decisionTreeTrainConfig, RandomForestTrainerConfig randomForestTrainConfig, SplittableRandom random, Features allFeatureVectors, HugeIntArray allLabels, int numberOfClasses, ImpurityCriterion impurityCriterion, ReadOnlyHugeLongArray trainSet, ProgressTracker progressTracker, LogLevel messageLogLevel, AtomicInteger numberOfTreesTrained) {
            this.maybePredictions = maybePredictions;
            this.decisionTreeTrainConfig = decisionTreeTrainConfig;
            this.randomForestTrainConfig = randomForestTrainConfig;
            this.random = random;
            this.allFeatureVectors = allFeatureVectors;
            this.allLabels = allLabels;
            this.numberOfClasses = numberOfClasses;
            this.impurityCriterion = impurityCriterion;
            this.trainSet = trainSet;
            this.progressTracker = progressTracker;
            this.messageLogLevel = messageLogLevel;
            this.numberOfTreesTrained = numberOfTreesTrained;
        }

        public static MemoryRange memoryEstimation(DecisionTreeTrainerConfig decisionTreeTrainConfig, long numberOfTrainingSamples, int numberOfClasses, int numberOfBaggedFeatures, double numberOfSamplesRatio) {
            long usedNumberOfTrainingSamples = (long)Math.ceil(numberOfSamplesRatio * (double)numberOfTrainingSamples);
            MemoryRange bootstrappedDatasetEstimation = MemoryRange.of((long)HugeLongArray.memoryEstimation((long)usedNumberOfTrainingSamples)).add(MemoryUsage.sizeOfBitset((long)usedNumberOfTrainingSamples));
            return MemoryRange.of((long)MemoryUsage.sizeOfInstance(TrainDecisionTreeTask.class)).add(FeatureBagger.memoryEstimation(numberOfBaggedFeatures)).add(DecisionTreeClassifierTrainer.memoryEstimation(decisionTreeTrainConfig, usedNumberOfTrainingSamples, numberOfClasses)).add(bootstrappedDatasetEstimation);
        }

        public DecisionTreePredictor<Integer> trainedTree() {
            return this.trainedTree;
        }

        @Override
        public void run() {
            FeatureBagger featureBagger = new FeatureBagger(this.random, this.allFeatureVectors.featureDimension(), this.randomForestTrainConfig.maxFeaturesRatio(this.allFeatureVectors.featureDimension()));
            DecisionTreeClassifierTrainer decisionTree = new DecisionTreeClassifierTrainer(this.impurityCriterion, this.allFeatureVectors, this.allLabels, this.numberOfClasses, this.decisionTreeTrainConfig, featureBagger);
            BootstrappedDataset bootstrappedDataset = this.bootstrappedDataset();
            this.trainedTree = decisionTree.train(bootstrappedDataset.allVectorsIndices());
            this.maybePredictions.ifPresent(predictionsCache -> OutOfBagError.addPredictionsForTree(this.trainedTree, this.numberOfClasses, this.allFeatureVectors, this.trainSet, bootstrappedDataset.trainSetIndices(), predictionsCache));
            this.progressTracker.logMessage(this.messageLogLevel, StringFormatting.formatWithLocale((String)"Trained decision tree %d out of %d", (Object[])new Object[]{this.numberOfTreesTrained.incrementAndGet(), this.randomForestTrainConfig.numberOfDecisionTrees()}));
        }

        private BootstrappedDataset bootstrappedDataset() {
            ReadOnlyHugeLongArray allVectorsIndices;
            BitSet trainSetIndices = new BitSet(this.trainSet.size());
            if (Double.compare(this.randomForestTrainConfig.numberOfSamplesRatio(), 0.0) == 0) {
                allVectorsIndices = this.trainSet;
                trainSetIndices.set(1L, this.trainSet.size());
            } else {
                allVectorsIndices = DatasetBootstrapper.bootstrap(this.random, this.randomForestTrainConfig.numberOfSamplesRatio(), this.trainSet, trainSetIndices);
            }
            return ImmutableBootstrappedDataset.of(trainSetIndices, allVectorsIndices);
        }

        @ValueClass
        static interface BootstrappedDataset {
            public BitSet trainSetIndices();

            public ReadOnlyHugeLongArray allVectorsIndices();
        }
    }
}

