/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.ml;

import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.OptionBuilder;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.cli.PosixParser;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.Classifier;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.ml.util.MetadataUtils;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructField;

/**
 * Example of the one-vs-rest reduction using {@link LogisticRegression} as the base classifier.
 *
 * <p>Loads LIBSVM-formatted labeled data from {@code -input}, fits a {@link OneVsRest} model and
 * prints the confusion matrix followed by the false positive rate of every label. Evaluation data
 * is either loaded from {@code -testInput} or held out from the input data with fraction
 * {@code -fracTest}.
 */
public class JavaOneVsRestExample {
    /** Seed for the train/test random split so repeated runs use the same split. */
    private static final long SPLIT_SEED = 12345L;

    /**
     * Entry point: parses the command line, then trains and evaluates the model.
     *
     * @param args command-line arguments; see {@link #generateCommandlineOptions()}
     */
    public static void main(String[] args) {
        Params params = JavaOneVsRestExample.parse(args);
        SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            run(jsc, params);
        } finally {
            // Stop the context even when loading, training or evaluation throws.
            jsc.stop();
        }
    }

    /**
     * Builds the classifier, prepares training/test data, fits the model and prints the metrics.
     *
     * @param jsc    active Spark context
     * @param params validated command-line settings
     */
    private static void run(JavaSparkContext jsc, Params params) {
        SQLContext jsql = new SQLContext(jsc);

        LogisticRegression classifier = new LogisticRegression()
            .setMaxIter(params.maxIter)
            .setTol(params.tol)
            .setFitIntercept(params.fitIntercept);
        // Only override these when given explicitly; otherwise LogisticRegression's defaults apply.
        if (params.regParam != null) {
            classifier.setRegParam(params.regParam);
        }
        if (params.elasticNetParam != null) {
            classifier.setElasticNetParam(params.elasticNetParam);
        }
        OneVsRest ovr = new OneVsRest().setClassifier(classifier);

        RDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), params.input);
        RDD<LabeledPoint> train;
        RDD<LabeledPoint> test;
        if (params.testInput != null) {
            train = inputData;
            // Load the test set with the feature dimension of the first training example.
            int numFeatures = inputData.first().features().size();
            test = MLUtils.loadLibSVMFile(jsc.sc(), params.testInput, numFeatures);
        } else {
            double f = params.fracTest;
            RDD<LabeledPoint>[] splits = inputData.randomSplit(new double[] {1.0 - f, f}, SPLIT_SEED);
            train = splits[0];
            test = splits[1];
        }

        DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class);
        OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache());

        DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class);
        DataFrame predictions = ovrModel.transform(testDataFrame.cache()).select("prediction", "label");

        MulticlassMetrics metrics = new MulticlassMetrics(predictions);
        StructField predictionColSchema = predictions.schema().apply("prediction");
        // NOTE(review): Option.get() throws if the prediction column carries no class-count metadata.
        int numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get();

        StringBuilder results = new StringBuilder("label\tfpr\n");
        for (int label = 0; label < numClasses; label++) {
            results.append(label)
                .append('\t')
                .append(metrics.falsePositiveRate(label))
                .append('\n');
        }

        Matrix confusionMatrix = metrics.confusionMatrix();
        System.out.println("Confusion Matrix");
        System.out.println(confusionMatrix);
        System.out.println();
        System.out.println(results);
    }

    /**
     * Parses and validates the command line. On any error, prints a message and the usage help,
     * then exits the JVM.
     *
     * @param args raw command-line arguments
     * @return the parsed settings, with defaults for options that were not given
     */
    private static Params parse(String[] args) {
        Options options = JavaOneVsRestExample.generateCommandlineOptions();
        Params params = new Params();
        try {
            CommandLine cmd = new PosixParser().parse(options, args);
            if (cmd.hasOption("input")) {
                params.input = cmd.getOptionValue("input");
            }
            if (cmd.hasOption("maxIter")) {
                params.maxIter = Integer.parseInt(cmd.getOptionValue("maxIter"));
            }
            if (cmd.hasOption("tol")) {
                params.tol = Double.parseDouble(cmd.getOptionValue("tol"));
            }
            if (cmd.hasOption("fitIntercept")) {
                params.fitIntercept = parseStrictBoolean(cmd.getOptionValue("fitIntercept"));
            }
            if (cmd.hasOption("regParam")) {
                params.regParam = Double.parseDouble(cmd.getOptionValue("regParam"));
            }
            if (cmd.hasOption("elasticNetParam")) {
                params.elasticNetParam = Double.parseDouble(cmd.getOptionValue("elasticNetParam"));
            }
            if (cmd.hasOption("testInput")) {
                params.testInput = cmd.getOptionValue("testInput");
            }
            if (cmd.hasOption("fracTest")) {
                params.fracTest = Double.parseDouble(cmd.getOptionValue("fracTest"));
            }
        } catch (ParseException e) {
            // Covers unknown options, missing arguments and the missing required -input option.
            JavaOneVsRestExample.printHelpAndQuit(options, e.getMessage());
        } catch (IllegalArgumentException e) {
            // NumberFormatException from the numeric options, or a non-boolean fitIntercept value.
            JavaOneVsRestExample.printHelpAndQuit(options, "invalid option value: " + e.getMessage());
        }
        // The weights passed to randomSplit are {1 - fracTest, fracTest}; the negated form also
        // rejects NaN.
        if (!(params.fracTest >= 0.0 && params.fracTest < 1.0)) {
            JavaOneVsRestExample.printHelpAndQuit(options,
                "fracTest " + params.fracTest + " value incorrect; should be in [0,1).");
        }
        return params;
    }

    /**
     * Parses {@code "true"} or {@code "false"}, ignoring case.
     *
     * @throws IllegalArgumentException for any other value, so that a typo is not silently read as
     *     {@code false} (as {@link Boolean#parseBoolean(String)} would do)
     */
    private static boolean parseStrictBoolean(String value) {
        if ("true".equalsIgnoreCase(value)) {
            return true;
        }
        if ("false".equalsIgnoreCase(value)) {
            return false;
        }
        throw new IllegalArgumentException("expected true or false but got: " + value);
    }

    /** Builds the set of supported command-line options. Each option takes exactly one argument. */
    private static Options generateCommandlineOptions() {
        return new Options()
            .addOption(argOption("input",
                "input path to labeled examples. This path must be specified", true))
            .addOption(argOption("testInput", "input path to test examples", false))
            .addOption(argOption("fracTest",
                "fraction of data to hold out for testing. If given option testInput, this option is ignored. default: 0.2",
                false))
            .addOption(argOption("maxIter",
                "maximum number of iterations for Logistic Regression. default:100", false))
            .addOption(argOption("tol",
                "the convergence tolerance of iterations for Logistic Regression. default: 1E-6", false))
            .addOption(argOption("fitIntercept",
                "fit intercept for logistic regression. default true", false))
            .addOption(argOption("regParam",
                "the regularization parameter for Logistic Regression.", false))
            .addOption(argOption("elasticNetParam",
                "the ElasticNet mixing parameter for Logistic Regression.", false));
    }

    /**
     * Creates a single-argument option whose argument name matches the option name.
     *
     * @param name        option name, also used as the argument name shown in the help text
     * @param description help text for the option
     * @param required    whether parsing fails when the option is absent
     */
    private static Option argOption(String name, String description, boolean required) {
        // OptionBuilder keeps static state that create(...) consumes and resets.
        OptionBuilder.withArgName(name);
        OptionBuilder.hasArg();
        if (required) {
            OptionBuilder.isRequired();
        }
        OptionBuilder.withDescription(description);
        return OptionBuilder.create(name);
    }

    /** Prints usage help for {@code options} and terminates the JVM with status -1. */
    private static void printHelpAndQuit(Options options) {
        HelpFormatter formatter = new HelpFormatter();
        formatter.printHelp("JavaOneVsRestExample", options);
        System.exit(-1);
    }

    /** Prints {@code message} (if any) to stderr, then usage help, then terminates the JVM. */
    private static void printHelpAndQuit(Options options, String message) {
        if (message != null) {
            System.err.println(message);
        }
        printHelpAndQuit(options);
    }

    /** Command-line settings, initialised with their defaults. */
    private static class Params {
        String input;
        String testInput = null;
        int maxIter = 100;
        double tol = 1.0E-6;
        boolean fitIntercept = true;
        // null means "not given": the corresponding LogisticRegression setter is not called.
        Double regParam = null;
        Double elasticNetParam = null;
        double fracTest = 0.2;

        private Params() {
        }
    }
}

