Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.mllib.regression;
0019 
0020 import java.util.ArrayList;
0021 import java.util.Arrays;
0022 import java.util.List;
0023 
0024 import scala.Tuple3;
0025 
0026 import org.junit.Assert;
0027 import org.junit.Test;
0028 
0029 import org.apache.spark.SharedSparkSession;
0030 import org.apache.spark.api.java.JavaDoubleRDD;
0031 import org.apache.spark.api.java.JavaRDD;
0032 
0033 public class JavaIsotonicRegressionSuite extends SharedSparkSession {
0034 
0035   private static List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
0036     List<Tuple3<Double, Double, Double>> input = new ArrayList<>(labels.length);
0037 
0038     for (int i = 1; i <= labels.length; i++) {
0039       input.add(new Tuple3<>(labels[i - 1], (double) i, 1.0));
0040     }
0041 
0042     return input;
0043   }
0044 
0045   private IsotonicRegressionModel runIsotonicRegression(double[] labels) {
0046     JavaRDD<Tuple3<Double, Double, Double>> trainRDD =
0047       jsc.parallelize(generateIsotonicInput(labels), 2).cache();
0048 
0049     return new IsotonicRegression().run(trainRDD);
0050   }
0051 
0052   @Test
0053   public void testIsotonicRegressionJavaRDD() {
0054     IsotonicRegressionModel model =
0055       runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
0056 
0057     Assert.assertArrayEquals(
0058       new double[]{1, 2, 7.0 / 3, 7.0 / 3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14);
0059   }
0060 
0061   @Test
0062   public void testIsotonicRegressionPredictionsJavaRDD() {
0063     IsotonicRegressionModel model =
0064       runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
0065 
0066     JavaDoubleRDD testRDD = jsc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0));
0067     List<Double> predictions = model.predict(testRDD).collect();
0068 
0069     Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14);
0070     Assert.assertEquals(1.0, predictions.get(1).doubleValue(), 1.0e-14);
0071     Assert.assertEquals(10.0, predictions.get(2).doubleValue(), 1.0e-14);
0072     Assert.assertEquals(12.0, predictions.get(3).doubleValue(), 1.0e-14);
0073     Assert.assertEquals(12.0, predictions.get(4).doubleValue(), 1.0e-14);
0074   }
0075 }