#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import py4j
import sys

from pyspark import SparkContext

if sys.version_info.major >= 3:
    unicode = str
    # Disable exception chaining (PEP 3134) in captured exceptions
    # in order to hide the JVM stacktrace. `exec` is used because
    # `raise ... from` is a syntax error under Python 2.
    exec("""
def raise_from(e):
    raise e from None
""")
else:
    def raise_from(e):
        raise e


class CapturedException(Exception):
    def __init__(self, desc, stackTrace, cause=None):
        self.desc = desc
        self.stackTrace = stackTrace
        self.cause = convert_exception(cause) if cause is not None else None

    def __str__(self):
        sql_conf = SparkContext._jvm.org.apache.spark.sql.internal.SQLConf.get()
        debug_enabled = sql_conf.pysparkJVMStacktraceEnabled()
        desc = self.desc
        if debug_enabled:
            desc = desc + "\n\nJVM stacktrace:\n%s" % self.stackTrace
        # Encode unicode instances under Python 2 so the description stays
        # human readable.
        if sys.version_info.major < 3 and isinstance(desc, unicode):
            return str(desc.encode('utf-8'))
        else:
            return str(desc)


class AnalysisException(CapturedException):
    """
    Failed to analyze a SQL query plan.
    """


class ParseException(CapturedException):
    """
    Failed to parse a SQL command.
    """


class IllegalArgumentException(CapturedException):
    """
    Passed an illegal or inappropriate argument.
    """


class StreamingQueryException(CapturedException):
    """
    Exception that stopped a :class:`StreamingQuery`.
    """


class QueryExecutionException(CapturedException):
    """
    Failed to execute a query.
    """


class PythonException(CapturedException):
    """
    Exceptions thrown from Python workers.
    """


class UnknownException(CapturedException):
    """
    None of the above exceptions.
    """
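
# Usage sketch (hypothetical; assumes an active SparkSession `spark` and that
# install_exception_handler() below has run). JVM-side errors then surface as
# the Python classes above, e.g.:
#
#     try:
#         spark.sql("SELECT nonexistent_column FROM some_table")
#     except AnalysisException as e:
#         print(e.desc)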


def convert_exception(e):
    s = e.toString()
    c = e.getCause()

    jvm = SparkContext._jvm
    jwriter = jvm.java.io.StringWriter()
    e.printStackTrace(jvm.java.io.PrintWriter(jwriter))
    stacktrace = jwriter.toString()
    if s.startswith('org.apache.spark.sql.AnalysisException: '):
        return AnalysisException(s.split(': ', 1)[1], stacktrace, c)
    if s.startswith('org.apache.spark.sql.catalyst.analysis'):
        return AnalysisException(s.split(': ', 1)[1], stacktrace, c)
    if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '):
        return ParseException(s.split(': ', 1)[1], stacktrace, c)
    if s.startswith('org.apache.spark.sql.streaming.StreamingQueryException: '):
        return StreamingQueryException(s.split(': ', 1)[1], stacktrace, c)
    if s.startswith('org.apache.spark.sql.execution.QueryExecutionException: '):
        return QueryExecutionException(s.split(': ', 1)[1], stacktrace, c)
    if s.startswith('java.lang.IllegalArgumentException: '):
        return IllegalArgumentException(s.split(': ', 1)[1], stacktrace, c)
    if c is not None and (
            c.toString().startswith('org.apache.spark.api.python.PythonException: ')
            # Make sure this only catches Python UDFs.
            and any(map(lambda v: "org.apache.spark.sql.execution.python" in v.toString(),
                        c.getStackTrace()))):
        msg = ("\n  An exception was thrown from the Python worker in the executor. "
               "Below is the Python worker stacktrace.\n%s" % c.getMessage())
        return PythonException(msg, stacktrace)
    return UnknownException(s, stacktrace, c)


def capture_sql_exception(f):
    def deco(*a, **kw):
        try:
            return f(*a, **kw)
        except py4j.protocol.Py4JJavaError as e:
            converted = convert_exception(e.java_exception)
            if not isinstance(converted, UnknownException):
                # Hide where the exception came from, since it shows a
                # non-Pythonic JVM exception message.
                raise_from(converted)
            else:
                raise
    return deco
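
# Sketch: install_exception_handler() below applies this wrapper to Py4J's
# `get_return_value`; the same wrapping could, hypothetically, be applied to
# any Py4J entry point:
#
#     safe_get_return_value = capture_sql_exception(py4j.protocol.get_return_value)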


def install_exception_handler():
    """
    Hook an exception handler into Py4J so that SQL exceptions raised in Java
    can be captured.

    When the Java API is called, Py4J invokes `get_return_value` to parse the
    returned object. If an exception occurred in the JVM, the result is a
    Java exception object and `get_return_value` raises
    py4j.protocol.Py4JJavaError. We replace the original `get_return_value`
    with one that captures the Java exception and raises a Python one with
    the same error message.

    It is idempotent and can be called multiple times.
    """
    original = py4j.protocol.get_return_value
    # The original `get_return_value` in py4j.protocol is never itself
    # patched, which keeps this function idempotent.
    patched = capture_sql_exception(original)
    # Only patch the copy used in py4j.java_gateway (Java API calls).
    py4j.java_gateway.get_return_value = patched
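
# After installation, a failing Py4J call raises one of the exception classes
# above instead of a raw Py4JJavaError (hypothetical example, assuming an
# active SparkSession `spark`):
#
#     install_exception_handler()
#     spark.sql("SELEC 1")  # raises ParseException, not Py4JJavaError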


def toJArray(gateway, jtype, arr):
    """
    Convert a Python list to a Java typed array.

    :param gateway: Py4J gateway
    :param jtype: Java type of the array elements
    :param arr: Python list to convert
    """
    jarr = gateway.new_array(jtype, len(arr))
    for i in range(0, len(arr)):
        jarr[i] = arr[i]
    return jarr
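
# Usage sketch (hypothetical; assumes an active SparkContext `sc`, whose Py4J
# gateway is reachable as `sc._gateway`):
#
#     gateway = sc._gateway
#     jarr = toJArray(gateway, gateway.jvm.java.lang.String, ["a", "b", "c"])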


def require_test_compiled():
    """
    Raise an exception if the test classes are not compiled.
    """
    import os
    import glob
    try:
        spark_home = os.environ['SPARK_HOME']
    except KeyError:
        raise RuntimeError('SPARK_HOME is not defined in environment')

    test_class_path = os.path.join(
        spark_home, 'sql', 'core', 'target', '*', 'test-classes')
    paths = glob.glob(test_class_path)

    if len(paths) == 0:
        raise RuntimeError(
            "%s doesn't exist. Spark SQL test classes are not compiled." % test_class_path)


class ForeachBatchFunction(object):
    """
    This is the Python implementation of the Java interface
    'ForeachBatchFunction'. It wraps the user-defined 'foreachBatch' function
    so that it can be called from the JVM when the query is active.
    """

    def __init__(self, sql_ctx, func):
        self.sql_ctx = sql_ctx
        self.func = func

    def call(self, jdf, batch_id):
        from pyspark.sql.dataframe import DataFrame
        try:
            self.func(DataFrame(jdf, self.sql_ctx), batch_id)
        except Exception as e:
            # Record the error before re-raising it to the JVM.
            self.error = e
            raise e

    class Java:
        # Py4J convention: a nested `Java` class with an `implements` list
        # registers this Python object as a callback implementation of the
        # listed Java interface.
        implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction']


def to_str(value):
    """
    A wrapper over str() that converts bool values to lower-case strings.
    If None is given, returns None instead of converting it to the string
    "None".
    """
    if isinstance(value, bool):
        return str(value).lower()
    elif value is None:
        return value
    else:
        return str(value)
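
# Expected behavior (illustrative, derived from the branches above):
#
#     to_str(True)   # -> "true"
#     to_str(None)   # -> None
#     to_str(1.5)    # -> "1.5"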