#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.testing.utils import ReusedPySparkTestCase


class JoinTests(ReusedPySparkTestCase):
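    """Checks that join and cogroup use narrow dependencies when both
    sides share a partitioner, so no extra shuffle stage is scheduled."""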

    def test_narrow_dependency_in_join(self):
        rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x))
        parted = rdd.partitionBy(2)
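        # union of two RDDs with the same partitioner keeps that partitioner
        # and its partition count; otherwise the partition counts are summed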
        self.assertEqual(2, parted.union(parted).getNumPartitions())
        self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions())
        self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions())

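        # the status tracker reports the stages each job ran; a narrow
        # dependency shows up below as one fewer shuffle stage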
        tracker = self.sc.statusTracker()

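        # self-join of a co-partitioned RDD: the join itself is narrow, so
        # the job needs only 2 stages (the partitionBy shuffle + the result)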
        self.sc.setJobGroup("test1", "test", True)
        d = sorted(parted.join(parted).collect())
        self.assertEqual(10, len(d))
        self.assertEqual((0, (0, 0)), d[0])
        jobId = tracker.getJobIdsForGroup("test1")[0]
        self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))

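        # joining with an unpartitioned RDD shuffles that side first,
        # adding a third stage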
        self.sc.setJobGroup("test2", "test", True)
        d = sorted(parted.join(rdd).collect())
        self.assertEqual(10, len(d))
        self.assertEqual((0, (0, 0)), d[0])
        jobId = tracker.getJobIdsForGroup("test2")[0]
        self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))

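        # cogroup behaves the same way: co-partitioned inputs stay narrow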
        self.sc.setJobGroup("test3", "test", True)
        d = sorted(parted.cogroup(parted).collect())
        self.assertEqual(10, len(d))
        self.assertEqual([[0], [0]], list(map(list, d[0][1])))
        jobId = tracker.getJobIdsForGroup("test3")[0]
        self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))

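        # cogroup with an unpartitioned RDD again pays for an extra shuffle stage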
        self.sc.setJobGroup("test4", "test", True)
        d = sorted(parted.cogroup(rdd).collect())
        self.assertEqual(10, len(d))
        self.assertEqual([[0], [0]], list(map(list, d[0][1])))
        jobId = tracker.getJobIdsForGroup("test4")[0]
        self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))


if __name__ == "__main__":
    import unittest
    from pyspark.tests.test_join import *

    try:
        # use xmlrunner for JUnit-style XML reports when it is installed
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)