0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 """
0018 User-defined function related classes and functions
0019 """
0020 import functools
0021 import sys
0022
0023 from pyspark import SparkContext, since
0024 from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
0025 from pyspark.sql.column import Column, _to_java_column, _to_seq
0026 from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string
0027 from pyspark.sql.pandas.types import to_arrow_type
0028
# Only UDFRegistration is exported by ``from pyspark.sql.udf import *``; the other
# names in this module are internal helpers.
__all__ = ["UDFRegistration"]
0030
0031
def _wrap_function(sc, func, returnType):
    """Serialize ``(func, returnType)`` and wrap it in a JVM ``PythonFunction``.

    :param sc: the active SparkContext, used for its JVM gateway, Python executable,
        Python version and accumulator.
    :param func: the Python callable to ship to executors.
    :param returnType: the UDF's return type, shipped together with ``func``.
    :return: a JVM ``PythonFunction`` built from the payload, environment, includes and
        broadcast variables produced by ``_prepare_for_python_RDD``.
    """
    payload, broadcasts, environment, python_includes = _prepare_for_python_RDD(
        sc, (func, returnType))
    return sc._jvm.PythonFunction(
        bytearray(payload),
        environment,
        python_includes,
        sc.pythonExec,
        sc.pythonVer,
        broadcasts,
        sc._javaAccumulator,
    )
0037
0038
def _create_udf(f, returnType, evalType):
    """Create a deterministic, unnamed ``UserDefinedFunction`` for ``f``.

    :param f: the Python callable.
    :param returnType: a DataType or DDL-formatted type string.
    :param evalType: the evaluation type (an int, e.g. a ``PythonEvalType`` constant).
    :return: the plain-function wrapper produced by ``UserDefinedFunction._wrapped``.
    """
    return UserDefinedFunction(
        f,
        returnType=returnType,
        name=None,
        evalType=evalType,
        deterministic=True,
    )._wrapped()
