0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import py4j
0019 import sys
0020
0021 from pyspark import SparkContext
0022
if sys.version_info.major >= 3:
    # Python 3 has no separate 'unicode' type; alias it to 'str' so the
    # py2-only isinstance check in CapturedException.__str__ stays valid.
    unicode = str

    # ``raise e from None`` is a SyntaxError on Python 2, so the Python-3-only
    # definition must be compiled lazily via exec() to keep this module
    # importable under both major versions. ``from None`` suppresses the
    # implicit exception-chaining context so users see only the converted
    # Python exception, not the internal Py4J frames.
    exec("""
def raise_from(e):
    raise e from None
""")
else:
    def raise_from(e):
        # Python 2 has no exception chaining; a plain raise is equivalent.
        raise e
0034
0035
class CapturedException(Exception):
    """
    Base class for exceptions captured from the JVM side.

    :param desc: the exception description (message portion of the Java
                 exception's ``toString()``).
    :param stackTrace: the full JVM stack trace as a string.
    :param cause: optional Java cause exception; converted recursively via
                  ``convert_exception`` when present.
    """

    def __init__(self, desc, stackTrace, cause=None):
        self.desc = desc
        self.stackTrace = stackTrace
        self.cause = convert_exception(cause) if cause is not None else None

    def __str__(self):
        # Best-effort conf lookup: reading SQLConf requires a live JVM
        # gateway. If the SparkContext/JVM is unavailable (e.g. the context
        # was never started, or we are in interpreter shutdown), fall back to
        # the plain description rather than letting __str__ itself raise.
        try:
            sql_conf = SparkContext._jvm.org.apache.spark.sql.internal.SQLConf.get()
            debug_enabled = sql_conf.pysparkJVMStacktraceEnabled()
        except Exception:
            debug_enabled = False

        desc = self.desc
        if debug_enabled:
            # Append the JVM-side stack trace only when explicitly enabled.
            desc = desc + "\n\nJVM stacktrace:\n%s" % self.stackTrace

        if sys.version_info.major < 3 and isinstance(desc, unicode):
            # Python 2: encode unicode descriptions so str() cannot fail.
            return str(desc.encode('utf-8'))
        else:
            return str(desc)
0053
0054
class AnalysisException(CapturedException):
    """
    Failed to analyze a SQL query plan.

    Raised when the JVM reports ``org.apache.spark.sql.AnalysisException`` or
    an ``org.apache.spark.sql.catalyst.analysis`` error (see
    ``convert_exception``).
    """
0059
0060
class ParseException(CapturedException):
    """
    Failed to parse a SQL command.

    Raised when the JVM reports
    ``org.apache.spark.sql.catalyst.parser.ParseException``.
    """
0065
0066
class IllegalArgumentException(CapturedException):
    """
    Passed an illegal or inappropriate argument.

    Raised when the JVM reports ``java.lang.IllegalArgumentException``.
    """
0071
0072
class StreamingQueryException(CapturedException):
    """
    Exception that stopped a :class:`StreamingQuery`.

    Raised when the JVM reports
    ``org.apache.spark.sql.streaming.StreamingQueryException``.
    """
0077
0078
class QueryExecutionException(CapturedException):
    """
    Failed to execute a query.

    Raised when the JVM reports
    ``org.apache.spark.sql.execution.QueryExecutionException``.
    """
0083
0084
class PythonException(CapturedException):
    """
    Exceptions thrown from Python workers.

    Produced by ``convert_exception`` when the Java cause is an
    ``org.apache.spark.api.python.PythonException`` originating from
    ``org.apache.spark.sql.execution.python`` (i.e. a Python UDF failure).
    """
0089
0090
class UnknownException(CapturedException):
    """
    None of the above exceptions.

    Fallback used by ``convert_exception`` when the JVM exception class is
    not one of the recognized types; ``capture_sql_exception`` re-raises the
    raw Py4J error in that case.
    """
0095
0096
def convert_exception(e):
    """
    Convert a Java exception object (a py4j ``java_exception`` proxy) into
    the matching Python-side :class:`CapturedException` subclass.

    :param e: Java exception proxy exposing ``toString``, ``getCause`` and
              ``printStackTrace``.
    :return: a :class:`CapturedException` subclass instance; falls back to
             :class:`UnknownException` when the Java class is unrecognized.
    """
    s = e.toString()
    c = e.getCause()

    # Render the full JVM stack trace into a Python string.
    jvm = SparkContext._jvm
    jwriter = jvm.java.io.StringWriter()
    e.printStackTrace(jvm.java.io.PrintWriter(jwriter))
    stacktrace = jwriter.toString()

    # Java class prefix of toString() -> Python exception class. The bare
    # 'catalyst.analysis' package prefix intentionally has no ': ' suffix so
    # it matches any analysis exception subclass in that package.
    prefix_to_class = (
        ('org.apache.spark.sql.AnalysisException: ', AnalysisException),
        ('org.apache.spark.sql.catalyst.analysis', AnalysisException),
        ('org.apache.spark.sql.catalyst.parser.ParseException: ', ParseException),
        ('org.apache.spark.sql.streaming.StreamingQueryException: ', StreamingQueryException),
        ('org.apache.spark.sql.execution.QueryExecutionException: ', QueryExecutionException),
        ('java.lang.IllegalArgumentException: ', IllegalArgumentException),
    )
    for prefix, cls in prefix_to_class:
        if s.startswith(prefix):
            # [-1] instead of [1]: toString() may contain no ': message'
            # separator (possible for the bare 'catalyst.analysis' prefix),
            # in which case the original [1] indexing raised IndexError;
            # [-1] falls back to the whole string.
            return cls(s.split(': ', 1)[-1], stacktrace, c)

    # A Python-worker failure surfaces as a PythonException cause whose stack
    # trace passes through org.apache.spark.sql.execution.python (Python UDFs).
    if c is not None and (
            c.toString().startswith('org.apache.spark.api.python.PythonException: ')
            and any(map(lambda v: "org.apache.spark.sql.execution.python" in v.toString(),
                        c.getStackTrace()))):
        msg = ("\n An exception was thrown from Python worker in the executor. "
               "The below is the Python worker stacktrace.\n%s" % c.getMessage())
        return PythonException(msg, stacktrace)
    return UnknownException(s, stacktrace, c)
