0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 import os
0018 import random
0019 import stat
0020 import sys
0021 import tempfile
0022 import time
0023 import unittest
0024
0025 from pyspark import SparkConf, SparkContext, TaskContext, BarrierTaskContext
0026 from pyspark.testing.utils import PySparkTestCase, SPARK_HOME
0027
# Python 3 has no ``xrange`` builtin; alias it to ``range`` so code written
# against the Python 2 name keeps working.
if sys.version_info >= (3,):
    xrange = range
0030
0031
class TaskContextTests(PySparkTestCase):
    """Tests for TaskContext and BarrierTaskContext on a ``local[4, 2]`` SparkContext."""

    def setUp(self):
        """Create a fresh SparkContext named after the test class.

        ``sys.path`` is snapshotted into ``_old_sys_path``; no tearDown is
        defined here, so cleanup is presumably inherited from PySparkTestCase
        (not visible in this file) -- confirm it stops ``self.sc`` and
        restores the path.
        """
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        # Per Spark's master URL format, local[4, 2] means 4 worker threads and
        # at most 2 attempts per task; test_attempt_number relies on the retry.
        self.sc = SparkContext('local[4, 2]', class_name)

    def test_stage_id(self):
        """Test the stage ids are available and incrementing as expected."""
        rdd = self.sc.parallelize(range(10))
        stage1 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
        stage2 = rdd.map(lambda x: TaskContext.get().stageId()).take(1)[0]
        # Instantiate TaskContext directly instead of going through get().
        stage3 = rdd.map(lambda x: TaskContext().stageId()).take(1)[0]
        self.assertEqual(stage1 + 1, stage2)
        self.assertEqual(stage1 + 2, stage3)
        self.assertEqual(stage2 + 1, stage3)

    def test_resources(self):
        """Test the resources are empty by default."""
        rdd = self.sc.parallelize(range(10))
        resources1 = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
        # Instantiate TaskContext directly instead of going through get().
        resources2 = rdd.map(lambda x: TaskContext().resources()).take(1)[0]
        self.assertEqual(len(resources1), 0)
        self.assertEqual(len(resources2), 0)

    def test_partition_id(self):
        """Test the partition id."""
        rdd1 = self.sc.parallelize(range(10), 1)
        rdd2 = self.sc.parallelize(range(10), 2)
        pids1 = rdd1.map(lambda x: TaskContext.get().partitionId()).collect()
        pids2 = rdd2.map(lambda x: TaskContext.get().partitionId()).collect()
        self.assertEqual(0, pids1[0])
        self.assertEqual(0, pids1[9])
        self.assertEqual(0, pids2[0])
        self.assertEqual(1, pids2[9])

    def test_attempt_number(self):
        """Verify the attempt numbers are correctly reported."""
        rdd = self.sc.parallelize(range(10))

        attempt_numbers = rdd.map(lambda x: TaskContext.get().attemptNumber()).collect()
        # Explicit loop: map() is lazy on Python 3, so wrapping the assertion
        # in map() would never actually execute it.
        for attempt in attempt_numbers:
            self.assertEqual(0, attempt)

        def fail_on_first(x):
            """Fail on the first attempt so we get a positive attempt number"""
            tc = TaskContext.get()
            attempt_number = tc.attemptNumber()
            partition_id = tc.partitionId()
            attempt_id = tc.taskAttemptId()
            if attempt_number == 0 and partition_id == 0:
                raise Exception("Failing on first attempt")
            else:
                return [x, partition_id, attempt_number, attempt_id]

        result = rdd.map(fail_on_first).collect()

        # Element 0 lives in partition 0 (retried once); element 9 is expected
        # in partition 3, which never fails.
        self.assertEqual([0, 0, 1], result[0][0:3])
        self.assertEqual([9, 3, 0], result[9][0:3])
        # Every row from partition 0 must come from the retry (attempt 1);
        # every other row from the first attempt (attempt 0).
        for row in result:
            expected_attempt = 1 if row[1] == 0 else 0
            self.assertEqual(expected_attempt, row[2])

        # Task attempt ids of different tasks must differ.
        self.assertNotEqual(result[0][3], result[9][3])

    def test_tc_on_driver(self):
        """Verify that getting the TaskContext on the driver returns None."""
        tc = TaskContext.get()
        self.assertIsNone(tc)

    def test_get_local_property(self):
        """Verify that local properties set on the driver are available in TaskContext."""
        key = "testkey"
        value = "testvalue"
        self.sc.setLocalProperty(key, value)
        try:
            rdd = self.sc.parallelize(range(1), 1)
            prop1 = rdd.map(lambda _: TaskContext.get().getLocalProperty(key)).collect()[0]
            self.assertEqual(prop1, value)
            prop2 = rdd.map(lambda _: TaskContext.get().getLocalProperty("otherkey")).collect()[0]
            self.assertIsNone(prop2)
        finally:
            # Reset the property even if an assertion above fails.
            self.sc.setLocalProperty(key, None)

    def test_barrier(self):
        """
        Verify that BarrierTaskContext.barrier() performs global sync among all barrier tasks
        within a stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            # Desynchronise the tasks so that only barrier() can line them up again.
            time.sleep(random.randint(1, 10))
            tc.barrier()
            return time.time()

        times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
        # All tasks must leave the barrier within one second of each other.
        self.assertLess(max(times) - min(times), 1)

    def test_all_gather(self):
        """
        Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks
        within a stage and passes messages properly.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            time.sleep(random.randint(1, 10))
            out = tc.allGather(str(tc.partitionId()))
            pids = [int(e) for e in out]
            return pids

        pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0]
        self.assertEqual(pids, [0, 1, 2, 3])

    def test_barrier_infos(self):
        """
        Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
        barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        task_infos = rdd.barrier().mapPartitions(f).map(
            lambda x: BarrierTaskContext.get().getTaskInfos()).collect()
        # One result per task, each listing the infos of all four tasks.
        self.assertEqual(len(task_infos), 4)
        self.assertEqual(len(task_infos[0]), 4)

    def test_context_get(self):
        """
        Verify that TaskContext.get() works both in or not in a barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            # Encode the kind of context in the offset added to the partition id:
            # +1 for a BarrierTaskContext, +2 for a plain TaskContext, -1 otherwise.
            taskContext = TaskContext.get()
            if isinstance(taskContext, BarrierTaskContext):
                yield taskContext.partitionId() + 1
            elif isinstance(taskContext, TaskContext):
                yield taskContext.partitionId() + 2
            else:
                yield -1

        result1 = rdd.mapPartitions(f).collect()
        self.assertEqual(result1, [2, 3, 4, 5])

        result2 = rdd.barrier().mapPartitions(f).collect()
        self.assertEqual(result2, [1, 2, 3, 4])

    def test_barrier_context_get(self):
        """
        Verify that BarrierTaskContext.get() only works in a barrier stage.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            # Yield -1 when BarrierTaskContext.get() raises, else the partition id.
            try:
                taskContext = BarrierTaskContext.get()
            except Exception:
                yield -1
            else:
                yield taskContext.partitionId()

        result1 = rdd.mapPartitions(f).collect()
        self.assertEqual(result1, [-1, -1, -1, -1])

        result2 = rdd.barrier().mapPartitions(f).collect()
        self.assertEqual(result2, [0, 1, 2, 3])
0214
0215
class TaskContextTestsWithWorkerReuse(unittest.TestCase):
    """TaskContext tests on a ``local[2]`` SparkContext with ``spark.python.worker.reuse`` on."""

    def setUp(self):
        """Start a two-thread local SparkContext with Python worker reuse enabled."""
        class_name = self.__class__.__name__
        conf = SparkConf().set("spark.python.worker.reuse", "true")
        self.sc = SparkContext('local[2]', class_name, conf=conf)

    def tearDown(self):
        """Stop the SparkContext created in setUp."""
        self.sc.stop()

    def test_barrier_with_python_worker_reuse(self):
        """
        Regression test for SPARK-25921: verify that BarrierTaskContext.barrier() works with
        a reused python worker.
        """
        # Run a trivial job first to record the pids of the worker processes.
        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()

        rdd = self.sc.parallelize(range(10), 2)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            # Desynchronise the tasks so that only barrier() can line them up again.
            time.sleep(random.randint(1, 10))
            tc.barrier()
            return (time.time(), os.getpid())

        result = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
        times = [entry[0] for entry in result]
        pids = [entry[1] for entry in result]

        # All tasks must leave the barrier within one second of each other ...
        self.assertLess(max(times) - min(times), 1)
        # ... and must have run in the worker processes observed earlier.
        for pid in pids:
            self.assertIn(pid, worker_pids)

    def test_task_context_correct_with_python_worker_reuse(self):
        """Verify the task context correct when reused python worker"""
        # Run a trivial job first to record the pids of the worker processes.
        worker_pids = self.sc.parallelize(range(2), 2).map(lambda x: os.getpid()).collect()

        rdd = self.sc.parallelize(range(10), 2)

        def context(iterator):
            # Report (TaskContext partition id, BarrierTaskContext partition id
            # or -1 when BarrierTaskContext.get() raises, worker pid).
            tp = TaskContext.get().partitionId()
            try:
                bp = BarrierTaskContext.get().partitionId()
            except Exception:
                bp = -1

            yield (tp, bp, os.getpid())

        def check(result, expected_bps):
            # Shared assertions for one collected job result.
            tps, bps, pids = zip(*result)
            self.assertEqual(tps, (0, 1))
            self.assertEqual(bps, expected_bps)
            for pid in pids:
                self.assertIn(pid, worker_pids)

        # Normal stage, then barrier stage, then normal stage again: the
        # context seen by a reused worker must match the current stage kind.
        check(rdd.mapPartitions(context).collect(), (-1, -1))
        check(rdd.barrier().mapPartitions(context).collect(), (0, 1))
        check(rdd.mapPartitions(context).collect(), (-1, -1))
0291
0292
class TaskContextTestsWithResources(unittest.TestCase):
    """TaskContext.resources() test on a local-cluster with one GPU resource configured."""

    def setUp(self):
        """Write a GPU discovery script and start a local-cluster SparkContext using it."""
        class_name = self.__class__.__name__
        # The script echoes a JSON description of a single "gpu" resource with address "0".
        with tempfile.NamedTemporaryFile(delete=False) as script:
            script.write(b'echo {\\"name\\": \\"gpu\\", \\"addresses\\": [\\"0\\"]}')
        self.tempFile = script
        # rwx for the owner, r-x for group and others, so the script can be executed.
        os.chmod(self.tempFile.name, stat.S_IRWXU | stat.S_IXGRP | stat.S_IRGRP |
                 stat.S_IROTH | stat.S_IXOTH)
        conf = SparkConf().set("spark.test.home", SPARK_HOME)
        conf = conf.set("spark.worker.resource.gpu.discoveryScript", self.tempFile.name)
        conf = conf.set("spark.worker.resource.gpu.amount", "1")
        conf = conf.set("spark.task.resource.gpu.amount", "1")
        conf = conf.set("spark.executor.resource.gpu.amount", "1")
        try:
            self.sc = SparkContext('local-cluster[2,1,1024]', class_name, conf=conf)
        except Exception:
            # tearDown is not run when setUp fails, so remove the script here.
            os.unlink(self.tempFile.name)
            raise

    def test_resources(self):
        """Test the resources are available."""
        rdd = self.sc.parallelize(range(10))
        resources = rdd.map(lambda x: TaskContext.get().resources()).take(1)[0]
        self.assertEqual(len(resources), 1)
        self.assertIn('gpu', resources)
        self.assertEqual(resources['gpu'].name, 'gpu')
        self.assertEqual(resources['gpu'].addresses, ['0'])

    def tearDown(self):
        """Stop the SparkContext and always remove the discovery script."""
        try:
            self.sc.stop()
        finally:
            os.unlink(self.tempFile.name)
0321
if __name__ == "__main__":
    import unittest
    # Star-import this module under its package-qualified name; presumably so
    # the test classes resolve as pyspark.tests.test_taskcontext.* -- confirm.
    from pyspark.tests.test_taskcontext import *

    # Use the XML-reporting runner when xmlrunner is installed; passing None
    # lets unittest fall back to its default runner.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)