package org.openml.webapplication.evaluate;

import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.json.JSONArray;
import org.openml.apiconnector.algorithms.Conversion;
import org.openml.apiconnector.algorithms.TaskInformation;
import org.openml.apiconnector.models.MetricScore;
import org.openml.apiconnector.settings.Constants;
import org.openml.apiconnector.xml.DataSetDescription;
import org.openml.apiconnector.xml.EstimationProcedure;
import org.openml.apiconnector.xml.EstimationProcedureType;
import org.openml.apiconnector.xml.EvaluationScore;
import org.openml.apiconnector.xml.Task;
import org.openml.webapplication.algorithm.InstancesHelper;
import org.openml.webapplication.io.Output;
import org.openml.webapplication.predictionCounter.FoldsPredictionCounter;
import org.openml.webapplication.predictionCounter.PredictionCounter;
import org.openml.weka.io.OpenmlWekaConnector;
import weka.classifiers.Evaluation;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:org/openml/webapplication/evaluate/EvaluateBatchPredictions.class */
/**
 * Evaluates a batch of uploaded predictions for a task against the task's dataset and splits.
 *
 * <p>All work happens in the constructor. It fetches the estimation procedure, dataset, splits and
 * predictions ARFF through the connector. It then feeds every prediction row into the
 * {@link Evaluation} of its (repeat, fold, sample) cell and into one global evaluation. The
 * resulting metrics are exposed through {@link #getEvaluationScores()}.
 */
public class EvaluateBatchPredictions implements PredictionEvaluator {
    private final int nrOfClasses;
    /** Column indices into {@link #predictions}. Repeat/fold/sample are treated as absent (value 0) when negative. */
    private final int ATT_PREDICTION_ROWID;
    private final int ATT_PREDICTION_FOLD;
    private final int ATT_PREDICTION_REPEAT;
    private final int ATT_PREDICTION_PREDICTION;
    /** Index of the sample column; -1 unless the task is a learning curve. */
    private final int ATT_PREDICTION_SAMPLE;
    /** For each class value (in dataset order), the index of its {@code confidence.<class>} column. */
    private final int[] ATT_PREDICTION_CONFIDENCE;
    private final Instances dataset;
    private final Instances splits;
    private final Instances predictions;
    private final PredictionCounter predictionCounter;
    private final String[] classes;
    private final EstimationProcedure estimationProcedure;
    private final TaskType taskType;
    private final JSONArray cost_matrix;
    /**
     * Per-cell evaluations, indexed as [repeat][fold][sample][slot], with 2 slots when bootstrapping and 1 otherwise.
     * NOTE(review): only slot 0 is fed predictions in this class. Slot 1 is presumably consumed by
     * Output.evaluatorToMap when bootstrapping; confirm.
     */
    private final Evaluation[][][][] sampleEvaluation;
    private final boolean bootstrap;
    private EvaluationScore[] evaluationScores;

    /**
     * Downloads everything needed to evaluate the predictions of {@code task} and runs the evaluation.
     *
     * @param openmlWekaConnector connector used to fetch the estimation procedure, dataset, splits and predictions
     * @param task the task whose predictions are evaluated
     * @param taskType task type; selects regression vs. confidence-based evaluation and learning-curve handling
     * @param predictionsId identifier passed to {@code getArffFromUrl} to load the predictions ARFF
     * @throws Exception if fetching fails, the class attribute or a {@code confidence.<class>} column is
     *     missing, or the predictions are inconsistent with the dataset or splits
     */
    public EvaluateBatchPredictions(OpenmlWekaConnector openmlWekaConnector, Task task, TaskType taskType, int predictionsId) throws Exception {
        int datasetId = TaskInformation.getSourceData(task).getData_set_id().intValue();
        this.estimationProcedure = openmlWekaConnector.estimationProcedureGet(TaskInformation.getEstimationProcedure(task).getId());
        DataSetDescription dataGet = openmlWekaConnector.dataGet(datasetId);
        this.taskType = taskType;
        this.dataset = openmlWekaConnector.getDataset(dataGet);
        this.splits = openmlWekaConnector.getSplitsFromTask(task);
        this.predictions = openmlWekaConnector.getArffFromUrl(predictionsId);
        Conversion.log("OK", "EvaluateBatchPredictions", "predictions: " + predictionsId);
        this.bootstrap = TaskInformation.getEstimationProcedure(task).getType() == EstimationProcedureType.BOOTSTRAPPING;
        String targetFeature = TaskInformation.getSourceData(task).getTarget_feature();
        this.cost_matrix = TaskInformation.getCostMatrix(task);
        if (this.dataset.attribute(targetFeature) == null) {
            throw new RuntimeException("Class attribute (" + targetFeature + ") not found");
        }
        this.dataset.setClass(this.dataset.attribute(targetFeature));
        this.predictionCounter = new FoldsPredictionCounter(this.splits);
        this.sampleEvaluation = new Evaluation[this.predictionCounter.getRepeats()][this.predictionCounter.getFolds()][this.predictionCounter.getSamples()][this.bootstrap ? 2 : 1];
        this.ATT_PREDICTION_ROWID = InstancesHelper.getRowIndex("row_id", this.predictions);
        this.ATT_PREDICTION_REPEAT = InstancesHelper.getRowIndex(new String[]{"repeat", "repeat_nr"}, this.predictions);
        this.ATT_PREDICTION_FOLD = InstancesHelper.getRowIndex(new String[]{"fold", "fold_nr"}, this.predictions);
        this.ATT_PREDICTION_PREDICTION = InstancesHelper.getRowIndex(new String[]{"prediction"}, this.predictions);
        if (taskType == TaskType.LEARNINGCURVE) {
            this.ATT_PREDICTION_SAMPLE = InstancesHelper.getRowIndex(new String[]{"sample", "sample_nr"}, this.predictions);
        } else {
            this.ATT_PREDICTION_SAMPLE = -1;
        }
        this.nrOfClasses = this.dataset.classAttribute().numValues();
        this.classes = new String[this.nrOfClasses];
        this.ATT_PREDICTION_CONFIDENCE = new int[this.nrOfClasses];
        for (int c = 0; c < this.classes.length; c++) {
            this.classes[c] = this.dataset.classAttribute().value(c);
            String confidenceColumn = "confidence." + this.classes[c];
            if (this.predictions.attribute(confidenceColumn) == null) {
                throw new Exception("Attribute " + confidenceColumn + " not found among predictions. ");
            }
            this.ATT_PREDICTION_CONFIDENCE[c] = this.predictions.attribute(confidenceColumn).index();
        }
        doEvaluation();
    }

