package weka.classifiers.meta;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.evaluation.EvaluationUtils;
import weka.classifiers.evaluation.ThresholdCurve;
import weka.classifiers.functions.Logistic;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.core.Capabilities;
import weka.core.Drawable;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;

/* loaded from: input_file:WEB-INF/lib/weka-stable-3.6.10.jar:weka/classifiers/meta/ThresholdSelector.class */
/**
 * Meta-classifier that tunes the probability threshold a base classifier uses for one
 * designated class of a binary (two-valued) class attribute.
 *
 * <p>A threshold curve is computed from the base classifier's predictions (obtained by
 * cross-validation, a single tuned fold, or the training set itself, depending on the
 * evaluation mode). The curve point that maximizes the selected measure becomes the
 * threshold. {@link #distributionForInstance(Instance)} then rescales the designated
 * class's probability piecewise-linearly so that this threshold maps to 0.5. Optionally,
 * the range of thresholds observed during selection is stretched to [0, 1]. A manual
 * threshold can be configured instead of automatic selection.
 *
 * <p>NOTE: this source was recovered by a decompiler. {@link #buildClassifier(Instances)}
 * could not be decompiled and is a stub that always throws.
 */
public class ThresholdSelector extends RandomizableSingleClassifierEnhancer implements OptionHandler, Drawable {
    static final long serialVersionUID = -1795038053239867444L;

    /** Range-correction mode: leave the base classifier's probabilities unscaled. */
    public static final int RANGE_NONE = 0;
    /** Range-correction mode: stretch the min/max thresholds seen during selection to [0, 1]. */
    public static final int RANGE_BOUNDS = 1;

    /** Evaluation mode: select the threshold on predictions for the training set itself. */
    public static final int EVAL_TRAINING_SET = 2;
    /** Evaluation mode: train on all but one fold and select on the held-out fold. */
    public static final int EVAL_TUNED_SPLIT = 1;
    /** Evaluation mode: select on n-fold cross-validated predictions. */
    public static final int EVAL_CROSS_VALIDATION = 0;

    /** Designated-class mode: the first class value. */
    public static final int OPTIMIZE_0 = 0;
    /** Designated-class mode: the second class value. */
    public static final int OPTIMIZE_1 = 1;
    /** Designated-class mode: the least frequent class value. */
    public static final int OPTIMIZE_LFREQ = 2;
    /** Designated-class mode: the most frequent class value. */
    public static final int OPTIMIZE_MFREQ = 3;
    /** Designated-class mode: a value named like "yes", "pos(itive)" or "1". */
    public static final int OPTIMIZE_POS_NAME = 4;

    /** Measure IDs used to rank points on the threshold curve. */
    public static final int FMEASURE = 1;
    public static final int ACCURACY = 2;
    public static final int TRUE_POS = 3;
    public static final int TRUE_NEG = 4;
    public static final int TP_RATE = 5;
    public static final int PRECISION = 6;
    public static final int RECALL = 7;

    /** Upper end of the threshold range; probabilities above the best threshold are scaled towards it. */
    protected double m_HighThreshold = 1.0d;
    /** Lower end of the threshold range; probabilities below the best threshold are scaled from it. */
    protected double m_LowThreshold = 0.0d;
    /** Threshold selected for the designated class; -Double.MAX_VALUE until one is set. */
    protected double m_BestThreshold = -Double.MAX_VALUE;
    /** Measure value at the selected threshold; -Double.MAX_VALUE is treated as "no model" by toString. */
    protected double m_BestValue = -Double.MAX_VALUE;
    /** Number of folds for cross-validation / the tuned split. */
    protected int m_NumXValFolds = 3;
    /** Index of the class value whose probability is thresholded. */
    protected int m_DesignatedClass = 0;
    /** How the designated class is chosen; one of the OPTIMIZE_* constants. */
    protected int m_ClassMode = OPTIMIZE_POS_NAME;
    /** How predictions for threshold selection are obtained; one of the EVAL_* constants. */
    protected int m_EvalMode = EVAL_TUNED_SPLIT;
    /** Range-correction mode; one of the RANGE_* constants. */
    protected int m_RangeMode = RANGE_NONE;
    /** Measure used to rank curve points; one of FMEASURE..RECALL. */
    int m_nMeasure = FMEASURE;
    /** Whether a user-supplied threshold replaces automatic selection. */
    protected boolean m_manualThreshold = false;
    /** The user-supplied threshold; negative means "not set". */
    protected double m_manualThresholdValue = -1.0d;
    /** A curve point's measure must exceed this value to be adopted as the best threshold. */
    protected static final double MIN_VALUE = 0.05d;

