Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 import random
0018 import sys
0019 import unittest
0020 
0021 from py4j.protocol import Py4JJavaError
0022 
0023 from pyspark import shuffle, PickleSerializer, SparkConf, SparkContext
0024 from pyspark.shuffle import Aggregator, ExternalMerger, ExternalSorter
0025 
# Python 3 has no ``xrange``; alias it to ``range`` (which is lazy there) so the
# tests below can use a single spelling on both Python 2 and Python 3.
if sys.version_info >= (3,):
    xrange = range
0028 
0029 
class MergerTests(unittest.TestCase):
    """Tests for ``ExternalMerger`` (with and without spilling) and ``shuffle.GroupByKey``."""

    def setUp(self):
        self.N = 1 << 12
        self.l = list(xrange(self.N))
        # (k, k) pairs: N distinct keys, each carrying itself as its value.
        self.data = list(zip(self.l, self.l))
        # Combiners are plain lists. ``list.append``/``list.extend`` return
        # None, so ``... or x`` yields the mutated list as the new combiner.
        self.agg = Aggregator(lambda x: [x],
                              lambda x, y: x.append(y) or x,
                              lambda x, y: x.extend(y) or x)

    def _new_merger(self, memory_limit, aggregator=None, **kwargs):
        """Build an ``ExternalMerger`` and register its ``_cleanup`` hook.

        The hook is registered with ``addCleanup`` so it runs after the test
        even when an assertion fails before the merged items are consumed.

        :param memory_limit: passed positionally to ``ExternalMerger``.
        :param aggregator: the ``Aggregator`` to use; defaults to ``self.agg``.
        :param kwargs: extra ``ExternalMerger`` keyword arguments
            (e.g. ``partitions``).
        :return: the new merger.
        """
        if aggregator is None:
            aggregator = self.agg
        merger = ExternalMerger(aggregator, memory_limit, **kwargs)
        self.addCleanup(merger._cleanup)
        return merger

    def test_small_dataset(self):
        # With this limit the merger is expected not to spill at all.
        m = self._new_merger(1000)
        m.mergeValues(self.data)
        self.assertEqual(m.spills, 0)
        self.assertEqual(sum(sum(v) for k, v in m.items()),
                         sum(xrange(self.N)))

        m = self._new_merger(1000)
        m.mergeCombiners(map(lambda kv: (kv[0], [kv[1]]), self.data))
        self.assertEqual(m.spills, 0)
        self.assertEqual(sum(sum(v) for k, v in m.items()),
                         sum(xrange(self.N)))

    def test_medium_dataset(self):
        # Small limits: at least one spill is expected, yet the merged
        # result must still cover every value.
        m = self._new_merger(20)
        m.mergeValues(self.data)
        self.assertGreaterEqual(m.spills, 1)
        self.assertEqual(sum(sum(v) for k, v in m.items()),
                         sum(xrange(self.N)))

        m = self._new_merger(10)
        m.mergeCombiners(map(lambda kv: (kv[0], [kv[1]]), self.data * 3))
        self.assertGreaterEqual(m.spills, 1)
        self.assertEqual(sum(sum(v) for k, v in m.items()),
                         sum(xrange(self.N)) * 3)

    def test_huge_dataset(self):
        m = self._new_merger(5, partitions=3)
        m.mergeCombiners(map(lambda kv: (kv[0], [str(kv[1])]), self.data * 10))
        self.assertGreaterEqual(m.spills, 1)
        self.assertEqual(sum(len(v) for k, v in m.items()),
                         self.N * 10)

    def test_group_by_key(self):

        def gen_data(N, step):
            # Key i (for i = 1, 1 + step, ...) appears i times, with
            # single-element value lists [0], [1], ..., [i - 1].
            for i in range(1, N + 1, step):
                for j in range(i):
                    yield (i, [j])

        def gen_gs(N, step=1):
            return shuffle.GroupByKey(gen_data(N, step))

        self.assertEqual(1, len(list(gen_gs(1))))
        self.assertEqual(2, len(list(gen_gs(2))))
        self.assertEqual(100, len(list(gen_gs(100))))
        self.assertEqual(list(range(1, 101)), [k for k, _ in gen_gs(100)])
        self.assertTrue(all(list(range(k)) == list(vs) for k, vs in gen_gs(100)))

        for k, vs in gen_gs(50002, 10000):
            self.assertEqual(k, len(vs))
            self.assertEqual(list(range(k)), list(vs))

        # The grouped values must also survive a pickle round trip.
        ser = PickleSerializer()
        groups = ser.loads(ser.dumps(list(gen_gs(50002, 30000))))
        for k, vs in groups:
            self.assertEqual(k, len(vs))
            self.assertEqual(list(range(k)), list(vs))

    def test_stopiteration_is_raised(self):
        """A ``StopIteration`` raised by any aggregator function must surface
        from the merge as ``Py4JJavaError`` or ``RuntimeError``."""

        def stopit(*args, **kwargs):
            raise StopIteration()

        def legit_create_combiner(x):
            return [x]

        def legit_merge_value(x, y):
            return x.append(y) or x

        def legit_merge_combiners(x, y):
            return x.extend(y) or x

        # Only two distinct keys, so merge functions are actually invoked.
        data = [(x % 2, x) for x in range(100)]

        # wrong create combiner
        m = self._new_merger(
            20, Aggregator(stopit, legit_merge_value, legit_merge_combiners))
        with self.assertRaises((Py4JJavaError, RuntimeError)):
            m.mergeValues(data)

        # wrong merge value
        m = self._new_merger(
            20, Aggregator(legit_create_combiner, stopit, legit_merge_combiners))
        with self.assertRaises((Py4JJavaError, RuntimeError)):
            m.mergeValues(data)

        # wrong merge combiners
        m = self._new_merger(
            20, Aggregator(legit_create_combiner, legit_merge_value, stopit))
        with self.assertRaises((Py4JJavaError, RuntimeError)):
            m.mergeCombiners(map(lambda kv: (kv[0], [kv[1]]), data))
