Back to home page




#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
0018 from collections import namedtuple
0019 import os
0020 import traceback
# Immutable record of a call location: the name of the function that was
# called, plus the source file and line number it is attributed to.
CallSite = namedtuple("CallSite", ["function", "file", "linenum"])
def first_spark_call():
    """
    Return a CallSite representing the first Spark call in the current call stack.

    The directory that contains this module is treated as the Spark package
    directory. It is taken from the innermost stack frame, which is this
    function. The stack is then scanned from the outermost frame inward. The
    first frame whose file lies inside that directory is the first Spark call.

    The result pairs that frame's function name with the file and line number
    of the frame just before it, i.e. the code that made the Spark call.

    If the outermost frame is already inside the Spark directory, there is no
    preceding frame. In that case the outermost frame's own function name,
    file and line number are returned.

    :return: a ``CallSite``, or ``None`` if the stack is empty.
    """
    stack = traceback.extract_stack()
    if not stack:
        return None

    # The innermost entry is this function, so its directory is the Spark dir.
    spark_dir = os.path.dirname(stack[-1][0])
    # Match on a path-separator boundary. A bare string prefix would also
    # accept sibling directories that merely share the name prefix
    # (e.g. ".../pyspark_jobs/job.py" when the Spark dir is ".../pyspark").
    # An empty dirname is left as "" so every path matches, as before.
    if spark_dir and not spark_dir.endswith(os.sep):
        spark_dir += os.sep

    # Index of the outermost frame inside the Spark dir. The innermost frame
    # (this function) is the fallback.
    first_spark_frame = next(
        (i for i, frame in enumerate(stack) if frame[0].startswith(spark_dir)),
        len(stack) - 1,
    )

    if first_spark_frame == 0:
        # No caller precedes the first Spark frame; report that frame itself.
        filename, lineno, funcname, _ = stack[0]
        return CallSite(function=funcname, file=filename, linenum=lineno)

    # Spark function name, attributed to the caller's file and line.
    _, _, spark_func, _ = stack[first_spark_frame]
    user_file, user_line, _, _ = stack[first_spark_frame - 1]
    return CallSite(function=spark_func, file=user_file, linenum=user_line)
class SCCallSiteSync(object):
    """
    Helper for setting the spark context call site.

    The call site is computed once, when the helper is constructed, via
    ``first_spark_call()``. It is formatted as ``"<function> at <file>:<line>"``.

    Entering the outermost ``with`` block passes this string to
    ``sc._jsc.setCallSite``. Nested blocks do not change it. When the
    outermost block exits, ``setCallSite(None)`` is called.

    ``__enter__`` returns the helper itself, so ``as css`` binds it. Exceptions
    raised inside the block are not suppressed.

    Example usage:
    from pyspark.context import SCCallSiteSync
    with SCCallSiteSync(<relevant SparkContext>) as css:
        <a Spark call>
    """

    # Nesting depth of active ``with`` blocks. This is a class attribute, so
    # it is shared by every instance (and every context) in the process.
    # NOTE(review): the counter is updated without any lock, so concurrent use
    # from several threads could interleave. Confirm callers are
    # single-threaded or add synchronization.
    _spark_stack_depth = 0

    def __init__(self, sc):
        # The call site is captured here, from the stack of the code
        # constructing this helper, not at __enter__ time.
        call_site = first_spark_call()
        if call_site is not None:
            self._call_site = "%s at %s:%s" % (
                call_site.function, call_site.file, call_site.linenum)
        else:
            self._call_site = "Error! Could not extract traceback info"
        self._context = sc

    def __enter__(self):
        # Only the outermost block sets the call site, so the value set by the
        # outermost block stays in effect for any nested blocks.
        if SCCallSiteSync._spark_stack_depth == 0:
            self._context._jsc.setCallSite(self._call_site)
        SCCallSiteSync._spark_stack_depth += 1
        # Return self so ``with SCCallSiteSync(sc) as css:`` binds the helper
        # instead of None.
        return self

    def __exit__(self, type, value, tb):
        # Clear the call site only when the outermost block exits. Returning
        # None lets any exception raised in the with-body propagate.
        SCCallSiteSync._spark_stack_depth -= 1
        if SCCallSiteSync._spark_stack_depth == 0:
            self._context._jsc.setCallSite(None)