0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 import random
0018 import sys
0019 import unittest
0020
0021 from py4j.protocol import Py4JJavaError
0022
0023 from pyspark import shuffle, PickleSerializer, SparkConf, SparkContext
0024 from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
0025
# ``xrange`` was removed in Python 3; alias it to ``range`` so the tests
# below can use ``xrange`` uniformly on both major versions.
if sys.version_info >= (3,):
    xrange = range
0028
0029
class MergerTests(unittest.TestCase):
    """Tests for the external (disk-spilling) merger in ``pyspark.shuffle``."""

    def setUp(self):
        # N distinct integer keys, each paired with itself as the value.
        self.N = 1 << 12
        self.l = [i for i in xrange(self.N)]
        self.data = list(zip(self.l, self.l))
        # Combiners are lists: create a singleton list, append one value,
        # extend with another combiner.  ``append``/``extend`` return None,
        # so ``or x`` yields the (mutated) list itself.
        self.agg = Aggregator(lambda x: [x],
                              lambda x, y: x.append(y) or x,
                              lambda x, y: x.extend(y) or x)

    def test_small_dataset(self):
        """A dataset within the memory limit must not spill to disk."""
        m = ExternalMerger(self.agg, 1000)
        m.mergeValues(self.data)
        self.assertEqual(m.spills, 0)
        self.assertEqual(sum(sum(v) for k, v in m.items()),
                         sum(xrange(self.N)))

        m = ExternalMerger(self.agg, 1000)
        m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), self.data))
        self.assertEqual(m.spills, 0)
        self.assertEqual(sum(sum(v) for k, v in m.items()),
                         sum(xrange(self.N)))

    def test_medium_dataset(self):
        """A dataset above the memory limit must spill at least once and
        still aggregate correctly."""
        m = ExternalMerger(self.agg, 20)
        m.mergeValues(self.data)
        self.assertTrue(m.spills >= 1)
        self.assertEqual(sum(sum(v) for k, v in m.items()),
                         sum(xrange(self.N)))

        m = ExternalMerger(self.agg, 10)
        m.mergeCombiners(map(lambda x_y2: (x_y2[0], [x_y2[1]]), self.data * 3))
        self.assertTrue(m.spills >= 1)
        self.assertEqual(sum(sum(v) for k, v in m.items()),
                         sum(xrange(self.N)) * 3)

    def test_huge_dataset(self):
        """A large dataset with multiple partitions spills and retains every
        value (checked by count rather than sum, since values are strings)."""
        m = ExternalMerger(self.agg, 5, partitions=3)
        m.mergeCombiners(map(lambda k_v: (k_v[0], [str(k_v[1])]), self.data * 10))
        self.assertTrue(m.spills >= 1)
        self.assertEqual(sum(len(v) for k, v in m.items()),
                         self.N * 10)
        m._cleanup()

    def test_group_by_key(self):
        """GroupByKey must yield each key exactly once with all of its
        values, both in memory and after a pickle round-trip."""

        def gen_data(N, step):
            for i in range(1, N + 1, step):
                for j in range(i):
                    yield (i, [j])

        def gen_gs(N, step=1):
            return shuffle.GroupByKey(gen_data(N, step))

        self.assertEqual(1, len(list(gen_gs(1))))
        self.assertEqual(2, len(list(gen_gs(2))))
        self.assertEqual(100, len(list(gen_gs(100))))
        self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)])
        self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100)))

        # Large keys: key i carries i values, so len(vs) must equal k.
        for k, vs in gen_gs(50002, 10000):
            self.assertEqual(k, len(vs))
            self.assertEqual(list(range(k)), list(vs))

        # Grouped values must survive serialization/deserialization.
        ser = PickleSerializer()
        l = ser.loads(ser.dumps(list(gen_gs(50002, 30000))))
        for k, vs in l:
            self.assertEqual(k, len(vs))
            self.assertEqual(list(range(k)), list(vs))

    def test_stopiteration_is_raised(self):
        """A StopIteration raised inside any user aggregation function must
        surface as a RuntimeError (or Py4JJavaError), not silently end
        iteration."""

        def stopit(*args, **kwargs):
            raise StopIteration()

        def legit_create_combiner(x):
            return [x]

        def legit_merge_value(x, y):
            return x.append(y) or x

        def legit_merge_combiners(x, y):
            return x.extend(y) or x

        data = [(x % 2, x) for x in range(100)]

        # faulty create-combiner
        m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20)
        with self.assertRaises((Py4JJavaError, RuntimeError)):
            m.mergeValues(data)

        # faulty merge-value
        m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20)
        with self.assertRaises((Py4JJavaError, RuntimeError)):
            m.mergeValues(data)

        # faulty merge-combiners
        m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20)
        with self.assertRaises((Py4JJavaError, RuntimeError)):
            m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data))
0130
0131
class SorterTests(unittest.TestCase):
    """Tests for ``pyspark.shuffle.ExternalSorter``."""

    def test_in_memory_sort(self):
        """With a generous memory limit, sorting matches builtin ``sorted``
        for every key/reverse combination."""
        l = list(range(1024))
        random.shuffle(l)
        sorter = ExternalSorter(1024)
        self.assertEqual(sorted(l), list(sorter.sorted(l)))
        self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
        self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
        self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
                         list(sorter.sorted(l, key=lambda x: -x, reverse=True)))

    def test_external_sort(self):
        """With a tiny memory limit the sorter must spill to disk (tracked
        via ``shuffle.DiskBytesSpilled``) and still sort correctly."""
        class CustomizedSorter(ExternalSorter):
            # Pin the limit so the sorter never grows its memory budget
            # and is forced to spill.
            def _next_limit(self):
                return self.memory_limit

        l = list(range(1024))
        random.shuffle(l)
        sorter = CustomizedSorter(1)
        self.assertEqual(sorted(l), list(sorter.sorted(l)))
        self.assertGreater(shuffle.DiskBytesSpilled, 0)
        last = shuffle.DiskBytesSpilled
        self.assertEqual(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
        self.assertGreater(shuffle.DiskBytesSpilled, last)
        last = shuffle.DiskBytesSpilled
        self.assertEqual(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
        self.assertGreater(shuffle.DiskBytesSpilled, last)
        last = shuffle.DiskBytesSpilled
        self.assertEqual(sorted(l, key=lambda x: -x, reverse=True),
                         list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
        self.assertGreater(shuffle.DiskBytesSpilled, last)

    def test_external_sort_in_rdd(self):
        """End-to-end: ``sortBy`` with a 1m Python worker memory limit,
        which presumably forces external sorting in the worker."""
        conf = SparkConf().set("spark.python.worker.memory", "1m")
        sc = SparkContext(conf=conf)
        try:
            l = list(range(10240))
            random.shuffle(l)
            rdd = sc.parallelize(l, 4)
            self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect())
        finally:
            # Always stop the context so a failed assertion does not leak
            # the SparkContext/JVM into subsequent tests.
            sc.stop()
0171
0172
if __name__ == "__main__":
    # Re-import so unittest discovers the test classes under their
    # canonical module path.
    from pyspark.tests.test_shuffle import *

    # Prefer an XML report runner when available; otherwise fall back to
    # the default text runner.
    try:
        import xmlrunner
    except ImportError:
        testRunner = None
    else:
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=testRunner, verbosity=2)