Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.mllib.regression;
0019 
0020 import java.util.ArrayList;
0021 import java.util.List;
0022 import java.util.Random;
0023 
0024 import org.junit.Assert;
0025 import org.junit.Test;
0026 
0027 import org.apache.spark.SharedSparkSession;
0028 import org.apache.spark.api.java.JavaRDD;
0029 import org.apache.spark.mllib.util.LinearDataGenerator;
0030 
0031 public class JavaRidgeRegressionSuite extends SharedSparkSession {
0032 
0033   private static double predictionError(List<LabeledPoint> validationData,
0034                                         RidgeRegressionModel model) {
0035     double errorSum = 0;
0036     for (LabeledPoint point : validationData) {
0037       double prediction = model.predict(point.features());
0038       errorSum += (prediction - point.label()) * (prediction - point.label());
0039     }
0040     return errorSum / validationData.size();
0041   }
0042 
0043   private static List<LabeledPoint> generateRidgeData(int numPoints, int numFeatures, double std) {
0044     // Pick weights as random values distributed uniformly in [-0.5, 0.5]
0045     Random random = new Random(42);
0046     double[] w = new double[numFeatures];
0047     for (int i = 0; i < w.length; i++) {
0048       w[i] = random.nextDouble() - 0.5;
0049     }
0050     return LinearDataGenerator.generateLinearInputAsList(0.0, w, numPoints, 42, std);
0051   }
0052 
0053   @Test
0054   public void runRidgeRegressionUsingConstructor() {
0055     int numExamples = 50;
0056     int numFeatures = 20;
0057     List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
0058 
0059     JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
0060             new ArrayList<>(data.subList(0, numExamples)));
0061     List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
0062 
0063     RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(1.0, 200, 0.0, 1.0);
0064     RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
0065     double unRegularizedErr = predictionError(validationData, model);
0066 
0067     ridgeSGDImpl.optimizer().setRegParam(0.1);
0068     model = ridgeSGDImpl.run(testRDD.rdd());
0069     double regularizedErr = predictionError(validationData, model);
0070 
0071     Assert.assertTrue(regularizedErr < unRegularizedErr);
0072   }
0073 
0074   @Test
0075   public void runRidgeRegressionUsingStaticMethods() {
0076     int numExamples = 50;
0077     int numFeatures = 20;
0078     List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
0079 
0080     JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
0081             new ArrayList<>(data.subList(0, numExamples)));
0082     List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
0083 
0084     RidgeRegressionModel model = new RidgeRegressionWithSGD(1.0, 200, 0.0, 1.0)
0085         .run(testRDD.rdd());
0086     double unRegularizedErr = predictionError(validationData, model);
0087 
0088     model = new RidgeRegressionWithSGD(1.0, 200, 0.1, 1.0)
0089         .run(testRDD.rdd());
0090     double regularizedErr = predictionError(validationData, model);
0091 
0092     Assert.assertTrue(regularizedErr < unRegularizedErr);
0093   }
0094 }