Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.mllib.regression;
0019 
0020 import java.util.List;
0021 
0022 import org.junit.Assert;
0023 import org.junit.Test;
0024 
0025 import org.apache.spark.SharedSparkSession;
0026 import org.apache.spark.api.java.JavaRDD;
0027 import org.apache.spark.mllib.linalg.Vector;
0028 import org.apache.spark.mllib.util.LinearDataGenerator;
0029 
0030 public class JavaLinearRegressionSuite extends SharedSparkSession {
0031 
0032   private static int validatePrediction(
0033       List<LabeledPoint> validationData, LinearRegressionModel model) {
0034     int numAccurate = 0;
0035     for (LabeledPoint point : validationData) {
0036       double prediction = model.predict(point.features());
0037       // A prediction is off if the prediction is more than 0.5 away from expected value.
0038       if (Math.abs(prediction - point.label()) <= 0.5) {
0039         numAccurate++;
0040       }
0041     }
0042     return numAccurate;
0043   }
0044 
0045   @Test
0046   public void runLinearRegressionUsingConstructor() {
0047     int nPoints = 100;
0048     double A = 3.0;
0049     double[] weights = {10, 10};
0050 
0051     JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
0052       LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
0053     List<LabeledPoint> validationData =
0054       LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
0055 
0056     LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0);
0057     linSGDImpl.setIntercept(true);
0058     LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
0059 
0060     int numAccurate = validatePrediction(validationData, model);
0061     Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
0062   }
0063 
0064   @Test
0065   public void runLinearRegressionUsingStaticMethods() {
0066     int nPoints = 100;
0067     double A = 0.0;
0068     double[] weights = {10, 10};
0069 
0070     JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
0071       LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
0072     List<LabeledPoint> validationData =
0073       LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 17, 0.1);
0074 
0075     LinearRegressionModel model = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0)
0076         .run(testRDD.rdd());
0077 
0078     int numAccurate = validatePrediction(validationData, model);
0079     Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
0080   }
0081 
0082   @Test
0083   public void testPredictJavaRDD() {
0084     int nPoints = 100;
0085     double A = 0.0;
0086     double[] weights = {10, 10};
0087     JavaRDD<LabeledPoint> testRDD = jsc.parallelize(
0088       LinearDataGenerator.generateLinearInputAsList(A, weights, nPoints, 42, 0.1), 2).cache();
0089     LinearRegressionWithSGD linSGDImpl = new LinearRegressionWithSGD(1.0, 100, 0.0, 1.0);
0090     LinearRegressionModel model = linSGDImpl.run(testRDD.rdd());
0091     JavaRDD<Vector> vectors = testRDD.map(LabeledPoint::features);
0092     JavaRDD<Double> predictions = model.predict(vectors);
0093     // Should be able to get the first prediction.
0094     predictions.first();
0095   }
0096 }