/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.decisiontree;

import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainer;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainerConfig;
import org.neo4j.gds.ml.decisiontree.FeatureBagger;
import org.neo4j.gds.ml.decisiontree.GiniIndex;
import org.neo4j.gds.ml.decisiontree.Group;
import org.neo4j.gds.ml.decisiontree.ImpurityCriterion;
import org.neo4j.gds.ml.decisiontree.TreeNode;
import org.neo4j.gds.ml.models.Features;

/**
 * Decision tree trainer for classification: every leaf stores the integer class that is
 * predicted for the training samples reaching it.
 */
public class DecisionTreeClassifierTrainer
extends DecisionTreeTrainer<Integer> {
    // Class label per training sample, looked up via the sample ids held in a Group's array.
    private final HugeIntArray allLabels;
    // Number of distinct classes; labels are used directly as indices into a count buffer,
    // so they must lie in [0, numberOfClasses).
    private final int numberOfClasses;

    /**
     * Creates a classification trainer.
     *
     * @param impurityCriterion criterion used to score candidate splits
     * @param features          training features; must have as many entries as {@code labels}
     * @param labels            class label of every training sample
     * @param numberOfClasses   size of the label domain
     * @param config            tree training configuration
     * @param featureBagger     strategy for selecting candidate features
     */
    public DecisionTreeClassifierTrainer(
        ImpurityCriterion impurityCriterion,
        Features features,
        HugeIntArray labels,
        int numberOfClasses,
        DecisionTreeTrainerConfig config,
        FeatureBagger featureBagger
    ) {
        super(features, config, impurityCriterion, featureBagger);
        // Only checked when assertions are enabled (-ea).
        assert labels.size() == features.size();
        this.allLabels = labels;
        this.numberOfClasses = numberOfClasses;
    }

    /**
     * Estimates the memory needed to train one classification tree.
     *
     * @param config                  tree training configuration
     * @param numberOfTrainingSamples number of samples the tree is trained on
     * @param numberOfClasses         size of the label domain
     * @return the sum of this trainer's own footprint, the tree estimate (integer leaves,
     *         Gini impurity data) and one {@code long[numberOfClasses]}
     */
    public static MemoryRange memoryEstimation(
        DecisionTreeTrainerConfig config,
        long numberOfTrainingSamples,
        int numberOfClasses
    ) {
        long trainerBytes = (long) MemoryUsage.sizeOfInstance(DecisionTreeClassifierTrainer.class);
        // Presumably accounts for the per-leaf class-count buffer allocated in toTerminal.
        long classCountBytes = MemoryUsage.sizeOfLongArray((long) numberOfClasses);
        return MemoryRange.of(trainerBytes)
            .add(DecisionTreeTrainer.estimateTree(
                config,
                numberOfTrainingSamples,
                TreeNode.leafMemoryEstimation(Integer.class),
                GiniIndex.GiniImpurityData.memoryEstimation(numberOfClasses)
            ))
            .add(classCountBytes);
    }

    /**
     * Picks the majority class among the samples in {@code group}.
     *
     * <p>Ties are resolved in favour of the smallest class index; an empty group yields class 0.
     * A label outside {@code [0, numberOfClasses)} raises an {@link ArrayIndexOutOfBoundsException}.
     *
     * @param group slice {@code [startIdx, startIdx + size)} of sample ids in {@code group.array()}
     * @return the class value stored in the resulting leaf
     */
    @Override
    protected Integer toTerminal(Group group) {
        long[] countPerClass = new long[this.numberOfClasses];
        HugeLongArray sampleIds = group.array();
        long from = group.startIdx();
        long until = from + group.size();
        for (long idx = from; idx < until; idx++) {
            countPerClass[this.allLabels.get(sampleIds.get(idx))]++;
        }

        // Strict '>' keeps the first (lowest-index) class on ties; the -1 sentinel makes
        // class 0 win even when every count is zero.
        int majorityClass = 0;
        long majorityCount = -1L;
        for (int cls = 0; cls < countPerClass.length; cls++) {
            if (countPerClass[cls] > majorityCount) {
                majorityCount = countPerClass[cls];
                majorityClass = cls;
            }
        }
        return majorityClass;
    }
}

