/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.decisiontree;

import java.util.ArrayDeque;
import java.util.Deque;
import org.neo4j.gds.annotation.ValueClass;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.core.utils.paged.ReadOnlyHugeLongArray;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.decisiontree.DecisionTreePredictor;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainerConfig;
import org.neo4j.gds.ml.decisiontree.FeatureBagger;
import org.neo4j.gds.ml.decisiontree.Group;
import org.neo4j.gds.ml.decisiontree.Groups;
import org.neo4j.gds.ml.decisiontree.ImmutableGroup;
import org.neo4j.gds.ml.decisiontree.ImmutableStackRecord;
import org.neo4j.gds.ml.decisiontree.ImpurityCriterion;
import org.neo4j.gds.ml.decisiontree.Splitter;
import org.neo4j.gds.ml.decisiontree.TreeNode;
import org.neo4j.gds.ml.models.Features;

/**
 * Trains a single binary decision tree over a set of training samples.
 *
 * <p>The tree is grown iteratively using an explicit stack of pending splits rather than recursion.
 * Subclasses decide what prediction a leaf node holds via {@link #toTerminal(Group)}.
 *
 * @param <PREDICTION> the numeric type stored in leaf nodes
 */
public abstract class DecisionTreeTrainer<PREDICTION extends Number> {
    private final ImpurityCriterion impurityCriterion;
    private final Features features;
    private final DecisionTreeTrainerConfig config;
    private final FeatureBagger featureBagger;

    // Recreated at the start of every train() call. Because train() reassigns this field and
    // splitAndPush() reads it, a single trainer instance should not run train() concurrently.
    private Splitter splitter;

    DecisionTreeTrainer(
        Features features,
        DecisionTreeTrainerConfig config,
        ImpurityCriterion impurityCriterion,
        FeatureBagger featureBagger
    ) {
        this.impurityCriterion = impurityCriterion;
        this.features = features;
        this.config = config;
        this.featureBagger = featureBagger;
    }

    /**
     * Estimates the memory needed while training one tree.
     *
     * <p>The estimate covers the resulting predictor, the stack of pending splits, and the splitter.
     *
     * @param config                  the trainer configuration (depth and size limits)
     * @param numberOfTrainingSamples the number of samples the tree is trained on
     * @param leafNodeSizeInBytes     the size in bytes of a single leaf node
     * @param sizeOfImpurityData      the size in bytes of the impurity data used by the splitter
     * @return the estimated memory range
     */
    static MemoryRange estimateTree(
        DecisionTreeTrainerConfig config,
        long numberOfTrainingSamples,
        long leafNodeSizeInBytes,
        long sizeOfImpurityData
    ) {
        MemoryRange predictorEstimation = estimateTree(config, numberOfTrainingSamples, leafNodeSizeInBytes);

        // Bound the effective depth both by maxDepth and by a limit derived from the sample count.
        long normalizedMaxDepth = Math.min(
            config.maxDepth(),
            Math.max(1L, numberOfTrainingSamples - config.minSplitSize() + 2L)
        );
        // Estimated as at most two pending stack records per tree level.
        long maxItemsOnStack = 2L * normalizedMaxDepth;

        MemoryRange stackRecords = MemoryRange.of(1L, maxItemsOnStack)
            .times(MemoryUsage.sizeOfInstance(ImmutableStackRecord.class));
        // NOTE(review): upper bound assuming the samples are spread evenly over index arrays held by
        // the pending records — confirm against how Splitter/Groups store indices.
        MemoryRange stackIndexArrays = MemoryRange.of(
            0L,
            HugeLongArray.memoryEstimation(numberOfTrainingSamples / maxItemsOnStack) * maxItemsOnStack
        );
        MemoryRange maxStackSize = MemoryRange.of(MemoryUsage.sizeOfInstance(ArrayDeque.class))
            .add(stackRecords)
            .add(stackIndexArrays);

        long splitterEstimation = Splitter.memoryEstimation(numberOfTrainingSamples, sizeOfImpurityData);
        return predictorEstimation.add(maxStackSize).add(splitterEstimation);
    }

    /**
     * Estimates the memory of a trained tree: the predictor object plus its leaf and split nodes.
     *
     * @param config                  the trainer configuration (depth and size limits)
     * @param numberOfTrainingSamples the number of samples the tree is trained on
     * @param leafNodeSizeInBytes     the size in bytes of a single leaf node
     * @return the estimated memory range, or an empty range when there are no training samples
     */
    public static MemoryRange estimateTree(
        DecisionTreeTrainerConfig config,
        long numberOfTrainingSamples,
        long leafNodeSizeInBytes
    ) {
        if (numberOfTrainingSamples == 0L) {
            return MemoryRange.empty();
        }

        // The leaf count is limited by the tree depth (2^maxDepth), by the minimum leaf size,
        // and by the minimum split size; the smallest of these bounds applies.
        double depthBound = Math.pow(2.0, config.maxDepth());
        double leafSizeBound = (double) numberOfTrainingSamples / config.minLeafSize();
        double splitSizeBound = 2.0 * numberOfTrainingSamples / config.minSplitSize();
        long maxNumLeafNodes = (long) Math.ceil(Math.min(depthBound, Math.min(leafSizeBound, splitSizeBound)));

        // Every split node has exactly two children, so a tree with L leaves has L - 1 split nodes.
        return MemoryRange.of(MemoryUsage.sizeOfInstance(DecisionTreePredictor.class))
            .add(MemoryRange.of(1L, maxNumLeafNodes).times(leafNodeSizeInBytes))
            .add(MemoryRange.of(0L, maxNumLeafNodes - 1L).times(TreeNode.splitMemoryEstimation()));
    }

