0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.ml;
0019
0020
0021 import java.util.Arrays;
0022
0023 import org.apache.spark.ml.Pipeline;
0024 import org.apache.spark.ml.PipelineModel;
0025 import org.apache.spark.ml.PipelineStage;
0026 import org.apache.spark.ml.classification.LogisticRegression;
0027 import org.apache.spark.ml.feature.HashingTF;
0028 import org.apache.spark.ml.feature.Tokenizer;
0029 import org.apache.spark.sql.Dataset;
0030 import org.apache.spark.sql.Row;
0031
0032 import org.apache.spark.sql.SparkSession;
0033
0034
0035
0036
0037 public class JavaPipelineExample {
0038 public static void main(String[] args) {
0039 SparkSession spark = SparkSession
0040 .builder()
0041 .appName("JavaPipelineExample")
0042 .getOrCreate();
0043
0044
0045
0046 Dataset<Row> training = spark.createDataFrame(Arrays.asList(
0047 new JavaLabeledDocument(0L, "a b c d e spark", 1.0),
0048 new JavaLabeledDocument(1L, "b d", 0.0),
0049 new JavaLabeledDocument(2L, "spark f g h", 1.0),
0050 new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0)
0051 ), JavaLabeledDocument.class);
0052
0053
0054 Tokenizer tokenizer = new Tokenizer()
0055 .setInputCol("text")
0056 .setOutputCol("words");
0057 HashingTF hashingTF = new HashingTF()
0058 .setNumFeatures(1000)
0059 .setInputCol(tokenizer.getOutputCol())
0060 .setOutputCol("features");
0061 LogisticRegression lr = new LogisticRegression()
0062 .setMaxIter(10)
0063 .setRegParam(0.001);
0064 Pipeline pipeline = new Pipeline()
0065 .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});
0066
0067
0068 PipelineModel model = pipeline.fit(training);
0069
0070
0071 Dataset<Row> test = spark.createDataFrame(Arrays.asList(
0072 new JavaDocument(4L, "spark i j k"),
0073 new JavaDocument(5L, "l m n"),
0074 new JavaDocument(6L, "spark hadoop spark"),
0075 new JavaDocument(7L, "apache hadoop")
0076 ), JavaDocument.class);
0077
0078
0079 Dataset<Row> predictions = model.transform(test);
0080 for (Row r : predictions.select("id", "text", "probability", "prediction").collectAsList()) {
0081 System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
0082 + ", prediction=" + r.get(3));
0083 }
0084
0085
0086 spark.stop();
0087 }
0088 }