/*
 * Decompiled with CFR 0.152.
 */
package eu.fbk.utils.eval;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.Ordering;
import eu.fbk.utils.eval.PrecisionRecall;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Objects;
import javax.annotation.Nullable;

/**
 * Immutable square confusion matrix over {@code numLabels} labels, addressed as
 * {@code (gold, predicted)}.
 *
 * <p>Counts are stored row-major in a private flat array that is copied at construction and never
 * exposed. Derived statistics (total count, per-label, micro- and macro-averaged
 * {@link PrecisionRecall}) are computed lazily and cached in transient fields, so they are simply
 * recomputed after deserialization.
 */
public final class ConfusionMatrix
implements Serializable {
    private static final long serialVersionUID = 1L;
    // Number of labels; the matrix has numLabels rows (gold) and numLabels columns (predicted).
    private final int numLabels;
    // Row-major counts: cell (gold, predicted) is at index gold * numLabels + predicted.
    private final double[] counts;
    // Cached sum of all cells; 0.0 doubles as "not computed yet" (also the post-deserialization value).
    private transient double countTotal;
    // Lazily filled per-label cache, indexed by label.
    @Nullable
    private transient PrecisionRecall[] labelPRs;
    @Nullable
    private transient PrecisionRecall microPR;
    @Nullable
    private transient PrecisionRecall macroPR;

    /**
     * Creates a confusion matrix by copying a square matrix of counts.
     *
     * @param matrix counts indexed as {@code matrix[gold][predicted]}; copied, so later changes
     *     to the argument do not affect this object
     * @throws IllegalArgumentException if a row's length differs from the number of rows
     */
    public ConfusionMatrix(double[][] matrix) {
        this.numLabels = matrix.length;
        this.counts = new double[this.numLabels * this.numLabels];
        for (int i = 0; i < this.numLabels; ++i) {
            double[] row = matrix[i];
            Preconditions.checkArgument(row.length == this.numLabels,
                    "Row %s has %s entries, expected %s", i, row.length, this.numLabels);
            System.arraycopy(row, 0, this.counts, i * this.numLabels, this.numLabels);
        }
    }

    /** Throws {@link IllegalArgumentException} unless {@code 0 <= label < numLabels}. */
    private void checkLabel(int label) {
        if (label < 0 || label >= this.numLabels) {
            throw new IllegalArgumentException("Invalid label " + label + " (matrix has " + this.numLabels + " labels)");
        }
    }

    /** Returns the number of labels, i.e. the number of rows (and of columns) of the matrix. */
    public int getNumLabels() {
        return this.numLabels;
    }

    /**
     * Returns the count of items with gold label {@code labelGold} predicted as
     * {@code labelPredicted}.
     *
     * @throws IllegalArgumentException if either label is out of range
     */
    public double getCount(int labelGold, int labelPredicted) {
        this.checkLabel(labelGold);
        this.checkLabel(labelPredicted);
        return this.counts[labelGold * this.numLabels + labelPredicted];
    }

    /**
     * Returns the sum of the row for gold label {@code label}.
     *
     * @throws IllegalArgumentException if the label is out of range
     */
    public double getCountGold(int label) {
        // Validate like the other accessors instead of failing with an array index error.
        this.checkLabel(label);
        int start = label * this.numLabels;
        double count = 0.0;
        for (int i = start; i < start + this.numLabels; ++i) {
            count += this.counts[i];
        }
        return count;
    }

    /**
     * Returns the sum of the column for predicted label {@code label}.
     *
     * @throws IllegalArgumentException if the label is out of range
     */
    public double getCountPredicted(int label) {
        this.checkLabel(label);
        double count = 0.0;
        for (int i = 0; i < this.numLabels; ++i) {
            count += this.counts[i * this.numLabels + label];
        }
        return count;
    }

    /**
     * Returns the sum of all cells. The value is cached. An all-zero matrix is re-summed on each
     * call because 0.0 also marks "not yet computed".
     */
    public double getCountTotal() {
        if (this.countTotal == 0.0) {
            double count = 0.0;
            for (double cell : this.counts) {
                count += cell;
            }
            this.countTotal = count;
        }
        return this.countTotal;
    }

    /**
     * Returns the one-vs-rest precision/recall statistics of {@code label} (cached).
     *
     * <p>TP is the diagonal cell, FP the rest of the label's column, FN the rest of the label's
     * row, and TN every remaining cell.
     *
     * @throws IllegalArgumentException if the label is out of range
     */
    public synchronized PrecisionRecall getLabelPR(int label) {
        this.checkLabel(label);
        if (this.labelPRs == null) {
            this.labelPRs = new PrecisionRecall[this.numLabels];
        }
        if (this.labelPRs[label] == null) {
            double tp = this.counts[label * this.numLabels + label];
            double fp = 0.0;
            double fn = 0.0;
            for (int i = 0; i < this.numLabels; ++i) {
                if (i == label) {
                    continue;
                }
                fp += this.counts[i * this.numLabels + label];
                fn += this.counts[label * this.numLabels + i];
            }
            double tn = this.getCountTotal() - tp - fp - fn;
            this.labelPRs[label] = PrecisionRecall.forCounts(tp, fp, fn, tn);
        }
        return this.labelPRs[label];
    }

    /** Returns micro-averaged statistics, obtained by pooling the per-label counts (cached). */
    public synchronized PrecisionRecall getMicroPR() {
        if (this.microPR == null) {
            double tp = 0.0;
            for (int i = 0; i < this.numLabels; ++i) {
                tp += this.counts[i * this.numLabels + i];
            }
            double total = this.getCountTotal();
            // Each off-diagonal cell is an FP for its predicted label and an FN for its gold
            // label, so both pooled counts equal the off-diagonal mass.
            double fp = total - tp;
            double fn = fp;
            // Sum over labels of per-label TN = total - tp_i - fp_i - fn_i.
            double tn = total * (double) this.numLabels - tp - fp - fn;
            this.microPR = PrecisionRecall.forCounts(tp, fp, fn, tn);
        }
        return this.microPR;
    }

    /**
     * Returns macro-averaged statistics (cached). These are the unweighted means of the per-label
     * precision, recall and accuracy, combined with the total count.
     */
    public synchronized PrecisionRecall getMacroPR() {
        if (this.macroPR == null) {
            double p = 0.0;
            double r = 0.0;
            double a = 0.0;
            for (int i = 0; i < this.numLabels; ++i) {
                PrecisionRecall pr = this.getLabelPR(i);
                p += pr.getPrecision();
                r += pr.getRecall();
                a += pr.getAccuracy();
            }
            double n = this.numLabels;
            this.macroPR = PrecisionRecall.forMeasures(p / n, r / n, a / n, this.getCountTotal());
        }
        return this.macroPR;
    }

    /** Two matrices are equal when they have the same number of labels and identical cells. */
    @Override
    public boolean equals(Object object) {
        if (object == this) {
            return true;
        }
        if (!(object instanceof ConfusionMatrix)) {
            return false;
        }
        ConfusionMatrix other = (ConfusionMatrix) object;
        return this.numLabels == other.numLabels && Arrays.equals(this.counts, other.counts);
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.numLabels, Arrays.hashCode(this.counts));
    }

    /** Renders the matrix as a text table, using numeric indexes as label names. */
    @Override
    public String toString() {
        return this.toString((String[]) null);
    }

    /**
     * Renders the matrix as a text table with row/column sums and percentages. It is followed by
     * per-label, macro and micro accuracy, precision, recall and F1 (as percentages).
     *
     * @param labelStrings optional label names; {@code null} or missing entries fall back to the
     *     numeric label index
     */
    public String toString(String... labelStrings) {
        double total = this.getCountTotal();
        PrecisionRecall micro = this.getMicroPR();
        PrecisionRecall macro = this.getMacroPR();
        // 10-char row header + 10 chars per label + " |" + two 10-char summary columns.
        String delim = Strings.repeat("-", 10 + this.numLabels * 10 + 2 + 10 + 10) + '\n';

        StringBuilder builder = new StringBuilder("pred->   |");
        for (int i = 0; i < this.numLabels; ++i) {
            builder.append(String.format("%10s", labelName(labelStrings, i)));
        }
        builder.append(" |       sum         %\n");
        builder.append(delim);

        // One row per gold label, then its row sum and share of the total.
        for (int gold = 0; gold < this.numLabels; ++gold) {
            double sum = this.getCountGold(gold);
            builder.append(String.format("%8s |", labelName(labelStrings, gold)));
            for (int predicted = 0; predicted < this.numLabels; ++predicted) {
                builder.append(String.format("%10.2f", this.getCount(gold, predicted)));
            }
            builder.append(String.format(" |%10.2f%10.2f\n", sum, sum / total * 100.0));
        }
        builder.append(delim);

        // Column sums and their share of the total.
        builder.append("     sum |");
        for (int i = 0; i < this.numLabels; ++i) {
            builder.append(String.format("%10.2f", this.getCountPredicted(i)));
        }
        builder.append(String.format(" |%10.2f%10.2f\n", total, 100.0));
        builder.append("       % |");
        for (int i = 0; i < this.numLabels; ++i) {
            builder.append(String.format("%10.2f", this.getCountPredicted(i) / total * 100.0));
        }
        builder.append(" |     macro     micro\n");
        builder.append(delim);

        // Per-label measures followed by the macro and micro averages.
        builder.append("     acc |");
        for (int i = 0; i < this.numLabels; ++i) {
            builder.append(String.format("%10.2f", this.getLabelPR(i).getAccuracy() * 100.0));
        }
        builder.append(String.format(" |%10.2f%10.2f\n", macro.getAccuracy() * 100.0, micro.getAccuracy() * 100.0));
        builder.append("    prec |");
        for (int i = 0; i < this.numLabels; ++i) {
            builder.append(String.format("%10.2f", this.getLabelPR(i).getPrecision() * 100.0));
        }
        builder.append(String.format(" |%10.2f%10.2f\n", macro.getPrecision() * 100.0, micro.getPrecision() * 100.0));
        builder.append("     rec |");
        for (int i = 0; i < this.numLabels; ++i) {
            builder.append(String.format("%10.2f", this.getLabelPR(i).getRecall() * 100.0));
        }
        builder.append(String.format(" |%10.2f%10.2f\n", macro.getRecall() * 100.0, micro.getRecall() * 100.0));
        builder.append("      F1 |");
        for (int i = 0; i < this.numLabels; ++i) {
            builder.append(String.format("%10.2f", this.getLabelPR(i).getF1() * 100.0));
        }
        builder.append(String.format(" |%10.2f%10.2f\n", macro.getF1() * 100.0, micro.getF1() * 100.0));
        return builder.toString();
    }

    /** Returns the display name of {@code label}, falling back to its index. */
    private static String labelName(String[] labelStrings, int label) {
        return labelStrings == null || label >= labelStrings.length
                ? Integer.toString(label)
                : labelStrings[label];
    }

    /**
     * Orders two measure values best-first. NaN (undefined) values always sort last, whatever the
     * value of {@code higherIsBetter}.
     */
    private static int compareMeasures(double left, double right, boolean higherIsBetter) {
        if (Double.isNaN(left)) {
            return Double.isNaN(right) ? 0 : 1;
        }
        if (Double.isNaN(right)) {
            return -1;
        }
        int result = Double.compare(left, right);
        return higherIsBetter ? -result : result;
    }

    /**
     * Returns an ordering of matrices by a per-label measure, best first and NaN last.
     *
     * @param measure the measure to compare
     * @param label the label whose one-vs-rest statistics are compared
     * @param higherIsBetter whether larger values rank first
     */
    public static Ordering<ConfusionMatrix> labelComparator(final PrecisionRecall.Measure measure, final int label, final boolean higherIsBetter) {
        return new Ordering<ConfusionMatrix>() {

            @Override
            public int compare(ConfusionMatrix left, ConfusionMatrix right) {
                return compareMeasures(left.getLabelPR(label).get(measure),
                        right.getLabelPR(label).get(measure), higherIsBetter);
            }
        };
    }

    /**
     * Returns an ordering of matrices by a micro-averaged measure, best first. NaN sorts last,
     * consistent with {@link #labelComparator}.
     */
    public static Ordering<ConfusionMatrix> microComparator(final PrecisionRecall.Measure measure, final boolean higherIsBetter) {
        return new Ordering<ConfusionMatrix>() {

            @Override
            public int compare(ConfusionMatrix left, ConfusionMatrix right) {
                return compareMeasures(left.getMicroPR().get(measure),
                        right.getMicroPR().get(measure), higherIsBetter);
            }
        };
    }

    /**
     * Returns an ordering of matrices by a macro-averaged measure, best first. NaN sorts last,
     * consistent with {@link #labelComparator}.
     */
    public static Ordering<ConfusionMatrix> macroComparator(final PrecisionRecall.Measure measure, final boolean higherIsBetter) {
        return new Ordering<ConfusionMatrix>() {

            @Override
            public int compare(ConfusionMatrix left, ConfusionMatrix right) {
                return compareMeasures(left.getMacroPR().get(measure),
                        right.getMacroPR().get(measure), higherIsBetter);
            }
        };
    }

    /**
     * Sums matrices cell by cell. The result has as many labels as the largest input, and smaller
     * inputs contribute only to their own top-left cells.
     *
     * @param matrixes the matrices to sum; traversed twice
     * @return {@code null} if there are no matrices, the single matrix itself if there is exactly
     *     one, otherwise a new matrix holding the sums
     */
    @Nullable
    public static ConfusionMatrix sum(Iterable<ConfusionMatrix> matrixes) {
        int numMatrixes = 0;
        int numLabels = 0;
        ConfusionMatrix first = null;
        for (ConfusionMatrix matrix : matrixes) {
            if (numMatrixes == 0) {
                first = matrix;
            }
            ++numMatrixes;
            numLabels = Math.max(numLabels, matrix.numLabels);
        }
        if (numMatrixes == 0) {
            return null;
        }
        if (numMatrixes == 1) {
            return first;
        }
        double[][] counts = new double[numLabels][numLabels];
        // One pass per matrix (not per cell); each cell still accumulates matrices in order.
        for (ConfusionMatrix matrix : matrixes) {
            int n = matrix.numLabels;
            for (int i = 0; i < n; ++i) {
                for (int j = 0; j < n; ++j) {
                    counts[i][j] += matrix.counts[i * n + j];
                }
            }
        }
        return new ConfusionMatrix(counts);
    }

    /** Returns a new, empty {@link Evaluator} for {@code numLabels} labels. */
    public static Evaluator evaluator(int numLabels) {
        return new Evaluator(numLabels);
    }

    /**
     * Mutable accumulator of {@code (gold, predicted)} counts that produces a
     * {@link ConfusionMatrix}. Every method synchronizes on the evaluator instance.
     */
    public static final class Evaluator {
        // counts[gold][predicted]
        private final double[][] counts;
        // Last result of getResult(); cleared by every add(...).
        @Nullable
        private ConfusionMatrix cachedResult;

        private Evaluator(int numLabels) {
            this.counts = new double[numLabels][numLabels];
            this.cachedResult = null;
        }

        /**
         * Adds {@code count} to cell {@code (labelGold, labelPredicted)}.
         *
         * @throws ArrayIndexOutOfBoundsException if a label is out of range
         */
        public synchronized Evaluator add(int labelGold, int labelPredicted, double count) {
            this.cachedResult = null;
            this.counts[labelGold][labelPredicted] += count;
            return this;
        }

        /**
         * Adds the cells of {@code matrix}. If the two label counts differ, only the common
         * top-left square is added.
         */
        public synchronized Evaluator add(ConfusionMatrix matrix) {
            this.cachedResult = null;
            int numLabels = Math.min(this.counts.length, matrix.getNumLabels());
            for (int i = 0; i < numLabels; ++i) {
                for (int j = 0; j < numLabels; ++j) {
                    this.counts[i][j] += matrix.getCount(i, j);
                }
            }
            return this;
        }

        /**
         * Adds the counts accumulated by another evaluator. If the two label counts differ, only
         * the common top-left square is added. Adding an evaluator to itself doubles its counts.
         */
        public Evaluator add(Evaluator evaluator) {
            // Take a consistent snapshot under the other evaluator's lock, then merge under ours.
            // Never holding both locks at once prevents deadlock when a.add(b) races b.add(a).
            double[][] other;
            synchronized (evaluator) {
                other = new double[evaluator.counts.length][];
                for (int i = 0; i < other.length; ++i) {
                    other[i] = evaluator.counts[i].clone();
                }
            }
            synchronized (this) {
                this.cachedResult = null;
                int numLabels = Math.min(this.counts.length, other.length);
                for (int i = 0; i < numLabels; ++i) {
                    for (int j = 0; j < numLabels; ++j) {
                        this.counts[i][j] += other[i][j];
                    }
                }
            }
            return this;
        }

        /** Returns a matrix of the counts so far (cached until the next {@code add}). */
        public synchronized ConfusionMatrix getResult() {
            if (this.cachedResult == null) {
                this.cachedResult = new ConfusionMatrix(this.counts);
            }
            return this.cachedResult;
        }
    }
}

