Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.mllib.stat;
0019 
0020 import java.util.Arrays;
0021 import java.util.List;
0022 
0023 import org.junit.After;
0024 import org.junit.Before;
0025 import org.junit.Test;
0026 import static org.junit.Assert.assertEquals;
0027 
0028 import org.apache.spark.SparkConf;
0029 import org.apache.spark.api.java.JavaDoubleRDD;
0030 import org.apache.spark.api.java.JavaRDD;
0031 import org.apache.spark.api.java.JavaSparkContext;
0032 import org.apache.spark.mllib.linalg.Vectors;
0033 import org.apache.spark.mllib.regression.LabeledPoint;
0034 import org.apache.spark.mllib.stat.test.BinarySample;
0035 import org.apache.spark.mllib.stat.test.ChiSqTestResult;
0036 import org.apache.spark.mllib.stat.test.KolmogorovSmirnovTestResult;
0037 import org.apache.spark.mllib.stat.test.StreamingTest;
0038 import org.apache.spark.sql.SparkSession;
0039 import org.apache.spark.streaming.Duration;
0040 import org.apache.spark.streaming.api.java.JavaDStream;
0041 import org.apache.spark.streaming.api.java.JavaStreamingContext;
0042 import static org.apache.spark.streaming.JavaTestUtils.*;
0043 
0044 public class JavaStatisticsSuite {
0045   private transient SparkSession spark;
0046   private transient JavaSparkContext jsc;
0047   private transient JavaStreamingContext ssc;
0048 
0049   @Before
0050   public void setUp() {
0051     SparkConf conf = new SparkConf()
0052       .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
0053     spark = SparkSession.builder()
0054       .master("local[2]")
0055       .appName("JavaStatistics")
0056       .config(conf)
0057       .getOrCreate();
0058     jsc = new JavaSparkContext(spark.sparkContext());
0059     ssc = new JavaStreamingContext(jsc, new Duration(1000));
0060     ssc.checkpoint("checkpoint");
0061   }
0062 
0063   @After
0064   public void tearDown() {
0065     spark.stop();
0066     ssc.stop();
0067     spark = null;
0068   }
0069 
0070   @Test
0071   public void testCorr() {
0072     JavaRDD<Double> x = jsc.parallelize(Arrays.asList(1.0, 2.0, 3.0, 4.0));
0073     JavaRDD<Double> y = jsc.parallelize(Arrays.asList(1.1, 2.2, 3.1, 4.3));
0074 
0075     Double corr1 = Statistics.corr(x, y);
0076     Double corr2 = Statistics.corr(x, y, "pearson");
0077     // Check default method
0078     assertEquals(corr1, corr2, 1e-5);
0079   }
0080 
0081   @Test
0082   public void kolmogorovSmirnovTest() {
0083     JavaDoubleRDD data = jsc.parallelizeDoubles(Arrays.asList(0.2, 1.0, -1.0, 2.0));
0084     KolmogorovSmirnovTestResult testResult1 = Statistics.kolmogorovSmirnovTest(data, "norm");
0085     KolmogorovSmirnovTestResult testResult2 = Statistics.kolmogorovSmirnovTest(
0086       data, "norm", 0.0, 1.0);
0087   }
0088 
0089   @Test
0090   public void chiSqTest() {
0091     JavaRDD<LabeledPoint> data = jsc.parallelize(Arrays.asList(
0092       new LabeledPoint(0.0, Vectors.dense(0.1, 2.3)),
0093       new LabeledPoint(1.0, Vectors.dense(1.5, 5.1)),
0094       new LabeledPoint(0.0, Vectors.dense(2.4, 8.1))));
0095     ChiSqTestResult[] testResults = Statistics.chiSqTest(data);
0096   }
0097 
0098   @Test
0099   public void streamingTest() {
0100     List<BinarySample> trainingBatch = Arrays.asList(
0101       new BinarySample(true, 1.0),
0102       new BinarySample(false, 2.0));
0103     JavaDStream<BinarySample> training =
0104       attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2);
0105     int numBatches = 2;
0106     StreamingTest model = new StreamingTest()
0107       .setWindowSize(0)
0108       .setPeacePeriod(0)
0109       .setTestMethod("welch");
0110     model.registerStream(training);
0111     attachTestOutputStream(training);
0112     runStreams(ssc, numBatches, numBatches);
0113   }
0114 }