/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.ml;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.VectorIndexer;
import org.apache.spark.ml.feature.VectorIndexerModel;
import org.apache.spark.ml.regression.RandomForestRegressionModel;
import org.apache.spark.ml.regression.RandomForestRegressor;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

/**
 * Example application that trains a random forest regressor with the Spark ML
 * pipeline API and reports its root mean squared error on a held-out split.
 *
 * <p>The input is read in LIBSVM format from
 * {@code data/mllib/sample_libsvm_data.txt}, a path relative to the working
 * directory of the Spark application.
 */
public class JavaRandomForestRegressorExample {
    /**
     * Runs the example end to end: load data, index features, train, predict,
     * evaluate and print the learned model.
     *
     * @param args command-line arguments; not used
     */
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("JavaRandomForestRegressorExample");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        try {
            SQLContext sqlContext = new SQLContext(jsc);

            // Load the dataset as a DataFrame using the "libsvm" data source.
            DataFrame data = sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");

            // Fit the feature indexer on the full dataset (before splitting).
            // maxCategories is set to 4; see the VectorIndexer documentation for
            // how that threshold separates categorical from continuous features.
            VectorIndexerModel featureIndexer = new VectorIndexer()
                .setInputCol("features")
                .setOutputCol("indexedFeatures")
                .setMaxCategories(4)
                .fit(data);

            // Split with weights 0.7 / 0.3; no seed is passed to randomSplit.
            DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
            DataFrame trainingData = splits[0];
            DataFrame testData = splits[1];

            // The regressor consumes the indexed feature column produced above.
            RandomForestRegressor rf = new RandomForestRegressor()
                .setLabelCol("label")
                .setFeaturesCol("indexedFeatures");

            // Chain the (already fitted) indexer and the regressor in one pipeline.
            Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{featureIndexer, rf});

            // Train on the training split, then predict on the test split.
            PipelineModel model = pipeline.fit(trainingData);
            DataFrame predictions = model.transform(testData);

            // Show the first 5 rows of predictions next to labels and raw features.
            predictions.select("prediction", "label", "features").show(5);

            // Evaluate the predictions with the "rmse" metric.
            RegressionEvaluator evaluator = new RegressionEvaluator()
                .setLabelCol("label")
                .setPredictionCol("prediction")
                .setMetricName("rmse");
            double rmse = evaluator.evaluate(predictions);
            System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);

            // Stage index 1 is the regressor (index 0 is the feature indexer).
            RandomForestRegressionModel rfModel = (RandomForestRegressionModel)model.stages()[1];
            System.out.println("Learned regression forest model:\n" + rfModel.toDebugString());
        } finally {
            // Always stop the context, even when a step above throws.
            jsc.stop();
        }
    }
}

