Back to home page

OSCL-LXR

 
 

    


0001 # -*- encoding: utf-8 -*-
0002 #
0003 # Licensed to the Apache Software Foundation (ASF) under one or more
0004 # contributor license agreements.  See the NOTICE file distributed with
0005 # this work for additional information regarding copyright ownership.
0006 # The ASF licenses this file to You under the Apache License, Version 2.0
0007 # (the "License"); you may not use this file except in compliance with
0008 # the License.  You may obtain a copy of the License at
0009 #
0010 #    http://www.apache.org/licenses/LICENSE-2.0
0011 #
0012 # Unless required by applicable law or agreed to in writing, software
0013 # distributed under the License is distributed on an "AS IS" BASIS,
0014 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0015 # See the License for the specific language governing permissions and
0016 # limitations under the License.
0017 #
0018 import os
0019 import sys
0020 import tempfile
0021 import threading
0022 import time
0023 import unittest
0024 has_resource_module = True
0025 try:
0026     import resource
0027 except ImportError:
0028     has_resource_module = False
0029 
0030 from py4j.protocol import Py4JJavaError
0031 
0032 from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest
0033 
if sys.version_info.major >= 3:
    # Python 3 has no ``xrange`` builtin; alias it to ``range`` so the tests
    # below can use a single spelling on both major Python versions.
    xrange = range
0036 
0037 
class WorkerTests(ReusedPySparkTestCase):
    """Tests for Python worker behaviour: cancellation, failures and reuse."""

    def test_cancel_task(self):
        # Cancelling a running job must kill its Python worker process while
        # leaving the daemon process alive and able to run new jobs.
        temp = tempfile.NamedTemporaryFile(delete=True)
        temp.close()
        path = temp.name

        def remove_pid_file():
            if os.path.exists(path):
                os.remove(path)

        # The task below recreates ``path`` after ``temp.close()`` removed it,
        # so make sure it is cleaned up once the test finishes.
        self.addCleanup(remove_pid_file)

        def sleep(x):
            import os
            import time
            with open(path, 'w') as f:
                f.write("%d %d" % (os.getppid(), os.getpid()))
            time.sleep(100)

        # start job in background thread
        def run():
            try:
                self.sc.parallelize(range(1), 1).foreach(sleep)
            except Exception:
                # The job is expected to fail once it is cancelled below.
                pass

        t = threading.Thread(target=run)
        t.daemon = True
        t.start()

        # Wait for the task to report "<parent pid> <worker pid>". The file can
        # exist (opened with 'w') before its content has been written, so an
        # empty or partial read is retried instead of failing with ValueError.
        timeout = 120
        deadline = time.time() + timeout
        daemon_pid, worker_pid = 0, 0
        while True:
            if os.path.exists(path):
                with open(path) as f:
                    data = f.read().split(' ')
                try:
                    daemon_pid, worker_pid = map(int, data)
                except ValueError:
                    pass  # pids not written yet; poll again
                else:
                    break
            if time.time() > deadline:
                self.fail("worker did not report its pids within %d seconds" % timeout)
            time.sleep(0.1)

        # cancel jobs
        self.sc.cancelAllJobs()
        t.join()

        # Signal 0 only checks that the process exists; OSError means it is gone.
        for _ in range(50):
            try:
                os.kill(worker_pid, 0)
                time.sleep(0.1)
            except OSError:
                break  # worker was killed
        else:
            self.fail("worker has not been killed after 5 seconds")

        try:
            os.kill(daemon_pid, 0)
        except OSError:
            self.fail("daemon had been killed")

        # run a normal job
        rdd = self.sc.parallelize(xrange(100), 1)
        self.assertEqual(100, rdd.map(str).count())

    def test_after_exception(self):
        # A Python exception inside a task must not break later jobs.
        def raise_exception(_):
            raise Exception()
        rdd = self.sc.parallelize(xrange(100), 1)
        with QuietTest(self.sc):
            self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
        self.assertEqual(100, rdd.map(str).count())

    def test_after_jvm_exception(self):
        # Reading a file that was deleted after the first read must fail on the
        # JVM side without breaking later jobs.
        tempFile = tempfile.NamedTemporaryFile(delete=False)
        tempFile.write(b"Hello World!")
        tempFile.close()
        data = self.sc.textFile(tempFile.name, 1)
        filtered_data = data.filter(lambda x: True)
        self.assertEqual(1, filtered_data.count())
        os.unlink(tempFile.name)
        with QuietTest(self.sc):
            self.assertRaises(Exception, lambda: filtered_data.count())

        rdd = self.sc.parallelize(xrange(100), 1)
        self.assertEqual(100, rdd.map(str).count())

    def test_accumulator_when_reuse_worker(self):
        # Accumulators updated by reused workers must not leak into each other.
        from pyspark.accumulators import INT_ACCUMULATOR_PARAM
        acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x))
        self.assertEqual(sum(range(100)), acc1.value)

        acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x))
        self.assertEqual(sum(range(100)), acc2.value)
        self.assertEqual(sum(range(100)), acc1.value)

    def test_reuse_worker_after_take(self):
        # A worker left partially consumed by first() must still serve a full
        # count() afterwards without hanging.
        rdd = self.sc.parallelize(xrange(100000), 1)
        self.assertEqual(0, rdd.first())

        def count():
            try:
                rdd.count()
            except Exception:
                pass

        t = threading.Thread(target=count)
        t.daemon = True
        t.start()
        t.join(5)
        # Thread.isAlive() was removed in Python 3.9; is_alive() works everywhere.
        self.assertFalse(t.is_alive())
        self.assertEqual(100000, rdd.count())

    def test_with_different_versions_of_python(self):
        # Pretending the driver runs Python "2.0" must make the job fail with a
        # Py4JJavaError; the original version is restored afterwards.
        rdd = self.sc.parallelize(range(10))
        rdd.count()
        version = self.sc.pythonVer
        self.sc.pythonVer = "2.0"
        try:
            with QuietTest(self.sc):
                self.assertRaises(Py4JJavaError, lambda: rdd.count())
        finally:
            self.sc.pythonVer = version

    def test_python_exception_non_hanging(self):
        # SPARK-21045: exceptions with non-ASCII messages must not hang PySpark.
        def f():
            raise Exception("exception with 中 and \xd6\xd0")

        try:
            self.sc.parallelize([1]).map(lambda x: f()).count()
        except Py4JJavaError as e:
            if sys.version_info.major < 3:
                # we have to use unicode here to avoid UnicodeDecodeError
                self.assertRegexpMatches(unicode(e).encode("utf-8"), "exception with 中")
            else:
                # assertRegexpMatches is a deprecated alias removed in Python 3.12.
                self.assertRegex(str(e), "exception with 中")
        else:
            self.fail("expected the job to fail with a Py4JJavaError")
