Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 from datetime import datetime, timedelta
0018 import hashlib
0019 import os
0020 import random
0021 import sys
0022 import tempfile
0023 import time
0024 from glob import glob
0025 
0026 from py4j.protocol import Py4JJavaError
0027 
0028 from pyspark import shuffle, RDD
0029 from pyspark.serializers import CloudPickleSerializer, BatchedSerializer, PickleSerializer,\
0030     MarshalSerializer, UTF8Deserializer, NoOpSerializer
0031 from pyspark.testing.utils import ReusedPySparkTestCase, SPARK_HOME, QuietTest
0032 
# Python 3 has no ``xrange`` builtin; alias it to ``range`` so the tests in
# this module can call ``xrange`` under both Python 2 and Python 3.
if sys.version_info[0] >= 3:
    xrange = range


# Module-level lambda returning the constant string "Hi".
global_func = lambda: "Hi"
0038 
0039 
0040 class RDDTests(ReusedPySparkTestCase):
0041 
0042     def test_range(self):
0043         self.assertEqual(self.sc.range(1, 1).count(), 0)
0044         self.assertEqual(self.sc.range(1, 0, -1).count(), 1)
0045         self.assertEqual(self.sc.range(0, 1 << 40, 1 << 39).count(), 2)
0046 
0047     def test_id(self):
0048         rdd = self.sc.parallelize(range(10))
0049         id = rdd.id()
0050         self.assertEqual(id, rdd.id())
0051         rdd2 = rdd.map(str).filter(bool)
0052         id2 = rdd2.id()
0053         self.assertEqual(id + 1, id2)
0054         self.assertEqual(id2, rdd2.id())
0055 
0056     def test_empty_rdd(self):
0057         rdd = self.sc.emptyRDD()
0058         self.assertTrue(rdd.isEmpty())
0059 
0060     def test_sum(self):
0061         self.assertEqual(0, self.sc.emptyRDD().sum())
0062         self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())
0063 
0064     def test_to_localiterator(self):
0065         rdd = self.sc.parallelize([1, 2, 3])
0066         it = rdd.toLocalIterator()
0067         self.assertEqual([1, 2, 3], sorted(it))
0068 
0069         rdd2 = rdd.repartition(1000)
0070         it2 = rdd2.toLocalIterator()
0071         self.assertEqual([1, 2, 3], sorted(it2))
0072 
0073     def test_to_localiterator_prefetch(self):
0074         # Test that we fetch the next partition in parallel
0075         # We do this by returning the current time and:
0076         # reading the first elem, waiting, and reading the second elem
0077         # If not in parallel then these would be at different times
0078         # But since they are being computed in parallel we see the time
0079         # is "close enough" to the same.
0080         rdd = self.sc.parallelize(range(2), 2)
0081         times1 = rdd.map(lambda x: datetime.now())
0082         times2 = rdd.map(lambda x: datetime.now())
0083         times_iter_prefetch = times1.toLocalIterator(prefetchPartitions=True)
0084         times_iter = times2.toLocalIterator(prefetchPartitions=False)
0085         times_prefetch_head = next(times_iter_prefetch)
0086         times_head = next(times_iter)
0087         time.sleep(2)
0088         times_next = next(times_iter)
0089         times_prefetch_next = next(times_iter_prefetch)
0090         self.assertTrue(times_next - times_head >= timedelta(seconds=2))
0091         self.assertTrue(times_prefetch_next - times_prefetch_head < timedelta(seconds=1))
0092 
0093     def test_save_as_textfile_with_unicode(self):
0094         # Regression test for SPARK-970
0095         x = u"\u00A1Hola, mundo!"
0096         data = self.sc.parallelize([x])
0097         tempFile = tempfile.NamedTemporaryFile(delete=True)
0098         tempFile.close()
0099         data.saveAsTextFile(tempFile.name)
0100         raw_contents = b''.join(open(p, 'rb').read()
0101                                 for p in glob(tempFile.name + "/part-0000*"))
0102         self.assertEqual(x, raw_contents.strip().decode("utf-8"))
0103 
0104     def test_save_as_textfile_with_utf8(self):
0105         x = u"\u00A1Hola, mundo!"
0106         data = self.sc.parallelize([x.encode("utf-8")])
0107         tempFile = tempfile.NamedTemporaryFile(delete=True)
0108         tempFile.close()
0109         data.saveAsTextFile(tempFile.name)
0110         raw_contents = b''.join(open(p, 'rb').read()
0111                                 for p in glob(tempFile.name + "/part-0000*"))
0112         self.assertEqual(x, raw_contents.strip().decode('utf8'))
0113 
0114     def test_transforming_cartesian_result(self):
0115         # Regression test for SPARK-1034
0116         rdd1 = self.sc.parallelize([1, 2])
0117         rdd2 = self.sc.parallelize([3, 4])
0118         cart = rdd1.cartesian(rdd2)
0119         result = cart.map(lambda x_y3: x_y3[0] + x_y3[1]).collect()
0120 
0121     def test_transforming_pickle_file(self):
0122         # Regression test for SPARK-2601
0123         data = self.sc.parallelize([u"Hello", u"World!"])
0124         tempFile = tempfile.NamedTemporaryFile(delete=True)
0125         tempFile.close()
0126         data.saveAsPickleFile(tempFile.name)
0127         pickled_file = self.sc.pickleFile(tempFile.name)
0128         pickled_file.map(lambda x: x).collect()
0129 
0130     def test_cartesian_on_textfile(self):
0131         # Regression test for
0132         path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
0133         a = self.sc.textFile(path)
0134         result = a.cartesian(a).collect()
0135         (x, y) = result[0]
0136         self.assertEqual(u"Hello World!", x.strip())
0137         self.assertEqual(u"Hello World!", y.strip())
0138 
0139     def test_cartesian_chaining(self):
0140         # Tests for SPARK-16589
0141         rdd = self.sc.parallelize(range(10), 2)
0142         self.assertSetEqual(
0143             set(rdd.cartesian(rdd).cartesian(rdd).collect()),
0144             set([((x, y), z) for x in range(10) for y in range(10) for z in range(10)])
0145         )
0146 
0147         self.assertSetEqual(
0148             set(rdd.cartesian(rdd.cartesian(rdd)).collect()),
0149             set([(x, (y, z)) for x in range(10) for y in range(10) for z in range(10)])
0150         )
0151 
0152         self.assertSetEqual(
0153             set(rdd.cartesian(rdd.zip(rdd)).collect()),
0154             set([(x, (y, y)) for x in range(10) for y in range(10)])
0155         )
0156 
0157     def test_zip_chaining(self):
0158         # Tests for SPARK-21985
0159         rdd = self.sc.parallelize('abc', 2)
0160         self.assertSetEqual(
0161             set(rdd.zip(rdd).zip(rdd).collect()),
0162             set([((x, x), x) for x in 'abc'])
0163         )
0164         self.assertSetEqual(
0165             set(rdd.zip(rdd.zip(rdd)).collect()),
0166             set([(x, (x, x)) for x in 'abc'])
0167         )
0168 
0169     def test_union_pair_rdd(self):
0170         # SPARK-31788: test if pair RDDs can be combined by union.
0171         rdd = self.sc.parallelize([1, 2])
0172         pair_rdd = rdd.zip(rdd)
0173         unionRDD = self.sc.union([pair_rdd, pair_rdd])
0174         self.assertEqual(
0175             set(unionRDD.collect()),
0176             set([(1, 1), (2, 2), (1, 1), (2, 2)])
0177         )
0178         self.assertEqual(unionRDD.count(), 4)
0179 
0180     def test_deleting_input_files(self):
0181         # Regression test for SPARK-1025
0182         tempFile = tempfile.NamedTemporaryFile(delete=False)
0183         tempFile.write(b"Hello World!")
0184         tempFile.close()
0185         data = self.sc.textFile(tempFile.name)
0186         filtered_data = data.filter(lambda x: True)
0187         self.assertEqual(1, filtered_data.count())
0188         os.unlink(tempFile.name)
0189         with QuietTest(self.sc):
0190             self.assertRaises(Exception, lambda: filtered_data.count())
0191 
0192     def test_sampling_default_seed(self):
0193         # Test for SPARK-3995 (default seed setting)
0194         data = self.sc.parallelize(xrange(1000), 1)
0195         subset = data.takeSample(False, 10)
0196         self.assertEqual(len(subset), 10)
0197 
0198     def test_aggregate_mutable_zero_value(self):
0199         # Test for SPARK-9021; uses aggregate and treeAggregate to build dict
0200         # representing a counter of ints
0201         # NOTE: dict is used instead of collections.Counter for Python 2.6
0202         # compatibility
0203         from collections import defaultdict
0204 
0205         # Show that single or multiple partitions work
0206         data1 = self.sc.range(10, numSlices=1)
0207         data2 = self.sc.range(10, numSlices=2)
0208 
0209         def seqOp(x, y):
0210             x[y] += 1
0211             return x
0212 
0213         def comboOp(x, y):
0214             for key, val in y.items():
0215                 x[key] += val
0216             return x
0217 
0218         counts1 = data1.aggregate(defaultdict(int), seqOp, comboOp)
0219         counts2 = data2.aggregate(defaultdict(int), seqOp, comboOp)
0220         counts3 = data1.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
0221         counts4 = data2.treeAggregate(defaultdict(int), seqOp, comboOp, 2)
0222 
0223         ground_truth = defaultdict(int, dict((i, 1) for i in range(10)))
0224         self.assertEqual(counts1, ground_truth)
0225         self.assertEqual(counts2, ground_truth)
0226         self.assertEqual(counts3, ground_truth)
0227         self.assertEqual(counts4, ground_truth)
0228 
0229     def test_aggregate_by_key_mutable_zero_value(self):
0230         # Test for SPARK-9021; uses aggregateByKey to make a pair RDD that
0231         # contains lists of all values for each key in the original RDD
0232 
0233         # list(range(...)) for Python 3.x compatibility (can't use * operator
0234         # on a range object)
0235         # list(zip(...)) for Python 3.x compatibility (want to parallelize a
0236         # collection, not a zip object)
0237         tuples = list(zip(list(range(10))*2, [1]*20))
0238         # Show that single or multiple partitions work
0239         data1 = self.sc.parallelize(tuples, 1)
0240         data2 = self.sc.parallelize(tuples, 2)
0241 
0242         def seqOp(x, y):
0243             x.append(y)
0244             return x
0245 
0246         def comboOp(x, y):
0247             x.extend(y)
0248             return x
0249 
0250         values1 = data1.aggregateByKey([], seqOp, comboOp).collect()
0251         values2 = data2.aggregateByKey([], seqOp, comboOp).collect()
0252         # Sort lists to ensure clean comparison with ground_truth
0253         values1.sort()
0254         values2.sort()
0255 
0256         ground_truth = [(i, [1]*2) for i in range(10)]
0257         self.assertEqual(values1, ground_truth)
0258         self.assertEqual(values2, ground_truth)
0259 
0260     def test_fold_mutable_zero_value(self):
0261         # Test for SPARK-9021; uses fold to merge an RDD of dict counters into
0262         # a single dict
0263         # NOTE: dict is used instead of collections.Counter for Python 2.6
0264         # compatibility
0265         from collections import defaultdict
0266 
0267         counts1 = defaultdict(int, dict((i, 1) for i in range(10)))
0268         counts2 = defaultdict(int, dict((i, 1) for i in range(3, 8)))
0269         counts3 = defaultdict(int, dict((i, 1) for i in range(4, 7)))
0270         counts4 = defaultdict(int, dict((i, 1) for i in range(5, 6)))
0271         all_counts = [counts1, counts2, counts3, counts4]
0272         # Show that single or multiple partitions work
0273         data1 = self.sc.parallelize(all_counts, 1)
0274         data2 = self.sc.parallelize(all_counts, 2)
0275 
0276         def comboOp(x, y):
0277             for key, val in y.items():
0278                 x[key] += val
0279             return x
0280 
0281         fold1 = data1.fold(defaultdict(int), comboOp)
0282         fold2 = data2.fold(defaultdict(int), comboOp)
0283 
0284         ground_truth = defaultdict(int)
0285         for counts in all_counts:
0286             for key, val in counts.items():
0287                 ground_truth[key] += val
0288         self.assertEqual(fold1, ground_truth)
0289         self.assertEqual(fold2, ground_truth)
0290 
0291     def test_fold_by_key_mutable_zero_value(self):
0292         # Test for SPARK-9021; uses foldByKey to make a pair RDD that contains
0293         # lists of all values for each key in the original RDD
0294 
0295         tuples = [(i, range(i)) for i in range(10)]*2
0296         # Show that single or multiple partitions work
0297         data1 = self.sc.parallelize(tuples, 1)
0298         data2 = self.sc.parallelize(tuples, 2)
0299 
0300         def comboOp(x, y):
0301             x.extend(y)
0302             return x
0303 
0304         values1 = data1.foldByKey([], comboOp).collect()
0305         values2 = data2.foldByKey([], comboOp).collect()
0306         # Sort lists to ensure clean comparison with ground_truth
0307         values1.sort()
0308         values2.sort()
0309 
0310         # list(range(...)) for Python 3.x compatibility
0311         ground_truth = [(i, list(range(i))*2) for i in range(10)]
0312         self.assertEqual(values1, ground_truth)
0313         self.assertEqual(values2, ground_truth)
0314 
0315     def test_aggregate_by_key(self):
0316         data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
0317 
0318         def seqOp(x, y):
0319             x.add(y)
0320             return x
0321 
0322         def combOp(x, y):
0323             x |= y
0324             return x
0325 
0326         sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect())
0327         self.assertEqual(3, len(sets))
0328         self.assertEqual(set([1]), sets[1])
0329         self.assertEqual(set([2]), sets[3])
0330         self.assertEqual(set([1, 3]), sets[5])
0331 
0332     def test_itemgetter(self):
0333         rdd = self.sc.parallelize([range(10)])
0334         from operator import itemgetter
0335         self.assertEqual([1], rdd.map(itemgetter(1)).collect())
0336         self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect())
0337 
0338     def test_namedtuple_in_rdd(self):
0339         from collections import namedtuple
0340         Person = namedtuple("Person", "id firstName lastName")
0341         jon = Person(1, "Jon", "Doe")
0342         jane = Person(2, "Jane", "Doe")
0343         theDoes = self.sc.parallelize([jon, jane])
0344         self.assertEqual([jon, jane], theDoes.collect())
0345 
0346     def test_large_broadcast(self):
0347         N = 10000
0348         data = [[float(i) for i in range(300)] for i in range(N)]
0349         bdata = self.sc.broadcast(data)  # 27MB
0350         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
0351         self.assertEqual(N, m)
0352 
0353     def test_unpersist(self):
0354         N = 1000
0355         data = [[float(i) for i in range(300)] for i in range(N)]
0356         bdata = self.sc.broadcast(data)  # 3MB
0357         bdata.unpersist()
0358         m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
0359         self.assertEqual(N, m)
0360         bdata.destroy(blocking=True)
0361         try:
0362             self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
0363         except Exception as e:
0364             pass
0365         else:
0366             raise Exception("job should fail after destroy the broadcast")
0367 
0368     def test_multiple_broadcasts(self):
0369         N = 1 << 21
0370         b1 = self.sc.broadcast(set(range(N)))  # multiple blocks in JVM
0371         r = list(range(1 << 15))
0372         random.shuffle(r)
0373         s = str(r).encode()
0374         checksum = hashlib.md5(s).hexdigest()
0375         b2 = self.sc.broadcast(s)
0376         r = list(set(self.sc.parallelize(range(10), 10).map(
0377             lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
0378         self.assertEqual(1, len(r))
0379         size, csum = r[0]
0380         self.assertEqual(N, size)
0381         self.assertEqual(checksum, csum)
0382 
0383         random.shuffle(r)
0384         s = str(r).encode()
0385         checksum = hashlib.md5(s).hexdigest()
0386         b2 = self.sc.broadcast(s)
0387         r = list(set(self.sc.parallelize(range(10), 10).map(
0388             lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
0389         self.assertEqual(1, len(r))
0390         size, csum = r[0]
0391         self.assertEqual(N, size)
0392         self.assertEqual(checksum, csum)
0393 
    def test_multithread_broadcast_pickle(self):
        """Broadcast vars recorded during pickling are tracked per thread.

        Pickling a closure that reads a broadcast's ``value`` records that
        broadcast in ``sc._pickled_broadcast_vars``. One closure is pickled on
        the main thread and another on a worker thread. The test then checks
        that each thread counts and clears only its own broadcast.
        """
        import threading

        b1 = self.sc.broadcast(list(range(3)))
        b2 = self.sc.broadcast(list(range(3)))

        def f1():
            return b1.value

        def f2():
            return b2.value

        # Number of broadcast vars observed after pickling each function.
        funcs_num_pickled = {f1: None, f2: None}

        def do_pickle(f, sc):
            # Serialize f inside a command tuple. The bytes are discarded; only
            # the side effect on sc._pickled_broadcast_vars matters here.
            command = (f, None, sc.serializer, sc.serializer)
            ser = CloudPickleSerializer()
            ser.dumps(command)

        def process_vars(sc):
            # Count the broadcast vars recorded for the calling thread, then clear them.
            broadcast_vars = list(sc._pickled_broadcast_vars)
            num_pickled = len(broadcast_vars)
            sc._pickled_broadcast_vars.clear()
            return num_pickled

        def run(f, sc):
            do_pickle(f, sc)
            funcs_num_pickled[f] = process_vars(sc)

        # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage
        do_pickle(f1, self.sc)

        # run all for f2, should only add/count/clear b2 from worker thread local storage
        t = threading.Thread(target=run, args=(f2, self.sc))
        t.start()
        t.join()

        # count number of vars pickled in main thread, only b1 should be counted and cleared
        funcs_num_pickled[f1] = process_vars(self.sc)

        self.assertEqual(funcs_num_pickled[f1], 1)
        self.assertEqual(funcs_num_pickled[f2], 1)
        self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)
0437 
0438     def test_large_closure(self):
0439         N = 200000
0440         data = [float(i) for i in xrange(N)]
0441         rdd = self.sc.parallelize(range(1), 1).map(lambda x: len(data))
0442         self.assertEqual(N, rdd.first())
0443         # regression test for SPARK-6886
0444         self.assertEqual(1, rdd.map(lambda x: (x, 1)).groupByKey().count())
0445 
0446     def test_zip_with_different_serializers(self):
0447         a = self.sc.parallelize(range(5))
0448         b = self.sc.parallelize(range(100, 105))
0449         self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
0450         a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
0451         b = b._reserialize(MarshalSerializer())
0452         self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
0453         # regression test for SPARK-4841
0454         path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
0455         t = self.sc.textFile(path)
0456         cnt = t.count()
0457         self.assertEqual(cnt, t.zip(t).count())
0458         rdd = t.map(str)
0459         self.assertEqual(cnt, t.zip(rdd).count())
0460         # regression test for bug in _reserializer()
0461         self.assertEqual(cnt, t.zip(rdd).count())
0462 
0463     def test_zip_with_different_object_sizes(self):
0464         # regress test for SPARK-5973
0465         a = self.sc.parallelize(xrange(10000)).map(lambda i: '*' * i)
0466         b = self.sc.parallelize(xrange(10000, 20000)).map(lambda i: '*' * i)
0467         self.assertEqual(10000, a.zip(b).count())
0468 
0469     def test_zip_with_different_number_of_items(self):
0470         a = self.sc.parallelize(range(5), 2)
0471         # different number of partitions
0472         b = self.sc.parallelize(range(100, 106), 3)
0473         self.assertRaises(ValueError, lambda: a.zip(b))
0474         with QuietTest(self.sc):
0475             # different number of batched items in JVM
0476             b = self.sc.parallelize(range(100, 104), 2)
0477             self.assertRaises(Exception, lambda: a.zip(b).count())
0478             # different number of items in one pair
0479             b = self.sc.parallelize(range(100, 106), 2)
0480             self.assertRaises(Exception, lambda: a.zip(b).count())
0481             # same total number of items, but different distributions
0482             a = self.sc.parallelize([2, 3], 2).flatMap(range)
0483             b = self.sc.parallelize([3, 2], 2).flatMap(range)
0484             self.assertEqual(a.count(), b.count())
0485             self.assertRaises(Exception, lambda: a.zip(b).count())
0486 
0487     def test_count_approx_distinct(self):
0488         rdd = self.sc.parallelize(xrange(1000))
0489         self.assertTrue(950 < rdd.countApproxDistinct(0.03) < 1050)
0490         self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.03) < 1050)
0491         self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.03) < 1050)
0492         self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.03) < 1050)
0493 
0494         rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7)
0495         self.assertTrue(18 < rdd.countApproxDistinct() < 22)
0496         self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22)
0497         self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22)
0498         self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22)
0499 
0500         self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001))
0501 
    def test_histogram(self):
        """Exercise RDD.histogram with explicit bucket edges and with bucket counts.

        ``rdd.histogram(...)`` returns a ``(buckets, counts)`` pair; most
        assertions index ``[1]`` to compare only the counts. The cases cover:
        - empty RDDs and out-of-range values;
        - None/NaN values and infinite bucket edges;
        - invalid bucket arguments;
        - evenly spaced buckets requested via an int;
        - string elements.
        """
        # empty RDD: every bucket count is 0; an int bucket count raises ValueError
        rdd = self.sc.parallelize([])
        self.assertEqual([0], rdd.histogram([0, 10])[1])
        self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])
        self.assertRaises(ValueError, lambda: rdd.histogram(1))

        # out of range: values outside [first edge, last edge] are not counted
        rdd = self.sc.parallelize([10.01, -0.01])
        self.assertEqual([0], rdd.histogram([0, 10])[1])
        self.assertEqual([0, 0], rdd.histogram((0, 4, 10))[1])

        # in range, with one bucket and with two (the value 4 lands in the upper bucket)
        rdd = self.sc.parallelize(range(1, 5))
        self.assertEqual([4], rdd.histogram([0, 10])[1])
        self.assertEqual([3, 1], rdd.histogram([0, 4, 10])[1])

        # in range with one bucket exactly matching min/max: the last edge is inclusive
        self.assertEqual([4], rdd.histogram([1, 4])[1])

        # out of range with two buckets
        rdd = self.sc.parallelize([10.01, -0.01])
        self.assertEqual([0, 0], rdd.histogram([0, 5, 10])[1])

        # out of range with two uneven buckets
        rdd = self.sc.parallelize([10.01, -0.01])
        self.assertEqual([0, 0], rdd.histogram([0, 4, 10])[1])

        # in range with two buckets
        rdd = self.sc.parallelize([1, 2, 3, 5, 6])
        self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])

        # in range with two buckets; None and NaN are ignored
        rdd = self.sc.parallelize([1, 2, 3, 5, 6, None, float('nan')])
        self.assertEqual([3, 2], rdd.histogram([0, 5, 10])[1])

        # in range with two uneven buckets
        rdd = self.sc.parallelize([1, 2, 3, 5, 6])
        self.assertEqual([3, 2], rdd.histogram([0, 5, 11])[1])

        # mixed range with two uneven buckets (11.0 sits on the inclusive last edge)
        rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01])
        self.assertEqual([4, 3], rdd.histogram([0, 5, 11])[1])

        # mixed range with four uneven buckets
        rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1])
        self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])

        # mixed range with uneven buckets; None and NaN are ignored
        rdd = self.sc.parallelize([-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0,
                                   199.0, 200.0, 200.1, None, float('nan')])
        self.assertEqual([4, 2, 1, 3], rdd.histogram([0.0, 5.0, 11.0, 12.0, 200.0])[1])

        # infinite bucket edges: inf is counted in the last bucket, NaN is ignored
        rdd = self.sc.parallelize([10.01, -0.01, float('nan'), float("inf")])
        self.assertEqual([1, 2], rdd.histogram([float('-inf'), 0, float('inf')])[1])

        # invalid bucket arguments
        self.assertRaises(ValueError, lambda: rdd.histogram([]))
        self.assertRaises(ValueError, lambda: rdd.histogram([1]))
        self.assertRaises(ValueError, lambda: rdd.histogram(0))
        self.assertRaises(TypeError, lambda: rdd.histogram({}))

        # an int argument requests that many evenly spaced buckets between min and max
        rdd = self.sc.parallelize(range(1, 5))
        self.assertEqual(([1, 4], [4]), rdd.histogram(1))

        # int bucket count with a single element
        rdd = self.sc.parallelize([1])
        self.assertEqual(([1, 1], [1]), rdd.histogram(1))

        # int bucket count when all elements are equal (min == max)
        rdd = self.sc.parallelize([1] * 4)
        self.assertEqual(([1, 1], [4]), rdd.histogram(1))

        # int bucket count: two evenly spaced buckets
        rdd = self.sc.parallelize(range(1, 5))
        self.assertEqual(([1, 2.5, 4], [2, 2]), rdd.histogram(2))

        # int bucket count larger than the number of elements
        rdd = self.sc.parallelize([1, 2])
        buckets = [1 + 0.2 * i for i in range(6)]
        hist = [1, 0, 0, 0, 1]
        self.assertEqual((buckets, hist), rdd.histogram(5))

        # invalid RDDs: an int bucket count fails when the data contain inf or only NaN
        rdd = self.sc.parallelize([1, float('inf')])
        self.assertRaises(ValueError, lambda: rdd.histogram(2))
        rdd = self.sc.parallelize([float('nan')])
        self.assertRaises(ValueError, lambda: rdd.histogram(2))

        # strings: explicit string edges and a single bucket work; splitting evenly does not
        rdd = self.sc.parallelize(["ab", "ac", "b", "bd", "ef"], 2)
        self.assertEqual([2, 2], rdd.histogram(["a", "b", "c"])[1])
        self.assertEqual((["ab", "ef"], [5]), rdd.histogram(1))
        self.assertRaises(TypeError, lambda: rdd.histogram(2))