    /**
     * Trains a decision tree on the given training samples.
     *
     * <p>The root is created at depth 1. A split node popped at a depth {@code >= maxDepth} receives
     * leaf children. A child group smaller than {@code minSplitSize} also becomes a leaf.
     *
     * @param trainSetIndices the indices of the training samples; copied, not modified
     * @return a predictor wrapping the root of the trained tree
     */
    public DecisionTreePredictor<PREDICTION> train(ReadOnlyHugeLongArray trainSetIndices) {
        this.splitter = new Splitter(
            trainSetIndices.size(),
            this.impurityCriterion,
            this.featureBagger,
            this.features,
            this.config.minLeafSize()
        );
        Deque<StackRecord<PREDICTION>> stack = new ArrayDeque<>();

        // Work on a mutable copy; the groups created below reference ranges of this array.
        HugeLongArray mutableTrainSetIndices = HugeLongArray.newArray(trainSetIndices.size());
        mutableTrainSetIndices.setAll(index -> trainSetIndices.get(index));

        ImpurityCriterion.ImpurityData rootImpurity =
            this.impurityCriterion.groupImpurity(mutableTrainSetIndices, 0L, mutableTrainSetIndices.size());
        TreeNode<PREDICTION> root = splitAndPush(
            stack,
            ImmutableGroup.of(mutableTrainSetIndices, 0L, mutableTrainSetIndices.size(), rootImpurity),
            1
        );

        int maxDepth = this.config.maxDepth();
        int minSplitSize = this.config.minSplitSize();
        while (!stack.isEmpty()) {
            StackRecord<PREDICTION> record = stack.pop();
            TreeNode<PREDICTION> parent = record.node();
            Groups groups = record.split().groups();
            // The left child is built (and possibly pushed) before the right one.
            parent.setLeftChild(childNode(stack, groups.left(), record.depth(), maxDepth, minSplitSize));
            parent.setRightChild(childNode(stack, groups.right(), record.depth(), maxDepth, minSplitSize));
        }
        return new DecisionTreePredictor<PREDICTION>(root);
    }

    /**
     * Builds the child node for {@code childGroup} of a split node at {@code parentDepth}.
     *
     * <p>The child is a leaf when the depth limit is reached or the group is too small to split.
     * Otherwise the decision is delegated to {@link #splitAndPush}.
     */
    private TreeNode<PREDICTION> childNode(
        Deque<StackRecord<PREDICTION>> stack,
        Group childGroup,
        int parentDepth,
        int maxDepth,
        int minSplitSize
    ) {
        if (parentDepth >= maxDepth || childGroup.size() < minSplitSize) {
            return new TreeNode<PREDICTION>(toTerminal(childGroup));
        }
        return splitAndPush(stack, childGroup, parentDepth + 1);
    }

    /**
     * Computes the prediction stored in a leaf node for the samples in {@code group}.
     *
     * @param group the samples that end up in the leaf
     * @return the leaf's prediction
     */
    protected abstract PREDICTION toTerminal(Group group);

    /**
     * Tries to split {@code group}.
     *
     * <p>The group becomes a leaf if it is smaller than {@code minSplitSize} or the best split
     * leaves one side empty. Otherwise a split node is created and pushed onto {@code stack}, with
     * its children filled in later by {@link #train}.
     *
     * @param stack the stack of pending split nodes
     * @param group the non-empty group of samples to split
     * @param depth the depth of the node being created, starting at 1 for the root
     * @return the created leaf or split node
     */
    private TreeNode<PREDICTION> splitAndPush(Deque<StackRecord<PREDICTION>> stack, Group group, int depth) {
        assert group.size() > 0L;
        assert depth >= 1;

        if (group.size() < this.config.minSplitSize()) {
            return new TreeNode<PREDICTION>(toTerminal(group));
        }

        Split split = this.splitter.findBestSplit(group);
        Groups groups = split.groups();
        // A split that leaves one side empty separates nothing; the other side becomes a leaf.
        if (groups.right().size() == 0L) {
            return new TreeNode<PREDICTION>(toTerminal(groups.left()));
        }
        if (groups.left().size() == 0L) {
            return new TreeNode<PREDICTION>(toTerminal(groups.right()));
        }

        TreeNode<PREDICTION> node = new TreeNode<PREDICTION>(split.index(), split.value());
        stack.push(ImmutableStackRecord.of(node, split, depth));
        return node;
    }

    /** A split node whose children still have to be built, together with its split and depth. */
    @ValueClass
    interface StackRecord<PREDICTION extends Number> {
        TreeNode<PREDICTION> node();

        Split split();

        int depth();
    }

    /** The result of a split: the feature index, threshold value, and resulting left/right groups. */
    @ValueClass
    interface Split {
        int index();

        double value();

        Groups groups();
    }
}