    public static final Tag[] TAGS_RANGE = {
        new Tag(RANGE_NONE, "No range correction"),
        new Tag(RANGE_BOUNDS, "Correct based on min/max observed")
    };
    public static final Tag[] TAGS_EVAL = {
        new Tag(EVAL_TRAINING_SET, "Entire training set"),
        new Tag(EVAL_TUNED_SPLIT, "Single tuned fold"),
        new Tag(EVAL_CROSS_VALIDATION, "N-Fold cross validation")
    };
    public static final Tag[] TAGS_OPTIMIZE = {
        new Tag(OPTIMIZE_0, "First class value"),
        new Tag(OPTIMIZE_1, "Second class value"),
        new Tag(OPTIMIZE_LFREQ, "Least frequent class value"),
        new Tag(OPTIMIZE_MFREQ, "Most frequent class value"),
        new Tag(OPTIMIZE_POS_NAME, "Class value named: \"yes\", \"pos(itive)\",\"1\"")
    };
    public static final Tag[] TAGS_MEASURE = {
        new Tag(FMEASURE, "FMEASURE"),
        new Tag(ACCURACY, "ACCURACY"),
        new Tag(TRUE_POS, "TRUE_POS"),
        new Tag(TRUE_NEG, "TRUE_NEG"),
        new Tag(TP_RATE, "TP_RATE"),
        new Tag(PRECISION, "PRECISION"),
        new Tag(RECALL, "RECALL")
    };

    /** Creates a selector wrapping a {@link Logistic} base classifier. */
    public ThresholdSelector() {
        this.m_Classifier = new Logistic();
    }

    @Override // weka.classifiers.SingleClassifierEnhancer
    protected String defaultClassifierString() {
        return "weka.classifiers.functions.Logistic";
    }

    /**
     * Collects the base classifier's predictions used for threshold selection.
     *
     * @param instances the data to evaluate on
     * @param mode one of EVAL_CROSS_VALIDATION, EVAL_TUNED_SPLIT or EVAL_TRAINING_SET
     * @param numFolds folds for cross-validation, or the number of folds the data is split
     *     into for the tuned split (unused for the training set)
     * @return the collected predictions
     * @throws Exception if building or evaluating the base classifier fails
     * @throws RuntimeException if {@code mode} is not recognised
     */
    protected FastVector getPredictions(Instances instances, int mode, int numFolds) throws Exception {
        EvaluationUtils evaluator = new EvaluationUtils();
        evaluator.setSeed(this.m_Seed);
        switch (mode) {
            case EVAL_CROSS_VALIDATION:
                return evaluator.getCVPredictions(this.m_Classifier, instances, numFolds);
            case EVAL_TUNED_SPLIT:
                return tunedSplitPredictions(evaluator, instances, numFolds);
            case EVAL_TRAINING_SET:
                return evaluator.getTrainTestPredictions(this.m_Classifier, instances, instances);
            default:
                throw new RuntimeException("Unrecognized evaluation mode");
        }
    }

