/*
 * Decompiled with CFR 0.152.
 */
package eu.fbk.utils.svm;

import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import eu.fbk.utils.eval.ConfusionMatrix;
import eu.fbk.utils.svm.Vector;
import java.io.IOException;
import java.util.Iterator;
import javax.annotation.Nullable;

/**
 * A {@link Vector} paired with an integer label and, optionally, a probability for each label.
 *
 * <p>Feature access ({@code doSize}, {@code doGetFeature}, {@code doGetValue}) is delegated to the
 * wrapped vector. Instances are obtained through {@link #create(Vector, int, float[])}, which
 * selects one of three private representations depending on the probabilities supplied.
 */
public abstract class LabelledVector
extends Vector {
    private static final long serialVersionUID = 2L;
    private final Vector vector;
    private final int label;

    /**
     * Creates a labelled view of {@code vector}.
     *
     * <p>The representation is chosen as follows:
     * <ul>
     * <li>no probabilities ({@code null} or empty): the assigned label has probability 1 and every
     * other label has probability 0;</li>
     * <li>exactly two complementary probabilities ({@code p[1] == 1 - p[0]}): only {@code p[0]}
     * is stored;</li>
     * <li>otherwise: a private copy of the array is stored, and labels beyond its end have
     * probability 0.</li>
     * </ul>
     *
     * @param vector the vector to label
     * @param label the label to assign
     * @param probabilities optional per-label probabilities indexed by label; the array itself is
     *     not retained
     * @return the labelled vector
     */
    static LabelledVector create(Vector vector, int label, @Nullable float[] probabilities) {
        if (probabilities == null || probabilities.length == 0) {
            return new LabelledVector0(vector, label);
        }
        // The compact binary form reconstructs p[1] as 1 - p[0]; use it only when that
        // reconstruction is exact, otherwise p[1] would be silently replaced by another value.
        if (probabilities.length == 2 && 1.0f - probabilities[0] == probabilities[1]) {
            return new LabelledVector1(vector, label, probabilities[0]);
        }
        return new LabelledVectorN(vector, label, probabilities);
    }

    private LabelledVector(Vector vector, int label) {
        super(vector.getId());
        this.vector = vector;
        this.label = label;
    }

    /**
     * Returns the label assigned to this vector.
     *
     * @return the label
     */
    public final int getLabel() {
        return this.label;
    }

    /**
     * Returns the probability associated with the given label.
     *
     * @param label a non-negative label
     * @return the probability of {@code label}; 0 for labels with no stored probability
     * @throws IllegalArgumentException if {@code label} is negative
     */
    public final float getProbability(int label) {
        Preconditions.checkArgument(label >= 0);
        return this.doGetProbability(label);
    }

    @Override
    final int doSize() {
        return this.vector.doSize();
    }

    @Override
    final String doGetFeature(int index) {
        return this.vector.doGetFeature(index);
    }

    @Override
    final float doGetValue(int index) {
        return this.vector.doGetValue(index);
    }

    /**
     * Representation-specific probability lookup; {@code label} is expected to be non-negative.
     */
    abstract float doGetProbability(int var1);

    /**
     * Returns {@code this} when it already carries the requested labelling, otherwise delegates to
     * the superclass implementation.
     *
     * <p>With probabilities, a match requires the same label and {@code getProbability(i) ==
     * probabilities[i]} for every supplied index (so a NaN probability never matches). Without
     * probabilities, a match requires the same label with probability 1.
     */
    @Override
    final LabelledVector doLabel(int label, float ... probabilities) {
        if (this.label == label && this.hasProbabilities(label, probabilities)) {
            return this;
        }
        return super.doLabel(label, probabilities);
    }

    // Tells whether the stored probabilities agree with the requested ones (see doLabel).
    private boolean hasProbabilities(int label, @Nullable float[] probabilities) {
        if (probabilities == null || probabilities.length == 0) {
            return this.getProbability(label) == 1.0f;
        }
        for (int i = 0; i < probabilities.length; ++i) {
            if (this.getProbability(i) != probabilities[i]) {
                return false;
            }
        }
        return true;
    }

    /** Returns the unlabelled form of the wrapped vector. */
    @Override
    final Vector doUnlabel() {
        return this.vector.doUnlabel();
    }

    /**
     * Builds a confusion matrix comparing gold labels with predicted labels.
     *
     * <p>The two iterables are traversed in parallel; cell {@code [g][p]} counts the pairs whose
     * gold label is {@code g} and whose predicted label is {@code p}.
     *
     * @param goldVectors vectors carrying the reference labels
     * @param predictedVectors vectors carrying the predicted labels, in the same order
     * @param numLabels number of labels; every label must lie in {@code [0, numLabels)}
     * @return the resulting confusion matrix
     * @throws IllegalArgumentException if the iterables differ in size, {@code numLabels} is
     *     negative, or a gold/predicted label is out of range
     */
    public static ConfusionMatrix evaluate(Iterable<? extends LabelledVector> goldVectors,
            Iterable<? extends LabelledVector> predictedVectors, int numLabels) {
        Preconditions.checkArgument(numLabels >= 0, "Negative number of labels: %s", numLabels);
        int goldSize = Iterables.size(goldVectors);
        int predictedSize = Iterables.size(predictedVectors);
        if (goldSize != predictedSize) {
            throw new IllegalArgumentException("Number of gold vectors (" + goldSize
                    + ") different from number of predicted vectors (" + predictedSize + ")");
        }
        double[][] matrix = new double[numLabels][numLabels];
        Iterator<? extends LabelledVector> goldIterator = goldVectors.iterator();
        Iterator<? extends LabelledVector> predictedIterator = predictedVectors.iterator();
        while (goldIterator.hasNext()) {
            int goldLabel = goldIterator.next().getLabel();
            int predictedLabel = predictedIterator.next().getLabel();
            Preconditions.checkArgument(goldLabel >= 0 && goldLabel < numLabels,
                    "Gold label %s out of range [0, %s)", goldLabel, numLabels);
            Preconditions.checkArgument(predictedLabel >= 0 && predictedLabel < numLabels,
                    "Predicted label %s out of range [0, %s)", predictedLabel, numLabels);
            matrix[goldLabel][predictedLabel] += 1.0;
        }
        return new ConfusionMatrix(matrix);
    }

    /** Representation storing an explicit probability for each label index. */
    private static final class LabelledVectorN
    extends LabelledVector {
        private static final long serialVersionUID = 1L;
        // Private copy: the array comes from the caller, who may keep mutating it.
        private final float[] probabilities;

        private LabelledVectorN(Vector vector, int label, float[] probabilities) {
            super(vector, label);
            this.probabilities = probabilities.clone();
        }

        @Override
        float doGetProbability(int label) {
            Preconditions.checkArgument(label >= 0);
            return label < this.probabilities.length ? this.probabilities[label] : 0.0f;
        }

        // Emits "<label> (0:<p0> 1:<p1> ...) " followed by the superclass rendering.
        @Override
        void doToString(Appendable out) throws IOException {
            out.append(Integer.toString(this.getLabel()));
            out.append(" (");
            for (int i = 0; i < this.probabilities.length; ++i) {
                out.append(i == 0 ? "" : " ").append(Integer.toString(i)).append(':').append(Float.toString(this.probabilities[i]));
            }
            out.append(") ");
            super.doToString(out);
        }
    }

    /**
     * Binary representation storing only the probability of label 0; label 1 has probability
     * {@code 1 - probability0} and any other label has probability 0.
     */
    private static final class LabelledVector1
    extends LabelledVector {
        private static final long serialVersionUID = 1L;
        private final float probability0;

        private LabelledVector1(Vector vector, int label, float probability0) {
            super(vector, label);
            this.probability0 = probability0;
        }

        @Override
        float doGetProbability(int label) {
            Preconditions.checkArgument(label >= 0);
            return label == 0 ? this.probability0 : (label == 1 ? 1.0f - this.probability0 : 0.0f);
        }

        // Emits "<label> (0:<p0> 1:<1-p0>) " followed by the superclass rendering.
        @Override
        void doToString(Appendable out) throws IOException {
            out.append(Integer.toString(this.getLabel()));
            out.append(" (0:").append(Float.toString(this.probability0)).append(" 1:").append(Float.toString(1.0f - this.probability0)).append(") ");
            super.doToString(out);
        }
    }

    /** Representation without probabilities: the assigned label has probability 1, others 0. */
    private static final class LabelledVector0
    extends LabelledVector {
        private static final long serialVersionUID = 1L;

        private LabelledVector0(Vector vector, int label) {
            super(vector, label);
        }

        @Override
        float doGetProbability(int label) {
            Preconditions.checkArgument(label >= 0);
            return label == this.getLabel() ? 1.0f : 0.0f;
        }

        // Emits "<label> " followed by the superclass rendering.
        @Override
        void doToString(Appendable out) throws IOException {
            out.append(Integer.toString(this.getLabel())).append(' ');
            super.doToString(out);
        }
    }
}

