/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.mllib;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.GradientBoostedTrees;
import org.apache.spark.mllib.tree.configuration.BoostingStrategy;
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel;
import org.apache.spark.mllib.util.MLUtils;
import scala.Tuple2;

public final class JavaGradientBoostedTreesRunner {
    private static void usage() {
        System.err.println("Usage: JavaGradientBoostedTreesRunner <libsvm format data file> <Classification/Regression>");
        System.exit(-1);
    }

    public static void main(String[] args) {
        String datapath = "data/mllib/sample_libsvm_data.txt";
        String algo = "Classification";
        if (args.length >= 1) {
            datapath = args[0];
        }
        if (args.length >= 2) {
            algo = args[1];
        }
        if (args.length > 2) {
            JavaGradientBoostedTreesRunner.usage();
        }
        SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        JavaRDD data = MLUtils.loadLibSVMFile((SparkContext)sc.sc(), (String)datapath).toJavaRDD().cache();
        BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams((String)algo);
        boostingStrategy.setNumIterations(10);
        boostingStrategy.treeStrategy().setMaxDepth(5);
        if (algo.equals("Classification")) {
            Integer numClasses = data.map((Function)new Function<LabeledPoint, Double>(){

                public Double call(LabeledPoint p) {
                    return p.label();
                }
            }).countByValue().size();
            boostingStrategy.treeStrategy().setNumClasses(numClasses.intValue());
            final GradientBoostedTreesModel model = GradientBoostedTrees.train((JavaRDD)data, (BoostingStrategy)boostingStrategy);
            JavaPairRDD predictionAndLabel = data.mapToPair((PairFunction)new PairFunction<LabeledPoint, Double, Double>(){

                public Tuple2<Double, Double> call(LabeledPoint p) {
                    return new Tuple2((Object)model.predict(p.features()), (Object)p.label());
                }
            });
            Double trainErr = 1.0 * (double)predictionAndLabel.filter((Function)new Function<Tuple2<Double, Double>, Boolean>(){

                public Boolean call(Tuple2<Double, Double> pl) {
                    return !((Double)pl._1()).equals(pl._2());
                }
            }).count() / (double)data.count();
            System.out.println("Training error: " + trainErr);
            System.out.println("Learned classification tree model:\n" + model);
        } else if (algo.equals("Regression")) {
            final GradientBoostedTreesModel model = GradientBoostedTrees.train((JavaRDD)data, (BoostingStrategy)boostingStrategy);
            JavaPairRDD predictionAndLabel = data.mapToPair((PairFunction)new PairFunction<LabeledPoint, Double, Double>(){

                public Tuple2<Double, Double> call(LabeledPoint p) {
                    return new Tuple2((Object)model.predict(p.features()), (Object)p.label());
                }
            });
            Double trainMSE = (Double)predictionAndLabel.map((Function)new Function<Tuple2<Double, Double>, Double>(){

                public Double call(Tuple2<Double, Double> pl) {
                    Double diff = (Double)pl._1() - (Double)pl._2();
                    return diff * diff;
                }
            }).reduce((Function2)new Function2<Double, Double, Double>(){

                public Double call(Double a, Double b) {
                    return a + b;
                }
            }) / (double)data.count();
            System.out.println("Training Mean Squared Error: " + trainMSE);
            System.out.println("Learned regression tree model:\n" + model);
        } else {
            JavaGradientBoostedTreesRunner.usage();
        }
        sc.stop();
    }
}

