/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.gds.ml.decisiontree;

import org.neo4j.gds.collections.ha.HugeDoubleArray;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.core.utils.mem.MemoryRange;
import org.neo4j.gds.mem.MemoryUsage;
import org.neo4j.gds.ml.decisiontree.ImpurityCriterion;

/**
 * Impurity criterion that scores a group of feature vectors by the mean squared error (variance)
 * of their regression targets around the group mean.
 *
 * <p>Impurity is computed from running sums as {@code E[x^2] - E[x]^2}. This lets
 * {@link #incrementalImpurity} and {@link #decrementalImpurity} update a group in O(1) when a
 * single element is added or removed.
 */
public class SplitMeanSquaredError implements ImpurityCriterion {
    /** Regression target per feature-vector index. */
    private final HugeDoubleArray targets;

    public SplitMeanSquaredError(HugeDoubleArray targets) {
        this.targets = targets;
    }

    public static MemoryRange memoryEstimation() {
        return MemoryRange.of(MemoryUsage.sizeOfInstance(SplitMeanSquaredError.class));
    }

    /**
     * Computes the impurity of the {@code size} feature vectors whose indices are stored in
     * {@code group[startIdx, startIdx + size)}.
     *
     * @param group    array of feature-vector indices
     * @param startIdx first position in {@code group} belonging to this group
     * @param size     number of consecutive positions, starting at {@code startIdx}, in this group
     * @return fresh impurity data for the group; an all-zero instance when {@code size <= 0}
     */
    @Override
    public MSEImpurityData groupImpurity(HugeLongArray group, long startIdx, long size) {
        if (size <= 0L) {
            return new MSEImpurityData(0.0, 0.0, 0.0, 0L);
        }
        double sum = 0.0;
        double sumOfSquares = 0.0;
        // `size` is a count, so the exclusive end position is startIdx + size.
        long endIdx = startIdx + size;
        for (long i = startIdx; i < endIdx; ++i) {
            double value = targets.get(group.get(i));
            sum += value;
            sumOfSquares += value * value;
        }
        return new MSEImpurityData(meanSquaredError(sum, sumOfSquares, size), sumOfSquares, sum, size);
    }

    /**
     * Adds the target of {@code featureVectorIdx} to the group described by {@code impurityData},
     * updating it in place.
     *
     * @param featureVectorIdx index into the targets array
     * @param impurityData     must be an {@link MSEImpurityData} (cast unchecked)
     */
    @Override
    public void incrementalImpurity(long featureVectorIdx, ImpurityCriterion.ImpurityData impurityData) {
        MSEImpurityData mseImpurityData = (MSEImpurityData) impurityData;
        double value = targets.get(featureVectorIdx);
        updateImpurityData(
            mseImpurityData.sum() + value,
            mseImpurityData.sumOfSquares() + value * value,
            mseImpurityData.groupSize() + 1L,
            mseImpurityData
        );
    }

    /**
     * Removes the target of {@code featureVectorIdx} from the group described by
     * {@code impurityData}, updating it in place.
     *
     * @param featureVectorIdx index into the targets array; assumed to currently belong to the group
     * @param impurityData     must be an {@link MSEImpurityData} (cast unchecked)
     */
    @Override
    public void decrementalImpurity(long featureVectorIdx, ImpurityCriterion.ImpurityData impurityData) {
        MSEImpurityData mseImpurityData = (MSEImpurityData) impurityData;
        double value = targets.get(featureVectorIdx);
        updateImpurityData(
            mseImpurityData.sum() - value,
            mseImpurityData.sumOfSquares() - value * value,
            mseImpurityData.groupSize() - 1L,
            mseImpurityData
        );
    }

    /** Writes the given running sums, group size and the derived MSE into {@code mseImpurityData}. */
    private static void updateImpurityData(
        double sum,
        double sumOfSquares,
        long groupSize,
        MSEImpurityData mseImpurityData
    ) {
        mseImpurityData.setImpurity(meanSquaredError(sum, sumOfSquares, groupSize));
        mseImpurityData.setSum(sum);
        mseImpurityData.setSumOfSquares(sumOfSquares);
        mseImpurityData.setGroupSize(groupSize);
    }

    /**
     * Returns {@code sumOfSquares/n - (sum/n)^2}, or 0 for an empty group.
     *
     * <p>Without the empty-group guard this would be {@code 0/0 = NaN}, e.g. after decrementing
     * a group down to size 0. Note that the {@code E[x^2] - E[x]^2} form is subject to
     * floating-point cancellation, so the result may be a tiny value slightly below zero.
     */
    private static double meanSquaredError(double sum, double sumOfSquares, long groupSize) {
        if (groupSize <= 0L) {
            return 0.0;
        }
        double mean = sum / groupSize;
        return sumOfSquares / groupSize - mean * mean;
    }

    /** Mutable running statistics (sum, sum of squares, count) plus the derived MSE of one group. */
    static class MSEImpurityData implements ImpurityCriterion.ImpurityData {
        private double impurity;
        private double sumOfSquares;
        private double sum;
        private long groupSize;

        MSEImpurityData(double impurity, double sumOfSquares, double sum, long groupSize) {
            this.impurity = impurity;
            this.sumOfSquares = sumOfSquares;
            this.sum = sum;
            this.groupSize = groupSize;
        }

        public static long memoryEstimation() {
            return MemoryUsage.sizeOfInstance(MSEImpurityData.class);
        }

        @Override
        public double impurity() {
            return impurity;
        }

        @Override
        public long groupSize() {
            return groupSize;
        }

        /**
         * Copies all fields of this instance into {@code impurityData}.
         *
         * @param impurityData target; must be an {@link MSEImpurityData} (cast unchecked)
         */
        @Override
        public void copyTo(ImpurityCriterion.ImpurityData impurityData) {
            MSEImpurityData target = (MSEImpurityData) impurityData;
            target.setImpurity(impurity());
            target.setSumOfSquares(sumOfSquares());
            target.setSum(sum());
            target.setGroupSize(groupSize());
        }

        public void setGroupSize(long groupSize) {
            this.groupSize = groupSize;
        }

        public void setSum(double sum) {
            this.sum = sum;
        }

        public void setSumOfSquares(double sumOfSquares) {
            this.sumOfSquares = sumOfSquares;
        }

        public double sum() {
            return sum;
        }

        public double sumOfSquares() {
            return sumOfSquares;
        }

        public void setImpurity(double impurity) {
            this.impurity = impurity;
        }
    }
}

