import os
import sys
import tempfile
import threading
import time
import unittest

has_resource_module = True
try:
    import resource
except ImportError:
    has_resource_module = False

from py4j.protocol import Py4JJavaError

from pyspark.testing.utils import ReusedPySparkTestCase, PySparkTestCase, QuietTest

if sys.version_info[0] >= 3:
    xrange = range


class WorkerTests(ReusedPySparkTestCase):
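    # cancelling a job should kill the Python worker that runs it,
    # but leave the daemon process alive for later jobs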
    def test_cancel_task(self):
        temp = tempfile.NamedTemporaryFile(delete=True)
        temp.close()
        path = temp.name

        def sleep(x):
            import os
            import time
            with open(path, 'w') as f:
                f.write("%d %d" % (os.getppid(), os.getpid()))
            time.sleep(100)

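        # run the sleeping job in a background thread so the main thread can cancel it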
        def run():
            try:
                self.sc.parallelize(range(1), 1).foreach(sleep)
            except Exception:
                pass

        t = threading.Thread(target=run)
        t.daemon = True
        t.start()

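        # wait until the task has written the daemon and worker PIDs to the temp file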
        daemon_pid, worker_pid = 0, 0
        while True:
            if os.path.exists(path):
                with open(path) as f:
                    data = f.read().split(' ')
                daemon_pid, worker_pid = map(int, data)
                break
            time.sleep(0.1)

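        # cancel all jobs and wait for the background thread to finish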
        self.sc.cancelAllJobs()
        t.join()

        for i in range(50):
            try:
                os.kill(worker_pid, 0)  # signal 0 only probes whether the process exists
                time.sleep(0.1)
            except OSError:
                break  # worker was killed
        else:
            self.fail("worker has not been killed after 5 seconds")

        try:
            os.kill(daemon_pid, 0)
        except OSError:
            self.fail("daemon was killed unexpectedly")

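        # the daemon should still be able to serve a normal job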
        rdd = self.sc.parallelize(xrange(100), 1)
        self.assertEqual(100, rdd.map(str).count())

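    # a Python exception in one job must not break worker reuse for later jobs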
    def test_after_exception(self):
        def raise_exception(_):
            raise Exception()

        rdd = self.sc.parallelize(xrange(100), 1)
        with QuietTest(self.sc):
            self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
        self.assertEqual(100, rdd.map(str).count())

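    # a JVM-side failure (missing input file) must not leave the context unusable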
    def test_after_jvm_exception(self):
        tempFile = tempfile.NamedTemporaryFile(delete=False)
        tempFile.write(b"Hello World!")
        tempFile.close()
        data = self.sc.textFile(tempFile.name, 1)
        filtered_data = data.filter(lambda x: True)
        self.assertEqual(1, filtered_data.count())
        os.unlink(tempFile.name)  # deleting the input makes the next count fail in the JVM
        with QuietTest(self.sc):
            self.assertRaises(Exception, lambda: filtered_data.count())

        rdd = self.sc.parallelize(xrange(100), 1)
        self.assertEqual(100, rdd.map(str).count())

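    # accumulators must stay correct when workers are reused across jobs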
    def test_accumulator_when_reuse_worker(self):
        from pyspark.accumulators import INT_ACCUMULATOR_PARAM
        acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc1.add(x))
        self.assertEqual(sum(range(100)), acc1.value)

        acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM)
        self.sc.parallelize(xrange(100), 20).foreach(lambda x: acc2.add(x))
        self.assertEqual(sum(range(100)), acc2.value)
        self.assertEqual(sum(range(100)), acc1.value)

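    # a worker interrupted by first()/take() must still be safely reusable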
    def test_reuse_worker_after_take(self):
        rdd = self.sc.parallelize(xrange(100000), 1)
        self.assertEqual(0, rdd.first())

        def count():
            try:
                rdd.count()
            except Exception:
                pass

        t = threading.Thread(target=count)
        t.daemon = True
        t.start()
        t.join(5)
        self.assertFalse(t.is_alive())
        self.assertEqual(100000, rdd.count())

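    # a driver/worker Python version mismatch must fail fast instead of hanging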
    def test_with_different_versions_of_python(self):
        rdd = self.sc.parallelize(range(10))
        rdd.count()
        version = self.sc.pythonVer
        self.sc.pythonVer = "2.0"
        try:
            with QuietTest(self.sc):
                self.assertRaises(Py4JJavaError, lambda: rdd.count())
        finally:
            self.sc.pythonVer = version

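    # an exception message that is not ASCII-encodable must not hang the worker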
    def test_python_exception_non_hanging(self):
        try:
            def f():
                raise Exception("exception with 中 and \xd6\xd0")

            self.sc.parallelize([1]).map(lambda x: f()).count()
        except Py4JJavaError as e:
            if sys.version_info.major < 3:
                # use unicode here to avoid a UnicodeDecodeError on Python 2
                self.assertRegexpMatches(unicode(e).encode("utf-8"), "exception with 中")
            else:
                self.assertRegexpMatches(str(e), "exception with 中")


class WorkerReuseTest(PySparkTestCase):

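    # every partition of the second job should run in a worker forked for the first job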
    def test_reuse_worker_of_parallelize_xrange(self):
        rdd = self.sc.parallelize(xrange(20), 8)
        previous_pids = rdd.map(lambda x: os.getpid()).collect()
        current_pids = rdd.map(lambda x: os.getpid()).collect()
        for pid in current_pids:
            self.assertIn(pid, previous_pids)


@unittest.skipIf(
    not has_resource_module,
    "Memory limit feature in the Python worker depends on "
    "Python's 'resource' module, which is not available.")
class WorkerMemoryTest(PySparkTestCase):

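    # the worker's address-space rlimit should reflect spark.executor.pyspark.memory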
    def test_memory_limit(self):
        self.sc._conf.set("spark.executor.pyspark.memory", "2g")
        rdd = self.sc.parallelize(xrange(1), 1)

        def getrlimit():
            import resource
            return resource.getrlimit(resource.RLIMIT_AS)

        actual = rdd.map(lambda _: getrlimit()).collect()
        self.assertEqual(1, len(actual))
        self.assertEqual(2, len(actual[0]))
        [(soft_limit, hard_limit)] = actual
        self.assertEqual(soft_limit, 2 * 1024 * 1024 * 1024)
        self.assertEqual(hard_limit, 2 * 1024 * 1024 * 1024)


if __name__ == "__main__":
    from pyspark.tests.test_worker import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)