/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.ml;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.examples.ml.Document;
import org.apache.spark.examples.ml.LabeledDocument;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

/**
 * Example program that trains a small text-classification pipeline and prints
 * its predictions for a handful of unlabeled documents.
 *
 * <p>The pipeline has three stages, applied in order: a {@link Tokenizer}
 * reading the {@code text} column and writing {@code words}, a
 * {@link HashingTF} writing a 1000-dimensional {@code features} column, and a
 * {@link LogisticRegression} estimator.
 */
public class JavaSimpleTextClassificationPipeline {
    /**
     * Builds the training data, fits the pipeline, scores the test documents
     * and prints one line per prediction to standard output.
     *
     * @param args command-line arguments; not used
     */
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().appName("JavaSimpleTextClassificationPipeline").getOrCreate();
        try {
            // Training documents; constructor arguments are presumably
            // (id, text, label) -- confirm against LabeledDocument.
            List<LabeledDocument> localTraining = Lists.newArrayList(
                new LabeledDocument(0L, "a b c d e spark", 1.0),
                new LabeledDocument(1L, "b d", 0.0),
                new LabeledDocument(2L, "spark f g h", 1.0),
                new LabeledDocument(3L, "hadoop mapreduce", 0.0));
            Dataset<Row> training = spark.createDataFrame(localTraining, LabeledDocument.class);

            // Configure the three pipeline stages. HashingTF consumes whatever
            // column the tokenizer produces, so the two stay in sync.
            Tokenizer tokenizer = new Tokenizer()
                .setInputCol("text")
                .setOutputCol("words");
            HashingTF hashingTF = new HashingTF()
                .setNumFeatures(1000)
                .setInputCol(tokenizer.getOutputCol())
                .setOutputCol("features");
            LogisticRegression lr = new LogisticRegression()
                .setMaxIter(10)
                .setRegParam(0.001);
            Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[]{tokenizer, hashingTF, lr});

            // Fit all stages on the training data.
            PipelineModel model = pipeline.fit(training);

            // Unlabeled documents to score with the fitted model.
            List<Document> localTest = Lists.newArrayList(
                new Document(4L, "spark i j k"),
                new Document(5L, "l m n"),
                new Document(6L, "spark hadoop spark"),
                new Document(7L, "apache hadoop"));
            Dataset<Row> test = spark.createDataFrame(localTest, Document.class);

            Dataset<Row> predictions = model.transform(test);
            for (Row r : predictions.select("id", "text", "probability", "prediction").collectAsList()) {
                System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) + ", prediction=" + r.get(3));
            }
        } finally {
            // Stop the session even when fitting or scoring throws, so the
            // Spark context is not left running.
            spark.stop();
        }
    }
}

