0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.regression;
0019
0020 import java.util.ArrayList;
0021 import java.util.Arrays;
0022 import java.util.List;
0023
0024 import scala.Tuple3;
0025
0026 import org.junit.Assert;
0027 import org.junit.Test;
0028
0029 import org.apache.spark.SharedSparkSession;
0030 import org.apache.spark.api.java.JavaDoubleRDD;
0031 import org.apache.spark.api.java.JavaRDD;
0032
0033 public class JavaIsotonicRegressionSuite extends SharedSparkSession {
0034
0035 private static List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) {
0036 List<Tuple3<Double, Double, Double>> input = new ArrayList<>(labels.length);
0037
0038 for (int i = 1; i <= labels.length; i++) {
0039 input.add(new Tuple3<>(labels[i - 1], (double) i, 1.0));
0040 }
0041
0042 return input;
0043 }
0044
0045 private IsotonicRegressionModel runIsotonicRegression(double[] labels) {
0046 JavaRDD<Tuple3<Double, Double, Double>> trainRDD =
0047 jsc.parallelize(generateIsotonicInput(labels), 2).cache();
0048
0049 return new IsotonicRegression().run(trainRDD);
0050 }
0051
0052 @Test
0053 public void testIsotonicRegressionJavaRDD() {
0054 IsotonicRegressionModel model =
0055 runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
0056
0057 Assert.assertArrayEquals(
0058 new double[]{1, 2, 7.0 / 3, 7.0 / 3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14);
0059 }
0060
0061 @Test
0062 public void testIsotonicRegressionPredictionsJavaRDD() {
0063 IsotonicRegressionModel model =
0064 runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12});
0065
0066 JavaDoubleRDD testRDD = jsc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0));
0067 List<Double> predictions = model.predict(testRDD).collect();
0068
0069 Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14);
0070 Assert.assertEquals(1.0, predictions.get(1).doubleValue(), 1.0e-14);
0071 Assert.assertEquals(10.0, predictions.get(2).doubleValue(), 1.0e-14);
0072 Assert.assertEquals(12.0, predictions.get(3).doubleValue(), 1.0e-14);
0073 Assert.assertEquals(12.0, predictions.get(4).doubleValue(), 1.0e-14);
0074 }
0075 }