0044
0045
class UserDefinedFunction(object):
    """
    User defined function in Python

    .. versionadded:: 1.3

    .. note:: The constructor of this class is not supposed to be directly called.
        Use :meth:`pyspark.sql.functions.udf` or :meth:`pyspark.sql.functions.pandas_udf`
        to create this instance.
    """
    def __init__(self, func,
                 returnType=StringType(),
                 name=None,
                 evalType=PythonEvalType.SQL_BATCHED_UDF,
                 deterministic=True):
        """Validate and store the UDF definition; no JVM object is created here.

        :param func: the Python callable to run.
        :param returnType: a :class:`DataType` or a DDL-formatted type string. Strings
            are parsed lazily by the :attr:`returnType` property.
        :param name: the UDF name. Defaults to ``func.__name__``, or to the class name of
            ``func`` for callables that have no ``__name__``.
        :param evalType: the evaluation type, an int (e.g. a ``PythonEvalType`` constant).
        :param deterministic: whether the function is deterministic.
        :raises TypeError: if ``func`` is not callable, ``returnType`` is neither a
            DataType nor a str, or ``evalType`` is not an int.
        """
        if not callable(func):
            raise TypeError(
                "Invalid function: not a function or callable (__call__ is not defined): "
                "{0}".format(type(func)))

        if not isinstance(returnType, (DataType, str)):
            raise TypeError(
                "Invalid return type: returnType should be DataType or str "
                "but is {}".format(returnType))

        if not isinstance(evalType, int):
            raise TypeError(
                "Invalid evaluation type: evalType should be an int but is {}".format(evalType))

        self.func = func
        # Raw return type as given (DataType or DDL string); see the returnType property.
        self._returnType = returnType
        # Lazily-populated caches for the validated DataType and the JVM UDF object.
        self._returnType_placeholder = None
        self._judf_placeholder = None
        self._name = name or (
            func.__name__ if hasattr(func, '__name__')
            else func.__class__.__name__)
        self.evalType = evalType
        self.deterministic = deterministic

    @property
    def returnType(self):
        """The validated :class:`DataType` returned by this UDF.

        A DDL string passed to the constructor is parsed on first access rather than at
        construction time (presumably because parsing needs an active Spark context —
        TODO confirm against ``_parse_datatype_string``). The parsed type is then checked
        against the constraints of ``self.evalType``. It is cached only after validation
        succeeds, so an invalid return type keeps raising on every access instead of
        being silently returned on the second one.

        :raises NotImplementedError: if a Pandas UDF's return type cannot be converted
            to an Arrow type, or if a grouped aggregate Pandas UDF returns a StructType.
        :raises TypeError: if a grouped map, mapInPandas or cogrouped map function does
            not declare a StructType return type.
        """
        if self._returnType_placeholder is not None:
            return self._returnType_placeholder

        if isinstance(self._returnType, DataType):
            return_type = self._returnType
        else:
            return_type = _parse_datatype_string(self._returnType)

        if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF or \
                self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
            try:
                to_arrow_type(return_type)
            except TypeError:
                raise NotImplementedError(
                    "Invalid return type with scalar Pandas UDFs: %s is "
                    "not supported" % str(return_type))
        elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
            if isinstance(return_type, StructType):
                try:
                    to_arrow_type(return_type)
                except TypeError:
                    raise NotImplementedError(
                        "Invalid return type with grouped map Pandas UDFs or "
                        "at groupby.applyInPandas: %s is not supported" % str(
                            return_type))
            else:
                raise TypeError("Invalid return type for grouped map Pandas "
                                "UDFs or at groupby.applyInPandas: return type must be a "
                                "StructType.")
        elif self.evalType == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
            if isinstance(return_type, StructType):
                try:
                    to_arrow_type(return_type)
                except TypeError:
                    raise NotImplementedError(
                        "Invalid return type in mapInPandas: "
                        "%s is not supported" % str(return_type))
            else:
                raise TypeError("Invalid return type in mapInPandas: "
                                "return type must be a StructType.")
        elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
            if isinstance(return_type, StructType):
                try:
                    to_arrow_type(return_type)
                except TypeError:
                    raise NotImplementedError(
                        "Invalid return type in cogroup.applyInPandas: "
                        "%s is not supported" % str(return_type))
            else:
                raise TypeError("Invalid return type in cogroup.applyInPandas: "
                                "return type must be a StructType.")
        elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
            try:
                # StructType results are rejected for grouped aggregate UDFs; raising
                # TypeError here routes them to the same error as non-Arrow types.
                if isinstance(return_type, StructType):
                    raise TypeError
                to_arrow_type(return_type)
            except TypeError:
                raise NotImplementedError(
                    "Invalid return type with grouped aggregate Pandas UDFs: "
                    "%s is not supported" % str(return_type))

        # Cache only once validation has passed.
        self._returnType_placeholder = return_type
        return return_type

    @property
    def _judf(self):
        """The JVM ``UserDefinedPythonFunction``, created on first access and cached.

        Initialization is not guarded by a lock, so concurrent first accesses may each
        build a JVM object; the cache simply keeps whichever assignment lands last.
        """
        if self._judf_placeholder is None:
            self._judf_placeholder = self._create_judf()
        return self._judf_placeholder

    def _create_judf(self):
        """Build the JVM ``UserDefinedPythonFunction`` for this UDF.

        Obtains a session via ``SparkSession.builder.getOrCreate()``, serializes
        ``self.func`` with its validated return type, and passes the name, evaluation
        type and determinism flag to the JVM constructor.
        """
        from pyspark.sql import SparkSession

        spark = SparkSession.builder.getOrCreate()
        sc = spark.sparkContext

        wrapped_func = _wrap_function(sc, self.func, self.returnType)
        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
        judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
            self._name, wrapped_func, jdt, self.evalType, self.deterministic)
        return judf

    def __call__(self, *cols):
        """Apply the UDF to the given columns (Columns or column names) and return a
        :class:`Column`."""
        judf = self._judf
        sc = SparkContext._active_spark_context
        return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))

    def _wrapped(self):
        """
        Wrap this udf with a function and attach docstring from func
        """
        # The callable may lack ``__name__`` and/or ``__module__`` (e.g. a
        # functools.partial or a callable instance). Exclude those two attributes from
        # what functools.wraps copies, and assign them explicitly below with fallbacks.
        assignments = tuple(
            a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__')

        @functools.wraps(self.func, assigned=assignments)
        def wrapper(*args):
            return self(*args)

        wrapper.__name__ = self._name
        wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
                              else self.func.__class__.__module__)

        # Expose the UDF definition on the wrapper so it can be inspected and
        # re-registered (UDFRegistration.register reads these attributes).
        wrapper.func = self.func
        wrapper.returnType = self.returnType
        wrapper.evalType = self.evalType
        wrapper.deterministic = self.deterministic
        wrapper.asNondeterministic = functools.wraps(
            self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped())
        return wrapper

    def asNondeterministic(self):
        """
        Updates UserDefinedFunction to nondeterministic.

        .. versionadded:: 2.3
        """
        # Drop the cached JVM UDF so the next access to ``_judf`` rebuilds it with
        # deterministic=False.
        self._judf_placeholder = None
        self.deterministic = False
        return self
