/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.df.split;

import java.util.Arrays;
import org.apache.commons.lang.ArrayUtils;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataUtils;
import org.apache.mahout.df.data.Instance;
import org.apache.mahout.df.split.IgSplit;
import org.apache.mahout.df.split.Split;

/**
 * Information-gain split computation driven by per-value label tallies.
 *
 * <p>Instances are first counted into a {@code [value index][label]} table; the
 * gain for an attribute is then derived from those tallies alone.
 *
 * <p>Numerical splits keep their tallies in instance fields, so concurrent
 * {@link #computeSplit(Data, int)} calls on the same object would overwrite
 * each other's working state.
 */
public class OptIgSplit extends IgSplit {
    /** Working table for numerical splits: {@code counts[valueIndex][label]}. */
    private int[][] counts;

    /**
     * Per-label tallies for numerical splits. Starts as the totals over the
     * whole data and is decremented as the candidate threshold advances.
     */
    private int[] countAll;

    /**
     * Per-label tallies of the instances whose value lies before the current
     * candidate threshold in the sorted value array.
     */
    private int[] countLess;

    /**
     * Computes the split for one attribute, choosing the numerical or the
     * categorical strategy according to the dataset's description of it.
     *
     * @param data instances to evaluate
     * @param attr index of the attribute to split on
     * @return the split computed by the selected strategy
     */
    @Override
    public Split computeSplit(Data data, int attr) {
        return data.getDataset().isNumerical(attr)
            ? numericalSplit(data, attr)
            : categoricalSplit(data, attr);
    }

    /**
     * Information gain of partitioning {@code data} by every value of a
     * categorical attribute: the entropy of the label distribution minus the
     * size-weighted entropy of each per-value subset.
     *
     * @param data instances to evaluate
     * @param attr index of the categorical attribute
     * @return a split carrying the attribute index and its information gain
     */
    private static Split categoricalSplit(Data data, int attr) {
        // NOTE(review): presumably the distinct values of attr within data; every
        // instance value must be found in this array or the table lookup below fails.
        double[] values = data.values(attr);
        int[][] perValue = new int[values.length][data.getDataset().nblabels()];
        int[] overall = new int[data.getDataset().nblabels()];

        for (int i = 0; i < data.size(); i++) {
            Instance instance = data.get(i);
            perValue[ArrayUtils.indexOf(values, instance.get(attr))][instance.getLabel()]++;
            overall[instance.getLabel()]++;
        }

        int total = data.size();
        double hy = entropy(overall, total);
        double hyx = 0.0;
        double invTotal = 1.0 / total;
        for (int[] valueCounts : perValue) {
            int subset = DataUtils.sum(valueCounts);
            hyx += subset * invTotal * entropy(valueCounts, subset);
        }
        return new Split(attr, hy - hyx);
    }

    /**
     * Fetches the attribute's values from {@code data} and sorts them ascending.
     *
     * <p>The array handed back by {@code data.values(attr)} is sorted in place.
     * NOTE(review): assumes that array is not shared with other callers - confirm.
     */
    private static double[] sortedValues(Data data, int attr) {
        double[] sorted = data.values(attr);
        Arrays.sort(sorted);
        return sorted;
    }

    /**
     * Allocates zeroed working tables: one row of label tallies per entry of
     * {@code values}, plus the two per-label running totals.
     *
     * @param data   source of the label count (via its dataset)
     * @param values attribute values; only the length is used here
     */
    protected void initCounts(Data data, double[] values) {
        counts = new int[values.length][data.getDataset().nblabels()];
        countAll = new int[data.getDataset().nblabels()];
        countLess = new int[data.getDataset().nblabels()];
    }

    /**
     * Tallies every instance into {@code counts} (by the position of its
     * attribute value within {@code values}, and by label) and into
     * {@code countAll} (by label only).
     *
     * @param data   instances to tally
     * @param attr   index of the attribute being examined
     * @param values lookup array mapping an attribute value to its table row
     */
    protected void computeFrequencies(Data data, int attr, double[] values) {
        for (int i = 0; i < data.size(); i++) {
            Instance instance = data.get(i);
            counts[ArrayUtils.indexOf(values, instance.get(attr))][instance.getLabel()]++;
            countAll[instance.getLabel()]++;
        }
    }

    /**
     * Finds the best threshold for a numerical attribute.
     *
     * <p>Candidates are the sorted attribute values. For each candidate the
     * data is viewed as two groups - instances tallied before the candidate
     * and the remainder - and the gain is the overall entropy minus the
     * size-weighted entropy of both groups. The first candidate with the
     * strictly highest gain wins.
     *
     * @param data instances to evaluate
     * @param attr index of the numerical attribute
     * @return a split built from the attribute index, the best gain and the
     *         winning candidate value
     * @throws IllegalStateException if no candidate was selected, which
     *         happens when there are no candidate values at all
     */
    protected Split numericalSplit(Data data, int attr) {
        double[] values = sortedValues(data, attr);
        initCounts(data, values);
        computeFrequencies(data, attr, values);

        int total = data.size();
        double hy = entropy(countAll, total);
        double invTotal = 1.0 / total;

        int best = -1;
        double bestIg = -1.0;
        for (int index = 0; index < values.length; index++) {
            // Gain with the boundary placed just before values[index].
            int below = DataUtils.sum(countLess);
            double ig = hy - below * invTotal * entropy(countLess, below);
            int remaining = DataUtils.sum(countAll);
            ig -= remaining * invTotal * entropy(countAll, remaining);

            if (ig > bestIg) {
                bestIg = ig;
                best = index;
            }

            // Shift this value's tallies across the boundary for the next candidate.
            // NOTE(review): DataUtils.add / dec presumably update the first array
            // element-wise by the second - confirm against DataUtils.
            DataUtils.add(countLess, counts[index]);
            DataUtils.dec(countAll, counts[index]);
        }

        if (best == -1) {
            throw new IllegalStateException("no best split found !");
        }
        return new Split(attr, bestIg, values[best]);
    }

    /**
     * Entropy of a label distribution given as raw tallies.
     *
     * <p>Zero tallies contribute nothing, and an empty set has entropy 0.
     * The natural logarithm is divided by {@code LOG2}, which is not declared
     * in this class (presumably inherited from {@code IgSplit}).
     *
     * @param counts   per-label tallies
     * @param dataSize number of instances the tallies were taken over
     * @return the entropy of the distribution, or 0 when {@code dataSize} is 0
     */
    private static double entropy(int[] counts, int dataSize) {
        if (dataSize == 0) {
            return 0.0;
        }
        double invDataSize = 1.0 / dataSize;
        double entropy = 0.0;
        for (int count : counts) {
            if (count != 0) {
                double p = count * invDataSize;
                entropy += -p * Math.log(p) / LOG2;
            }
        }
        return entropy;
    }
}

