/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.decisiontree;

import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.paged.HugeSerialIndirectMergeSort;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.decisiontree.DecisionTreeTrainer;
import org.neo4j.gds.ml.decisiontree.FeatureBagger;
import org.neo4j.gds.ml.decisiontree.Group;
import org.neo4j.gds.ml.decisiontree.ImmutableGroup;
import org.neo4j.gds.ml.decisiontree.ImmutableGroups;
import org.neo4j.gds.ml.decisiontree.ImmutableSplit;
import org.neo4j.gds.ml.decisiontree.ImpurityCriterion;
import org.neo4j.gds.ml.models.Features;

/**
 * Searches a {@link Group} of training samples for the feature and threshold whose split
 * yields the lowest combined impurity according to the configured {@link ImpurityCriterion}.
 *
 * <p>An instance owns scratch state ({@code sortCache} and {@code rightImpurityData}) that is
 * overwritten by every call to {@link #findBestSplit(Group)}.
 */
public class Splitter {
    private final ImpurityCriterion impurityCriterion;
    private final Features features;
    private final FeatureBagger featureBagger;
    private final int minLeafSize;
    private final HugeLongArray sortCache;
    private final ImpurityCriterion.ImpurityData rightImpurityData;

    /**
     * @param trainSetSize      length of the scratch array handed to the sorter
     * @param impurityCriterion criterion used to build, update and combine impurity data
     * @param featureBagger     supplies the feature indices examined by each split search
     * @param features          feature vectors, looked up by sample index
     * @param minLeafSize       smallest number of samples allowed on either side of a split
     */
    Splitter(long trainSetSize, ImpurityCriterion impurityCriterion, FeatureBagger featureBagger, Features features, int minLeafSize) {
        this.impurityCriterion = impurityCriterion;
        this.features = features;
        this.featureBagger = featureBagger;
        this.minLeafSize = minLeafSize;
        this.sortCache = HugeLongArray.newArray(trainSetSize);
        this.rightImpurityData = emptyImpurityData(impurityCriterion);
    }

    /**
     * Estimates the memory used by one splitter together with the temporaries of a split search.
     *
     * @param numberOfTrainingSamples number of samples each index array is sized for
     * @param sizeOfImpurityData      estimated size of a single impurity data object
     * @return the sum of the instance, sort cache, impurity data and child array estimates
     */
    static long memoryEstimation(long numberOfTrainingSamples, long sizeOfImpurityData) {
        long instanceBytes = MemoryUsage.sizeOfInstance(Splitter.class);
        // One scratch array for the sorter.
        long sortCacheBytes = HugeLongArray.memoryEstimation(numberOfTrainingSamples);
        // Impurity data: left, right, best left, best right.
        long impurityDataBytes = 4L * sizeOfImpurityData;
        // Child index arrays allocated per search: left, right, best left, best right.
        long childArrayBytes = 4L * HugeLongArray.memoryEstimation(numberOfTrainingSamples);
        return instanceBytes + sortCacheBytes + impurityDataBytes + childArrayBytes;
    }

