/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.decisiontree;

import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.decisiontree.ImpurityCriterion;

/**
 * Gini impurity criterion for decision-tree classification.
 *
 * <p>For a group of {@code n} samples with per-class counts {@code c_k}, the Gini impurity is
 * {@code 1 - sum_k c_k^2 / n^2}. It can be computed from scratch for a contiguous range of a group
 * ({@link #groupImpurity}). It can also be updated in constant time (independent of group size)
 * when a single sample is added to ({@link #incrementalImpurity}) or removed from
 * ({@link #decrementalImpurity}) a group.
 */
public class GiniIndex implements ImpurityCriterion {
    /**
     * Class label per feature vector id. Labels are used as indices into count arrays of length
     * {@link #numberOfClasses}, so they are assumed to lie in {@code [0, numberOfClasses)}.
     */
    private final HugeIntArray expectedMappedLabels;
    /** Number of distinct class labels; the length of every class-count array. */
    private final int numberOfClasses;

    public GiniIndex(HugeIntArray expectedMappedLabels, int numberOfClasses) {
        this.expectedMappedLabels = expectedMappedLabels;
        this.numberOfClasses = numberOfClasses;
    }

    /**
     * Estimates memory as one {@link HugeIntArray} with an entry per training sample plus the
     * size of a {@code GiniIndex} instance.
     *
     * @param numberOfTrainingSamples number of training samples
     * @return the estimated memory range
     */
    public static MemoryRange memoryEstimation(long numberOfTrainingSamples) {
        return MemoryRange.of(HugeIntArray.memoryEstimation(numberOfTrainingSamples))
            .add(MemoryRange.of(MemoryUsage.sizeOfInstance(GiniIndex.class)));
    }

    /**
     * Computes the Gini impurity of the samples referenced by
     * {@code group[startIndex, startIndex + size)}.
     *
     * @param group      feature vector ids; each one is looked up in the expected labels
     * @param startIndex first index into {@code group} that belongs to the range
     * @param size       number of consecutive entries of {@code group} to include
     * @return new impurity data holding the impurity, per-class counts and {@code size}; an empty
     *         range yields impurity {@code 0.0} and all-zero counts
     */
    @Override
    public GiniImpurityData groupImpurity(HugeLongArray group, long startIndex, long size) {
        long[] groupClassCounts = new long[this.numberOfClasses];
        if (size == 0L) {
            return new GiniImpurityData(0.0, groupClassCounts, size);
        }

        // The range ends at startIndex + size. Bounding the loop by `size` alone would skip samples
        // whenever startIndex > 0, and the counts would no longer sum to `size`.
        long endIndex = startIndex + size;
        for (long i = startIndex; i < endIndex; ++i) {
            int label = this.expectedMappedLabels.get(group.get(i));
            groupClassCounts[label]++;
        }

        // Square in double arithmetic: long products silently overflow once a count or the
        // group size exceeds ~3.04e9.
        double sumOfSquares = 0.0;
        for (long count : groupClassCounts) {
            sumOfSquares += (double) count * count;
        }
        double impurity = 1.0 - sumOfSquares / ((double) size * size);
        return new GiniImpurityData(impurity, groupClassCounts, size);
    }

    /**
     * Updates {@code impurityData} in place as if the sample {@code featureVectorIdx} were added to
     * its group.
     *
     * @param featureVectorIdx id of the sample being added
     * @param impurityData     state to update; must be a {@link GiniImpurityData}
     */
    @Override
    public void incrementalImpurity(long featureVectorIdx, ImpurityCriterion.ImpurityData impurityData) {
        GiniImpurityData giniImpurityData = (GiniImpurityData) impurityData;
        int label = this.expectedMappedLabels.get(featureVectorIdx);
        long newClassCount = giniImpurityData.classCounts[label] + 1L;
        long newGroupSize = giniImpurityData.groupSize() + 1L;
        updateImpurityData(label, newGroupSize, newClassCount, giniImpurityData);
    }