0126
0127
def capture_sql_exception(f):
    """
    Wrap ``f`` so that JVM-side SQL exceptions raised through Py4J surface
    as their Python :class:`CapturedException` counterparts.

    :param f: the callable to wrap (typically py4j's ``get_return_value``).
    :return: a wrapper with the same call signature as ``f``.
    """
    def wrapped(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except py4j.protocol.Py4JJavaError as err:
            converted = convert_exception(err.java_exception)
            if isinstance(converted, UnknownException):
                # Not a recognized Spark SQL exception -- re-raise the
                # original Py4J error untouched.
                raise
            # Known exception type: raise the Python equivalent, hiding the
            # noisy Py4J traceback context.
            raise_from(converted)
    return wrapped
0141
0142
def install_exception_handler():
    """
    Hook an exception handler into Py4j, which could capture some SQL exceptions in Java.

    When calling Java API, it will call `get_return_value` to parse the returned object.
    If any exception happened in JVM, the result will be Java exception object, it raise
    py4j.protocol.Py4JJavaError. We replace the original `get_return_value` with one that
    could capture the Java exception and throw a Python one (with the same error message).

    It's idempotent, could be called multiple times.
    """
    # Always take the pristine function from py4j.protocol, which this
    # function never overwrites -- that is what makes repeated calls
    # idempotent (we never wrap an already-wrapped function).
    original = py4j.protocol.get_return_value

    patched = capture_sql_exception(original)

    # Patch only the reference used by py4j.java_gateway (the Java API call
    # path); py4j.protocol keeps the unpatched original.
    py4j.java_gateway.get_return_value = patched
0159
0160
def toJArray(gateway, jtype, arr):
    """
    Convert a Python list into a Java array of the given element type.

    :param gateway: Py4j Gateway used to allocate the Java array
    :param jtype: Java element type of the resulting array
    :param arr: Python list whose elements are copied into the array
    :return: a Java array containing the elements of ``arr`` in order
    """
    jarr = gateway.new_array(jtype, len(arr))
    for idx, item in enumerate(arr):
        jarr[idx] = item
    return jarr
0172
0173
def require_test_compiled():
    """Raise an error when the Spark SQL test classes have not been compiled.

    Looks for any ``sql/core/target/*/test-classes`` directory under
    ``$SPARK_HOME`` and raises :class:`RuntimeError` if none exists or if
    ``SPARK_HOME`` is unset.
    """
    import os
    import glob

    if 'SPARK_HOME' not in os.environ:
        raise RuntimeError('SPARK_HOME is not defined in environment')
    spark_home = os.environ['SPARK_HOME']

    # The '*' matches the scala-version-specific build directory name.
    test_class_path = os.path.join(
        spark_home, 'sql', 'core', 'target', '*', 'test-classes')
    if not glob.glob(test_class_path):
        raise RuntimeError(
            "%s doesn't exist. Spark sql test classes are not compiled." % test_class_path)
0191
0192
class ForeachBatchFunction(object):
    """
    This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps
    the user-defined 'foreachBatch' function such that it can be called from the JVM when
    the query is active.
    """

    def __init__(self, sql_ctx, func):
        # Context passed through to DataFrame() when wrapping incoming
        # Java DataFrames.
        self.sql_ctx = sql_ctx
        # User-supplied callable taking (DataFrame, batch_id).
        self.func = func

    def call(self, jdf, batch_id):
        # Invoked from the JVM once per micro-batch; `jdf` is a Java
        # DataFrame proxy. Import is local to avoid a circular import at
        # module load time.
        from pyspark.sql.dataframe import DataFrame
        try:
            self.func(DataFrame(jdf, self.sql_ctx), batch_id)
        except Exception as e:
            # Stash the error for later inspection, then re-raise so Py4J
            # reports the failure back to the JVM caller.
            self.error = e
            raise e

    class Java:
        # Declares to Py4J which Java interface this Python object implements.
        implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction']
0214
0215
def to_str(value):
    """
    A wrapper over str(), but converts bool values to lower case strings.
    If None is given, just returns None, instead of converting it to string "None".
    """
    # Guard clauses for the two special cases; everything else is plain str().
    if value is None:
        return value
    if isinstance(value, bool):
        # Java/JSON-style booleans: True -> "true", False -> "false".
        return str(value).lower()
    return str(value)