0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.ml.regression;
0019
0020 import java.io.IOException;
0021 import java.util.List;
0022
0023 import org.junit.Test;
0024 import static org.junit.Assert.assertEquals;
0025
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.api.java.JavaRDD;
0028 import static org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInputAsList;
0029 import org.apache.spark.ml.feature.LabeledPoint;
0030 import org.apache.spark.sql.Dataset;
0031 import org.apache.spark.sql.Row;
0032
0033 public class JavaLinearRegressionSuite extends SharedSparkSession {
0034 private transient Dataset<Row> dataset;
0035 private transient JavaRDD<LabeledPoint> datasetRDD;
0036
0037 @Override
0038 public void setUp() throws IOException {
0039 super.setUp();
0040 List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
0041 datasetRDD = jsc.parallelize(points, 2);
0042 dataset = spark.createDataFrame(datasetRDD, LabeledPoint.class);
0043 dataset.createOrReplaceTempView("dataset");
0044 }
0045
0046 @Test
0047 public void linearRegressionDefaultParams() {
0048 LinearRegression lr = new LinearRegression();
0049 assertEquals("label", lr.getLabelCol());
0050 assertEquals("auto", lr.getSolver());
0051 LinearRegressionModel model = lr.fit(dataset);
0052 model.transform(dataset).createOrReplaceTempView("prediction");
0053 Dataset<Row> predictions = spark.sql("SELECT label, prediction FROM prediction");
0054 predictions.collect();
0055
0056 assertEquals("features", model.getFeaturesCol());
0057 assertEquals("prediction", model.getPredictionCol());
0058 }
0059
0060 @Test
0061 public void linearRegressionWithSetters() {
0062
0063 LinearRegression lr = new LinearRegression()
0064 .setMaxIter(10)
0065 .setRegParam(1.0).setSolver("l-bfgs");
0066 LinearRegressionModel model = lr.fit(dataset);
0067 LinearRegression parent = (LinearRegression) model.parent();
0068 assertEquals(10, parent.getMaxIter());
0069 assertEquals(1.0, parent.getRegParam(), 0.0);
0070
0071
0072 LinearRegressionModel model2 =
0073 lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
0074 LinearRegression parent2 = (LinearRegression) model2.parent();
0075 assertEquals(5, parent2.getMaxIter());
0076 assertEquals(0.1, parent2.getRegParam(), 0.0);
0077 assertEquals("thePred", model2.getPredictionCol());
0078 }
0079 }