0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 from pyspark.testing.utils import ReusedPySparkTestCase
0018
0019
0020 class RDDBarrierTests(ReusedPySparkTestCase):
0021 def test_map_partitions(self):
0022 """Test RDDBarrier.mapPartitions"""
0023 rdd = self.sc.parallelize(range(12), 4)
0024 self.assertFalse(rdd._is_barrier())
0025
0026 rdd1 = rdd.barrier().mapPartitions(lambda it: it)
0027 self.assertTrue(rdd1._is_barrier())
0028
0029 def test_map_partitions_with_index(self):
0030 """Test RDDBarrier.mapPartitionsWithIndex"""
0031 rdd = self.sc.parallelize(range(12), 4)
0032 self.assertFalse(rdd._is_barrier())
0033
0034 def f(index, iterator):
0035 yield index
0036 rdd1 = rdd.barrier().mapPartitionsWithIndex(f)
0037 self.assertTrue(rdd1._is_barrier())
0038 self.assertEqual(rdd1.collect(), [0, 1, 2, 3])
0039
0040
0041 if __name__ == "__main__":
0042 import unittest
0043 from pyspark.tests.test_rddbarrier import *
0044
0045 try:
0046 import xmlrunner
0047 testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
0048 except ImportError:
0049 testRunner = None
0050 unittest.main(testRunner=testRunner, verbosity=2)