0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 from pyspark.testing.utils import ReusedPySparkTestCase
0018
0019
class JoinTests(ReusedPySparkTestCase):
    """Partition-count and stage-count checks for union/join/cogroup on pair RDDs.

    Every check runs against the ``self.sc`` context exposed on the test case.
    """

    def test_narrow_dependency_in_join(self):
        """Compare jobs over a ``partitionBy(2)`` RDD with jobs over its source RDD.

        Asserted behavior:

        * the union of the partitioned RDD with itself has 2 partitions, while
          a union with the source RDD (in either order) has the source's
          partition count plus 2;
        * ``join``/``cogroup`` of the partitioned RDD with itself runs a job
          with 2 stages, whereas pairing it with the source RDD runs a job
          with 3 stages.

        NOTE(review): going by the test name, the smaller stage count is
        presumably due to a narrow dependency between co-partitioned inputs;
        that reasoning is not visible in this file.
        """
        pairs = self.sc.parallelize(range(10)).map(lambda x: (x, x))
        by_key = pairs.partitionBy(2)

        # Partition counts of the three union combinations.
        self.assertEqual(2, by_key.union(by_key).getNumPartitions())
        self.assertEqual(pairs.getNumPartitions() + 2, by_key.union(pairs).getNumPartitions())
        self.assertEqual(pairs.getNumPartitions() + 2, pairs.union(by_key).getNumPartitions())

        status = self.sc.statusTracker()

        def stages_in_first_job(group_id):
            # Number of stage ids reported for the first job listed under the group.
            first_job = status.getJobIdsForGroup(group_id)[0]
            return len(status.getJobInfo(first_job).stageIds)

        def as_lists(grouped_row):
            # Materialize both value iterables of a cogroup row for comparison.
            return [list(values) for values in grouped_row[1]]

        # join: partitioned RDD with itself.
        self.sc.setJobGroup("test1", "test", True)
        joined = sorted(by_key.join(by_key).collect())
        self.assertEqual(10, len(joined))
        self.assertEqual((0, (0, 0)), joined[0])
        self.assertEqual(2, stages_in_first_job("test1"))

        # join: partitioned RDD with the source RDD.
        self.sc.setJobGroup("test2", "test", True)
        joined = sorted(by_key.join(pairs).collect())
        self.assertEqual(10, len(joined))
        self.assertEqual((0, (0, 0)), joined[0])
        self.assertEqual(3, stages_in_first_job("test2"))

        # cogroup: partitioned RDD with itself.
        self.sc.setJobGroup("test3", "test", True)
        grouped = sorted(by_key.cogroup(by_key).collect())
        self.assertEqual(10, len(grouped))
        self.assertEqual([[0], [0]], as_lists(grouped[0]))
        self.assertEqual(2, stages_in_first_job("test3"))

        # cogroup: partitioned RDD with the source RDD.
        self.sc.setJobGroup("test4", "test", True)
        grouped = sorted(by_key.cogroup(pairs).collect())
        self.assertEqual(10, len(grouped))
        self.assertEqual([[0], [0]], as_lists(grouped[0]))
        self.assertEqual(3, stages_in_first_job("test4"))
0058
0059
if __name__ == "__main__":
    import unittest
    from pyspark.tests.test_join import *

    # Start with no explicit runner (unittest picks its default) and switch to
    # the XML-report runner only when xmlrunner can be imported.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        # xmlrunner is optional; keep testRunner as None.
        pass
    unittest.main(testRunner=testRunner, verbosity=2)