0130 
0131 
class SorterTests(unittest.TestCase):
    """Tests for ``ExternalSorter``, standalone and through ``RDD.sortBy``."""

    @staticmethod
    def _sort_kwargs():
        """Return the ``key``/``reverse`` combinations each sort test checks."""
        def negate(x):
            return -x
        return [{},
                {"reverse": True},
                {"key": negate},
                {"key": negate, "reverse": True}]

    def test_in_memory_sort(self):
        values = list(range(1024))
        random.shuffle(values)
        sorter = ExternalSorter(1024)
        for kwargs in self._sort_kwargs():
            self.assertEqual(sorted(values, **kwargs),
                             list(sorter.sorted(values, **kwargs)),
                             msg=repr(kwargs))

    def test_external_sort(self):
        class CustomizedSorter(ExternalSorter):
            # NOTE(review): presumably pins the limit at its initial value so
            # the tiny limit below stays in force for every batch; confirm
            # against ExternalSorter._next_limit.
            def _next_limit(self):
                return self.memory_limit

        values = list(range(1024))
        random.shuffle(values)
        sorter = CustomizedSorter(1)
        for kwargs in self._sort_kwargs():
            # ``DiskBytesSpilled`` is a module-level counter that earlier tests
            # may already have bumped. Comparing against its value right
            # before each call requires every call, including the first, to
            # spill on its own.
            spilled_before = shuffle.DiskBytesSpilled
            self.assertEqual(sorted(values, **kwargs),
                             list(sorter.sorted(values, **kwargs)),
                             msg=repr(kwargs))
            self.assertGreater(shuffle.DiskBytesSpilled, spilled_before)

    def test_external_sort_in_rdd(self):
        # Very small Python worker memory ("1m") so that, per this test's
        # intent, sortBy goes through the external sorter.
        conf = SparkConf().set("spark.python.worker.memory", "1m")
        sc = SparkContext(conf=conf)
        try:
            values = list(range(10240))
            random.shuffle(values)
            rdd = sc.parallelize(values, 4)
            self.assertEqual(sorted(values), rdd.sortBy(lambda x: x).collect())
        finally:
            # Always stop the context. Only one SparkContext may be active per
            # process, so a leaked one breaks later tests that create their own.
            sc.stop()
0171 
0172 
if __name__ == "__main__":
    # Pull the test classes in from the package module so unittest.main()
    # runs those definitions.
    # NOTE(review): presumably so test objects resolve to an importable module
    # rather than ``__main__`` (e.g. for pickling); confirm.
    from pyspark.tests.test_shuffle import *

    # Write XML reports to target/test-reports when xmlrunner is installed.
    # Otherwise pass None so unittest.main() uses its default runner.
    try:
        import xmlrunner
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        runner = None
    unittest.main(testRunner=runner, verbosity=2)