    /**
     * Finds the split of {@code group} with the lowest combined impurity over the sampled features.
     *
     * <p>For every sampled feature the group's sample indices are sorted via
     * {@link HugeSerialIndirectMergeSort} keyed by that feature's value, then samples are moved
     * one at a time from the right side to the left side while the impurity data of both sides
     * is updated in place. Only left sizes between {@code minLeafSize} and
     * {@code group.size() - minLeafSize} (inclusive) are scored.
     *
     * <p>If no scored candidate has a combined impurity below {@code Double.MAX_VALUE}, the
     * returned split carries feature index {@code -1}, value {@code Double.MAX_VALUE} and a
     * left group size of {@code -1}.
     *
     * @param group the samples to split
     * @return the best split found, with its left and right child groups
     */
    DecisionTreeTrainer.Split findBestSplit(Group group) {
        final long groupSize = group.size();
        final HugeLongArray groupIndices = group.array();
        final long groupOffset = group.startIdx();

        int winningFeature = -1;
        double winningThreshold = Double.MAX_VALUE;
        double lowestImpurity = Double.MAX_VALUE;
        long winningLeftSize = -1L;

        HugeLongArray leftChildArray = HugeLongArray.newArray(groupSize);
        HugeLongArray rightChildArray = HugeLongArray.newArray(groupSize);
        HugeLongArray bestLeftChildArray = HugeLongArray.newArray(groupSize);
        HugeLongArray bestRightChildArray = HugeLongArray.newArray(groupSize);

        ImpurityCriterion.ImpurityData bestLeftImpurityData = emptyImpurityData(this.impurityCriterion);
        ImpurityCriterion.ImpurityData bestRightImpurityData = emptyImpurityData(this.impurityCriterion);

        // Both candidate and best right arrays start out holding every sample of the group.
        rightChildArray.setAll(position -> groupIndices.get(groupOffset + position));
        rightChildArray.copyTo(bestRightChildArray, groupSize);

        final long warmUpCount = (long) this.minLeafSize - 1L;
        final long largestLeftSize = groupSize - (long) this.minLeafSize;

        int[] sampledFeatures = this.featureBagger.sample();
        for (int featureIdx : sampledFeatures) {
            HugeSerialIndirectMergeSort.sort(
                rightChildArray,
                groupSize,
                sampleIdx -> this.features.get(sampleIdx)[featureIdx],
                this.sortCache
            );

            // Start with everything on the right: its impurity equals that of the whole group.
            group.impurityData().copyTo(this.rightImpurityData);

            // Move the first (minLeafSize - 1) samples to the left without scoring them,
            // because a left side that small is not an admissible leaf.
            for (long position = 0L; position < warmUpCount; position++) {
                long movedSample = rightChildArray.get(position);
                leftChildArray.set(position, movedSample);
                this.impurityCriterion.decrementalImpurity(movedSample, this.rightImpurityData);
            }
            ImpurityCriterion.ImpurityData leftImpurityData =
                this.impurityCriterion.groupImpurity(leftChildArray, 0L, warmUpCount);

            // From here on every move produces an admissible split, so score each one.
            boolean improvedWithThisFeature = false;
            for (long leftSize = (long) this.minLeafSize; leftSize <= largestLeftSize; leftSize++) {
                long position = leftSize - 1L;
                long movedSample = rightChildArray.get(position);
                leftChildArray.set(position, movedSample);
                this.impurityCriterion.incrementalImpurity(movedSample, leftImpurityData);
                this.impurityCriterion.decrementalImpurity(movedSample, this.rightImpurityData);

                double candidateImpurity =
                    this.impurityCriterion.combinedImpurity(leftImpurityData, this.rightImpurityData);
                if (candidateImpurity < lowestImpurity) {
                    improvedWithThisFeature = true;
                    winningFeature = featureIdx;
                    // The threshold is the feature value of the last sample moved to the left.
                    winningThreshold = this.features.get(movedSample)[featureIdx];
                    lowestImpurity = candidateImpurity;
                    winningLeftSize = leftSize;
                    leftImpurityData.copyTo(bestLeftImpurityData);
                    this.rightImpurityData.copyTo(bestRightImpurityData);
                }
            }

            if (improvedWithThisFeature) {
                // Swap references instead of copying: the working arrays are re-sorted and
                // refilled for the next feature anyway.
                HugeLongArray previousBestRight = bestRightChildArray;
                bestRightChildArray = rightChildArray;
                rightChildArray = previousBestRight;

                HugeLongArray previousBestLeft = bestLeftChildArray;
                bestLeftChildArray = leftChildArray;
                leftChildArray = previousBestLeft;
            }
        }

        Group leftGroup = ImmutableGroup.of(bestLeftChildArray, 0L, winningLeftSize, bestLeftImpurityData);
        Group rightGroup = ImmutableGroup.of(
            bestRightChildArray,
            winningLeftSize,
            groupSize - winningLeftSize,
            bestRightImpurityData
        );
        return ImmutableSplit.of(winningFeature, winningThreshold, ImmutableGroups.of(leftGroup, rightGroup));
    }

    /** Creates impurity data for an empty index range, used as a reusable copy target. */
    private static ImpurityCriterion.ImpurityData emptyImpurityData(ImpurityCriterion criterion) {
        return criterion.groupImpurity(HugeLongArray.of(new long[0]), 0L, 0L);
    }
}

