0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.regression;
0019
0020 import java.util.ArrayList;
0021 import java.util.List;
0022 import java.util.Random;
0023
0024 import org.junit.Assert;
0025 import org.junit.Test;
0026
0027 import org.apache.spark.SharedSparkSession;
0028 import org.apache.spark.api.java.JavaRDD;
0029 import org.apache.spark.mllib.util.LinearDataGenerator;
0030
0031 public class JavaRidgeRegressionSuite extends SharedSparkSession {
0032
0033 private static double predictionError(List<LabeledPoint> validationData,
0034 RidgeRegressionModel model) {
0035 double errorSum = 0;
0036 for (LabeledPoint point : validationData) {
0037 double prediction = model.predict(point.features());
0038 errorSum += (prediction - point.label()) * (prediction - point.label());
0039 }
0040 return errorSum / validationData.size();
0041 }
0042
0043 private static List<LabeledPoint> generateRidgeData(int numPoints, int numFeatures, double std) {
0044
0045 Random random = new Random(42);
0046 double[] w = new double[numFeatures];
0047 for (int i = 0; i < w.length; i++) {
0048 w[i] = random.nextDouble() - 0.5;
0049 }
0050 return LinearDataGenerator.generateLinearInputAsList(0.0, w, numPoints, 42, std);
0051 }
0052
0053 @Test
0054 public void runRidgeRegressionUsingConstructor() {
0055 int numExamples = 50;
0056 int numFeatures = 20;
0057 List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
0058
0059 JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
0060 new ArrayList<>(data.subList(0, numExamples)));
0061 List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
0062
0063 RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(1.0, 200, 0.0, 1.0);
0064 RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
0065 double unRegularizedErr = predictionError(validationData, model);
0066
0067 ridgeSGDImpl.optimizer().setRegParam(0.1);
0068 model = ridgeSGDImpl.run(testRDD.rdd());
0069 double regularizedErr = predictionError(validationData, model);
0070
0071 Assert.assertTrue(regularizedErr < unRegularizedErr);
0072 }
0073
0074 @Test
0075 public void runRidgeRegressionUsingStaticMethods() {
0076 int numExamples = 50;
0077 int numFeatures = 20;
0078 List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
0079
0080 JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
0081 new ArrayList<>(data.subList(0, numExamples)));
0082 List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
0083
0084 RidgeRegressionModel model = new RidgeRegressionWithSGD(1.0, 200, 0.0, 1.0)
0085 .run(testRDD.rdd());
0086 double unRegularizedErr = predictionError(validationData, model);
0087
0088 model = new RidgeRegressionWithSGD(1.0, 200, 0.1, 1.0)
0089 .run(testRDD.rdd());
0090 double regularizedErr = predictionError(validationData, model);
0091
0092 Assert.assertTrue(regularizedErr < unRegularizedErr);
0093 }
0094 }