/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.decisiontree;

import org.neo4j.gds.collections.ha.HugeIntArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.decisiontree.ImpurityCriterion;

/**
 * Entropy impurity criterion for classification decision trees.
 *
 * <p>The impurity of a group of training samples is the Shannon entropy of its label
 * distribution, measured in bits: {@code H = -sum_c p_c * log2(p_c)}. Here {@code p_c} is the
 * fraction of the group's samples whose mapped label is {@code c}. Labels are looked up in
 * {@code expectedMappedLabels} by feature vector index. They are used directly as indices into a
 * class-count array of length {@code numberOfClasses}.
 */
public class Entropy
implements ImpurityCriterion {
    private static final double LN_2 = Math.log(2.0);
    private final HugeIntArray expectedMappedLabels;
    private final int numberOfClasses;

    /**
     * @param expectedMappedLabels label of each training sample, indexed by feature vector index;
     *                             every value is used as an array index and must therefore lie in
     *                             {@code [0, numberOfClasses)}
     * @param numberOfClasses      number of distinct classes, i.e. the length of every
     *                             class-count array created by this criterion
     */
    public Entropy(HugeIntArray expectedMappedLabels, int numberOfClasses) {
        this.expectedMappedLabels = expectedMappedLabels;
        this.numberOfClasses = numberOfClasses;
    }

    /**
     * Estimates the memory of a label array with {@code numberOfTrainingSamples} entries plus
     * one {@code Entropy} instance.
     *
     * @param numberOfTrainingSamples number of training samples (entries in the label array)
     * @return the estimated memory range
     */
    public static MemoryRange memoryEstimation(long numberOfTrainingSamples) {
        return MemoryRange.of(HugeIntArray.memoryEstimation(numberOfTrainingSamples))
            .add(MemoryRange.of(MemoryUsage.sizeOfInstance(Entropy.class)));
    }

    /**
     * Computes the entropy of the samples at positions {@code [startIndex, startIndex + size)}
     * of {@code group}.
     *
     * @param group      feature vector indices; only the slice described by
     *                   {@code startIndex} and {@code size} is read
     * @param startIndex first position in {@code group} that belongs to the group
     * @param size       number of consecutive positions in {@code group} that belong to the group
     * @return fresh impurity data holding the entropy (in bits), the per-class counts and the
     *         group size; an empty group has entropy {@code 0}
     */
    @Override
    public EntropyImpurityData groupImpurity(HugeLongArray group, long startIndex, long size) {
        long[] groupClassCounts = new long[this.numberOfClasses];
        if (size == 0L) {
            return new EntropyImpurityData(0.0, groupClassCounts, size);
        }

        // The slice ends at startIndex + size. Bounding the loop by `size` alone would read
        // only group[startIndex, size) whenever startIndex > 0. The counts would then not sum
        // to `size`, and the entropy would be wrong.
        long endIndex = startIndex + size;
        for (long i = startIndex; i < endIndex; i++) {
            int label = this.expectedMappedLabels.get(group.get(i));
            groupClassCounts[label]++;
        }

        double impurity = 0.0;
        for (long count : groupClassCounts) {
            if (count == 0L) {
                // By convention 0 * log(0) contributes nothing (and log(0) is -Infinity).
                continue;
            }
            double p = (double) count / (double) size;
            impurity -= p * Math.log(p);
        }
        // Math.log is the natural log; divide by ln(2) to express the entropy in bits.
        return new EntropyImpurityData(impurity / LN_2, groupClassCounts, size);
    }

    /**
     * Updates {@code impurityData} in place to account for adding the sample
     * {@code featureVectorIdx} to the group.
     *
     * @param featureVectorIdx index of the sample in {@code expectedMappedLabels}
     * @param impurityData     must be an {@link EntropyImpurityData}; it is mutated
     */
    @Override
    public void incrementalImpurity(long featureVectorIdx, ImpurityCriterion.ImpurityData impurityData) {
        EntropyImpurityData entropyImpurityData = (EntropyImpurityData) impurityData;
        int label = this.expectedMappedLabels.get(featureVectorIdx);
        long newClassCount = entropyImpurityData.classCounts()[label] + 1L;
        long newGroupSize = entropyImpurityData.groupSize() + 1L;
        updateImpurityData(label, newGroupSize, newClassCount, entropyImpurityData);
    }

    /**
     * Updates {@code impurityData} in place to account for removing the sample
     * {@code featureVectorIdx} from the group.
     *
     * @param featureVectorIdx index of the sample in {@code expectedMappedLabels}; presumably
     *                         the sample is currently counted in {@code impurityData} (not checked)
     * @param impurityData     must be an {@link EntropyImpurityData}; it is mutated
     */
    @Override
    public void decrementalImpurity(long featureVectorIdx, ImpurityCriterion.ImpurityData impurityData) {
        EntropyImpurityData entropyImpurityData = (EntropyImpurityData) impurityData;
        int label = this.expectedMappedLabels.get(featureVectorIdx);
        long newClassCount = entropyImpurityData.classCounts()[label] - 1L;
        long newGroupSize = entropyImpurityData.groupSize() - 1L;
        updateImpurityData(label, newGroupSize, newClassCount, entropyImpurityData);
    }

    /**
     * Recomputes the entropy after the count of a single class changed, in O(1).
     *
     * <p>For a group of size {@code n} with class counts {@code c_k}:
     * {@code H * ln2 = ln(n) - S / n}, where {@code S = sum_k c_k * ln(c_k)}.
     * The previous entropy therefore gives {@code -S = n * (H * ln2 - ln(n))}. Only the term of
     * {@code label} changes, so {@code S' = S - prev*ln(prev) + new*ln(new)}. The new entropy is
     * then {@code H' = (ln(n') - S' / n') / ln2}.
     *
     * @param label         class whose count changed
     * @param newGroupSize  group size after the change
     * @param newClassCount count of {@code label} after the change
     * @param impurityData  data to update in place (counts, size and impurity)
     */
    private static void updateImpurityData(int label, long newGroupSize, long newClassCount, EntropyImpurityData impurityData) {
        long prevGroupSize = impurityData.groupSize();
        long prevClassCount = impurityData.classCounts()[label];
        double newImpurity = 0.0;
        if (newGroupSize > 0L) {
            // Start from the previous entropy in nats.
            newImpurity = impurityData.impurity() * LN_2;
            if (prevGroupSize > 0L) {
                // Turn it into -S. For an empty previous group, the stored impurity is expected
                // to be 0 (as produced by groupImpurity and by this method), so -S is already 0.
                // Skipping this step also avoids log(0).
                newImpurity -= Math.log(prevGroupSize);
                newImpurity *= (double) prevGroupSize;
            }
            // Swap the old term of `label` for the new one; zero counts contribute nothing.
            if (prevClassCount > 0L) {
                newImpurity += (double) prevClassCount * Math.log(prevClassCount);
            }
            if (newClassCount > 0L) {
                newImpurity -= (double) newClassCount * Math.log(newClassCount);
            }
            // -S' / n' + ln(n') is the new entropy in nats; convert to bits.
            newImpurity /= (double) newGroupSize;
            newImpurity += Math.log(newGroupSize);
            newImpurity /= LN_2;
        }
        impurityData.classCounts()[label] = newClassCount;
        impurityData.setGroupSize(newGroupSize);
        impurityData.setImpurity(newImpurity);
    }

    /**
     * Mutable impurity state of one group: entropy in bits, per-class sample counts and size.
     */
    static class EntropyImpurityData
    implements ImpurityCriterion.ImpurityData {
        private double impurity;
        private final long[] classCounts;
        private long groupSize;

        /**
         * @param impurity    entropy of the group, in bits
         * @param classCounts per-class sample counts; stored without copying
         * @param groupSize   number of samples in the group
         */
        EntropyImpurityData(double impurity, long[] classCounts, long groupSize) {
            this.impurity = impurity;
            this.classCounts = classCounts;
            this.groupSize = groupSize;
        }

        /**
         * Estimates the memory of one instance including a class-count array of
         * {@code numberOfClasses} longs.
         *
         * @param numberOfClasses length of the class-count array
         * @return estimated size in bytes
         */
        public static long memoryEstimation(int numberOfClasses) {
            return MemoryUsage.sizeOfInstance(EntropyImpurityData.class) + MemoryUsage.sizeOfLongArray(numberOfClasses);
        }

        /**
         * Copies impurity, group size and class counts into {@code impurityData}.
         *
         * @param impurityData must be an {@code EntropyImpurityData} whose class-count array is
         *                     at least as long as this one's
         */
        @Override
        public void copyTo(ImpurityCriterion.ImpurityData impurityData) {
            EntropyImpurityData target = (EntropyImpurityData) impurityData;
            target.setImpurity(this.impurity());
            target.setGroupSize(this.groupSize());
            System.arraycopy(this.classCounts(), 0, target.classCounts(), 0, this.classCounts().length);
        }

        /** @return entropy of the group, in bits */
        @Override
        public double impurity() {
            return this.impurity;
        }

        public void setImpurity(double impurity) {
            this.impurity = impurity;
        }

        /**
         * @return the live per-class count array (not a copy); updates to the impurity write
         *         into it directly
         */
        public long[] classCounts() {
            return this.classCounts;
        }

        /** @return number of samples in the group */
        @Override
        public long groupSize() {
            return this.groupSize;
        }

        public void setGroupSize(long groupSize) {
            this.groupSize = groupSize;
        }
    }
}

