0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.examples.ml;
0019
0020
0021 import org.apache.spark.ml.classification.LogisticRegression;
0022 import org.apache.spark.ml.classification.LogisticRegressionModel;
0023 import org.apache.spark.sql.Dataset;
0024 import org.apache.spark.sql.Row;
0025 import org.apache.spark.sql.SparkSession;
0026
0027
0028 public class JavaLogisticRegressionWithElasticNetExample {
0029 public static void main(String[] args) {
0030 SparkSession spark = SparkSession
0031 .builder()
0032 .appName("JavaLogisticRegressionWithElasticNetExample")
0033 .getOrCreate();
0034
0035
0036
0037 Dataset<Row> training = spark.read().format("libsvm")
0038 .load("data/mllib/sample_libsvm_data.txt");
0039
0040 LogisticRegression lr = new LogisticRegression()
0041 .setMaxIter(10)
0042 .setRegParam(0.3)
0043 .setElasticNetParam(0.8);
0044
0045
0046 LogisticRegressionModel lrModel = lr.fit(training);
0047
0048
0049 System.out.println("Coefficients: "
0050 + lrModel.coefficients() + " Intercept: " + lrModel.intercept());
0051
0052
0053 LogisticRegression mlr = new LogisticRegression()
0054 .setMaxIter(10)
0055 .setRegParam(0.3)
0056 .setElasticNetParam(0.8)
0057 .setFamily("multinomial");
0058
0059
0060 LogisticRegressionModel mlrModel = mlr.fit(training);
0061
0062
0063 System.out.println("Multinomial coefficients: " + lrModel.coefficientMatrix()
0064 + "\nMultinomial intercepts: " + mlrModel.interceptVector());
0065
0066
0067 spark.stop();
0068 }
0069 }