0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.stat;
0019
0020 import java.util.Arrays;
0021 import java.util.List;
0022
0023 import org.junit.After;
0024 import org.junit.Before;
0025 import org.junit.Test;
0026 import static org.junit.Assert.assertEquals;
0027
0028 import org.apache.spark.SparkConf;
0029 import org.apache.spark.api.java.JavaDoubleRDD;
0030 import org.apache.spark.api.java.JavaRDD;
0031 import org.apache.spark.api.java.JavaSparkContext;
0032 import org.apache.spark.mllib.linalg.Vectors;
0033 import org.apache.spark.mllib.regression.LabeledPoint;
0034 import org.apache.spark.mllib.stat.test.BinarySample;
0035 import org.apache.spark.mllib.stat.test.ChiSqTestResult;
0036 import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
0037 import org.apache.spark.mllib.stat.test.StreamingTest;
0038 import org.apache.spark.sql.SparkSession;
0039 import org.apache.spark.streaming.Duration;
0040 import org.apache.spark.streaming.api.java.JavaDStream;
0041 import org.apache.spark.streaming.api.java.JavaStreamingContext;
0042 import static org.apache.spark.streaming.JavaTestUtils.*;
0043
0044 public class JavaStatisticsSuite {
0045 private transient SparkSession spark;
0046 private transient JavaSparkContext jsc;
0047 private transient JavaStreamingContext ssc;
0048
0049 @Before
0050 public void setUp() {
0051 SparkConf conf = new SparkConf()
0052 .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
0053 spark = SparkSession.builder()
0054 .master("local[2]")
0055 .appName("JavaStatistics")
0056 .config(conf)
0057 .getOrCreate();
0058 jsc = new JavaSparkContext(spark.sparkContext());
0059 ssc = new JavaStreamingContext(jsc, new Duration(1000));
0060 ssc.checkpoint("checkpoint");
0061 }
0062
0063 @After
0064 public void tearDown() {
0065 spark.stop();
0066 ssc.stop();
0067 spark = null;
0068 }
0069
0070 @Test
0071 public void testCorr() {
0072 JavaRDD<Double> x = jsc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
0073 JavaRDD<Double> y = jsc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3));
0074
0075 Double corr1 = Statistics.corr(x, y);
0076 Double corr2 = Statistics.corr(x, y, "pearson");
0077
0078 assertEquals(corr1, corr2, 1e-5);
0079 }
0080
0081 @Test
0082 public void kolmogorovSmirnovTest() {
0083 JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0));
0084 KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm");
0085 KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest(
0086 data, "norm", 0.0, 1.0);
0087 }
0088
0089 @Test
0090 public void chiSqTest() {
0091 JavaRDD<LabeledPoint> data = jsc.parallelize(Arrays.asList(
0092 new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)),
0093 new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)),
0094 new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
0095 ChiSqTestResult[] testResults = Statistics.chiSqTest(data);
0096 }
0097
0098 @Test
0099 public void streamingTest() {
0100 List<BinarySample> trainingBatch = Arrays.asList(
0101 new BinarySample(true, 1.0),
0102 new BinarySample(false, 2.0));
0103 JavaDStream<BinarySample> training =
0104 attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2);
0105 int numBatches = 2;
0106 StreamingTest model = new StreamingTest()
0107 .setWindowSize(0)
0108 .setPeacePeriod(0)
0109 .setTestMethod("welch");
0110 model.registerStream(training);
0111 attachTestOutputStream(training);
0112 runStreams(ssc, numBatches, numBatches);
0113 }
0114 }