0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.ml;
0019
0020
0021 import org.apache.spark.ml.regression.LinearRegression;
0022 import org.apache.spark.ml.regression.LinearRegressionModel;
0023 import org.apache.spark.ml.regression.LinearRegressionTrainingSummary;
0024 import org.apache.spark.ml.linalg.Vectors;
0025 import org.apache.spark.sql.Dataset;
0026 import org.apache.spark.sql.Row;
0027 import org.apache.spark.sql.SparkSession;
0028
0029
0030 public class JavaLinearRegressionWithElasticNetExample {
0031 public static void main(String[] args) {
0032 SparkSession spark = SparkSession
0033 .builder()
0034 .appName("JavaLinearRegressionWithElasticNetExample")
0035 .getOrCreate();
0036
0037
0038
0039 Dataset<Row> training = spark.read().format("libsvm")
0040 .load("data/mllib/sample_linear_regression_data.txt");
0041
0042 LinearRegression lr = new LinearRegression()
0043 .setMaxIter(10)
0044 .setRegParam(0.3)
0045 .setElasticNetParam(0.8);
0046
0047
0048 LinearRegressionModel lrModel = lr.fit(training);
0049
0050
0051 System.out.println("Coefficients: "
0052 + lrModel.coefficients() + " Intercept: " + lrModel.intercept());
0053
0054
0055 LinearRegressionTrainingSummary trainingSummary = lrModel.summary();
0056 System.out.println("numIterations: " + trainingSummary.totalIterations());
0057 System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory()));
0058 trainingSummary.residuals().show();
0059 System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError());
0060 System.out.println("r2: " + trainingSummary.r2());
0061
0062
0063 spark.stop();
0064 }
0065 }