0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.ml;
0019
0020
0021 import org.apache.spark.ml.evaluation.RegressionEvaluator;
0022 import org.apache.spark.ml.param.ParamMap;
0023 import org.apache.spark.ml.regression.LinearRegression;
0024 import org.apache.spark.ml.tuning.ParamGridBuilder;
0025 import org.apache.spark.ml.tuning.TrainValidationSplit;
0026 import org.apache.spark.ml.tuning.TrainValidationSplitModel;
0027 import org.apache.spark.sql.Dataset;
0028 import org.apache.spark.sql.Row;
0029
0030 import org.apache.spark.sql.SparkSession;
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040 public class JavaModelSelectionViaTrainValidationSplitExample {
0041 public static void main(String[] args) {
0042 SparkSession spark = SparkSession
0043 .builder()
0044 .appName("JavaModelSelectionViaTrainValidationSplitExample")
0045 .getOrCreate();
0046
0047
0048 Dataset<Row> data = spark.read().format("libsvm")
0049 .load("data/mllib/sample_linear_regression_data.txt");
0050
0051
0052 Dataset<Row>[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345);
0053 Dataset<Row> training = splits[0];
0054 Dataset<Row> test = splits[1];
0055
0056 LinearRegression lr = new LinearRegression();
0057
0058
0059
0060
0061 ParamMap[] paramGrid = new ParamGridBuilder()
0062 .addGrid(lr.regParam(), new double[] {0.1, 0.01})
0063 .addGrid(lr.fitIntercept())
0064 .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
0065 .build();
0066
0067
0068
0069 TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
0070 .setEstimator(lr)
0071 .setEvaluator(new RegressionEvaluator())
0072 .setEstimatorParamMaps(paramGrid)
0073 .setTrainRatio(0.8)
0074 .setParallelism(2);
0075
0076
0077 TrainValidationSplitModel model = trainValidationSplit.fit(training);
0078
0079
0080
0081 model.transform(test)
0082 .select("features", "label", "prediction")
0083 .show();
0084
0085
0086 spark.stop();
0087 }
0088 }