0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 package org.apache.spark.mllib.random;
0019
0020 import java.io.Serializable;
0021 import java.util.Arrays;
0022
0023 import org.junit.Assert;
0024 import org.junit.Test;
0025
0026 import org.apache.spark.SharedSparkSession;
0027 import org.apache.spark.api.java.JavaDoubleRDD;
0028 import org.apache.spark.api.java.JavaRDD;
0029 import org.apache.spark.mllib.linalg.Vector;
0030 import static org.apache.spark.mllib.random.RandomRDDs.*;
0031
0032 public class JavaRandomRDDsSuite extends SharedSparkSession {
0033
0034 @Test
0035 public void testUniformRDD() {
0036 long m = 1000L;
0037 int p = 2;
0038 long seed = 1L;
0039 JavaDoubleRDD rdd1 = uniformJavaRDD(jsc, m);
0040 JavaDoubleRDD rdd2 = uniformJavaRDD(jsc, m, p);
0041 JavaDoubleRDD rdd3 = uniformJavaRDD(jsc, m, p, seed);
0042 for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0043 Assert.assertEquals(m, rdd.count());
0044 }
0045 }
0046
0047 @Test
0048 public void testNormalRDD() {
0049 long m = 1000L;
0050 int p = 2;
0051 long seed = 1L;
0052 JavaDoubleRDD rdd1 = normalJavaRDD(jsc, m);
0053 JavaDoubleRDD rdd2 = normalJavaRDD(jsc, m, p);
0054 JavaDoubleRDD rdd3 = normalJavaRDD(jsc, m, p, seed);
0055 for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0056 Assert.assertEquals(m, rdd.count());
0057 }
0058 }
0059
0060 @Test
0061 public void testLNormalRDD() {
0062 double mean = 4.0;
0063 double std = 2.0;
0064 long m = 1000L;
0065 int p = 2;
0066 long seed = 1L;
0067 JavaDoubleRDD rdd1 = logNormalJavaRDD(jsc, mean, std, m);
0068 JavaDoubleRDD rdd2 = logNormalJavaRDD(jsc, mean, std, m, p);
0069 JavaDoubleRDD rdd3 = logNormalJavaRDD(jsc, mean, std, m, p, seed);
0070 for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0071 Assert.assertEquals(m, rdd.count());
0072 }
0073 }
0074
0075 @Test
0076 public void testPoissonRDD() {
0077 double mean = 2.0;
0078 long m = 1000L;
0079 int p = 2;
0080 long seed = 1L;
0081 JavaDoubleRDD rdd1 = poissonJavaRDD(jsc, mean, m);
0082 JavaDoubleRDD rdd2 = poissonJavaRDD(jsc, mean, m, p);
0083 JavaDoubleRDD rdd3 = poissonJavaRDD(jsc, mean, m, p, seed);
0084 for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0085 Assert.assertEquals(m, rdd.count());
0086 }
0087 }
0088
0089 @Test
0090 public void testExponentialRDD() {
0091 double mean = 2.0;
0092 long m = 1000L;
0093 int p = 2;
0094 long seed = 1L;
0095 JavaDoubleRDD rdd1 = exponentialJavaRDD(jsc, mean, m);
0096 JavaDoubleRDD rdd2 = exponentialJavaRDD(jsc, mean, m, p);
0097 JavaDoubleRDD rdd3 = exponentialJavaRDD(jsc, mean, m, p, seed);
0098 for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0099 Assert.assertEquals(m, rdd.count());
0100 }
0101 }
0102
0103 @Test
0104 public void testGammaRDD() {
0105 double shape = 1.0;
0106 double jscale = 2.0;
0107 long m = 1000L;
0108 int p = 2;
0109 long seed = 1L;
0110 JavaDoubleRDD rdd1 = gammaJavaRDD(jsc, shape, jscale, m);
0111 JavaDoubleRDD rdd2 = gammaJavaRDD(jsc, shape, jscale, m, p);
0112 JavaDoubleRDD rdd3 = gammaJavaRDD(jsc, shape, jscale, m, p, seed);
0113 for (JavaDoubleRDD rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0114 Assert.assertEquals(m, rdd.count());
0115 }
0116 }
0117
0118
0119 @Test
0120 @SuppressWarnings("unchecked")
0121 public void testUniformVectorRDD() {
0122 long m = 100L;
0123 int n = 10;
0124 int p = 2;
0125 long seed = 1L;
0126 JavaRDD<Vector> rdd1 = uniformJavaVectorRDD(jsc, m, n);
0127 JavaRDD<Vector> rdd2 = uniformJavaVectorRDD(jsc, m, n, p);
0128 JavaRDD<Vector> rdd3 = uniformJavaVectorRDD(jsc, m, n, p, seed);
0129 for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0130 Assert.assertEquals(m, rdd.count());
0131 Assert.assertEquals(n, rdd.first().size());
0132 }
0133 }
0134
0135 @Test
0136 @SuppressWarnings("unchecked")
0137 public void testNormalVectorRDD() {
0138 long m = 100L;
0139 int n = 10;
0140 int p = 2;
0141 long seed = 1L;
0142 JavaRDD<Vector> rdd1 = normalJavaVectorRDD(jsc, m, n);
0143 JavaRDD<Vector> rdd2 = normalJavaVectorRDD(jsc, m, n, p);
0144 JavaRDD<Vector> rdd3 = normalJavaVectorRDD(jsc, m, n, p, seed);
0145 for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0146 Assert.assertEquals(m, rdd.count());
0147 Assert.assertEquals(n, rdd.first().size());
0148 }
0149 }
0150
0151 @Test
0152 @SuppressWarnings("unchecked")
0153 public void testLogNormalVectorRDD() {
0154 double mean = 4.0;
0155 double std = 2.0;
0156 long m = 100L;
0157 int n = 10;
0158 int p = 2;
0159 long seed = 1L;
0160 JavaRDD<Vector> rdd1 = logNormalJavaVectorRDD(jsc, mean, std, m, n);
0161 JavaRDD<Vector> rdd2 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p);
0162 JavaRDD<Vector> rdd3 = logNormalJavaVectorRDD(jsc, mean, std, m, n, p, seed);
0163 for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0164 Assert.assertEquals(m, rdd.count());
0165 Assert.assertEquals(n, rdd.first().size());
0166 }
0167 }
0168
0169 @Test
0170 @SuppressWarnings("unchecked")
0171 public void testPoissonVectorRDD() {
0172 double mean = 2.0;
0173 long m = 100L;
0174 int n = 10;
0175 int p = 2;
0176 long seed = 1L;
0177 JavaRDD<Vector> rdd1 = poissonJavaVectorRDD(jsc, mean, m, n);
0178 JavaRDD<Vector> rdd2 = poissonJavaVectorRDD(jsc, mean, m, n, p);
0179 JavaRDD<Vector> rdd3 = poissonJavaVectorRDD(jsc, mean, m, n, p, seed);
0180 for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0181 Assert.assertEquals(m, rdd.count());
0182 Assert.assertEquals(n, rdd.first().size());
0183 }
0184 }
0185
0186 @Test
0187 @SuppressWarnings("unchecked")
0188 public void testExponentialVectorRDD() {
0189 double mean = 2.0;
0190 long m = 100L;
0191 int n = 10;
0192 int p = 2;
0193 long seed = 1L;
0194 JavaRDD<Vector> rdd1 = exponentialJavaVectorRDD(jsc, mean, m, n);
0195 JavaRDD<Vector> rdd2 = exponentialJavaVectorRDD(jsc, mean, m, n, p);
0196 JavaRDD<Vector> rdd3 = exponentialJavaVectorRDD(jsc, mean, m, n, p, seed);
0197 for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0198 Assert.assertEquals(m, rdd.count());
0199 Assert.assertEquals(n, rdd.first().size());
0200 }
0201 }
0202
0203 @Test
0204 @SuppressWarnings("unchecked")
0205 public void testGammaVectorRDD() {
0206 double shape = 1.0;
0207 double jscale = 2.0;
0208 long m = 100L;
0209 int n = 10;
0210 int p = 2;
0211 long seed = 1L;
0212 JavaRDD<Vector> rdd1 = gammaJavaVectorRDD(jsc, shape, jscale, m, n);
0213 JavaRDD<Vector> rdd2 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p);
0214 JavaRDD<Vector> rdd3 = gammaJavaVectorRDD(jsc, shape, jscale, m, n, p, seed);
0215 for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0216 Assert.assertEquals(m, rdd.count());
0217 Assert.assertEquals(n, rdd.first().size());
0218 }
0219 }
0220
0221 @Test
0222 public void testArbitrary() {
0223 long size = 10;
0224 long seed = 1L;
0225 int numPartitions = 0;
0226 StringGenerator gen = new StringGenerator();
0227 JavaRDD<String> rdd1 = randomJavaRDD(jsc, gen, size);
0228 JavaRDD<String> rdd2 = randomJavaRDD(jsc, gen, size, numPartitions);
0229 JavaRDD<String> rdd3 = randomJavaRDD(jsc, gen, size, numPartitions, seed);
0230 for (JavaRDD<String> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0231 Assert.assertEquals(size, rdd.count());
0232 Assert.assertEquals(2, rdd.first().length());
0233 }
0234 }
0235
0236 @Test
0237 @SuppressWarnings("unchecked")
0238 public void testRandomVectorRDD() {
0239 UniformGenerator generator = new UniformGenerator();
0240 long m = 100L;
0241 int n = 10;
0242 int p = 2;
0243 long seed = 1L;
0244 JavaRDD<Vector> rdd1 = randomJavaVectorRDD(jsc, generator, m, n);
0245 JavaRDD<Vector> rdd2 = randomJavaVectorRDD(jsc, generator, m, n, p);
0246 JavaRDD<Vector> rdd3 = randomJavaVectorRDD(jsc, generator, m, n, p, seed);
0247 for (JavaRDD<Vector> rdd : Arrays.asList(rdd1, rdd2, rdd3)) {
0248 Assert.assertEquals(m, rdd.count());
0249 Assert.assertEquals(n, rdd.first().size());
0250 }
0251 }
0252 }
0253
0254
0255 class StringGenerator implements RandomDataGenerator<String>, Serializable {
0256 @Override
0257 public String nextValue() {
0258 return "42";
0259 }
0260
0261 @Override
0262 public StringGenerator copy() {
0263 return new StringGenerator();
0264 }
0265
0266 @Override
0267 public void setSeed(long seed) {
0268 }
0269 }