    /** Creates a fresh Evaluation over the dataset, passing the task's cost matrix when one is defined. */
    private Evaluation newEvaluation() throws Exception {
        if (this.cost_matrix != null) {
            return new Evaluation(this.dataset, InstancesHelper.doubleToCostMatrix(this.cost_matrix));
        }
        return new Evaluation(this.dataset);
    }

    /** Returns true when the score is present and is neither NaN nor infinite. */
    private static boolean isFiniteScore(Double score) {
        return score != null && !score.isNaN() && !score.isInfinite();
    }

    /**
     * Feeds every prediction into the per-cell and global evaluations, verifies the prediction counts
     * and builds {@link #evaluationScores}.
     *
     * <p>Per-cell scores are reported unless the estimation procedure is leave-one-out or
     * test-on-training-data. Global scores are always reported. A global score's standard deviation
     * is taken over the per-(repeat, fold) scores of the last sample.
     */
    private void doEvaluation() throws Exception {
        final int evaluationsPerCell = this.bootstrap ? 2 : 1;

        Evaluation[] globalEvaluations = new Evaluation[evaluationsPerCell];
        for (int e = 0; e < globalEvaluations.length; e++) {
            globalEvaluations[e] = newEvaluation();
        }
        // Bounded loop: the decompiled version used `while (true)` without a break and never terminated.
        for (Evaluation[][][] repeatEvaluations : this.sampleEvaluation) {
            for (Evaluation[][] foldEvaluations : repeatEvaluations) {
                for (Evaluation[] cellEvaluations : foldEvaluations) {
                    for (int e = 0; e < evaluationsPerCell; e++) {
                        cellEvaluations[e] = newEvaluation();
                    }
                }
            }
        }

        final int numDatasetInstances = this.dataset.numInstances();
        for (int p = 0; p < this.predictions.numInstances(); p++) {
            Instance prediction = this.predictions.instance(p);
            int repeat = this.ATT_PREDICTION_REPEAT < 0 ? 0 : (int) prediction.value(this.ATT_PREDICTION_REPEAT);
            int fold = this.ATT_PREDICTION_FOLD < 0 ? 0 : (int) prediction.value(this.ATT_PREDICTION_FOLD);
            int sample = this.ATT_PREDICTION_SAMPLE < 0 ? 0 : (int) prediction.value(this.ATT_PREDICTION_SAMPLE);
            int rowId = (int) prediction.value(this.ATT_PREDICTION_ROWID);
            this.predictionCounter.addPrediction(repeat, fold, sample, rowId);
            if (rowId < 0) {
                throw new RuntimeException("Making a prediction for negative row_id " + rowId + " (0-based). ");
            }
            if (numDatasetInstances <= rowId) {
                throw new RuntimeException("Making a prediction for row_id " + rowId + " (0-based) while dataset has only " + numDatasetInstances + " instances. ");
            }
            // Short-circuit evaluation keeps every array access below in bounds.
            if (repeat < 0 || repeat >= this.sampleEvaluation.length
                    || fold < 0 || fold >= this.sampleEvaluation[repeat].length
                    || sample < 0 || sample >= this.sampleEvaluation[repeat][fold].length) {
                throw new RuntimeException("Prediction for row_id " + rowId + " refers to repeat " + repeat + ", fold " + fold + ", sample " + sample + ", which is outside the task's splits. ");
            }

            Instance groundTruth = this.dataset.instance(rowId);
            Evaluation cellEvaluation = this.sampleEvaluation[repeat][fold][sample][0];
            // For learning curves only predictions of the last sample (index getSamples() - 1) enter the global evaluation.
            boolean addToGlobal = this.taskType != TaskType.LEARNINGCURVE || sample == this.predictionCounter.getSamples() - 1;
            if (this.taskType == TaskType.REGRESSION) {
                double predicted = prediction.value(this.ATT_PREDICTION_PREDICTION);
                if (addToGlobal) {
                    globalEvaluations[0].evaluateModelOnce(predicted, groundTruth);
                }
                cellEvaluation.evaluateModelOnce(predicted, groundTruth);
            } else {
                double[] confidences = InstancesHelper.predictionToConfidences(this.dataset, prediction, this.ATT_PREDICTION_CONFIDENCE, this.ATT_PREDICTION_PREDICTION);
                if (addToGlobal) {
                    globalEvaluations[0].evaluateModelOnceAndRecordPrediction(confidences, groundTruth);
                }
                cellEvaluation.evaluateModelOnceAndRecordPrediction(confidences, groundTruth);
            }
        }
        if (!this.predictionCounter.check()) {
            throw new RuntimeException("Prediction count does not match: " + this.predictionCounter.getErrorMessage());
        }

        final boolean reportPerCellScores = this.estimationProcedure.getType() != EstimationProcedureType.LEAVEONEOUT
                && this.estimationProcedure.getType() != EstimationProcedureType.TESTONTRAININGDATA;
        final DecimalFormat decimalFormat = Constants.defaultDecimalFormat;
        List<EvaluationScore> scores = new ArrayList<>();
        // Metric name -> scores of the last sample of every (repeat, fold); used for the global standard deviation.
        Map<String, List<Double>> lastSampleScores = new HashMap<>();
        for (int r = 0; r < this.sampleEvaluation.length; r++) {
            for (int f = 0; f < this.sampleEvaluation[r].length; f++) {
                int lastSample = this.sampleEvaluation[r][f].length - 1;
                for (int s = 0; s <= lastSample; s++) {
                    Map<String, MetricScore> metrics = Output.evaluatorToMap(this.sampleEvaluation[r][f][s], this.nrOfClasses, this.taskType, this.bootstrap);
                    for (Map.Entry<String, MetricScore> entry : metrics.entrySet()) {
                        String metricName = entry.getKey();
                        MetricScore metricScore = entry.getValue();
                        Double score = metricScore.getScore();
                        if (!isFiniteScore(score)) {
                            continue;
                        }
                        if (reportPerCellScores) {
                            String array = metricScore.getArrayAsString(decimalFormat);
                            EvaluationScore evaluationScore = this.taskType == TaskType.LEARNINGCURVE
                                    ? new EvaluationScore(metricName, score, array, Integer.valueOf(r), Integer.valueOf(f), Integer.valueOf(s), Integer.valueOf(this.predictionCounter.getShadowTypeSize(r, f, s)))
                                    : new EvaluationScore(metricName, score, array, Integer.valueOf(r), Integer.valueOf(f));
                            scores.add(evaluationScore);
                        }
                        if (s == lastSample) {
                            List<Double> perFold = lastSampleScores.get(metricName);
                            if (perFold == null) {
                                perFold = new ArrayList<>();
                                lastSampleScores.put(metricName, perFold);
                            }
                            perFold.add(score);
                        }
                    }
                }
            }
        }

        StandardDeviation standardDeviation = new StandardDeviation();
        Map<String, MetricScore> globalMetrics = Output.evaluatorToMap(globalEvaluations, this.nrOfClasses, this.taskType, this.bootstrap);
        for (Map.Entry<String, MetricScore> entry : globalMetrics.entrySet()) {
            String metricName = entry.getKey();
            MetricScore metricScore = entry.getValue();
            Double score = metricScore.getScore();
            if (!isFiniteScore(score)) {
                continue;
            }
            List<Double> perFold = lastSampleScores.get(metricName);
            Double stdev = perFold == null
                    ? null
                    : Double.valueOf(standardDeviation.evaluate(ArrayUtils.toPrimitive(perFold.toArray(new Double[perFold.size()]))));
            scores.add(new EvaluationScore(metricName, score, stdev, metricScore.getArrayAsString(decimalFormat)));
        }
        this.evaluationScores = scores.toArray(new EvaluationScore[scores.size()]);
    }

    /**
     * Returns the scores computed in the constructor.
     *
     * @return the per-cell scores (when reported for this estimation procedure) followed by the global scores
     */
    @Override // org.openml.webapplication.evaluate.PredictionEvaluator
    public EvaluationScore[] getEvaluationScores() {
        return this.evaluationScores;
    }

    /**
     * Returns the counter that tracked which (repeat, fold, sample, row_id) predictions were seen.
     *
     * @return the prediction counter built from the task's splits
     */
    @Override // org.openml.webapplication.evaluate.PredictionEvaluator
    public PredictionCounter getPredictionCounter() {
        return this.predictionCounter;
    }
}
