/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml;

import java.util.Arrays;
import java.util.List;

import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
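/**
 * Java example for Estimator, Transformer, and Param.
 */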
public class JavaEstimatorTransformerParamExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaEstimatorTransformerParamExample")
      .getOrCreate();
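    // Prepare training data as (label, features) rows.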
    List<Row> dataTraining = Arrays.asList(
      RowFactory.create(1.0, Vectors.dense(0.0, 1.1, 0.1)),
      RowFactory.create(0.0, Vectors.dense(2.0, 1.0, -1.0)),
      RowFactory.create(0.0, Vectors.dense(2.0, 1.3, 1.0)),
      RowFactory.create(1.0, Vectors.dense(0.0, 1.2, -0.5))
    );
    StructType schema = new StructType(new StructField[]{
      new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
      new StructField("features", new VectorUDT(), false, Metadata.empty())
    });
    Dataset<Row> training = spark.createDataFrame(dataTraining, schema);
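    // Create a LogisticRegression instance. This instance is an Estimator.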
    LogisticRegression lr = new LogisticRegression();
    // Print out the parameters, documentation, and any default values.
    System.out.println("LogisticRegression parameters:\n" + lr.explainParams() + "\n");
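    // We may set parameters using setter methods.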
    lr.setMaxIter(10).setRegParam(0.01);
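    // Learn a LogisticRegression model. This uses the parameters stored in lr.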
    LogisticRegressionModel model1 = lr.fit(training);
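    // Since model1 is a Model (i.e., a Transformer produced by an Estimator),
    // we can view the parameters it used during fit().
    // This prints the parameter (name: value) pairs, where names are unique IDs
    // for this LogisticRegression instance.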
    System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap());
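    // We may alternatively specify parameters using a ParamMap.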
    ParamMap paramMap = new ParamMap()
      .put(lr.maxIter().w(20))  // Specify one Param.
      .put(lr.maxIter(), 30)  // This overwrites the original maxIter.
      .put(lr.regParam().w(0.1), lr.threshold().w(0.55));  // Specify multiple Params.
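    // One can also combine ParamMaps.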
    ParamMap paramMap2 = new ParamMap()
      .put(lr.probabilityCol().w("myProbability"));  // Change the output column name.
    ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
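    // Now learn a new model using the paramMapCombined parameters.
    // paramMapCombined overrides all parameters set earlier via lr.set* methods.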
    LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
    System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());
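    // Prepare test data, using the same schema as the training data.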
    List<Row> dataTest = Arrays.asList(
      RowFactory.create(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
      RowFactory.create(0.0, Vectors.dense(3.0, 2.0, -0.1)),
      RowFactory.create(1.0, Vectors.dense(0.0, 2.2, -1.5))
    );
    Dataset<Row> test = spark.createDataFrame(dataTest, schema);
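    // Make predictions on the test data using the Transformer.transform() method.
    // LogisticRegression.transform will only use the 'features' column.
    // Note that model2.transform() outputs a 'myProbability' column instead of the
    // usual 'probability' column since we renamed the lr.probabilityCol parameter.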
    Dataset<Row> results = model2.transform(test);
    Dataset<Row> rows = results.select("features", "label", "myProbability", "prediction");
    for (Row r : rows.collectAsList()) {
      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
        + ", prediction=" + r.get(3));
    }
    spark.stop();
  }
}