0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.regression;
0019
0020 import java.util.List;
0021
0022 import org.junit.Assert;
0023 import org.junit.Test;
0024
0025 import org.apache.spark.SharedSparkSession;
0026 import org.apache.spark.api.java.JavaRDD;
0027 import org.apache.spark.mllib.linalg.Vector;
0028 import org.apache.spark.mllib.util.LinearDataGenerator;
0029
0030 public class JavaLinearRegressionSuite extends SharedSparkSession {
0031
0032 private static int validatePrediction(
0033 List<LabeledPoint> validationData, LinearRegressionModel model) {
0034 int numAccurate = 0;
0035 for (LabeledPoint point : validationData) {
0036 double prediction = model.predict(point.features());
0037
0038 if (Math.abs(prediction - point.label()) <= 0.5) {
0039 numAccurate++;
0040 }
0041 }
0042 return numAccurate;
0043 }
0044
0045 @Test
0046 public void runLinearRegressionUsingConstructor() {
0047 int nPoints = 100;
0048 double A = 3.0;
0049 double[] weights = {10, 10};
0050
0051 JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
0052 LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
0053 List<LabeledPoint> validationData =
0054 LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
0055
0056 LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0);
0057 linSGDImpl.setIntercept(true);
0058 LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
0059
0060 int numAccurate = validatePrediction(validationData, model);
0061 Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
0062 }
0063
0064 @Test
0065 public void runLinearRegressionUsingStaticMethods() {
0066 int nPoints = 100;
0067 double A = 0.0;
0068 double[] weights = {10, 10};
0069
0070 JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
0071 LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
0072 List<LabeledPoint> validationData =
0073 LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
0074
0075 LinearRegressionModel model = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0)
0076 .run(testRDD.rdd());
0077
0078 int numAccurate = validatePrediction(validationData, model);
0079 Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
0080 }
0081
0082 @Test
0083 public void testPredictJavaRDD() {
0084 int nPoints = 100;
0085 double A = 0.0;
0086 double[] weights = {10, 10};
0087 JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
0088 LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
0089 LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0);
0090 LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
0091 JavaRDD<Vector> vectors = testRDD.map(LabeledPoint::features);
0092 JavaRDD<Double> predictions = model.predict(vectors);
0093
0094 predictions.first();
0095 }
0096 }