0222
0223
class UDFRegistration(object):
    """
    Wrapper for user-defined function registration. This instance can be accessed by
    :attr:`spark.udf` or :attr:`sqlContext.udf`.

    .. versionadded:: 1.3.1
    """

    def __init__(self, sparkSession):
        self.sparkSession = sparkSession

    @ignore_unicode_prefix
    @since("1.3.1")
    def register(self, name, f, returnType=None):
        """Register a Python function (including lambda function) or a user-defined function
        as a SQL function.

        :param name: name of the user-defined function in SQL statements.
        :param f: a Python function, or a user-defined function. The user-defined function can
            be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
            :meth:`pyspark.sql.functions.pandas_udf`.
        :param returnType: the return type of the registered user-defined function. The value can
            be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
        :return: a user-defined function.

        To register a nondeterministic Python function, users need to first build
        a nondeterministic user-defined function for the Python function and then register it
        as a SQL function.

        `returnType` can be optionally specified when `f` is a Python function but not
        when `f` is a user-defined function. Please see below.

        1. When `f` is a Python function:

            `returnType` defaults to string type and can be optionally specified. The produced
            object must match the specified type. In this case, this API works as if
            `register(name, f, returnType=StringType())`.

            >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x))
            >>> spark.sql("SELECT stringLengthString('test')").collect()
            [Row(stringLengthString(test)=u'4')]

            >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
            [Row(stringLengthString(text)=u'3')]

            >>> from pyspark.sql.types import IntegerType
            >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
            >>> spark.sql("SELECT stringLengthInt('test')").collect()
            [Row(stringLengthInt(test)=4)]

        2. When `f` is a user-defined function:

            Spark uses the return type of the given user-defined function as the return type of
            the registered user-defined function. `returnType` should not be specified.
            In this case, this API works as if `register(name, f)`.

            >>> from pyspark.sql.types import IntegerType
            >>> from pyspark.sql.functions import udf
            >>> slen = udf(lambda s: len(s), IntegerType())
            >>> _ = spark.udf.register("slen", slen)
            >>> spark.sql("SELECT slen('test')").collect()
            [Row(slen(test)=4)]

            >>> import random
            >>> from pyspark.sql.functions import udf
            >>> from pyspark.sql.types import IntegerType
            >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
            >>> new_random_udf = spark.udf.register("random_udf", random_udf)
            >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
            [Row(random_udf()=82)]

            >>> import pandas as pd  # doctest: +SKIP
            >>> from pyspark.sql.functions import pandas_udf
            >>> @pandas_udf("integer")  # doctest: +SKIP
            ... def add_one(s: pd.Series) -> pd.Series:
            ...     return s + 1
            ...
            >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
            >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
            [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]

            >>> @pandas_udf("integer")  # doctest: +SKIP
            ... def sum_udf(v: pd.Series) -> int:
            ...     return v.sum()
            ...
            >>> _ = spark.udf.register("sum_udf", sum_udf)  # doctest: +SKIP
            >>> q = "SELECT sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
            >>> spark.sql(q).collect()  # doctest: +SKIP
            [Row(sum_udf(v1)=1), Row(sum_udf(v1)=5)]

        .. note:: Registration for a user-defined function (case 2.) was added from
            Spark 2.3.0.
        """

        # UserDefinedFunction instances and the wrappers built by
        # UserDefinedFunction._wrapped both expose ``asNondeterministic``; its presence
        # distinguishes a user-defined function from a plain Python function.
        if hasattr(f, 'asNondeterministic'):
            if returnType is not None:
                # Wrap in a 1-tuple so a tuple-valued returnType still formats correctly.
                raise TypeError(
                    "Invalid return type: data type can not be specified when f is "
                    "a user-defined function, but got %s." % (returnType,))
            if f.evalType not in (PythonEvalType.SQL_BATCHED_UDF,
                                  PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                                  PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                                  PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                                  PythonEvalType.SQL_MAP_PANDAS_ITER_UDF):
                raise ValueError(
                    "Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
                    "SQL_SCALAR_PANDAS_ITER_UDF, SQL_GROUPED_AGG_PANDAS_UDF or "
                    "SQL_MAP_PANDAS_ITER_UDF.")
            # Re-create the UDF under the SQL name, keeping f's type and determinism.
            register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
                                               evalType=f.evalType,
                                               deterministic=f.deterministic)
            return_udf = f
        else:
            if returnType is None:
                returnType = StringType()
            register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
                                               evalType=PythonEvalType.SQL_BATCHED_UDF)
            return_udf = register_udf._wrapped()
        self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf)
        return return_udf

    @ignore_unicode_prefix
    @since(2.3)
    def registerJavaFunction(self, name, javaClassName, returnType=None):
        """Register a Java user-defined function as a SQL function.

        In addition to a name and the function itself, the return type can be optionally specified.
        When the return type is not specified we would infer it via reflection.

        :param name: name of the user-defined function
        :param javaClassName: fully qualified name of java class
        :param returnType: the return type of the registered Java function. The value can be either
            a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

        >>> from pyspark.sql.types import IntegerType
        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
        >>> spark.sql("SELECT javaStringLength('test')").collect()
        [Row(javaStringLength(test)=4)]

        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
        >>> spark.sql("SELECT javaStringLength2('test')").collect()
        [Row(javaStringLength2(test)=4)]

        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
        >>> spark.sql("SELECT javaStringLength3('test')").collect()
        [Row(javaStringLength3(test)=4)]
        """

        # A None JVM data type is forwarded when no returnType is given; per the
        # docstring, the JVM side then infers it via reflection.
        jdt = None
        if returnType is not None:
            if not isinstance(returnType, DataType):
                returnType = _parse_datatype_string(returnType)
            jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
        self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)

    @ignore_unicode_prefix
    @since(2.3)
    def registerJavaUDAF(self, name, javaClassName):
        """Register a Java user-defined aggregate function as a SQL function.

        :param name: name of the user-defined aggregate function
        :param javaClassName: fully qualified name of java class

        >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
        >>> df = spark.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
        >>> df.createOrReplaceTempView("df")
        >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name order by name desc") \
            .collect()
        [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
        """

        self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)
0406
0407
def _test():
    """Run this module's doctests against a local SparkSession.

    The session is always stopped, even if collecting or running the doctests raises.
    Exits the process with status -1 when any doctest fails.
    """
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.udf
    globs = pyspark.sql.udf.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("sql.udf tests")\
        .getOrCreate()
    try:
        # The doctests refer to the session as ``spark``.
        globs['spark'] = spark
        # Deleting ``register`` keeps doctest from collecting the examples in its
        # docstring. NOTE(review): presumably they are covered by other test suites;
        # confirm.
        del pyspark.sql.udf.UDFRegistration.register
        failure_count, _ = doctest.testmod(
            pyspark.sql.udf, globs=globs,
            optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
    finally:
        spark.stop()
    if failure_count:
        sys.exit(-1)
0427
0428
# Running this module directly executes its doctests via _test().
if __name__ == "__main__":
    _test()