Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 import os
0019 import sys
0020 import tempfile
0021 import unittest
0022 
0023 from pyspark import SparkConf, SparkContext, BasicProfiler
0024 from pyspark.testing.utils import PySparkTestCase
0025 
# Pick the text-buffer class for the running interpreter: on Python 3 it
# lives in the io module, on Python 2 in the StringIO module.  The check uses
# the numeric major version rather than the free-form sys.version string,
# because a lexicographic string comparison is not a reliable version test
# (for example "10" sorts before "3").
if sys.version_info[0] >= 3:
    from io import StringIO
else:
    from StringIO import StringIO
0030 
0031 
class ProfilerTests(PySparkTestCase):
    """Tests for the Python profiler hooks exposed on ``SparkContext``.

    Every test runs against its own local context that is created in
    ``setUp`` with ``spark.python.profile`` set to ``"true"``.
    """

    def setUp(self):
        """Create a ``local[4]`` SparkContext with Python profiling enabled."""
        # Snapshot of sys.path taken before the context is created.
        # NOTE(review): presumably restored by PySparkTestCase.tearDown -- confirm.
        self._old_sys_path = list(sys.path)
        class_name = self.__class__.__name__
        conf = SparkConf().set("spark.python.profile", "true")
        self.sc = SparkContext('local[4]', class_name, conf=conf)

    def test_profiler(self):
        """Run a job and check its profile is collected, shown and dumped."""
        self.do_computation()

        # Exactly one profiler entry is expected for the single job; each
        # entry unpacks into three items, of which the first two are used.
        profilers = self.sc.profiler_collector.profilers
        self.assertEqual(1, len(profilers))
        rdd_id, profiler, _ = profilers[0]
        stats = profiler.stats()
        self.assertIsNotNone(stats)
        width, stat_list = stats.get_print_list([])
        func_names = [func_name for fname, n, func_name in stat_list]
        self.assertIn("heavy_foo", func_names)

        # Capture what show_profiles() prints.  stdout is restored in a
        # ``finally`` so that a failure here cannot leave it redirected for
        # the remainder of the test run; the assertion runs after the restore.
        captured = StringIO()
        old_stdout = sys.stdout
        sys.stdout = captured
        try:
            self.sc.show_profiles()
        finally:
            sys.stdout = old_stdout
        self.assertIn("heavy_foo", captured.getvalue())

        # Dump into a private directory instead of the shared temp dir, so a
        # stale "rdd_<id>.pstats" left by an earlier run cannot satisfy the
        # assertion, and nothing is left behind afterwards.
        import shutil
        dump_dir = tempfile.mkdtemp()
        try:
            self.sc.dump_profiles(dump_dir)
            self.assertIn("rdd_%d.pstats" % rdd_id, os.listdir(dump_dir))
        finally:
            shutil.rmtree(dump_dir, ignore_errors=True)

    def test_custom_profiler(self):
        """Check that a user-supplied profiler class is instantiated and used."""
        class TestCustomProfiler(BasicProfiler):
            # The parameter name ``id`` is kept as in the original override.
            def show(self, id):
                # Record a marker instead of printing, so the test can tell
                # that this override was the one invoked.
                self.result = "Custom formatting"

        self.sc.profiler_collector.profiler_cls = TestCustomProfiler

        self.do_computation()

        profilers = self.sc.profiler_collector.profilers
        self.assertEqual(1, len(profilers))
        _, profiler, _ = profilers[0]
        self.assertIsInstance(profiler, TestCustomProfiler)

        self.sc.show_profiles()
        self.assertEqual("Custom formatting", profiler.result)

    def do_computation(self):
        """Run a small job whose per-element function is named ``heavy_foo``.

        The tests look for that function name in the collected profile.
        """
        def heavy_foo(x):
            # Busy loop of 2**18 iterations; the argument is not used for
            # anything but being overwritten.
            for i in range(1 << 18):
                x = 1

        rdd = self.sc.parallelize(range(100))
        rdd.foreach(heavy_foo)
0086 
0087 
class ProfilerTests2(unittest.TestCase):
    """Checks the profiler entry points when profiling is switched off."""

    def test_profiler_disabled(self):
        """Both profile accessors must raise ``RuntimeError`` when disabled."""
        expected_message = "'spark.python.profile' configuration must be set"
        disabled_conf = SparkConf().set("spark.python.profile", "false")
        sc = SparkContext(conf=disabled_conf)
        try:
            # assertRaisesRegexp (not assertRaisesRegex) is kept because this
            # file still supports Python 2, which only has the former.
            with self.assertRaisesRegexp(RuntimeError, expected_message):
                sc.show_profiles()
            with self.assertRaisesRegexp(RuntimeError, expected_message):
                sc.dump_profiles("/tmp/abc")
        finally:
            # Always release the context, even when an assertion fails.
            sc.stop()
0102 
0103 
if __name__ == "__main__":
    # NOTE(review): the star import re-imports this module's names through its
    # package path; presumably so the test classes resolve to an importable
    # module rather than ``__main__`` -- confirm before removing.
    from pyspark.tests.test_profiler import *

    # Use the XML-reporting runner when xmlrunner is installed; otherwise pass
    # None so that unittest.main picks its default runner.
    try:
        import xmlrunner
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        runner = None
    unittest.main(testRunner=runner, verbosity=2)