Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 import os
0018 import time
0019 import random
0020 import threading
0021 import unittest
0022 
0023 from pyspark import SparkContext, SparkConf
0024 
0025 
class PinThreadTests(unittest.TestCase):
    """Tests for PySpark's pinned-thread mode (SPARK-22340).

    These tests live in a separate class because they rely on the
    'PYSPARK_PIN_THREAD' environment variable, which must be set *before*
    the SparkContext is created so that each Python thread is pinned to a
    dedicated JVM thread.
    """

    @classmethod
    def setUpClass(cls):
        # Remember the previous value so tearDownClass can restore it exactly,
        # avoiding leakage into other test classes in the same process.
        cls.old_pin_thread = os.environ.get("PYSPARK_PIN_THREAD")
        os.environ["PYSPARK_PIN_THREAD"] = "true"
        cls.sc = SparkContext('local[4]', cls.__name__, conf=SparkConf())

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()
        # Restore (or remove) the environment variable set in setUpClass.
        if cls.old_pin_thread is not None:
            os.environ["PYSPARK_PIN_THREAD"] = cls.old_pin_thread
        else:
            del os.environ["PYSPARK_PIN_THREAD"]

    def test_pinned_thread(self):
        """Each Python thread should map to its own, reused JVM thread."""
        threads = []
        exceptions = []
        property_name = "test_property_%s" % PinThreadTests.__name__
        jvm_thread_ids = []

        for i in range(10):
            # Bind the loop variable as a default argument. Closures are
            # late-bound in Python: without `i=i`, every thread (they start
            # only after this loop completes) would observe i == 9, so all
            # ten threads would set and read the *same* property value and
            # this test could no longer detect cross-thread leakage.
            def test_local_property(i=i):
                jvm_thread_id = self.sc._jvm.java.lang.Thread.currentThread().getId()
                jvm_thread_ids.append(jvm_thread_id)

                # If a property is set in this thread, later it should get the same property
                # within this thread.
                self.sc.setLocalProperty(property_name, str(i))

                # 5 threads, 1 second sleep. 5 threads without a sleep.
                time.sleep(i % 2)

                try:
                    assert self.sc.getLocalProperty(property_name) == str(i)

                    # Each command might create a thread in multi-threading mode in Py4J.
                    # This assert makes sure that the created thread is being reused.
                    assert jvm_thread_id == self.sc._jvm.java.lang.Thread.currentThread().getId()
                except Exception as e:
                    # Collect failures instead of raising: an exception raised
                    # inside a worker thread would otherwise be swallowed by
                    # the threading module and the test would pass vacuously.
                    exceptions.append(e)
            threads.append(threading.Thread(target=test_local_property))

        for t in threads:
            t.start()

        for t in threads:
            t.join()

        # Re-raise any failure captured inside the worker threads.
        for e in exceptions:
            raise e

        # Created JVM threads should be 10 because there are 10 Python threads.
        assert len(set(jvm_thread_ids)) == 10

    def test_multiple_group_jobs(self):
        # SPARK-22340 Add a mode to pin Python thread into JVM's

        group_a = "job_ids_to_cancel"
        group_b = "job_ids_to_run"

        threads = []
        thread_ids = range(4)
        thread_ids_to_cancel = [i for i in thread_ids if i % 2 == 0]
        thread_ids_to_run = [i for i in thread_ids if i % 2 != 0]

        # A list which records whether job is cancelled.
        # The index of the array is the thread index which job run in.
        is_job_cancelled = [False for _ in thread_ids]

        def run_job(job_group, index):
            """
            Executes a job with the group ``job_group``. Each job runs a single
            task that sleeps for 15 seconds, then records at ``index`` whether
            the job was cancelled (any exception from ``collect`` is treated as
            cancellation).
            """
            try:
                self.sc.setJobGroup(job_group, "test rdd collect with setting job group")
                self.sc.parallelize([15]).map(lambda x: time.sleep(x)).collect()
                is_job_cancelled[index] = False
            except Exception:
                # Assume that exception means job cancellation.
                is_job_cancelled[index] = True

        # Test if job succeeded when not cancelled.
        run_job(group_a, 0)
        self.assertFalse(is_job_cancelled[0])

        # Run jobs
        for i in thread_ids_to_cancel:
            t = threading.Thread(target=run_job, args=(group_a, i))
            t.start()
            threads.append(t)

        for i in thread_ids_to_run:
            t = threading.Thread(target=run_job, args=(group_b, i))
            t.start()
            threads.append(t)

        # Wait to make sure all jobs are executed.
        time.sleep(3)
        # And then, cancel one job group.
        self.sc.cancelJobGroup(group_a)

        # Wait until all threads launching jobs are finished.
        for t in threads:
            t.join()

        for i in thread_ids_to_cancel:
            self.assertTrue(
                is_job_cancelled[i],
                "Thread {i}: Job in group A was not cancelled.".format(i=i))

        for i in thread_ids_to_run:
            self.assertFalse(
                is_job_cancelled[i],
                "Thread {i}: Job in group B did not succeed.".format(i=i))
0145 
0146 
if __name__ == "__main__":
    # `unittest` is already imported at module level; the duplicate
    # `import unittest` that used to live here was redundant and is removed.
    from pyspark.tests.test_pin_thread import *

    try:
        # Prefer the XML test runner (used by Spark's CI for report files);
        # fall back to the default text runner when it is not installed.
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)