Back to home page

OSCL-LXR

 
 

    


0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one or more
0003  * contributor license agreements.  See the NOTICE file distributed with
0004  * this work for additional information regarding copyright ownership.
0005  * The ASF licenses this file to You under the Apache License, Version 2.0
0006  * (the "License"); you may not use this file except in compliance with
0007  * the License.  You may obtain a copy of the License at
0008  *
0009  *    http://www.apache.org/licenses/LICENSE-2.0
0010  *
0011  * Unless required by applicable law or agreed to in writing, software
0012  * distributed under the License is distributed on an "AS IS" BASIS,
0013  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014  * See the License for the specific language governing permissions and
0015  * limitations under the License.
0016  */
0017 
0018 package org.apache.spark.mllib.random;
0019 
0020 import java.io.Serializable;
0021 import java.util.Arrays;
0022 
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025 
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.api.java.JavaDoubleRDD;
0028 import org.apache.spark.api.java.JavaRDD;
0029 import org.apache.spark.mllib.linalg.Vector;
0030 import static org.apache.spark.mllib.random.RandomRDDs.*;
0031 
0032 public class JavaRandomRDDsSuite extends SharedSparkSession {
0033 
0034   @Test
0035   public void testUniformRDD() {
0036     long m = 1000L;
0037     int p = 2;
0038     long seed = 1L;
0039     JavaDoubleRDD rdd1 = uniformJavaRDD(jsc, m);
0040     JavaDoubleRDD rdd2 = uniformJavaRDD(jsc, m, p);
0041     JavaDoubleRDD rdd3 = uniformJavaRDD(jsc, m, p, seed);
0042     for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0043       Assert.assertEquals(m, rdd.count());
0044     }
0045   }
0046 
0047   @Test
0048   public void testNormalRDD() {
0049     long m = 1000L;
0050     int p = 2;
0051     long seed = 1L;
0052     JavaDoubleRDD rdd1 = normalJavaRDD(jsc, m);
0053     JavaDoubleRDD rdd2 = normalJavaRDD(jsc, m, p);
0054     JavaDoubleRDD rdd3 = normalJavaRDD(jsc, m, p, seed);
0055     for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0056       Assert.assertEquals(m, rdd.count());
0057     }
0058   }
0059 
0060   @Test
0061   public void testLNormalRDD() {
0062     double mean = 4.0;
0063     double std = 2.0;
0064     long m = 1000L;
0065     int p = 2;
0066     long seed = 1L;
0067     JavaDoubleRDD rdd1 = logNormalJavaRDD(jsc, mean, std, m);
0068     JavaDoubleRDD rdd2 = logNormalJavaRDD(jsc, mean, std, m, p);
0069     JavaDoubleRDD rdd3 = logNormalJavaRDD(jsc, mean, std, m, p, seed);
0070     for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0071       Assert.assertEquals(m, rdd.count());
0072     }
0073   }
0074 
0075   @Test
0076   public void testPoissonRDD() {
0077     double mean = 2.0;
0078     long m = 1000L;
0079     int p = 2;
0080     long seed = 1L;
0081     JavaDoubleRDD rdd1 = poissonJavaRDD(jsc, mean, m);
0082     JavaDoubleRDD rdd2 = poissonJavaRDD(jsc, mean, m, p);
0083     JavaDoubleRDD rdd3 = poissonJavaRDD(jsc, mean, m, p, seed);
0084     for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0085       Assert.assertEquals(m, rdd.count());
0086     }
0087   }
0088 
0089   @Test
0090   public void testExponentialRDD() {
0091     double mean = 2.0;
0092     long m = 1000L;
0093     int p = 2;
0094     long seed = 1L;
0095     JavaDoubleRDD rdd1 = exponentialJavaRDD(jsc, mean, m);
0096     JavaDoubleRDD rdd2 = exponentialJavaRDD(jsc, mean, m, p);
0097     JavaDoubleRDD rdd3 = exponentialJavaRDD(jsc, mean, m, p, seed);
0098     for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0099       Assert.assertEquals(m, rdd.count());
0100     }
0101   }
0102 
0103   @Test
0104   public void testGammaRDD() {
0105     double shape = 1.0;
0106     double jscale = 2.0;
0107     long m = 1000L;
0108     int p = 2;
0109     long seed = 1L;
0110     JavaDoubleRDD rdd1 = gammaJavaRDD(jsc, shape, jscale, m);
0111     JavaDoubleRDD rdd2 = gammaJavaRDD(jsc, shape, jscale, m, p);
0112     JavaDoubleRDD rdd3 = gammaJavaRDD(jsc, shape, jscale, m, p, seed);
0113     for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0114       Assert.assertEquals(m, rdd.count());
0115     }
0116   }
0117 
0118 
0119   @Test
0120   @SuppressWarnings("unchecked")
0121   public void testUniformVectorRDD() {
0122     long m = 100L;
0123     int n = 10;
0124     int p = 2;
0125     long seed = 1L;
0126     JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(jsc, m, n);
0127     JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(jsc, m, n, p);
0128     JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(jsc, m, n, p, seed);
0129     for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0130       Assert.assertEquals(m, rdd.count());
0131       Assert.assertEquals(n, rdd.first().size());
0132     }
0133   }
0134 
0135   @Test
0136   @SuppressWarnings("unchecked")
0137   public void testNormalVectorRDD() {
0138     long m = 100L;
0139     int n = 10;
0140     int p = 2;
0141     long seed = 1L;
0142     JavaRDD<Vector> rdd1 = normalJavaVectorRDD(jsc, m, n);
0143     JavaRDD<Vector> rdd2 = normalJavaVectorRDD(jsc, m, n, p);
0144     JavaRDD<Vector> rdd3 = normalJavaVectorRDD(jsc, m, n, p, seed);
0145     for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0146       Assert.assertEquals(m, rdd.count());
0147       Assert.assertEquals(n, rdd.first().size());
0148     }
0149   }
0150 
0151   @Test
0152   @SuppressWarnings("unchecked")
0153   public void testLogNormalVectorRDD() {
0154     double mean = 4.0;
0155     double std = 2.0;
0156     long m = 100L;
0157     int n = 10;
0158     int p = 2;
0159     long seed = 1L;
0160     JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(jsc, mean, std, m, n);
0161     JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p);
0162     JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p, seed);
0163     for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0164       Assert.assertEquals(m, rdd.count());
0165       Assert.assertEquals(n, rdd.first().size());
0166     }
0167   }
0168 
0169   @Test
0170   @SuppressWarnings("unchecked")
0171   public void testPoissonVectorRDD() {
0172     double mean = 2.0;
0173     long m = 100L;
0174     int n = 10;
0175     int p = 2;
0176     long seed = 1L;
0177     JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(jsc, mean, m, n);
0178     JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(jsc, mean, m, n, p);
0179     JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(jsc, mean, m, n, p, seed);
0180     for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0181       Assert.assertEquals(m, rdd.count());
0182       Assert.assertEquals(n, rdd.first().size());
0183     }
0184   }
0185 
0186   @Test
0187   @SuppressWarnings("unchecked")
0188   public void testExponentialVectorRDD() {
0189     double mean = 2.0;
0190     long m = 100L;
0191     int n = 10;
0192     int p = 2;
0193     long seed = 1L;
0194     JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(jsc, mean, m, n);
0195     JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(jsc, mean, m, n, p);
0196     JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(jsc, mean, m, n, p, seed);
0197     for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0198       Assert.assertEquals(m, rdd.count());
0199       Assert.assertEquals(n, rdd.first().size());
0200     }
0201   }
0202 
0203   @Test
0204   @SuppressWarnings("unchecked")
0205   public void testGammaVectorRDD() {
0206     double shape = 1.0;
0207     double jscale = 2.0;
0208     long m = 100L;
0209     int n = 10;
0210     int p = 2;
0211     long seed = 1L;
0212     JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(jsc, shape, jscale, m, n);
0213     JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p);
0214     JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p, seed);
0215     for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0216       Assert.assertEquals(m, rdd.count());
0217       Assert.assertEquals(n, rdd.first().size());
0218     }
0219   }
0220 
0221   @Test
0222   public void testArbitrary() {
0223     long size = 10;
0224     long seed = 1L;
0225     int numPartitions = 0;
0226     StringGenerator gen = new StringGenerator();
0227     JavaRDD<String> rdd1 = randomJavaRDD(jsc, gen, size);
0228     JavaRDD<String> rdd2 = randomJavaRDD(jsc, gen, size, numPartitions);
0229     JavaRDD<String> rdd3 = randomJavaRDD(jsc, gen, size, numPartitions, seed);
0230     for (JavaRDD<String> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0231       Assert.assertEquals(size, rdd.count());
0232       Assert.assertEquals(2, rdd.first().length());
0233     }
0234   }
0235 
0236   @Test
0237   @SuppressWarnings("unchecked")
0238   public void testRandomVectorRDD() {
0239     UniformGenerator generator = new UniformGenerator();
0240     long m = 100L;
0241     int n = 10;
0242     int p = 2;
0243     long seed = 1L;
0244     JavaRDD<Vector> rdd1 = randomJavaVectorRDD(jsc, generator, m, n);
0245     JavaRDD<Vector> rdd2 = randomJavaVectorRDD(jsc, generator, m, n, p);
0246     JavaRDD<Vector> rdd3 = randomJavaVectorRDD(jsc, generator, m, n, p, seed);
0247     for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0248       Assert.assertEquals(m, rdd.count());
0249       Assert.assertEquals(n, rdd.first().size());
0250     }
0251   }
0252 }
0253 
0254 // This is just a test generator, it always returns a string of 42
0255 class StringGenerator implements RandomDataGenerator<String>, Serializable {
0256   @Override
0257   public String nextValue() {
0258     return "42";
0259   }
0260 
0261   @Override
0262   public StringGenerator copy() {
0263     return new StringGenerator();
0264   }
0265 
0266   @Override
0267   public void setSeed(long seed) {
0268   }
0269 }