/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.ml;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.examples.ml.Document;
import org.apache.spark.examples.ml.LabeledDocument;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

/**
 * Example of model selection with {@link CrossValidator} on a small text-classification pipeline.
 *
 * <p>A {@link Pipeline} of {@link Tokenizer} -> {@link HashingTF} -> {@link LogisticRegression}
 * is tuned with 2-fold cross-validation over a grid of 3 {@code numFeatures} values times 2
 * {@code regParam} values, scored by a {@link BinaryClassificationEvaluator}. The resulting
 * {@link CrossValidatorModel} is applied to a few unlabeled documents, and each document's id,
 * text, probability and prediction are printed to standard output.
 */
public class JavaCrossValidatorExample {

    /**
     * Runs the cross-validation example end to end.
     *
     * @param args command-line arguments (not used)
     */
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        // Stop the context on every exit path, not only after a successful run.
        try {
            SQLContext jsql = new SQLContext(jsc);

            // Labeled training documents: every text labeled 1.0 contains "spark", every 0.0 text does not.
            List<LabeledDocument> localTraining = Lists.newArrayList(
                new LabeledDocument(0L, "a b c d e spark", 1.0),
                new LabeledDocument(1L, "b d", 0.0),
                new LabeledDocument(2L, "spark f g h", 1.0),
                new LabeledDocument(3L, "hadoop mapreduce", 0.0),
                new LabeledDocument(4L, "b spark who", 1.0),
                new LabeledDocument(5L, "g d a y", 0.0),
                new LabeledDocument(6L, "spark fly", 1.0),
                new LabeledDocument(7L, "was mapreduce", 0.0),
                new LabeledDocument(8L, "e spark program", 1.0),
                new LabeledDocument(9L, "a e c l", 0.0),
                new LabeledDocument(10L, "spark compile", 1.0),
                new LabeledDocument(11L, "hadoop software", 0.0));
            DataFrame training =
                jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);

            // Pipeline stages: split "text" into "words", hash words into "features", then classify.
            Tokenizer tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words");
            // numFeatures / regParam set here are initial values; the grid below supplies the
            // values that cross-validation actually tries.
            HashingTF hashingTF = new HashingTF()
                .setNumFeatures(1000)
                .setInputCol(tokenizer.getOutputCol())
                .setOutputCol("features");
            LogisticRegression lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01);
            Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {tokenizer, hashingTF, lr});

            // Treat the whole pipeline as the estimator so feature and model parameters are tuned jointly.
            CrossValidator crossval = new CrossValidator()
                .setEstimator(pipeline)
                .setEvaluator(new BinaryClassificationEvaluator());
            // 3 x 2 = 6 parameter combinations, each evaluated across the folds.
            ParamMap[] paramGrid = new ParamGridBuilder()
                .addGrid(hashingTF.numFeatures(), new int[] {10, 100, 1000})
                .addGrid(lr.regParam(), new double[] {0.1, 0.01})
                .build();
            crossval.setEstimatorParamMaps(paramGrid);
            crossval.setNumFolds(2);

            CrossValidatorModel cvModel = crossval.fit(training);

            // Unlabeled test documents to score with the selected model.
            List<Document> localTest = Lists.newArrayList(
                new Document(4L, "spark i j k"),
                new Document(5L, "l m n"),
                new Document(6L, "mapreduce spark"),
                new Document(7L, "apache hadoop"));
            DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);

            DataFrame predictions = cvModel.transform(test);
            for (Row r : predictions.select("id", "text", "probability", "prediction").collect()) {
                System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
                    + ", prediction=" + r.get(3));
            }
        } finally {
            jsc.stop();
        }
    }
}