0598 
0599     def test_repartitionAndSortWithinPartitions_asc(self):
0600         rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
0601 
0602         repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, True)
0603         partitions = repartitioned.glom().collect()
0604         self.assertEqual(partitions[0], [(0, 5), (0, 8), (2, 6)])
0605         self.assertEqual(partitions[1], [(1, 3), (3, 8), (3, 8)])
0606 
0607     def test_repartitionAndSortWithinPartitions_desc(self):
0608         rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
0609 
0610         repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2, False)
0611         partitions = repartitioned.glom().collect()
0612         self.assertEqual(partitions[0], [(2, 6), (0, 5), (0, 8)])
0613         self.assertEqual(partitions[1], [(3, 8), (3, 8), (1, 3)])
0614 
0615     def test_repartition_no_skewed(self):
0616         num_partitions = 20
0617         a = self.sc.parallelize(range(int(1000)), 2)
0618         l = a.repartition(num_partitions).glom().map(len).collect()
0619         zeros = len([x for x in l if x == 0])
0620         self.assertTrue(zeros == 0)
0621         l = a.coalesce(num_partitions, True).glom().map(len).collect()
0622         zeros = len([x for x in l if x == 0])
0623         self.assertTrue(zeros == 0)
0624 
0625     def test_repartition_on_textfile(self):
0626         path = os.path.join(SPARK_HOME, "python/test_support/hello/hello.txt")
0627         rdd = self.sc.textFile(path)
0628         result = rdd.repartition(1).collect()
0629         self.assertEqual(u"Hello World!", result[0])
0630 
0631     def test_distinct(self):
0632         rdd = self.sc.parallelize((1, 2, 3)*10, 10)
0633         self.assertEqual(rdd.getNumPartitions(), 10)
0634         self.assertEqual(rdd.distinct().count(), 3)
0635         result = rdd.distinct(5)
0636         self.assertEqual(result.getNumPartitions(), 5)
0637         self.assertEqual(result.count(), 3)
0638 
0639     def test_external_group_by_key(self):
0640         self.sc._conf.set("spark.python.worker.memory", "1m")
0641         N = 2000001
0642         kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x))
0643         gkv = kv.groupByKey().cache()
0644         self.assertEqual(3, gkv.count())
0645         filtered = gkv.filter(lambda kv: kv[0] == 1)
0646         self.assertEqual(1, filtered.count())
0647         self.assertEqual([(1, N // 3)], filtered.mapValues(len).collect())
0648         self.assertEqual([(N // 3, N // 3)],
0649                          filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
0650         result = filtered.collect()[0][1]
0651         self.assertEqual(N // 3, len(result))
0652         self.assertTrue(isinstance(result.data, shuffle.ExternalListOfList))
0653 
0654     def test_sort_on_empty_rdd(self):
0655         self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
0656 
0657     def test_sample(self):
0658         rdd = self.sc.parallelize(range(0, 100), 4)
0659         wo = rdd.sample(False, 0.1, 2).collect()
0660         wo_dup = rdd.sample(False, 0.1, 2).collect()
0661         self.assertSetEqual(set(wo), set(wo_dup))
0662         wr = rdd.sample(True, 0.2, 5).collect()
0663         wr_dup = rdd.sample(True, 0.2, 5).collect()
0664         self.assertSetEqual(set(wr), set(wr_dup))
0665         wo_s10 = rdd.sample(False, 0.3, 10).collect()
0666         wo_s20 = rdd.sample(False, 0.3, 20).collect()
0667         self.assertNotEqual(set(wo_s10), set(wo_s20))
0668         wr_s11 = rdd.sample(True, 0.4, 11).collect()
0669         wr_s21 = rdd.sample(True, 0.4, 21).collect()
0670         self.assertNotEqual(set(wr_s11), set(wr_s21))
0671 
0672     def test_null_in_rdd(self):
0673         jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc)
0674         rdd = RDD(jrdd, self.sc, UTF8Deserializer())
0675         self.assertEqual([u"a", None, u"b"], rdd.collect())
0676         rdd = RDD(jrdd, self.sc, NoOpSerializer())
0677         self.assertEqual([b"a", None, b"b"], rdd.collect())
0678 
0679     def test_multiple_python_java_RDD_conversions(self):
0680         # Regression test for SPARK-5361
0681         data = [
0682             (u'1', {u'director': u'David Lean'}),
0683             (u'2', {u'director': u'Andrew Dominik'})
0684         ]
0685         data_rdd = self.sc.parallelize(data)
0686         data_java_rdd = data_rdd._to_java_object_rdd()
0687         data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd)
0688         converted_rdd = RDD(data_python_rdd, self.sc)
0689         self.assertEqual(2, converted_rdd.count())
0690 
0691         # conversion between python and java RDD threw exceptions
0692         data_java_rdd = converted_rdd._to_java_object_rdd()
0693         data_python_rdd = self.sc._jvm.SerDeUtil.javaToPython(data_java_rdd)
0694         converted_rdd = RDD(data_python_rdd, self.sc)
0695         self.assertEqual(2, converted_rdd.count())
0696 
0697     # Regression test for SPARK-6294
0698     def test_take_on_jrdd(self):
0699         rdd = self.sc.parallelize(xrange(1 << 20)).map(lambda x: str(x))
0700         rdd._jrdd.first()
0701 
0702     def test_sortByKey_uses_all_partitions_not_only_first_and_last(self):
0703         # Regression test for SPARK-5969
0704         seq = [(i * 59 % 101, i) for i in range(101)]  # unsorted sequence
0705         rdd = self.sc.parallelize(seq)
0706         for ascending in [True, False]:
0707             sort = rdd.sortByKey(ascending=ascending, numPartitions=5)
0708             self.assertEqual(sort.collect(), sorted(seq, reverse=not ascending))
0709             sizes = sort.glom().map(len).collect()
0710             for size in sizes:
0711                 self.assertGreater(size, 0)
0712 
0713     def test_pipe_functions(self):
0714         data = ['1', '2', '3']
0715         rdd = self.sc.parallelize(data)
0716         with QuietTest(self.sc):
0717             self.assertEqual([], rdd.pipe('java').collect())
0718             self.assertRaises(Py4JJavaError, rdd.pipe('java', checkCode=True).collect)
0719         result = rdd.pipe('cat').collect()
0720         result.sort()
0721         for x, y in zip(data, result):
0722             self.assertEqual(x, y)
0723         self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect)
0724         self.assertEqual([], rdd.pipe('grep 4').collect())
0725 
0726     def test_pipe_unicode(self):
0727         # Regression test for SPARK-20947
0728         data = [u'\u6d4b\u8bd5', '1']
0729         rdd = self.sc.parallelize(data)
0730         result = rdd.pipe('cat').collect()
0731         self.assertEqual(data, result)
0732 
0733     def test_stopiteration_in_user_code(self):
0734 
0735         def stopit(*x):
0736             raise StopIteration()
0737 
0738         seq_rdd = self.sc.parallelize(range(10))
0739         keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
0740         msg = "Caught StopIteration thrown from user's code; failing the task"
0741 
0742         self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
0743         self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
0744         self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
0745         self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
0746         self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
0747         self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
0748         self.assertRaisesRegexp(Py4JJavaError, msg,
0749                                 seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
0750 
0751         # these methods call the user function both in the driver and in the executor
0752         # the exception raised is different according to where the StopIteration happens
0753         # RuntimeError is raised if in the driver
0754         # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker)
0755         self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
0756                                 keyed_rdd.reduceByKeyLocally, stopit)
0757         self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
0758                                 seq_rdd.aggregate, 0, stopit, lambda *x: 1)
0759         self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
0760                                 seq_rdd.aggregate, 0, lambda *x: 1, stopit)
0761 
0762     def test_overwritten_global_func(self):
0763         # Regression test for SPARK-27000
0764         global global_func
0765         self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Hi")
0766         global_func = lambda: "Yeah"
0767         self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Yeah")
0768 
0769     def test_to_local_iterator_failure(self):
0770         # SPARK-27548 toLocalIterator task failure not propagated to Python driver
0771 
0772         def fail(_):
0773             raise RuntimeError("local iterator error")
0774 
0775         rdd = self.sc.range(10).map(fail)
0776 
0777         with self.assertRaisesRegexp(Exception, "local iterator error"):
0778             for _ in rdd.toLocalIterator():
0779                 pass
0780 
0781     def test_to_local_iterator_collects_single_partition(self):
0782         # Test that partitions are not computed until requested by iteration
0783 
0784         def fail_last(x):
0785             if x == 9:
0786                 raise RuntimeError("This should not be hit")
0787             return x
0788 
0789         rdd = self.sc.range(12, numSlices=4).map(fail_last)
0790         it = rdd.toLocalIterator()
0791 
0792         # Only consume first 4 elements from partitions 1 and 2, this should not collect the last
0793         # partition which would trigger the error
0794         for i in range(4):
0795             self.assertEqual(i, next(it))
0796 
0797     def test_multiple_group_jobs(self):
0798         import threading
0799         group_a = "job_ids_to_cancel"
0800         group_b = "job_ids_to_run"
0801 
0802         threads = []
0803         thread_ids = range(4)
0804         thread_ids_to_cancel = [i for i in thread_ids if i % 2 == 0]
0805         thread_ids_to_run = [i for i in thread_ids if i % 2 != 0]
0806 
0807         # A list which records whether job is cancelled.
0808         # The index of the array is the thread index which job run in.
0809         is_job_cancelled = [False for _ in thread_ids]
0810 
0811         def run_job(job_group, index):
0812             """
0813             Executes a job with the group ``job_group``. Each job waits for 3 seconds
0814             and then exits.
0815             """
0816             try:
0817                 self.sc.parallelize([15]).map(lambda x: time.sleep(x)) \
0818                     .collectWithJobGroup(job_group, "test rdd collect with setting job group")
0819                 is_job_cancelled[index] = False
0820             except Exception:
0821                 # Assume that exception means job cancellation.
0822                 is_job_cancelled[index] = True
0823 
0824         # Test if job succeeded when not cancelled.
0825         run_job(group_a, 0)
0826         self.assertFalse(is_job_cancelled[0])
0827 
0828         # Run jobs
0829         for i in thread_ids_to_cancel:
0830             t = threading.Thread(target=run_job, args=(group_a, i))
0831             t.start()
0832             threads.append(t)
0833 
0834         for i in thread_ids_to_run:
0835             t = threading.Thread(target=run_job, args=(group_b, i))
0836             t.start()
0837             threads.append(t)
0838 
0839         # Wait to make sure all jobs are executed.
0840         time.sleep(3)
0841         # And then, cancel one job group.
0842         self.sc.cancelJobGroup(group_a)
0843 
0844         # Wait until all threads launching jobs are finished.
0845         for t in threads:
0846             t.join()
0847 
0848         for i in thread_ids_to_cancel:
0849             self.assertTrue(
0850                 is_job_cancelled[i],
0851                 "Thread {i}: Job in group A was not cancelled.".format(i=i))
0852 
0853         for i in thread_ids_to_run:
0854             self.assertFalse(
0855                 is_job_cancelled[i],
0856                 "Thread {i}: Job in group B did not succeeded.".format(i=i))
0857 
0858 
if __name__ == "__main__":
    import unittest
    from pyspark.tests.test_rdd import *

    # Write XML test reports to target/test-reports when xmlrunner is importable;
    # a None runner makes unittest.main use its default text runner.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)