    /**
     * Updates {@code impurityData} in place as if the sample {@code featureVectorIdx} were removed
     * from its group. The sample is assumed to currently be a member of that group.
     *
     * @param featureVectorIdx id of the sample being removed
     * @param impurityData     state to update; must be a {@link GiniImpurityData}
     */
    @Override
    public void decrementalImpurity(long featureVectorIdx, ImpurityCriterion.ImpurityData impurityData) {
        GiniImpurityData giniImpurityData = (GiniImpurityData) impurityData;
        int label = this.expectedMappedLabels.get(featureVectorIdx);
        long newClassCount = giniImpurityData.classCounts[label] - 1L;
        long newGroupSize = giniImpurityData.groupSize() - 1L;
        updateImpurityData(label, newGroupSize, newClassCount, giniImpurityData);
    }

    /**
     * Applies a change of one class count and of the group size to {@code impurityData} and
     * recomputes its impurity from the previous impurity, without rescanning all classes.
     *
     * @param label         class whose count changes
     * @param newGroupSize  group size after the change
     * @param newClassCount count of {@code label} after the change
     * @param impurityData  state updated in place (class count, group size and impurity)
     */
    private static void updateImpurityData(int label, long newGroupSize, long newClassCount, GiniImpurityData impurityData) {
        long oldGroupSize = impurityData.groupSize();
        long prevClassCount = impurityData.classCounts()[label];
        double previousImpurity = impurityData.impurity();

        impurityData.classCounts()[label] = newClassCount;
        impurityData.setGroupSize(newGroupSize);

        if (newGroupSize == 0L) {
            // An empty group has impurity 0 (as in groupImpurity). The incremental formula below
            // would divide by zero here and store NaN.
            impurityData.setImpurity(0.0);
            return;
        }

        // With S = sum_k c_k^2 and I = 1 - S / n^2 we have S = (1 - I) * n^2, hence for new size m:
        //   I' = 1 - (S - prev^2 + new^2) / m^2
        //      = I * n^2/m^2 + (1 - n^2/m^2) + prev^2/m^2 - new^2/m^2
        // Squares are taken in double arithmetic to avoid long overflow.
        double groupSizeSquared = (double) oldGroupSize * oldGroupSize;
        double newGroupSizeSquared = (double) newGroupSize * newGroupSize;
        double sizeRatio = groupSizeSquared / newGroupSizeSquared;
        double newImpurity = previousImpurity * sizeRatio
            + (1.0 - sizeRatio)
            + ((double) prevClassCount * prevClassCount) / newGroupSizeSquared
            - ((double) newClassCount * newClassCount) / newGroupSizeSquared;
        impurityData.setImpurity(newImpurity);
    }

    /**
     * Mutable Gini impurity state of one group: its current impurity, per-class sample counts and
     * group size.
     */
    static class GiniImpurityData implements ImpurityCriterion.ImpurityData {
        private double impurity;
        /** Per-class counts; the array is exposed by {@link #classCounts()} and updated in place. */
        private final long[] classCounts;
        private long groupSize;

        GiniImpurityData(double impurity, long[] classCounts, long groupSize) {
            this.impurity = impurity;
            this.classCounts = classCounts;
            this.groupSize = groupSize;
        }

        /**
         * Estimates the size of one instance plus its class-count array.
         *
         * @param numberOfClasses length of the class-count array
         * @return estimated size in bytes
         */
        public static long memoryEstimation(int numberOfClasses) {
            return MemoryUsage.sizeOfInstance(GiniImpurityData.class) + MemoryUsage.sizeOfLongArray(numberOfClasses);
        }

        /**
         * Copies impurity, group size and class counts into {@code impurityData}. The target must be
         * a {@code GiniImpurityData} whose count array is at least as long as this one's.
         */
        @Override
        public void copyTo(ImpurityCriterion.ImpurityData impurityData) {
            GiniImpurityData giniImpurityData = (GiniImpurityData) impurityData;
            giniImpurityData.setImpurity(this.impurity());
            giniImpurityData.setGroupSize(this.groupSize());
            System.arraycopy(this.classCounts(), 0, giniImpurityData.classCounts(), 0, this.classCounts().length);
        }

        @Override
        public double impurity() {
            return this.impurity;
        }

        public void setImpurity(double impurity) {
            this.impurity = impurity;
        }

        /** Returns the internal (mutable) per-class count array, not a copy. */
        public long[] classCounts() {
            return this.classCounts;
        }

        @Override
        public long groupSize() {
            return this.groupSize;
        }

        public void setGroupSize(long groupSize) {
            this.groupSize = groupSize;
        }
    }
}