0167 
0168 
class WorkerReuseTest(PySparkTestCase):
    """Checks that a second job over the same RDD runs in already-used workers."""

    def test_reuse_worker_of_parallelize_xrange(self):
        rdd = self.sc.parallelize(xrange(20), 8)
        previous_pids = rdd.map(lambda x: os.getpid()).collect()
        current_pids = rdd.map(lambda x: os.getpid()).collect()
        # Every pid used by the second job must have served the first job too.
        # assertIn reports the offending pid and the known pids on failure.
        for pid in current_pids:
            self.assertIn(pid, previous_pids)
0177 
0178 
@unittest.skipIf(
    not has_resource_module,
    "Memory limit feature in Python worker is dependent on "
    "Python's 'resource' module; however, not found.")
class WorkerMemoryTest(PySparkTestCase):
    """Checks that spark.executor.pyspark.memory sets the worker's RLIMIT_AS."""

    def test_memory_limit(self):
        self.sc._conf.set("spark.executor.pyspark.memory", "2g")
        rdd = self.sc.parallelize(xrange(1), 1)

        def getrlimit():
            import resource
            return resource.getrlimit(resource.RLIMIT_AS)

        actual = rdd.map(lambda _: getrlimit()).collect()
        # assertEqual (unlike assertTrue(a == b)) reports both values on failure.
        self.assertEqual(1, len(actual))
        self.assertEqual(2, len(actual[0]))
        [(soft_limit, hard_limit)] = actual
        expected_limit = 2 * 1024 * 1024 * 1024  # the "2g" configured above, in bytes
        self.assertEqual(soft_limit, expected_limit)
        self.assertEqual(hard_limit, expected_limit)
0199 
0200 
if __name__ == "__main__":
    # Re-import the tests through their package path so unittest.main() finds
    # them; ``unittest`` itself is already imported at module level.
    from pyspark.tests.test_worker import *

    # Emit XML reports when xmlrunner is installed, else use the default runner.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)