Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 from pyspark.testing.utils import ReusedPySparkTestCase
0018 
0019 
class RDDBarrierTests(ReusedPySparkTestCase):
    """Tests for barrier-mode RDD transformations obtained via ``RDD.barrier()``.

    Each test checks that a plain RDD is not flagged as barrier, that the
    transformation applied through ``barrier()`` yields a barrier RDD, and
    that running the resulting stage produces the expected data.
    """

    def test_map_partitions(self):
        """Test RDDBarrier.mapPartitions"""
        rdd = self.sc.parallelize(range(12), 4)
        self.assertFalse(rdd._is_barrier())

        rdd1 = rdd.barrier().mapPartitions(lambda it: it)
        self.assertTrue(rdd1._is_barrier())
        # Actually run the barrier stage (as the mapPartitionsWithIndex test
        # does) so the test verifies results and not just the flag: an
        # identity mapPartitions must return every element, in order.
        # NOTE(review): barrier stages launch all 4 tasks together; this
        # assumes the shared test SparkContext has at least 4 task slots.
        self.assertEqual(rdd1.collect(), list(range(12)))

    def test_map_partitions_with_index(self):
        """Test RDDBarrier.mapPartitionsWithIndex"""
        rdd = self.sc.parallelize(range(12), 4)
        self.assertFalse(rdd._is_barrier())

        def f(index, iterator):
            # Emit only the partition index, giving one value per partition.
            yield index

        rdd1 = rdd.barrier().mapPartitionsWithIndex(f)
        self.assertTrue(rdd1._is_barrier())
        self.assertEqual(rdd1.collect(), [0, 1, 2, 3])
0040 
if __name__ == "__main__":
    import unittest

    # Pull the test classes into __main__ so unittest.main() discovers them.
    from pyspark.tests.test_rddbarrier import *  # noqa: F401

    # Prefer JUnit-style XML reports when xmlrunner is installed; a None
    # runner makes unittest.main() use its default text runner instead.
    try:
        import xmlrunner
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        runner = None

    unittest.main(testRunner=runner, verbosity=2)