package com.whylogs.core.metrics;

import com.shaded.whylabs.com.google.common.base.Preconditions;
import com.shaded.whylabs.com.google.common.collect.Lists;
import com.shaded.whylabs.com.google.common.collect.Sets;
import com.whylogs.core.message.ScoreMatrixMessage;
import com.whylogs.core.statistics.NumberTracker;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import javax.annotation.Nullable;
import lombok.NonNull;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/whylogs/core/metrics/ClassificationMetrics.class */
/**
 * Accumulates classification results as a square matrix of {@link NumberTracker} cells.
 *
 * <p>Every distinct prediction/target value is turned into a text label. The position of a label
 * in {@link #getLabels()} is its index in the matrix: the first index is the label of the
 * prediction and the second index is the label of the target. Each cell tracks the scores that
 * were recorded for that (prediction, target) pair.
 *
 * <p>{@link #update(Object, Object, double)} and {@link #merge(ClassificationMetrics)} rebuild the
 * label list in sorted order whenever the set of labels grows. This class performs no
 * synchronization of its own.
 */
public class ClassificationMetrics {
    private static final Logger log = LoggerFactory.getLogger(ClassificationMetrics.class);

    /** Known labels; a label's position here is its row/column index in {@link #values}. */
    private List<String> labels;
    private final String predictionField;
    private final String targetField;
    private final String scoreField;
    /** Square matrix addressed as {@code values[predictionIndex][targetIndex]}. */
    private NumberTracker[][] values;

    /**
     * Creates an empty metrics object with no labels.
     *
     * @param predictionField key under which {@link #track(Map)} looks up the predicted value
     * @param targetField key under which {@link #track(Map)} looks up the target value
     * @param scoreField key under which {@link #track(Map)} looks up the score
     */
    public ClassificationMetrics(String predictionField, String targetField, String scoreField) {
        this(Lists.newArrayList(), predictionField, targetField, scoreField, newMatrix(0));
    }

    private ClassificationMetrics(
            List<String> labels,
            String predictionField,
            String targetField,
            String scoreField,
            NumberTracker[][] values) {
        this.labels = labels;
        this.predictionField = predictionField;
        this.targetField = targetField;
        this.scoreField = scoreField;
        this.values = values;
    }

    /**
     * @return an unmodifiable view of the current labels, in matrix index order
     */
    public List<String> getLabels() {
        return Collections.unmodifiableList(this.labels);
    }