    /**
     * Trains on all but one stratified fold and predicts the held-out fold.
     *
     * <p>The first fold where both sides contain the designated class is used. If no fold
     * qualifies, the last fold is used anyway.
     */
    private FastVector tunedSplitPredictions(EvaluationUtils evaluator, Instances instances, int numFolds)
            throws Exception {
        Instances shuffled = new Instances(instances);
        Random random = new Random(this.m_Seed);
        shuffled.randomize(random);
        shuffled.stratify(numFolds);
        Instances train = null;
        Instances test = null;
        for (int fold = 0; fold < numFolds; fold++) {
            train = shuffled.trainCV(numFolds, fold, random);
            test = shuffled.testCV(numFolds, fold);
            if (checkForInstance(train) && checkForInstance(test)) {
                break;
            }
        }
        // NOTE(review): numFolds < 1 would leave train/test null; callers are assumed to pass >= 2.
        return evaluator.getTrainTestPredictions(this.m_Classifier, train, test);
    }

    public String measureTipText() {
        return "Sets the measure for determining the threshold.";
    }

    /** Sets the measure used to rank threshold-curve points; tags from other tag sets are ignored. */
    public void setMeasure(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_MEASURE) {
            this.m_nMeasure = selectedTag.getSelectedTag().getID();
        }
    }

    public SelectedTag getMeasure() {
        return new SelectedTag(this.m_nMeasure, TAGS_MEASURE);
    }

    /**
     * Picks the threshold-curve point that maximizes the configured measure.
     *
     * <p>The curve is built for {@code m_DesignatedClass}. The best point's threshold and
     * measure value are stored only if the value exceeds {@link #MIN_VALUE}. With
     * {@link #RANGE_BOUNDS} range correction, the smallest and largest thresholds on the
     * curve are stored as the low/high bounds. Those bounds start from 1.0 and 0.0, so the
     * stored range always covers at least [min(t), max(t)] ∪ [1, 0]. An empty curve leaves
     * all state unchanged.
     *
     * @param predictions predictions of the base classifier
     * @throws IllegalStateException if {@code m_nMeasure} is not a known measure
     */
    protected void findThreshold(FastVector predictions) {
        Instances curve = new ThresholdCurve().getCurve(predictions, this.m_DesignatedClass);
        if (curve.numInstances() == 0) {
            return;
        }

        int primary;
        int secondary = -1; // only used by ACCURACY
        switch (this.m_nMeasure) {
            case FMEASURE:
                primary = curve.attribute(ThresholdCurve.FMEASURE_NAME).index();
                break;
            case ACCURACY:
                // Ranked by the sum of the "True Positives" and "True Negatives" attributes.
                primary = curve.attribute(ThresholdCurve.TRUE_POS_NAME).index();
                secondary = curve.attribute(ThresholdCurve.TRUE_NEG_NAME).index();
                break;
            case TRUE_POS:
                primary = curve.attribute(ThresholdCurve.TRUE_POS_NAME).index();
                break;
            case TRUE_NEG:
                primary = curve.attribute(ThresholdCurve.TRUE_NEG_NAME).index();
                break;
            case TP_RATE:
                primary = curve.attribute(ThresholdCurve.TP_RATE_NAME).index();
                break;
            case PRECISION:
                primary = curve.attribute(ThresholdCurve.PRECISION_NAME).index();
                break;
            case RECALL:
                primary = curve.attribute(ThresholdCurve.RECALL_NAME).index();
                break;
            default:
                throw new IllegalStateException("Unrecognized measure: " + this.m_nMeasure);
        }
        int thresholdIndex = curve.attribute("Threshold").index();

        Instance best = curve.instance(0);
        double bestValue = measureValue(best, primary, secondary);
        double low = 1.0d;
        double high = 0.0d;
        // Every point, including the first, takes part in the range scan. Skipping point 0
        // would drop the lowest threshold from the range correction.
        for (int k = 0; k < curve.numInstances(); k++) {
            Instance point = curve.instance(k);
            double value = measureValue(point, primary, secondary);
            if (value > bestValue) {
                best = point;
                bestValue = value;
            }
            if (this.m_RangeMode == RANGE_BOUNDS) {
                double threshold = point.value(thresholdIndex);
                if (threshold < low) {
                    low = threshold;
                }
                if (threshold > high) {
                    high = threshold;
                }
            }
        }

        if (bestValue > MIN_VALUE) {
            this.m_BestThreshold = best.value(thresholdIndex);
            this.m_BestValue = bestValue;
        }
        if (this.m_RangeMode == RANGE_BOUNDS) {
            this.m_LowThreshold = low;
            this.m_HighThreshold = high;
        }
    }

    /** Score of a curve point: one attribute, or the sum of two when {@code secondary >= 0}. */
    private static double measureValue(Instance point, int primary, int secondary) {
        return secondary >= 0 ? point.value(primary) + point.value(secondary) : point.value(primary);
    }

    @Override // weka.classifiers.RandomizableSingleClassifierEnhancer, weka.classifiers.SingleClassifierEnhancer, weka.classifiers.Classifier, weka.core.OptionHandler
    public Enumeration listOptions() {
        Vector result = new Vector(6);
        result.addElement(new Option("\tThe class for which threshold is determined. Valid values are:\n\t1, 2 (for first and second classes, respectively), 3 (for whichever\n\tclass is least frequent), and 4 (for whichever class value is most\n\tfrequent), and 5 (for the first class named any of \"yes\",\"pos(itive)\"\n\t\"1\", or method 3 if no matches). (default 5).", "C", 1, "-C <integer>"));
        result.addElement(new Option("\tNumber of folds used for cross validation. If just a\n\thold-out set is used, this determines the size of the hold-out set\n\t(default 3).", "X", 1, "-X <number of folds>"));
        result.addElement(new Option("\tSets whether confidence range correction is applied. This\n\tcan be used to ensure the confidences range from 0 to 1.\n\tUse 0 for no range correction, 1 for correction based on\n\tthe min/max values seen during threshold selection\n\t(default 0).", "R", 1, "-R <integer>"));
        result.addElement(new Option("\tSets the evaluation mode. Use 0 for\n\tevaluation using cross-validation,\n\t1 for evaluation using hold-out set,\n\tand 2 for evaluation on the\n\ttraining data (default 1).", "E", 1, "-E <integer>"));
        result.addElement(new Option("\tMeasure used for evaluation (default is FMEASURE).\n", "M", 1, "-M [FMEASURE|ACCURACY|TRUE_POS|TRUE_NEG|TP_RATE|PRECISION|RECALL]"));
        result.addElement(new Option("\tSet a manual threshold to use. This option overrides\n\tautomatic selection and options pertaining to\n\tautomatic selection will be ignored.\n\t(default -1, i.e. do not use a manual threshold).", "manual", 1, "-manual <real>"));
        Enumeration inherited = super.listOptions();
        while (inherited.hasMoreElements()) {
            result.addElement(inherited.nextElement());
        }
        return result.elements();
    }

    /**
     * Parses the options described by {@link #listOptions()}; absent options revert to defaults.
     *
     * @param options the option array (consumed options are removed by {@code Utils.getOption})
     * @throws Exception if an option value cannot be parsed or is out of range
     */
    @Override // weka.classifiers.RandomizableSingleClassifierEnhancer, weka.classifiers.SingleClassifierEnhancer, weka.classifiers.Classifier, weka.core.OptionHandler
    public void setOptions(String[] options) throws Exception {
        // Reset like every other option, so a manual threshold from an earlier call cannot leak.
        // A negative value (the documented default of -1) disables the manual threshold.
        String manualString = Utils.getOption("manual", options);
        setManualThresholdValue(manualString.length() > 0 ? Double.parseDouble(manualString) : -1.0d);

        // -C is 1-based on the command line, 0-based internally.
        String classString = Utils.getOption('C', options);
        int classMode = classString.length() != 0 ? Integer.parseInt(classString) - 1 : OPTIMIZE_POS_NAME;
        setDesignatedClass(new SelectedTag(classMode, TAGS_OPTIMIZE));

        String evalString = Utils.getOption('E', options);
        int evalMode = evalString.length() != 0 ? Integer.parseInt(evalString) : EVAL_TUNED_SPLIT;
        setEvaluationMode(new SelectedTag(evalMode, TAGS_EVAL));

        String rangeString = Utils.getOption('R', options);
        int rangeMode = rangeString.length() != 0 ? Integer.parseInt(rangeString) : RANGE_NONE;
        setRangeCorrection(new SelectedTag(rangeMode, TAGS_RANGE));

        String measureString = Utils.getOption('M', options);
        if (measureString.length() != 0) {
            setMeasure(new SelectedTag(measureString, TAGS_MEASURE));
        } else {
            setMeasure(new SelectedTag(FMEASURE, TAGS_MEASURE));
        }

        String foldsString = Utils.getOption('X', options);
        setNumXValFolds(foldsString.length() != 0 ? Integer.parseInt(foldsString) : 3);

        super.setOptions(options);
    }

    /**
     * Returns the current settings as an option array.
     *
     * <p>The array has a fixed length of the superclass options plus 12. Unused trailing
     * slots are filled with empty strings.
     */
    @Override // weka.classifiers.RandomizableSingleClassifierEnhancer, weka.classifiers.SingleClassifierEnhancer, weka.classifiers.Classifier, weka.core.OptionHandler
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        String[] options = new String[superOptions.length + 12];
        int pos = 0;
        if (this.m_manualThreshold) {
            options[pos++] = "-manual";
            options[pos++] = "" + getManualThresholdValue();
        }
        options[pos++] = "-C";
        options[pos++] = "" + (this.m_ClassMode + 1);
        options[pos++] = "-X";
        options[pos++] = "" + getNumXValFolds();
        options[pos++] = "-E";
        options[pos++] = "" + this.m_EvalMode;
        options[pos++] = "-R";
        options[pos++] = "" + this.m_RangeMode;
        options[pos++] = "-M";
        options[pos++] = "" + getMeasure().getSelectedTag().getReadable();
        System.arraycopy(superOptions, 0, options, pos, superOptions.length);
        pos += superOptions.length;
        while (pos < options.length) {
            options[pos++] = "";
        }
        return options;
    }

    /** Inherits the base capabilities but restricts the class attribute to binary only. */
    @Override // weka.classifiers.SingleClassifierEnhancer, weka.classifiers.Classifier, weka.core.CapabilitiesHandler
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAllClasses();
        capabilities.disableAllClassDependencies();
        capabilities.enable(Capabilities.Capability.BINARY_CLASS);
        return capabilities;
    }

    /* JADX WARN: Can't fix incorrect switch cases order, some code will duplicate */
    /* JADX WARN: Code restructure failed: missing block: B:43:0x00eb, code lost:

        if (r10 != false) goto L41;
     */
    /* JADX WARN: Failed to find 'out' block for switch in B:10:0x0067. Please report as an issue. */
    /* JADX WARN: Removed duplicated region for block: B:14:0x0138  */
    /* JADX WARN: Removed duplicated region for block: B:16:0x0141  */
    @Override // weka.classifiers.Classifier
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
        NOTE(review): this is a placeholder left by the decompiler. Calling it always throws;
        restore the real implementation from the original weka-stable 3.6.10 sources.
    */
    public void buildClassifier(weka.core.Instances r7) throws java.lang.Exception {
        /*
            Method dump skipped, instructions count: 404
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: weka.classifiers.meta.ThresholdSelector.buildClassifier(weka.core.Instances):void");
    }

    /** Returns true if any instance's class value index equals the designated class. */
    private boolean checkForInstance(Instances instances) throws Exception {
        int count = instances.numInstances();
        for (int idx = 0; idx < count; idx++) {
            if ((int) instances.instance(idx).classValue() == this.m_DesignatedClass) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the base distribution with the designated class's probability rescaled.
     *
     * <p>The rescaling is piecewise-linear. {@code m_LowThreshold} maps to 0, the best
     * threshold maps to 0.5, and {@code m_HighThreshold} maps to 1. Results are clipped to
     * [0, 1]. For two-class distributions, the other class receives the complement.
     */
    @Override // weka.classifiers.Classifier
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] dist = this.m_Classifier.distributionForInstance(instance);
        double prob = dist[this.m_DesignatedClass];
        double warped;
        if (prob > this.m_BestThreshold) {
            warped = 0.5d + (prob - this.m_BestThreshold) / ((this.m_HighThreshold - this.m_BestThreshold) * 2.0d);
        } else if (prob == this.m_BestThreshold) {
            // Exactly at the threshold the lower branch evaluates to 0.5. Returning it directly
            // avoids 0/0 = NaN when the best threshold coincides with m_LowThreshold.
            warped = 0.5d;
        } else {
            warped = (prob - this.m_LowThreshold) / ((this.m_BestThreshold - this.m_LowThreshold) * 2.0d);
        }
        if (warped < 0.0d) {
            warped = 0.0d;
        } else if (warped > 1.0d) {
            warped = 1.0d;
        }
        dist[this.m_DesignatedClass] = warped;
        if (dist.length == 2) {
            dist[(this.m_DesignatedClass + 1) % 2] = 1.0d - warped;
        }
        return dist;
    }

    public String globalInfo() {
        return "A metaclassifier that selects a mid-point threshold on the probability output by a Classifier. The midpoint threshold is set so that a given performance measure is optimized. By default this is the F-measure; other measures can be chosen with the measure option. Performance is measured either on the training data, a hold-out set or using cross-validation. In addition, the probabilities returned by the base learner can have their range expanded so that the output probabilities will reside between 0 and 1 (this is useful if the scheme normally produces probabilities in a very narrow range).";
    }

    public String designatedClassTipText() {
        return "Sets the class value for which the optimization is performed. The options are: pick the first class value; pick the second class value; pick whichever class is least frequent; pick whichever class value is most frequent; pick the first class named any of \"yes\",\"pos(itive)\", \"1\", or the least frequent if no matches).";
    }

    public SelectedTag getDesignatedClass() {
        return new SelectedTag(this.m_ClassMode, TAGS_OPTIMIZE);
    }

    /** Sets the designated-class mode; tags from other tag sets are ignored. */
    public void setDesignatedClass(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_OPTIMIZE) {
            this.m_ClassMode = selectedTag.getSelectedTag().getID();
        }
    }

    public String evaluationModeTipText() {
        return "Sets the method used to determine the threshold/performance curve. The options are: perform optimization based on the entire training set (may result in overfitting); perform an n-fold cross-validation (may be time consuming); perform one fold of an n-fold cross-validation (faster but likely less accurate).";
    }

    /** Sets the evaluation mode; tags from other tag sets are ignored. */
    public void setEvaluationMode(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_EVAL) {
            this.m_EvalMode = selectedTag.getSelectedTag().getID();
        }
    }

    public SelectedTag getEvaluationMode() {
        return new SelectedTag(this.m_EvalMode, TAGS_EVAL);
    }

    public String rangeCorrectionTipText() {
        return "Sets the type of prediction range correction performed. The options are: do not do any range correction; expand predicted probabilities so that the minimum probability observed during the optimization maps to 0, and the maximum maps to 1 (values outside this range are clipped to 0 and 1).";
    }

    /** Sets the range-correction mode; tags from other tag sets are ignored. */
    public void setRangeCorrection(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_RANGE) {
            this.m_RangeMode = selectedTag.getSelectedTag().getID();
        }
    }

    public SelectedTag getRangeCorrection() {
        return new SelectedTag(this.m_RangeMode, TAGS_RANGE);
    }

    public String numXValFoldsTipText() {
        return "Sets the number of folds used during full cross-validation and tuned fold evaluation. This number will be automatically reduced if there are insufficient positive examples.";
    }

    public int getNumXValFolds() {
        return this.m_NumXValFolds;
    }

    /**
     * Sets the number of folds.
     *
     * @throws IllegalArgumentException if {@code i < 2}
     */
    public void setNumXValFolds(int i) {
        if (i < 2) {
            throw new IllegalArgumentException("Number of folds must be greater than 1");
        }
        this.m_NumXValFolds = i;
    }

    /** Delegates to the base classifier if it is drawable; otherwise returns 0. */
    @Override // weka.core.Drawable
    public int graphType() {
        if (this.m_Classifier instanceof Drawable) {
            return ((Drawable) this.m_Classifier).graphType();
        }
        return 0;
    }

    /** Delegates to the base classifier if it is drawable; otherwise throws. */
    @Override // weka.core.Drawable
    public String graph() throws Exception {
        if (this.m_Classifier instanceof Drawable) {
            return ((Drawable) this.m_Classifier).graph();
        }
        throw new Exception("Classifier: " + getClassifierSpec() + " cannot be graphed");
    }

    public String manualThresholdValueTipText() {
        return "Sets a manual threshold value to use. If this is set (non-negative value between 0 and 1), then all options pertaining to automatic threshold selection are ignored. ";
    }

    /**
     * Sets the manual threshold.
     *
     * <p>A value in [0, 1] enables the manual threshold. A negative value (or NaN) stores the
     * value and disables it.
     *
     * @param d the threshold value
     * @throws IllegalArgumentException if {@code d > 1}; state is left unchanged in that case
     */
    public void setManualThresholdValue(double d) throws Exception {
        if (d > 1.0d) {
            throw new IllegalArgumentException("Threshold must be in the range 0..1.");
        }
        this.m_manualThresholdValue = d;
        this.m_manualThreshold = d >= 0.0d;
    }

    public double getManualThresholdValue() {
        return this.m_manualThresholdValue;
    }

    /** Describes the selected threshold and settings, followed by the base classifier's description. */
    @Override
    public String toString() {
        if (this.m_BestValue == -Double.MAX_VALUE) {
            return "ThresholdSelector: No model built yet.";
        }
        StringBuilder text = new StringBuilder();
        text.append("Threshold Selector.\nClassifier: ").append(this.m_Classifier.getClass().getName()).append("\n");
        text.append("Index of designated class: ").append(this.m_DesignatedClass).append("\n");
        if (this.m_manualThreshold) {
            text.append("User supplied threshold: ").append(this.m_BestThreshold).append("\n");
        } else {
            text.append("Evaluation mode: ");
            switch (this.m_EvalMode) {
                case EVAL_CROSS_VALIDATION:
                    text.append(this.m_NumXValFolds).append("-fold cross-validation");
                    break;
                case EVAL_TUNED_SPLIT:
                    text.append("tuning on 1/").append(this.m_NumXValFolds).append(" of the data");
                    break;
                default:
                    text.append("tuning on the training data");
                    break;
            }
            text.append("\n");
            text.append("Threshold: ").append(this.m_BestThreshold).append("\n");
            text.append("Best value: ").append(this.m_BestValue).append("\n");
            if (this.m_RangeMode == RANGE_BOUNDS) {
                text.append("Expanding range [").append(this.m_LowThreshold).append(",")
                        .append(this.m_HighThreshold).append("] to [0, 1]\n");
            }
            text.append("Measure: ").append(getMeasure().getSelectedTag().getReadable()).append("\n");
        }
        text.append(this.m_Classifier.toString());
        return text.toString();
    }

    @Override // weka.classifiers.Classifier, weka.core.RevisionHandler
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.43 $");
    }

    /** Command-line entry point. */
    public static void main(String[] strArr) {
        runClassifier(new ThresholdSelector(), strArr);
    }
}
