#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import time
import threading
import unittest

from pyspark import SparkContext, SparkConf


class PinThreadTests(unittest.TestCase):
    # These tests live in a separate class because they rely on the
    # 'PYSPARK_PIN_THREAD' environment variable, which must be set before
    # the shared SparkContext below is created.

    @classmethod
    def setUpClass(cls):
        cls.old_pin_thread = os.environ.get("PYSPARK_PIN_THREAD")
        os.environ["PYSPARK_PIN_THREAD"] = "true"
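        # Pinned-thread mode is decided when the JVM gateway is launched, so
        # the variable has to be set before the SparkContext is created.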
        cls.sc = SparkContext('local[4]', cls.__name__, conf=SparkConf())

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()
        if cls.old_pin_thread is not None:
            os.environ["PYSPARK_PIN_THREAD"] = cls.old_pin_thread
        else:
            del os.environ["PYSPARK_PIN_THREAD"]

    def test_pinned_thread(self):
        threads = []
        exceptions = []
        property_name = "test_property_%s" % PinThreadTests.__name__
        jvm_thread_ids = []

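        # Start 10 Python threads. With thread pinning enabled, each Python
        # thread should be served by its own dedicated JVM thread, so local
        # properties set in one thread must not leak into another.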
        for i in range(10):
            def test_local_property(i):
                jvm_thread_id = self.sc._jvm.java.lang.Thread.currentThread().getId()
                jvm_thread_ids.append(jvm_thread_id)

                # If a property is set in this thread, it should later be
                # retrievable from the same thread.
                self.sc.setLocalProperty(property_name, str(i))

                # 5 threads sleep for 1 second; 5 threads do not sleep.
                time.sleep(i % 2)

                try:
                    assert self.sc.getLocalProperty(property_name) == str(i)

                    # Each command might create a new thread in Py4J's
                    # multi-threading mode. This assert makes sure the created
                    # JVM thread is being reused.
                    assert jvm_thread_id == self.sc._jvm.java.lang.Thread.currentThread().getId()
                except Exception as e:
                    exceptions.append(e)

            # Pass 'i' as an argument; a plain closure would late-bind the loop
            # variable, so every thread would see i == 9 once the loop finishes.
            threads.append(threading.Thread(target=test_local_property, args=(i,)))

        for t in threads:
            t.start()

        for t in threads:
            t.join()

        # Re-raise any assertion failure captured in the worker threads.
        for e in exceptions:
            raise e

        # 10 Python threads should map to 10 distinct JVM threads.
        assert len(set(jvm_thread_ids)) == 10
0083
    def test_multiple_group_jobs(self):
        # Jobs in group A are cancelled; jobs in group B should run to completion.
        group_a = "job_ids_to_cancel"
        group_b = "job_ids_to_run"

        threads = []
        thread_ids = range(4)
        thread_ids_to_cancel = [i for i in thread_ids if i % 2 == 0]
        thread_ids_to_run = [i for i in thread_ids if i % 2 != 0]

        # Records whether each job was cancelled; the list index is the index
        # of the thread in which the job runs.
        is_job_cancelled = [False for _ in thread_ids]

        def run_job(job_group, index):
            """
            Executes a job in the group ``job_group``. Each job sleeps for
            15 seconds and then exits.
            """
            try:
                self.sc.setJobGroup(job_group, "test rdd collect with setting job group")
                self.sc.parallelize([15]).map(lambda x: time.sleep(x)).collect()
                is_job_cancelled[index] = False
            except Exception:
                # Assume that the exception means the job was cancelled.
                is_job_cancelled[index] = True

        # Test that a job succeeds when it is not cancelled.
        run_job(group_a, 0)
        self.assertFalse(is_job_cancelled[0])

        # Kick off the jobs: group A in even-numbered threads, group B in odd.
        for i in thread_ids_to_cancel:
            t = threading.Thread(target=run_job, args=(group_a, i))
            t.start()
            threads.append(t)

        for i in thread_ids_to_run:
            t = threading.Thread(target=run_job, args=(group_b, i))
            t.start()
            threads.append(t)

        # Wait to make sure all the jobs are submitted.
        time.sleep(3)

        # Cancel job group A only. With pinned threads, setJobGroup applies per
        # thread, so the jobs in group B should be unaffected.
        self.sc.cancelJobGroup(group_a)

        # Wait until all threads launching jobs are finished.
        for t in threads:
            t.join()

        for i in thread_ids_to_cancel:
            self.assertTrue(
                is_job_cancelled[i],
                "Thread {i}: Job in group A was not cancelled.".format(i=i))

        for i in thread_ids_to_run:
            self.assertFalse(
                is_job_cancelled[i],
                "Thread {i}: Job in group B did not succeed.".format(i=i))


if __name__ == "__main__":
    import unittest
    from pyspark.tests.test_pin_thread import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)