0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import os
0019 import sys
0020 import tempfile
0021 import unittest
0022
0023 from pyspark import SparkConf, SparkContext, BasicProfiler
0024 from pyspark.testing.utils import PySparkTestCase
0025
# Import a text-mode StringIO that exists on both Python 2 and 3.
# Compare sys.version_info (a tuple) rather than the sys.version string:
# lexical string comparison would misorder a future two-digit major
# version ("10..." < "3"), whereas tuple comparison is always correct.
if sys.version_info >= (3,):
    from io import StringIO
else:
    from StringIO import StringIO
0030
0031
class ProfilerTests(PySparkTestCase):
    """Tests for the Python worker profiler enabled via the
    ``spark.python.profile`` configuration."""

    def setUp(self):
        # Overrides PySparkTestCase.setUp so the SparkContext is created
        # with profiling turned on; _old_sys_path mirrors the base class
        # so its tearDown can still restore sys.path.
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        conf = SparkConf().set("spark.python.profile", "true")
        self.sc = SparkContext('local[4]', class_name, conf=conf)

    def test_profiler(self):
        """One computation records exactly one profile; its stats name the
        profiled function, and the profile can be shown and dumped."""
        self.do_computation()

        profilers = self.sc.profiler_collector.profilers
        self.assertEqual(1, len(profilers))
        # Renamed from `id` to avoid shadowing the builtin.
        rdd_id, profiler, _ = profilers[0]
        stats = profiler.stats()
        self.assertTrue(stats is not None)
        width, stat_list = stats.get_print_list([])
        func_names = [func_name for fname, n, func_name in stat_list]
        self.assertTrue("heavy_foo" in func_names)

        # Capture show_profiles() output; restore stdout in a finally block
        # so a failure inside show_profiles cannot leave stdout redirected
        # for the rest of the test run.
        old_stdout = sys.stdout
        sys.stdout = io = StringIO()
        try:
            self.sc.show_profiles()
        finally:
            sys.stdout = old_stdout
        self.assertTrue("heavy_foo" in io.getvalue())

        d = tempfile.gettempdir()
        self.sc.dump_profiles(d)
        self.assertTrue("rdd_%d.pstats" % rdd_id in os.listdir(d))

    def test_custom_profiler(self):
        """A user-supplied profiler class set on the collector is used for
        subsequent computations, and show_profiles() dispatches to it."""
        class TestCustomProfiler(BasicProfiler):
            def show(self, id):
                self.result = "Custom formatting"

        self.sc.profiler_collector.profiler_cls = TestCustomProfiler

        self.do_computation()

        profilers = self.sc.profiler_collector.profilers
        self.assertEqual(1, len(profilers))
        _, profiler, _ = profilers[0]
        self.assertTrue(isinstance(profiler, TestCustomProfiler))

        self.sc.show_profiles()
        self.assertEqual("Custom formatting", profiler.result)

    def do_computation(self):
        """Run a deliberately CPU-heavy RDD action so the profiler has
        something to record (heavy_foo shows up in the stats)."""
        def heavy_foo(x):
            # Busy loop: enough iterations to register in profile output.
            for i in range(1 << 18):
                x = 1

        rdd = self.sc.parallelize(range(100))
        rdd.foreach(heavy_foo)
0086
0087
class ProfilerTests2(unittest.TestCase):
    """Checks profiler API behavior when profiling is switched off."""

    def test_profiler_disabled(self):
        """With spark.python.profile=false, both show_profiles and
        dump_profiles must raise a RuntimeError pointing at the config."""
        conf = SparkConf().set("spark.python.profile", "false")
        sc = SparkContext(conf=conf)
        try:
            expected_msg = "'spark.python.profile' configuration must be set"
            for action in (lambda: sc.show_profiles(),
                           lambda: sc.dump_profiles("/tmp/abc")):
                self.assertRaisesRegexp(RuntimeError, expected_msg, action)
        finally:
            # Always release the context, even if an assertion fails.
            sc.stop()
0102
0103
if __name__ == "__main__":
    from pyspark.tests.test_profiler import *

    # Prefer the XML-report runner when xmlrunner is installed; otherwise
    # fall back to unittest's default runner (testRunner=None).
    try:
        import xmlrunner
    except ImportError:
        testRunner = None
    else:
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=testRunner, verbosity=2)