    /**
     * Builds the confusion matrix from the per-cell counts.
     *
     * @return a new {@code size x size} array where each entry is the count reported by the
     *     doubles tracker of the corresponding cell
     */
    public long[][] getConfusionMatrix() {
        final int size = this.labels.size();
        final long[][] counts = new long[size][size];
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                counts[row][col] = this.values[row][col].getDoubles().getCount();
            }
        }
        return counts;
    }

    /** Allocates a {@code size x size} matrix with a fresh tracker in every cell. */
    private static NumberTracker[][] newMatrix(int size) {
        final NumberTracker[][] matrix = new NumberTracker[size][size];
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                matrix[row][col] = new NumberTracker();
            }
        }
        return matrix;
    }

    /**
     * Records one row of data, reading the prediction, target and score from the configured keys.
     *
     * <p>A score that is a {@link Number} is used directly; any other non-null score is parsed
     * from its string form. A missing or unparseable score is recorded as {@code 0.0} (a parse
     * failure is logged as a warning).
     *
     * @param map the record to read from
     * @throws IllegalStateException if the prediction or target field name is null
     */
    public void track(Map<String, ?> map) {
        Preconditions.checkState(this.predictionField != null);
        Preconditions.checkState(this.targetField != null);
        final Object prediction = map.get(this.predictionField);
        final Object target = map.get(this.targetField);
        final double score = parseScore(map.get(this.scoreField));
        update(prediction, target, score);
    }

    /** Converts a raw score value to a double, falling back to 0.0 when absent or unparseable. */
    private static double parseScore(Object rawScore) {
        if (rawScore instanceof Number) {
            return ((Number) rawScore).doubleValue();
        }
        if (rawScore != null) {
            try {
                return Double.parseDouble(rawScore.toString());
            } catch (NumberFormatException e) {
                log.warn("Failed to parse score: {}", rawScore, e);
            }
        }
        return 0.0d;
    }

    /**
     * Records a single (prediction, target, score) observation.
     *
     * <p>If both labels are already known the score is tracked in the existing cell. Otherwise the
     * label list is rebuilt (sorted) with the new label(s), the existing cells are carried over
     * into a larger matrix, and the score is tracked there.
     *
     * <p>NOTE(review): a null prediction/target becomes a null label and nothing here guards
     * against it; sorting a label list that mixes null with real labels would presumably throw.
     * Confirm whether callers can pass null.
     *
     * @param prediction the predicted value; booleans map to "1"/"0", everything else to
     *     {@code toString()}
     * @param target the expected value, converted to a label the same way
     * @param score the score to track in the (prediction, target) cell
     */
    public <T> void update(T prediction, T target, double score) {
        final String predictionLabel = textValue(prediction);
        final String targetLabel = textValue(target);
        final int predictionIdx = this.labels.indexOf(predictionLabel);
        final int targetIdx = this.labels.indexOf(targetLabel);
        if (predictionIdx >= 0 && targetIdx >= 0) {
            this.values[predictionIdx][targetIdx].track(Double.valueOf(score));
            return;
        }

        // At least one label is new: grow the label set and re-home the existing cells.
        final HashSet<String> labelSet = Sets.newHashSet(this.labels);
        if (predictionIdx < 0) {
            labelSet.add(predictionLabel);
        }
        if (targetIdx < 0) {
            labelSet.add(targetLabel);
        }
        final List<String> grownLabels = sortedLabels(labelSet);
        final NumberTracker[][] grownValues = newMatrix(grownLabels.size());
        addMatrix(this.labels, this.values, grownLabels, grownValues);
        grownValues[grownLabels.indexOf(predictionLabel)][grownLabels.indexOf(targetLabel)]
                .track(Double.valueOf(score));
        this.labels = grownLabels;
        this.values = grownValues;
    }

    /** Returns the labels of the given set as a new list in natural String order. */
    private static List<String> sortedLabels(HashSet<String> labelSet) {
        final ArrayList<String> sorted = Lists.newArrayList(labelSet);
        Collections.sort(sorted);
        return sorted;
    }

    /** Converts a value to its label: null stays null, booleans become "1"/"0". */
    private static String textValue(Object obj) {
        if (obj == null) {
            return null;
        }
        if (obj instanceof Boolean) {
            return ((Boolean) obj).booleanValue() ? "1" : "0";
        }
        return obj.toString();
    }

    /**
     * Renders the labels on one line followed by one bracketed line per matrix row.
     */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append("Labels: ");
        for (String label : this.labels) {
            sb.append(label);
            sb.append(", ");
        }
        sb.append('\n');
        final int size = this.labels.size();
        for (int row = 0; row < size; row++) {
            sb.append('[');
            for (int col = 0; col < size; col++) {
                sb.append(this.values[row][col]);
                if (col + 1 < size) {
                    sb.append(", ");
                }
            }
            sb.append("]\n");
        }
        return sb.toString();
    }

    /**
     * Combines this object with another into a new instance; neither input is modified.
     *
     * <p>The result uses the sorted union of both label lists and keeps this object's
     * prediction, target and score field names.
     *
     * @param other the metrics to merge in; when null a {@link #copy()} of this object is returned
     * @return a new merged instance
     */
    public ClassificationMetrics merge(ClassificationMetrics other) {
        if (other == null) {
            return copy();
        }
        final HashSet<String> labelSet = Sets.newHashSet(this.labels);
        labelSet.addAll(other.labels);
        final List<String> mergedLabels = sortedLabels(labelSet);
        final NumberTracker[][] mergedValues = newMatrix(mergedLabels.size());
        addMatrix(this.labels, this.values, mergedLabels, mergedValues);
        addMatrix(other.labels, other.values, mergedLabels, mergedValues);
        // Argument order must match the constructor: prediction field first, then target field.
        return new ClassificationMetrics(
                mergedLabels, this.predictionField, this.targetField, this.scoreField, mergedValues);
    }

    /**
     * Merges every cell of a source matrix into the matching cell of a destination matrix.
     *
     * @param srcLabels labels indexing {@code src}
     * @param src source matrix
     * @param dstLabels labels indexing {@code dst}; expected to contain every source label
     * @param dst destination matrix, updated in place
     */
    private static void addMatrix(
            List<String> srcLabels,
            NumberTracker[][] src,
            List<String> dstLabels,
            NumberTracker[][] dst) {
        final int size = srcLabels.size();
        // Resolve each source index to its destination index once instead of per cell.
        final int[] dstIndex = new int[size];
        for (int i = 0; i < size; i++) {
            dstIndex[i] = dstLabels.indexOf(srcLabels.get(i));
        }
        for (int row = 0; row < size; row++) {
            final NumberTracker[] dstRow = dst[dstIndex[row]];
            for (int col = 0; col < size; col++) {
                final int dstCol = dstIndex[col];
                dstRow[dstCol] = dstRow[dstCol].merge(src[row][col]);
            }
        }
    }

    /**
     * Creates a copy with its own label list and its own matrix; each cell is produced by merging
     * the current cell into a fresh tracker.
     *
     * @return the copy
     */
    @NonNull
    public ClassificationMetrics copy() {
        final int size = this.labels.size();
        final NumberTracker[][] copied = newMatrix(size);
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                copied[row][col] = copied[row][col].merge(this.values[row][col]);
            }
        }
        return new ClassificationMetrics(
                Lists.newArrayList(this.labels),
                this.predictionField,
                this.targetField,
                this.scoreField,
                copied);
    }

    /**
     * Serializes the labels, the matrix cells (row by row) and the field names.
     *
     * @return a builder populated from this object
     */
    @NonNull
    public ScoreMatrixMessage.Builder toProtobuf() {
        final ScoreMatrixMessage.Builder builder = ScoreMatrixMessage.newBuilder();
        this.labels.forEach(builder::addLabels);
        final int size = this.labels.size();
        for (int row = 0; row < size; row++) {
            for (int col = 0; col < size; col++) {
                builder.addScores(this.values[row][col].toProtobuf());
            }
        }
        builder.setPredictionField(this.predictionField);
        builder.setTargetField(this.targetField);
        // track() does not require a score field, so it may be null here; generated protobuf
        // setters reject null, so leave the field unset in that case.
        if (this.scoreField != null) {
            builder.setScoreField(this.scoreField);
        }
        return builder;
    }

    /**
     * Rebuilds a metrics object from its serialized form.
     *
     * <p>Scores are read in row-major order: score {@code k} goes to cell
     * {@code [k / size][k % size]}. Cells with no corresponding score keep a fresh tracker.
     *
     * @param message the serialized matrix, may be null
     * @return the deserialized object, or null when the message is null or empty, or when its
     *     scores cannot be placed in a matrix sized by its labels (a warning is logged)
     */
    @Nullable
    public static ClassificationMetrics fromProtobuf(@Nullable ScoreMatrixMessage message) {
        if (message == null || message.getSerializedSize() == 0) {
            return null;
        }
        final int size = message.getLabelsCount();
        final int scoreCount = message.getScoresCount();
        if (size == 0 && scoreCount > 0) {
            log.warn("Skipping classification ScoreMatrix: has scores but no labels");
            return null;
        }
        // More scores than cells would index past the end of the matrix below.
        if (scoreCount > (long) size * size) {
            log.warn(
                    "Skipping classification ScoreMatrix: {} scores do not fit a {}x{} matrix",
                    scoreCount,
                    size,
                    size);
            return null;
        }

        final List<String> labels = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            labels.add(message.getLabels(i));
        }
        final NumberTracker[][] matrix = newMatrix(size);
        for (int k = 0; k < scoreCount; k++) {
            matrix[k / size][k % size] = NumberTracker.fromProtobuf(message.getScores(k));
        }
        return new ClassificationMetrics(
                labels,
                message.getPredictionField(),
                message.getTargetField(),
                message.getScoreField(),
                matrix);
    }

    /** @return the key used to look up the predicted value in {@link #track(Map)} */
    public String getPredictionField() {
        return this.predictionField;
    }

    /** @return the key used to look up the target value in {@link #track(Map)} */
    public String getTargetField() {
        return this.targetField;
    }

    /** @return the key used to look up the score in {@link #track(Map)} */
    public String getScoreField() {
        return this.scoreField;
    }

    /**
     * @return the internal matrix itself (not a copy), addressed as
     *     {@code [predictionIndex][targetIndex]}
     */
    public NumberTracker[][] getValues() {
        return this.values;
    }
}
