0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 from datetime import datetime, timedelta
0018 import hashlib
0019 import os
0020 import random
0021 import sys
0022 import tempfile
0023 import time
0024 from glob import glob
0025
0026 from py4j.protocol import Py4JJavaError
0027
0028 from pyspark import shuffle, RDD
0029 from pyspark.serializers import CloudPickleSerializer, BatchedSerializer, PickleSerializer,\
0030 MarshalSerializer, UTF8Deserializer, NoOpSerializer
0031 from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME, QuietTest
0032
# ``xrange`` does not exist on Python 3; alias it to ``range`` so the tests
# below can use a single name on both major versions.
if sys.version_info >= (3,):
    xrange = range


# Module-level function that test_overwritten_global_func rebinds at runtime.
# Kept as a lambda on purpose: the test checks that closures shipped to workers
# observe the *current* global binding.
global_func = lambda: "Hi"
0038
0039
class RDDTests(ReusedPySparkTestCase):
    # Tests for core pyspark RDD operations, run against the shared SparkContext
    # that ReusedPySparkTestCase exposes as ``self.sc``.

    @staticmethod
    def _read_part_files(path):
        # Return the concatenated raw bytes of every "part-0000*" file under ``path``
        # (the output layout of saveAsTextFile). Each file is read inside ``with`` so
        # its handle is closed deterministically instead of leaking until GC.
        contents = []
        for part in glob(path + "/part-0000*"):
            with open(part, 'rb') as f:
                contents.append(f.read())
        return b''.join(contents)

    def test_range(self):
        self.assertEqual(self.sc.range(1, 1).count(), 0)
        self.assertEqual(self.sc.range(1, 0, -1).count(), 1)
        self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2)

    def test_id(self):
        # An RDD's id is stable across calls. The map().filter() chain below is
        # expected to produce exactly one new id (presumably because pyspark
        # pipelines consecutive narrow transformations into a single RDD).
        rdd = self.sc.parallelize(range(10))
        rdd_id = rdd.id()
        self.assertEqual(rdd_id, rdd.id())
        rdd2 = rdd.map(str).filter(bool)
        id2 = rdd2.id()
        self.assertEqual(rdd_id + 1, id2)
        self.assertEqual(id2, rdd2.id())

    def test_empty_rdd(self):
        rdd = self.sc.emptyRDD()
        self.assertTrue(rdd.isEmpty())

    def test_sum(self):
        self.assertEqual(0, self.sc.emptyRDD().sum())
        self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())

    def test_to_localiterator(self):
        rdd = self.sc.parallelize([1, 2, 3])
        it = rdd.toLocalIterator()
        self.assertEqual([1, 2, 3], sorted(it))

        # Many (mostly empty) partitions must still yield every element.
        rdd2 = rdd.repartition(1000)
        it2 = rdd2.toLocalIterator()
        self.assertEqual([1, 2, 3], sorted(it2))

    def test_to_localiterator_prefetch(self):
        # Each element records the time at which its partition was computed.
        # Reading the first element, sleeping, then reading the second shows
        # the difference: without prefetching the second partition is computed
        # only after the sleep (gap >= 2s); with prefetching it is computed while
        # the first is consumed, so both timestamps are close (gap < 1s).
        rdd = self.sc.parallelize(range(2), 2)
        times1 = rdd.map(lambda x: datetime.now())
        times2 = rdd.map(lambda x: datetime.now())
        times_iter_prefetch = times1.toLocalIterator(prefetchPartitions=True)
        times_iter = times2.toLocalIterator(prefetchPartitions=False)
        times_prefetch_head = next(times_iter_prefetch)
        times_head = next(times_iter)
        time.sleep(2)
        times_next = next(times_iter)
        times_prefetch_next = next(times_iter_prefetch)
        self.assertTrue(times_next - times_head >= timedelta(seconds=2))
        self.assertTrue(times_prefetch_next - times_prefetch_head < timedelta(seconds=1))

    def test_save_as_textfile_with_unicode(self):
        # Unicode strings must be written to text files as UTF-8.
        x = u"\u00A1Hola, mundo!"
        data = self.sc.parallelize([x])
        tempFile = tempfile.NamedTemporaryFile(delete=True)
        tempFile.close()
        data.saveAsTextFile(tempFile.name)
        raw_contents = self._read_part_files(tempFile.name)
        self.assertEqual(x, raw_contents.strip().decode("utf-8"))

    def test_save_as_textfile_with_utf8(self):
        # Already-encoded UTF-8 bytes must be written through unchanged.
        x = u"\u00A1Hola, mundo!"
        data = self.sc.parallelize([x.encode("utf-8")])
        tempFile = tempfile.NamedTemporaryFile(delete=True)
        tempFile.close()
        data.saveAsTextFile(tempFile.name)
        raw_contents = self._read_part_files(tempFile.name)
        self.assertEqual(x, raw_contents.strip().decode('utf8'))

    def test_transforming_cartesian_result(self):
        # Mapping over the result of cartesian() must work and yield every pair.
        rdd1 = self.sc.parallelize([1, 2])
        rdd2 = self.sc.parallelize([3, 4])
        cart = rdd1.cartesian(rdd2)
        result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect()
        self.assertEqual([4, 5, 5, 6], sorted(result))

    def test_transforming_pickle_file(self):
        # Mapping over an RDD read back from a pickle file must work.
        data = self.sc.parallelize([u"Hello", u"World!"])
        tempFile = tempfile.NamedTemporaryFile(delete=True)
        tempFile.close()
        data.saveAsPickleFile(tempFile.name)
        pickled_file = self.sc.pickleFile(tempFile.name)
        result = pickled_file.map(lambda x: x).collect()
        self.assertEqual([u"Hello", u"World!"], sorted(result))

    def test_cartesian_on_textfile(self):
        # cartesian() of a text-file RDD with itself.
        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
        a = self.sc.textFile(path)
        result = a.cartesian(a).collect()
        (x, y) = result[0]
        self.assertEqual(u"Hello World!", x.strip())
        self.assertEqual(u"Hello World!", y.strip())

    def test_cartesian_chaining(self):
        # cartesian() must nest correctly whichever side is itself a cartesian/zip.
        rdd = self.sc.parallelize(range(10), 2)
        self.assertSetEqual(
            set(rdd.cartesian(rdd).cartesian(rdd).collect()),
            set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)])
        )

        self.assertSetEqual(
            set(rdd.cartesian(rdd.cartesian(rdd)).collect()),
            set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)])
        )

        self.assertSetEqual(
            set(rdd.cartesian(rdd.zip(rdd)).collect()),
            set([(x, (y, y)) for x in range(10) for y in range(10)])
        )

    def test_zip_chaining(self):
        # zip() must nest correctly on either side.
        rdd = self.sc.parallelize('abc', 2)
        self.assertSetEqual(
            set(rdd.zip(rdd).zip(rdd).collect()),
            set([((x, x), x) for x in 'abc'])
        )
        self.assertSetEqual(
            set(rdd.zip(rdd.zip(rdd)).collect()),
            set([(x, (x, x)) for x in 'abc'])
        )

    def test_union_pair_rdd(self):
        # Union of pair RDDs produced by zip() keeps every element (duplicates included).
        rdd = self.sc.parallelize([1, 2])
        pair_rdd = rdd.zip(rdd)
        unionRDD = self.sc.union([pair_rdd, pair_rdd])
        self.assertEqual(
            set(unionRDD.collect()),
            set([(1, 1), (2, 2), (1, 1), (2, 2)])
        )
        self.assertEqual(unionRDD.count(), 4)

    def test_deleting_input_files(self):
        # Recomputing an RDD whose input file has been deleted must fail.
        tempFile = tempfile.NamedTemporaryFile(delete=False)
        tempFile.write(b"Hello World!")
        tempFile.close()
        data = self.sc.textFile(tempFile.name)
        filtered_data = data.filter(lambda x: True)
        self.assertEqual(1, filtered_data.count())
        os.unlink(tempFile.name)
        with QuietTest(self.sc):
            self.assertRaises(Exception, lambda: filtered_data.count())

    def test_sampling_default_seed(self):
        # takeSample() without an explicit seed.
        data = self.sc.parallelize(xrange(1000), 1)
        subset = data.takeSample(False, 10)
        self.assertEqual(len(subset), 10)

    def test_aggregate_mutable_zero_value(self):
        # aggregate()/treeAggregate() with a mutable zero value (a defaultdict):
        # every key 0..9 must be counted exactly once, regardless of partition count.
        from collections import defaultdict

        # One and two partitions of the same data.
        data1 = self.sc.range(10, numSlices=1)
        data2 = self.sc.range(10, numSlices=2)

        def seqOp(x, y):
            x[y] += 1
            return x

        def comboOp(x, y):
            for key, val in y.items():
                x[key] += val
            return x

        counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp)
        counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp)
        counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
        counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2)

        ground_truth = defaultdict(int, dict((i, 1) for i in range(10)))
        self.assertEqual(counts1, ground_truth)
        self.assertEqual(counts2, ground_truth)
        self.assertEqual(counts3, ground_truth)
        self.assertEqual(counts4, ground_truth)

    def test_aggregate_by_key_mutable_zero_value(self):
        # aggregateByKey() with a mutable zero value (an empty list). Each key
        # 0..9 occurs twice with value 1, so each key must end up with [1, 1];
        # a zero list shared between keys would presumably accumulate more.
        tuples = list(zip(list(range(10))*2, [1]*20))

        data1 = self.sc.parallelize(tuples, 1)
        data2 = self.sc.parallelize(tuples, 2)

        def seqOp(x, y):
            x.append(y)
            return x

        def comboOp(x, y):
            x.extend(y)
            return x

        values1 = data1.aggregateByKey([], seqOp, comboOp).collect()
        values2 = data2.aggregateByKey([], seqOp, comboOp).collect()
        # Sort so the comparison is independent of key order across partitions.
        values1.sort()
        values2.sort()

        ground_truth = [(i, [1]*2) for i in range(10)]
        self.assertEqual(values1, ground_truth)
        self.assertEqual(values2, ground_truth)

    def test_fold_mutable_zero_value(self):
        # fold() with a mutable zero value (a defaultdict) must equal the plain
        # sequential merge of all input dicts.
        from collections import defaultdict

        counts1 = defaultdict(int, dict((i, 1) for i in range(10)))
        counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8)))
        counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7)))
        counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6)))
        all_counts = [counts1, counts2, counts3, counts4]
        # One and two partitions of the same data.
        data1 = self.sc.parallelize(all_counts, 1)
        data2 = self.sc.parallelize(all_counts, 2)

        def comboOp(x, y):
            for key, val in y.items():
                x[key] += val
            return x

        fold1 = data1.fold(defaultdict(int), comboOp)
        fold2 = data2.fold(defaultdict(int), comboOp)

        ground_truth = defaultdict(int)
        for counts in all_counts:
            for key, val in counts.items():
                ground_truth[key] += val
        self.assertEqual(fold1, ground_truth)
        self.assertEqual(fold2, ground_truth)

    def test_fold_by_key_mutable_zero_value(self):
        # foldByKey() with a mutable zero value (an empty list): each key i
        # occurs twice with value range(i), so it must fold to range(i) twice.
        tuples = [(i, range(i)) for i in range(10)]*2

        data1 = self.sc.parallelize(tuples, 1)
        data2 = self.sc.parallelize(tuples, 2)

        def comboOp(x, y):
            x.extend(y)
            return x

        values1 = data1.foldByKey([], comboOp).collect()
        values2 = data2.foldByKey([], comboOp).collect()
        # Sort so the comparison is independent of key order across partitions.
        values1.sort()
        values2.sort()

        ground_truth = [(i, list(range(i))*2) for i in range(10)]
        self.assertEqual(values1, ground_truth)
        self.assertEqual(values2, ground_truth)

    def test_aggregate_by_key(self):
        data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)

        def seqOp(x, y):
            x.add(y)
            return x

        def combOp(x, y):
            x |= y
            return x

        sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect())
        self.assertEqual(3, len(sets))
        self.assertEqual(set([1]), sets[1])
        self.assertEqual(set([2]), sets[3])
        self.assertEqual(set([1, 3]), sets[5])

    def test_itemgetter(self):
        # operator.itemgetter (single and multiple indices) must be usable as a map function.
        rdd = self.sc.parallelize([range(10)])
        from operator import itemgetter
        self.assertEqual([1], rdd.map(itemgetter(1)).collect())
        self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect())

    def test_namedtuple_in_rdd(self):
        # Locally defined namedtuples must round-trip through an RDD.
        from collections import namedtuple
        Person = namedtuple("Person", "id firstName lastName")
        jon = Person(1, "Jon", "Doe")
        jane = Person(2, "Jane", "Doe")
        theDoes = self.sc.parallelize([jon, jane])
        self.assertEqual([jon, jane], theDoes.collect())

    def test_large_broadcast(self):
        N = 10000
        data = [[float(i) for i in range(300)] for i in range(N)]
        bdata = self.sc.broadcast(data)
        m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
        self.assertEqual(N, m)

    def test_unpersist(self):
        # After unpersist() the broadcast is still usable by later jobs; after
        # destroy() any job that reads it must fail.
        N = 1000
        data = [[float(i) for i in range(300)] for i in range(N)]
        bdata = self.sc.broadcast(data)
        bdata.unpersist()
        m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
        self.assertEqual(N, m)
        bdata.destroy(blocking=True)
        self.assertRaises(
            Exception,
            lambda: self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum())

    def test_multiple_broadcasts(self):
        # A large broadcast (b1) must stay readable while a second broadcast (b2)
        # is replaced between jobs; each job must see the current b2 contents.
        N = 1 << 21
        b1 = self.sc.broadcast(set(range(N)))
        # Keep the payload under its own name: the collected result must not be
        # confused with it when the payload is reshuffled for the second round.
        payload = list(range(1 << 15))

        def check_with_new_b2():
            random.shuffle(payload)
            s = str(payload).encode()
            checksum = hashlib.md5(s).hexdigest()
            b2 = self.sc.broadcast(s)
            r = list(set(self.sc.parallelize(range(10), 10).map(
                lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
            # All 10 tasks must agree on both broadcast values.
            self.assertEqual(1, len(r))
            size, csum = r[0]
            self.assertEqual(N, size)
            self.assertEqual(checksum, csum)

        check_with_new_b2()
        check_with_new_b2()

    def test_multithread_broadcast_pickle(self):
        # Pickling closures that reference broadcasts records those broadcasts in
        # sc._pickled_broadcast_vars. The assertions expect each thread to see only
        # the broadcast it pickled itself.
        import threading

        b1 = self.sc.broadcast(list(range(3)))
        b2 = self.sc.broadcast(list(range(3)))

        def f1():
            return b1.value

        def f2():
            return b2.value

        funcs_num_pickled = {f1: None, f2: None}

        def do_pickle(f, sc):
            command = (f, None, sc.serializer, sc.serializer)
            ser = CloudPickleSerializer()
            ser.dumps(command)

        def process_vars(sc):
            broadcast_vars = list(sc._pickled_broadcast_vars)
            num_pickled = len(broadcast_vars)
            sc._pickled_broadcast_vars.clear()
            return num_pickled

        def run(f, sc):
            do_pickle(f, sc)
            funcs_num_pickled[f] = process_vars(sc)

        # Pickle f1 in the main thread, but do not consume the recorded vars yet.
        do_pickle(f1, self.sc)

        # Pickle f2 and consume its recorded vars in a separate thread.
        t = threading.Thread(target=run, args=(f2, self.sc))
        t.start()
        t.join()

        # Now consume the vars recorded by the main thread.
        funcs_num_pickled[f1] = process_vars(self.sc)

        self.assertEqual(funcs_num_pickled[f1], 1)
        self.assertEqual(funcs_num_pickled[f2], 1)
        self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)

    def test_large_closure(self):
        N = 200000
        data = [float(i) for i in xrange(N)]
        rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
        self.assertEqual(N, rdd.first())
        # The large closure must also survive a shuffle.
        self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count())

    def test_zip_with_different_serializers(self):
        a = self.sc.parallelize(range(5))
        b = self.sc.parallelize(range(100, 105))
        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
        a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
        b = b._reserialize(MarshalSerializer())
        self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])

        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
        t = self.sc.textFile(path)
        cnt = t.count()
        self.assertEqual(cnt, t.zip(t).count())
        rdd = t.map(str)
        self.assertEqual(cnt, t.zip(rdd).count())
        # Intentionally repeated: zipping the same pair a second time must still work.
        self.assertEqual(cnt, t.zip(rdd).count())

    def test_zip_with_different_object_sizes(self):
        # Elements of very different serialized sizes must still zip pairwise.
        a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i)
        b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i)
        self.assertEqual(10000, a.zip(b).count())

    def test_zip_with_different_number_of_items(self):
        a = self.sc.parallelize(range(5), 2)
        # Different number of partitions: rejected as soon as zip() is called.
        b = self.sc.parallelize(range(100, 106), 3)
        self.assertRaises(ValueError, lambda: a.zip(b))
        with QuietTest(self.sc):
            # Same number of partitions, but b has fewer items.
            b = self.sc.parallelize(range(100, 104), 2)
            self.assertRaises(Exception, lambda: a.zip(b).count())
            # Same number of partitions, but b has more items.
            b = self.sc.parallelize(range(100, 106), 2)
            self.assertRaises(Exception, lambda: a.zip(b).count())
            # Same total count, but distributed differently across partitions
            # (sizes 2+3 versus 3+2).
            a = self.sc.parallelize([2, 3], 2).flatMap(range)
            b = self.sc.parallelize([3, 2], 2).flatMap(range)
            self.assertEqual(a.count(), b.count())
            self.assertRaises(Exception, lambda: a.zip(b).count())

    def test_count_approx_distinct(self):
        rdd = self.sc.parallelize(xrange(1000))
        self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050)
        self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050)
        self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050)
        self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050)

        rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7)
        self.assertTrue(18 < rdd.countApproxDistinct() < 22)
        self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22)
        self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22)
        self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22)

        # An overly small relative accuracy is rejected.
        self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001))

    def test_histogram(self):
        # Empty RDD.
        rdd = self.sc.parallelize([])
        self.assertEqual([0], rdd.histogram([0, 10])[1])
        self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])
        self.assertRaises(ValueError, lambda: rdd.histogram(1))

        # All values out of range.
        rdd = self.sc.parallelize([10.01, -0.01])
        self.assertEqual([0], rdd.histogram([0, 10])[1])
        self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1])

        # All values in range, one and two buckets.
        rdd = self.sc.parallelize(range(1, 5))
        self.assertEqual([4], rdd.histogram([0, 10])[1])
        self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1])

        # Bucket boundaries equal to the min and max values.
        self.assertEqual([4], rdd.histogram([1, 4])[1])

        # Out of range with two even buckets.
        rdd = self.sc.parallelize([10.01, -0.01])
        self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1])

        # Out of range with two uneven buckets.
        rdd = self.sc.parallelize([10.01, -0.01])
        self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])

        # In range with two buckets.
        rdd = self.sc.parallelize([1, 2, 3, 5, 6])
        self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])

        # In range with two buckets; None and NaN are not counted.
        rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')])
        self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])

        # In range with two uneven buckets.
        rdd = self.sc.parallelize([1, 2, 3, 5, 6])
        self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1])

        # Mixed in/out of range; the last bucket includes its upper bound (11.0).
        rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01])
        self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1])

        # Mixed range with four uneven buckets.
        rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1])
        self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])

        # Same as above, plus None and NaN, which are not counted.
        rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0,
                                   199.0, 200.0, 200.1, None, float('nan')])
        self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])

        # Infinite bucket boundaries: inf is counted, NaN is not.
        rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")])
        self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1])

        # Invalid bucket specifications.
        self.assertRaises(ValueError, lambda: rdd.histogram([]))
        self.assertRaises(ValueError, lambda: rdd.histogram([1]))
        self.assertRaises(ValueError, lambda: rdd.histogram(0))
        self.assertRaises(TypeError, lambda: rdd.histogram({}))

        # Bucket count instead of boundaries: boundaries derived from min/max.
        rdd = self.sc.parallelize(range(1, 5))
        self.assertEqual(([1, 4], [4]), rdd.histogram(1))

        # Single element.
        rdd = self.sc.parallelize([1])
        self.assertEqual(([1, 1], [1]), rdd.histogram(1))

        # All elements equal.
        rdd = self.sc.parallelize([1] * 4)
        self.assertEqual(([1, 1], [4]), rdd.histogram(1))

        # Two evenly spaced buckets.
        rdd = self.sc.parallelize(range(1, 5))
        self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2))

        # More buckets than distinct values.
        rdd = self.sc.parallelize([1, 2])
        buckets = [1 + 0.2 * i for i in range(6)]
        hist = [1, 0, 0, 0, 1]
        self.assertEqual((buckets, hist), rdd.histogram(5))

        # A bucket count cannot be used when min/max is infinite or NaN.
        rdd = self.sc.parallelize([1, float('inf')])
        self.assertRaises(ValueError, lambda: rdd.histogram(2))
        rdd = self.sc.parallelize([float('nan')])
        self.assertRaises(ValueError, lambda: rdd.histogram(2))

        # String values: explicit boundaries and a single bucket work; more
        # computed buckets raise TypeError.
        rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2)
        self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1])
        self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1))
        self.assertRaises(TypeError, lambda: rdd.histogram(2))

    def test_repartitionAndSortWithinPartitions_asc(self):
        rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)

        repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True)
        partitions = repartitioned.glom().collect()
        self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)])
        self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)])

    def test_repartitionAndSortWithinPartitions_desc(self):
        rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)

        repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False)
        partitions = repartitioned.glom().collect()
        self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)])
        self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)])

    def test_repartition_no_skewed(self):
        # Growing 2 partitions to 20 (via repartition or shuffling coalesce) must
        # leave no partition empty.
        num_partitions = 20
        a = self.sc.parallelize(range(1000), 2)
        sizes = a.repartition(num_partitions).glom().map(len).collect()
        self.assertEqual(0, sizes.count(0))
        sizes = a.coalesce(num_partitions, True).glom().map(len).collect()
        self.assertEqual(0, sizes.count(0))

    def test_repartition_on_textfile(self):
        path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
        rdd = self.sc.textFile(path)
        result = rdd.repartition(1).collect()
        self.assertEqual(u"Hello World!", result[0])

    def test_distinct(self):
        rdd = self.sc.parallelize((1, 2, 3)*10, 10)
        self.assertEqual(rdd.getNumPartitions(), 10)
        self.assertEqual(rdd.distinct().count(), 3)
        result = rdd.distinct(5)
        self.assertEqual(result.getNumPartitions(), 5)
        self.assertEqual(result.count(), 3)

    def test_external_group_by_key(self):
        # A tiny worker memory limit forces groupByKey to use its external
        # (spilling) representation for the grouped values.
        # NOTE(review): this conf change is never reverted, so it stays in effect for
        # later tests sharing self.sc — consider restoring the previous value.
        self.sc._conf.set("spark.python.worker.memory", "1m")
        N = 2000001
        kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x))
        gkv = kv.groupByKey().cache()
        self.assertEqual(3, gkv.count())
        filtered = gkv.filter(lambda kv: kv[0] == 1)
        self.assertEqual(1, filtered.count())
        self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect())
        self.assertEqual([(N // 3, N // 3)],
                         filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
        result = filtered.collect()[0][1]
        self.assertEqual(N // 3, len(result))
        self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList))

    def test_sort_on_empty_rdd(self):
        self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())

    def test_sample(self):
        # The same seed must reproduce the same sample; different seeds must not.
        rdd = self.sc.parallelize(range(0, 100), 4)
        wo = rdd.sample(False, 0.1, 2).collect()
        wo_dup = rdd.sample(False, 0.1, 2).collect()
        self.assertSetEqual(set(wo), set(wo_dup))
        wr = rdd.sample(True, 0.2, 5).collect()
        wr_dup = rdd.sample(True, 0.2, 5).collect()
        self.assertSetEqual(set(wr), set(wr_dup))
        wo_s10 = rdd.sample(False, 0.3, 10).collect()
        wo_s20 = rdd.sample(False, 0.3, 20).collect()
        self.assertNotEqual(set(wo_s10), set(wo_s20))
        wr_s11 = rdd.sample(True, 0.4, 11).collect()
        wr_s21 = rdd.sample(True, 0.4, 21).collect()
        self.assertNotEqual(set(wr_s11), set(wr_s21))

    def test_null_in_rdd(self):
        # A JVM-side RDD containing a null must come back as None with both
        # the UTF-8 and the no-op deserializer.
        jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc)
        rdd = RDD(jrdd, self.sc, UTF8Deserializer())
        self.assertEqual([u"a", None, u"b"], rdd.collect())
        rdd = RDD(jrdd, self.sc, NoOpSerializer())
        self.assertEqual([b"a", None, b"b"], rdd.collect())

    def test_multiple_python_java_RDD_conversions(self):
        # Python -> Java -> Python conversion must work repeatedly.
        data = [
            (u'1', {u'director': u'David Lean'}),
            (u'2', {u'director': u'Andrew Dominik'})
        ]
        data_rdd = self.sc.parallelize(data)
        data_java_rdd = data_rdd._to_java_object_rdd()
        data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd)
        converted_rdd = RDD(data_python_rdd, self.sc)
        self.assertEqual(2, converted_rdd.count())

        # Convert the already-converted RDD a second time.
        data_java_rdd = converted_rdd._to_java_object_rdd()
        data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd)
        converted_rdd = RDD(data_python_rdd, self.sc)
        self.assertEqual(2, converted_rdd.count())

    def test_take_on_jrdd(self):
        # Calling first() directly on the underlying Java RDD must not fail.
        rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x))
        rdd._jrdd.first()

    def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
        # Keys are a permutation of 0..100; after sorting into 5 partitions every
        # partition must be non-empty, in both sort directions.
        seq = [(i * 59 % 101, i) for i in range(101)]
        rdd = self.sc.parallelize(seq)
        for ascending in [True, False]:
            sort = rdd.sortByKey(ascending=ascending, numPartitions=5)
            self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending))
            sizes = sort.glom().map(len).collect()
            for size in sizes:
                self.assertGreater(size, 0)

    def test_pipe_functions(self):
        # With checkCode=True a non-zero exit status of the piped command fails the job.
        data = ['1', '2', '3']
        rdd = self.sc.parallelize(data)
        with QuietTest(self.sc):
            self.assertEqual([], rdd.pipe('java').collect())
            self.assertRaises(Py4JJavaError, rdd.pipe('java', checkCode=True).collect)
        result = rdd.pipe('cat').collect()
        result.sort()
        for x, y in zip(data, result):
            self.assertEqual(x, y)
        self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
        self.assertEqual([], rdd.pipe('grep 4').collect())

    def test_pipe_unicode(self):
        # Non-ASCII text must round-trip through pipe().
        data = [u'\u6d4b\u8bd5', '1']
        rdd = self.sc.parallelize(data)
        result = rdd.pipe('cat').collect()
        self.assertEqual(data, result)

    def test_stopiteration_in_user_code(self):
        # A StopIteration raised by a user function must fail the job with a clear
        # message instead of silently truncating iteration.

        def stopit(*x):
            raise StopIteration()

        seq_rdd = self.sc.parallelize(range(10))
        keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
        msg = "Caught StopIteration thrown from user's code; failing the task"

        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
        self.assertRaisesRegexp(Py4JJavaError, msg,
                                seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)

        # For these actions the error may surface as either Py4JJavaError or
        # RuntimeError (presumably depending on whether the user function runs on
        # an executor or in the driver).
        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
                                keyed_rdd.reduceByKeyLocally, stopit)
        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
                                seq_rdd.aggregate, 0, stopit, lambda *x: 1)
        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
                                seq_rdd.aggregate, 0, lambda *x: 1, stopit)

    def test_overwritten_global_func(self):
        # A closure referencing a module-level function must pick up that global's
        # current binding each time it is shipped, not a stale copy.
        global global_func
        self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Hi")
        global_func = lambda: "Yeah"
        self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Yeah")

    def test_to_local_iterator_failure(self):
        # An exception raised in a task must propagate to the code iterating
        # over toLocalIterator().

        def fail(_):
            raise RuntimeError("local iterator error")

        rdd = self.sc.range(10).map(fail)

        with self.assertRaisesRegexp(Exception, "local iterator error"):
            for _ in rdd.toLocalIterator():
                pass

    def test_to_local_iterator_collects_single_partition(self):
        # toLocalIterator() must compute partitions on demand: reading only the
        # first four elements must not run the later partition containing 9,
        # whose map function raises.

        def fail_last(x):
            if x == 9:
                raise RuntimeError("This should not be hit")
            return x

        rdd = self.sc.range(12, numSlices=4).map(fail_last)
        it = rdd.toLocalIterator()

        for i in range(4):
            self.assertEqual(i, next(it))

    def test_multiple_group_jobs(self):
        # Cancelling one job group must cancel only that group's jobs while jobs
        # in another group run to completion.
        import threading
        group_a = "job_ids_to_cancel"
        group_b = "job_ids_to_run"

        threads = []
        thread_ids = range(4)
        thread_ids_to_cancel = [i for i in thread_ids if i % 2 == 0]
        thread_ids_to_run = [i for i in thread_ids if i % 2 != 0]

        # is_job_cancelled[i] records whether the job launched by thread i was
        # cancelled.
        is_job_cancelled = [False for _ in thread_ids]

        def run_job(job_group, index):
            """
            Executes a job with the group ``job_group``. Each job sleeps for 15 seconds
            and then exits.
            """
            try:
                self.sc.parallelize([15]).map(lambda x: time.sleep(x)) \
                    .collectWithJobGroup(job_group, "test rdd collect with setting job group")
                is_job_cancelled[index] = False
            except Exception:
                # Any failure is treated as the job having been cancelled.
                is_job_cancelled[index] = True

        # Sanity check: a job in group A succeeds when nothing cancels it.
        run_job(group_a, 0)
        self.assertFalse(is_job_cancelled[0])

        # Launch jobs in both groups concurrently.
        for i in thread_ids_to_cancel:
            t = threading.Thread(target=run_job, args=(group_a, i))
            t.start()
            threads.append(t)

        for i in thread_ids_to_run:
            t = threading.Thread(target=run_job, args=(group_b, i))
            t.start()
            threads.append(t)

        # Give the jobs time to start before cancelling group A.
        time.sleep(3)

        self.sc.cancelJobGroup(group_a)

        # Wait for every launching thread to finish.
        for t in threads:
            t.join()

        for i in thread_ids_to_cancel:
            self.assertTrue(
                is_job_cancelled[i],
                "Thread {i}: Job in group A was not cancelled.".format(i=i))

        for i in thread_ids_to_run:
            self.assertFalse(
                is_job_cancelled[i],
                "Thread {i}: Job in group B did not succeed.".format(i=i))
0857
0858
if __name__ == "__main__":
    # Entry point for running this test module directly.
    import unittest
    # Bring the test classes into this module's namespace via the package path;
    # unittest.main() below runs the tests it finds here.
    from pyspark.tests.test_rdd import *

    # Use xmlrunner's XML reporter (writing to target/test-reports) when that
    # package is installed; otherwise testRunner=None makes unittest.main() fall
    # back to its default runner.
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)