/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

import java.util.Arrays;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
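
/**
 * Java example for model selection via cross-validation.
 */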
public class JavaModelSelectionViaCrossValidationExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaModelSelectionViaCrossValidationExample")
      .getOrCreate();
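
    // Prepare training documents, which are labeled. JavaLabeledDocument and
    // JavaDocument are simple bean classes defined alongside this example in
    // the same package.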
    Dataset<Row> training = spark.createDataFrame(Arrays.asList(
      new JavaLabeledDocument(0L, "a b c d e spark", 1.0),
      new JavaLabeledDocument(1L, "b d", 0.0),
      new JavaLabeledDocument(2L, "spark f g h", 1.0),
      new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0),
      new JavaLabeledDocument(4L, "b spark who", 1.0),
      new JavaLabeledDocument(5L, "g d a y", 0.0),
      new JavaLabeledDocument(6L, "spark fly", 1.0),
      new JavaLabeledDocument(7L, "was mapreduce", 0.0),
      new JavaLabeledDocument(8L, "e spark program", 1.0),
      new JavaLabeledDocument(9L, "a e c l", 0.0),
      new JavaLabeledDocument(10L, "spark compile", 1.0),
      new JavaLabeledDocument(11L, "hadoop software", 0.0)
    ), JavaLabeledDocument.class);
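
    // Configure an ML pipeline, which consists of three stages:
    // tokenizer, hashingTF, and lr.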
    Tokenizer tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words");
    HashingTF hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol())
      .setOutputCol("features");
    LogisticRegression lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01);
    Pipeline pipeline = new Pipeline()
      .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});
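
    // Use a ParamGridBuilder to construct a grid of parameters to search over.
    // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
    // this grid has 3 x 2 = 6 parameter settings for CrossValidator to choose from.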
    ParamMap[] paramGrid = new ParamGridBuilder()
      .addGrid(hashingTF.numFeatures(), new int[] {10, 100, 1000})
      .addGrid(lr.regParam(), new double[] {0.1, 0.01})
      .build();
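
    // Treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
    // This allows jointly choosing parameters for all Pipeline stages. A CrossValidator
    // requires an Estimator, a set of Estimator ParamMaps, and an Evaluator; note that
    // BinaryClassificationEvaluator's default metric is areaUnderROC. Two folds are
    // used here to keep the example fast (use 3 or more in practice), and
    // setParallelism(2) evaluates up to 2 parameter settings in parallel.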
    CrossValidator cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new BinaryClassificationEvaluator())
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(2)
      .setParallelism(2);
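
    // Run cross-validation, and choose the best set of parameters.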
    CrossValidatorModel cvModel = cv.fit(training);
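
    // Prepare test documents, which are unlabeled.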
    Dataset<Row> test = spark.createDataFrame(Arrays.asList(
      new JavaDocument(4L, "spark i j k"),
      new JavaDocument(5L, "l m n"),
      new JavaDocument(6L, "mapreduce spark"),
      new JavaDocument(7L, "apache hadoop")
    ), JavaDocument.class);
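
    // Make predictions on the test documents. cvModel uses the best model found
    // during cross-validation.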
    Dataset<Row> predictions = cvModel.transform(test);
    for (Row r : predictions.select("id", "text", "probability", "prediction").collectAsList()) {
      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
        + ", prediction=" + r.get(3));
    }
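
    // Optionally inspect the tuning results: avgMetrics() holds the average
    // evaluator metric (here, areaUnderROC) for each ParamMap in the grid.
    // This print is an illustrative addition, not part of the original example.
    System.out.println("Avg metrics per ParamMap: " + Arrays.toString(cvModel.avgMetrics()));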
    spark.stop();
  }
}