Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 import os
0018 import random
0019 import stat
0020 import sys
0021 import tempfile
0022 import time
0023 import unittest
0024 
0025 from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext
0026 from pyspark.testing.utils import PySparkTestCase, SPARK_HOME
0027 
if sys.version_info >= (3,):
    # Python 3 removed the ``xrange`` builtin; alias it to ``range`` so the
    # tests below can use a single name on both major Python versions.
    xrange = range
0030 
0031 
class TaskContextTests(PySparkTestCase):
    """Tests for TaskContext and BarrierTaskContext in normal and barrier stages."""

    def setUp(self):
        # Snapshot sys.path; presumably restored by PySparkTestCase.tearDown
        # (defined outside this file) -- confirm.
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        # Allow retries even though they are normally disabled in local mode:
        # 'local[4, 2]' means 4 worker threads with a maxFailures value of 2.
        self.sc = SparkContext('local[4, 2]', class_name)

    def test_stage_id(self):
        """Test the stage ids are available and incrementing as expected."""
        rdd = self.sc.parallelize(range(10))
        stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
        stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
        # Test using the constructor directly rather than the get()
        stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0]
        self.assertEqual(stage1 + 1, stage2)
        self.assertEqual(stage1 + 2, stage3)
        self.assertEqual(stage2 + 1, stage3)

    def test_resources(self):
        """Test the resources are empty by default."""
        rdd = self.sc.parallelize(range(10))
        resources1 = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
        # Test using the constructor directly rather than the get()
        resources2 = rdd.map(lambda x: TaskContext().resources()).take(1)[0]
        self.assertEqual(len(resources1), 0)
        self.assertEqual(len(resources2), 0)

    def test_partition_id(self):
        """Test the partition id."""
        rdd1 = self.sc.parallelize(range(10), 1)
        rdd2 = self.sc.parallelize(range(10), 2)
        pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect()
        pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect()
        self.assertEqual(0, pids1[0])
        self.assertEqual(0, pids1[9])
        self.assertEqual(0, pids2[0])
        self.assertEqual(1, pids2[9])

    def test_attempt_number(self):
        """Verify the attempt numbers are correctly reported."""
        rdd = self.sc.parallelize(range(10))
        # Verify a simple job with no failures.
        attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect()
        # Explicit loops: on Python 3 ``map``/``filter`` are lazy, so assertions
        # wrapped in them would never actually execute.
        for attempt in attempt_numbers:
            self.assertEqual(0, attempt)

        def fail_on_first(x):
            """Fail on the first attempt so we get a positive attempt number"""
            tc = TaskContext.get()
            attempt_number = tc.attemptNumber()
            partition_id = tc.partitionId()
            attempt_id = tc.taskAttemptId()
            if attempt_number == 0 and partition_id == 0:
                raise Exception("Failing on first attempt")
            else:
                return [x, partition_id, attempt_number, attempt_id]
        result = rdd.map(fail_on_first).collect()
        # Each row is [value, partition_id, attempt_number, attempt_id].
        # Partition 0 must be re-submitted once (attempt 1); all other
        # partitions should succeed on attempt 0.
        self.assertEqual([0, 0, 1], result[0][0:3])
        self.assertEqual([9, 3, 0], result[9][0:3])
        for row in result:
            expected_attempt = 1 if row[1] == 0 else 0
            self.assertEqual(expected_attempt, row[2])
        # The task attempt id should be different
        self.assertNotEqual(result[0][3], result[9][3])

    def test_tc_on_driver(self):
        """Verify that getting the TaskContext on the driver returns None."""
        tc = TaskContext.get()
        self.assertIsNone(tc)

    def test_get_local_property(self):
        """Verify that local properties set on the driver are available in TaskContext."""
        key = "testkey"
        value = "testvalue"
        self.sc.setLocalProperty(key, value)
        try:
            rdd = self.sc.parallelize(range(1), 1)
            prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0]
            self.assertEqual(prop1, value)
            prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
            self.assertIsNone(prop2)
        finally:
            # Clear the property so it does not leak into other tests.
            self.sc.setLocalProperty(key, None)

    def test_barrier(self):
        """
        Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks
        within a stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            # Random delay so tasks reach the barrier at different times.
            time.sleep(random.randint(1, 10))
            tc.barrier()
            return time.time()

        times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
        # After the barrier every task should finish within one second of the others.
        self.assertLess(max(times) - min(times), 1)

    def test_all_gather(self):
        """
        Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks
        within a stage and passes messages properly.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            time.sleep(random.randint(1, 10))
            out = tc.allGather(str(tc.partitionId()))
            pids = [int(e) for e in out]
            return pids

        pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0]
        self.assertEqual(pids, [0, 1, 2, 3])

    def test_barrier_infos(self):
        """
        Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
        barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        taskInfos = rdd.barrier().mapPartitions(f).map(lambda x: BarrierTaskContext.get()
                                                       .getTaskInfos()).collect()
        self.assertEqual(len(taskInfos), 4)
        self.assertEqual(len(taskInfos[0]), 4)

    def test_context_get(self):
        """
        Verify that TaskContext.get() works both in or not in a barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            taskContext = TaskContext.get()
            # BarrierTaskContext is checked first because it is presumably a
            # TaskContext subclass -- confirm.
            if isinstance(taskContext, BarrierTaskContext):
                yield taskContext.partitionId() + 1
            elif isinstance(taskContext, TaskContext):
                yield taskContext.partitionId() + 2
            else:
                yield -1

        # for normal stage
        result1 = rdd.mapPartitions(f).collect()
        self.assertEqual(result1, [2, 3, 4, 5])
        # for barrier stage
        result2 = rdd.barrier().mapPartitions(f).collect()
        self.assertEqual(result2, [1, 2, 3, 4])

    def test_barrier_context_get(self):
        """
        Verify that BarrierTaskContext.get() should only work in a barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            try:
                taskContext = BarrierTaskContext.get()
            except Exception:
                yield -1
            else:
                yield taskContext.partitionId()

        # for normal stage
        result1 = rdd.mapPartitions(f).collect()
        self.assertEqual(result1, [-1, -1, -1, -1])
        # for barrier stage
        result2 = rdd.barrier().mapPartitions(f).collect()
        self.assertEqual(result2, [0, 1, 2, 3])
0215 
class TaskContextTestsWithWorkerReuse(unittest.TestCase):
    """TaskContext / BarrierTaskContext behaviour when Python workers are reused."""

    def setUp(self):
        class_name = self.__class__.__name__
        conf = SparkConf().set("spark.python.worker.reuse", "true")
        self.sc = SparkContext('local[2]', class_name, conf=conf)

    def test_barrier_with_python_worker_reuse(self):
        """
        Regression test for SPARK-25921: verify that BarrierTaskContext.barrier() works
        with a reused python worker.
        """
        # start a normal job first to start all workers and get all worker pids
        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()
        # the worker will reuse in this barrier job
        rdd = self.sc.parallelize(range(10), 2)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            time.sleep(random.randint(1, 10))
            tc.barrier()
            return (time.time(), os.getpid())

        result = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
        times = [t for t, _ in result]
        pids = [pid for _, pid in result]
        # check both barrier and worker reuse effect
        self.assertLess(max(times) - min(times), 1)
        for pid in pids:
            self.assertIn(pid, worker_pids)

    def _check_stage_result(self, result, expected_bps, worker_pids):
        """Assert task/barrier partition ids and that every task ran in a reused worker.

        ``result`` is a list of ``(task_pid, barrier_pid, os_pid)`` tuples, one per
        partition, as yielded by the ``context`` function of the test below.
        """
        tps, bps, pids = zip(*result)
        self.assertEqual(tps, (0, 1))
        self.assertEqual(bps, expected_bps)
        for pid in pids:
            self.assertIn(pid, worker_pids)

    def test_task_context_correct_with_python_worker_reuse(self):
        """Verify the task context correct when reused python worker"""
        # start a normal job first to start all workers and get all worker pids
        worker_pids = self.sc.parallelize(xrange(2), 2).map(lambda x: os.getpid()).collect()
        # the worker will reuse in this barrier job
        rdd = self.sc.parallelize(xrange(10), 2)

        def context(iterator):
            tp = TaskContext.get().partitionId()
            # BarrierTaskContext.get() is expected to raise outside a barrier stage.
            try:
                bp = BarrierTaskContext.get().partitionId()
            except Exception:
                bp = -1

            yield (tp, bp, os.getpid())

        # normal stage after normal stage
        self._check_stage_result(rdd.mapPartitions(context).collect(), (-1, -1), worker_pids)
        # barrier stage after normal stage
        self._check_stage_result(
            rdd.barrier().mapPartitions(context).collect(), (0, 1), worker_pids)
        # normal stage after barrier stage
        self._check_stage_result(rdd.mapPartitions(context).collect(), (-1, -1), worker_pids)

    def tearDown(self):
        self.sc.stop()
0291 
0292 
class TaskContextTestsWithResources(unittest.TestCase):
    """TaskContext.resources() on a local-cluster configured with one GPU per task."""

    def setUp(self):
        class_name = self.__class__.__name__
        # Fake GPU discovery script that reports a single GPU at address "0".
        self.tempFile = tempfile.NamedTemporaryFile(delete=False)
        # Register removal right away: unittest skips tearDown when setUp raises
        # (e.g. the SparkContext fails to start), but registered cleanups always
        # run, and they run after tearDown has stopped the SparkContext.
        self.addCleanup(os.unlink, self.tempFile.name)
        self.tempFile.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
        self.tempFile.close()
        # The script must be executable so Spark can run it for resource discovery.
        os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP |
                 stat.S_IROTH | stat.S_IXOTH)
        conf = SparkConf().set("spark.test.home", SPARK_HOME)
        conf = conf.set("spark.worker.resource.gpu.discoveryScript", self.tempFile.name)
        conf = conf.set("spark.worker.resource.gpu.amount", "1")
        conf = conf.set("spark.task.resource.gpu.amount", "1")
        conf = conf.set("spark.executor.resource.gpu.amount", "1")
        self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)

    def test_resources(self):
        """Test the resources are available."""
        rdd = self.sc.parallelize(range(10))
        resources = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
        self.assertEqual(len(resources), 1)
        self.assertIn('gpu', resources)
        self.assertEqual(resources['gpu'].name, 'gpu')
        self.assertEqual(resources['gpu'].addresses, ['0'])

    def tearDown(self):
        # The discovery script is deleted by the cleanup registered in setUp,
        # which unittest runs after this method.
        self.sc.stop()
0321 
if __name__ == "__main__":
    # Pull the test classes in through their fully qualified module path
    # (presumably so closures shipped to workers reference an importable module
    # rather than __main__ -- confirm). ``unittest`` is already imported above.
    from pyspark.tests.test_taskcontext import *  # noqa: F401,F403

    # Write XML test reports when the optional xmlrunner package is available;
    # otherwise let unittest use its default text runner.
    try:
        import xmlrunner
        test_runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        test_runner = None
    unittest.main(testRunner=test_runner, verbosity=2)