#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
User-defined function related classes and functions
"""
import functools
import sys

from pyspark import SparkContext, since
from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string
from pyspark.sql.pandas.types import to_arrow_type

__all__ = ["UDFRegistration"]


def _wrap_function(sc, func, returnType):
    command = (func, returnType)
    pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
    return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
                                  sc.pythonVer, broadcast_vars, sc._javaAccumulator)


def _create_udf(f, returnType, evalType):
    # Set the name of the UserDefinedFunction object to be the name of function f
    udf_obj = UserDefinedFunction(
        f, returnType=returnType, name=None, evalType=evalType, deterministic=True)
    return udf_obj._wrapped()


class UserDefinedFunction(object):
    """
    User defined function in Python

    .. versionadded:: 1.3

    .. note:: The constructor of this class is not supposed to be directly called.
        Use :meth:`pyspark.sql.functions.udf` or :meth:`pyspark.sql.functions.pandas_udf`
        to create this instance.
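
    A minimal illustrative example (assumes an active :class:`SparkSession` named
    ``spark``):

    >>> from pyspark.sql.functions import udf
    >>> from pyspark.sql.types import IntegerType
    >>> slen = udf(lambda s: len(s), IntegerType())  # wraps a UserDefinedFunction
    >>> spark.createDataFrame([("test",)], ["s"]).select(slen("s")).collect()  # doctest: +SKIP
    [Row(<lambda>(s)=4)]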
    """
    def __init__(self, func,
                 returnType=StringType(),
                 name=None,
                 evalType=PythonEvalType.SQL_BATCHED_UDF,
                 deterministic=True):
        if not callable(func):
            raise TypeError(
                "Invalid function: not a function or callable (__call__ is not defined): "
                "{0}".format(type(func)))

        if not isinstance(returnType, (DataType, str)):
            raise TypeError(
                "Invalid return type: returnType should be DataType or str "
                "but is {}".format(returnType))

        if not isinstance(evalType, int):
            raise TypeError(
                "Invalid evaluation type: evalType should be an int but is {}".format(evalType))

        self.func = func
        self._returnType = returnType
        # Lazily initialized placeholders for the parsed return DataType and the
        # JVM UserDefinedPythonFunction (judf) object.
        self._returnType_placeholder = None
        self._judf_placeholder = None
        self._name = name or (
            func.__name__ if hasattr(func, '__name__')
            else func.__class__.__name__)
        self.evalType = evalType
        self.deterministic = deterministic

    @property
    def returnType(self):
        # This makes sure this is called after SparkContext is initialized.
        # ``_parse_datatype_string`` accesses the JVM to parse a DDL-formatted string.
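        # For example (illustrative): a UDF declared with returnType="array<int>" (a DDL
        # string) is parsed here on first access, whereas one declared with
        # ArrayType(IntegerType()) is used as-is.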
        if self._returnType_placeholder is None:
            if isinstance(self._returnType, DataType):
                self._returnType_placeholder = self._returnType
            else:
                self._returnType_placeholder = _parse_datatype_string(self._returnType)

        if self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF or \
                self.evalType == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF:
            try:
                to_arrow_type(self._returnType_placeholder)
            except TypeError:
                raise NotImplementedError(
                    "Invalid return type with scalar Pandas UDFs: %s is "
                    "not supported" % str(self._returnType_placeholder))
        elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
            if isinstance(self._returnType_placeholder, StructType):
                try:
                    to_arrow_type(self._returnType_placeholder)
                except TypeError:
                    raise NotImplementedError(
                        "Invalid return type with grouped map Pandas UDFs or "
                        "at groupby.applyInPandas: %s is not supported" % str(
                            self._returnType_placeholder))
            else:
                raise TypeError("Invalid return type for grouped map Pandas "
                                "UDFs or at groupby.applyInPandas: return type must be a "
                                "StructType.")
        elif self.evalType == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF:
            if isinstance(self._returnType_placeholder, StructType):
                try:
                    to_arrow_type(self._returnType_placeholder)
                except TypeError:
                    raise NotImplementedError(
                        "Invalid return type in mapInPandas: "
                        "%s is not supported" % str(self._returnType_placeholder))
            else:
                raise TypeError("Invalid return type in mapInPandas: "
                                "return type must be a StructType.")
        elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
            if isinstance(self._returnType_placeholder, StructType):
                try:
                    to_arrow_type(self._returnType_placeholder)
                except TypeError:
                    raise NotImplementedError(
                        "Invalid return type in cogroup.applyInPandas: "
                        "%s is not supported" % str(self._returnType_placeholder))
            else:
                raise TypeError("Invalid return type in cogroup.applyInPandas: "
                                "return type must be a StructType.")
        elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
            try:
                # StructType is not yet allowed as a return type, explicitly check here to fail fast
                if isinstance(self._returnType_placeholder, StructType):
                    raise TypeError
                to_arrow_type(self._returnType_placeholder)
            except TypeError:
                raise NotImplementedError(
                    "Invalid return type with grouped aggregate Pandas UDFs: "
                    "%s is not supported" % str(self._returnType_placeholder))

        return self._returnType_placeholder

    @property
    def _judf(self):
        # Concurrent access to a newly created UDF can initialize multiple
        # UserDefinedPythonFunctions. This is unlikely, does not affect
        # correctness, and should have a minimal performance impact.
        if self._judf_placeholder is None:
            self._judf_placeholder = self._create_judf()
        return self._judf_placeholder

    def _create_judf(self):
        from pyspark.sql import SparkSession

        spark = SparkSession.builder.getOrCreate()
        sc = spark.sparkContext

        wrapped_func = _wrap_function(sc, self.func, self.returnType)
        jdt = spark._jsparkSession.parseDataType(self.returnType.json())
        judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
            self._name, wrapped_func, jdt, self.evalType, self.deterministic)
        return judf

    def __call__(self, *cols):
        judf = self._judf
        sc = SparkContext._active_spark_context
        return Column(judf.apply(_to_seq(sc, cols, _to_java_column)))

    # This function improves the online help in interactive interpreters (e.g. the
    # built-in help / pydoc.help) by wrapping the UDF with the original function's
    # docstring and argument annotations. (See: SPARK-19161)
    def _wrapped(self):
        """
        Wrap this UDF in a plain function and attach the docstring from ``func``.
        """

        # A callable without a __name__ and/or __module__ attribute (for example,
        # functools.partial) may be wrapped here. In that case we must not copy those
        # attributes from the wrapped callable onto the wrapper, so they are removed
        # from the default names to assign and set manually after wrapping.
        assignments = tuple(
            a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__')
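        # Illustrative sketch (hypothetical callable, not executed): for
        #   add_one = functools.partial(lambda x, y: x + y, 1)
        # there is no __name__ or __module__, so the wrapper name falls back to
        # self._name ('partial' unless a name was given) and the module to 'functools'.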

        @functools.wraps(self.func, assigned=assignments)
        def wrapper(*args):
            return self(*args)

        wrapper.__name__ = self._name
        wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
                              else self.func.__class__.__module__)

        wrapper.func = self.func
        wrapper.returnType = self.returnType
        wrapper.evalType = self.evalType
        wrapper.deterministic = self.deterministic
        wrapper.asNondeterministic = functools.wraps(
            self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped())
        return wrapper

    def asNondeterministic(self):
        """
        Updates UserDefinedFunction to nondeterministic.

        .. versionadded:: 2.3
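
        An illustrative example (assumes an active :class:`SparkSession` named ``spark``):

        >>> import random
        >>> from pyspark.sql.functions import udf
        >>> from pyspark.sql.types import IntegerType
        >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
        >>> spark.range(1).select(random_udf()).collect()  # doctest: +SKIP
        [Row(<lambda>()=82)]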
        """
        # Here, we explicitly clean the cache to create a JVM UDF instance
        # with 'deterministic' updated. See SPARK-23233.
        self._judf_placeholder = None
        self.deterministic = False
        return self


class UDFRegistration(object):
    """
    Wrapper for user-defined function registration. This instance can be accessed by
    :attr:`spark.udf` or :attr:`sqlContext.udf`.

    .. versionadded:: 1.3.1
    """

    def __init__(self, sparkSession):
        self.sparkSession = sparkSession

    @ignore_unicode_prefix
    @since("1.3.1")
    def register(self, name, f, returnType=None):
        """Register a Python function (including lambda function) or a user-defined function
        as a SQL function.

        :param name: name of the user-defined function in SQL statements.
        :param f: a Python function, or a user-defined function. The user-defined function can
            be either row-at-a-time or vectorized. See :meth:`pyspark.sql.functions.udf` and
            :meth:`pyspark.sql.functions.pandas_udf`.
        :param returnType: the return type of the registered user-defined function. The value can
            be either a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
        :return: a user-defined function.

        To register a nondeterministic Python function, users need to first build
        a nondeterministic user-defined function for the Python function and then register it
        as a SQL function.

        `returnType` can be optionally specified when `f` is a Python function but not
        when `f` is a user-defined function. Please see below.

        1. When `f` is a Python function:

            `returnType` defaults to string type and can be optionally specified. The produced
            object must match the specified type. In this case, this API works as if
            `register(name, f, returnType=StringType())`.

            >>> strlen = spark.udf.register("stringLengthString", lambda x: len(x))
            >>> spark.sql("SELECT stringLengthString('test')").collect()
            [Row(stringLengthString(test)=u'4')]

            >>> spark.sql("SELECT 'foo' AS text").select(strlen("text")).collect()
            [Row(stringLengthString(text)=u'3')]

            >>> from pyspark.sql.types import IntegerType
            >>> _ = spark.udf.register("stringLengthInt", lambda x: len(x), IntegerType())
            >>> spark.sql("SELECT stringLengthInt('test')").collect()
            [Row(stringLengthInt(test)=4)]

        2. When `f` is a user-defined function:

            Spark uses the return type of the given user-defined function as the return type of
            the registered user-defined function. `returnType` should not be specified.
            In this case, this API works as if `register(name, f)`.

            >>> from pyspark.sql.types import IntegerType
            >>> from pyspark.sql.functions import udf
            >>> slen = udf(lambda s: len(s), IntegerType())
            >>> _ = spark.udf.register("slen", slen)
            >>> spark.sql("SELECT slen('test')").collect()
            [Row(slen(test)=4)]

            >>> import random
            >>> from pyspark.sql.functions import udf
            >>> from pyspark.sql.types import IntegerType
            >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic()
            >>> new_random_udf = spark.udf.register("random_udf", random_udf)
            >>> spark.sql("SELECT random_udf()").collect()  # doctest: +SKIP
            [Row(random_udf()=82)]

            >>> import pandas as pd  # doctest: +SKIP
            >>> from pyspark.sql.functions import pandas_udf
            >>> @pandas_udf("integer")  # doctest: +SKIP
            ... def add_one(s: pd.Series) -> pd.Series:
            ...     return s + 1
            ...
            >>> _ = spark.udf.register("add_one", add_one)  # doctest: +SKIP
            >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()  # doctest: +SKIP
            [Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]

            >>> @pandas_udf("integer")  # doctest: +SKIP
            ... def sum_udf(v: pd.Series) -> int:
            ...     return v.sum()
            ...
            >>> _ = spark.udf.register("sum_udf", sum_udf)  # doctest: +SKIP
            >>> q = "SELECT sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
            >>> spark.sql(q).collect()  # doctest: +SKIP
            [Row(sum_udf(v1)=1), Row(sum_udf(v1)=5)]

            .. note:: Registration for a user-defined function (case 2.) was added in
                Spark 2.3.0.
        """

        # This checks whether the input function is a user-defined function or a plain
        # Python function.
        if hasattr(f, 'asNondeterministic'):
            if returnType is not None:
                raise TypeError(
                    "Invalid return type: data type can not be specified when f is "
                    "a user-defined function, but got %s." % returnType)
            if f.evalType not in [PythonEvalType.SQL_BATCHED_UDF,
                                  PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                                  PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                                  PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                                  PythonEvalType.SQL_MAP_PANDAS_ITER_UDF]:
                raise ValueError(
                    "Invalid f: f must be SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
                    "SQL_SCALAR_PANDAS_ITER_UDF, SQL_GROUPED_AGG_PANDAS_UDF or "
                    "SQL_MAP_PANDAS_ITER_UDF.")
            register_udf = UserDefinedFunction(f.func, returnType=f.returnType, name=name,
                                               evalType=f.evalType,
                                               deterministic=f.deterministic)
            return_udf = f
        else:
            if returnType is None:
                returnType = StringType()
            register_udf = UserDefinedFunction(f, returnType=returnType, name=name,
                                               evalType=PythonEvalType.SQL_BATCHED_UDF)
            return_udf = register_udf._wrapped()
        self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf)
        return return_udf

    @ignore_unicode_prefix
    @since(2.3)
    def registerJavaFunction(self, name, javaClassName, returnType=None):
        """Register a Java user-defined function as a SQL function.

        In addition to a name and the function itself, the return type can be optionally specified.
        When the return type is not specified, it is inferred via reflection.

        :param name: name of the user-defined function
        :param javaClassName: fully qualified name of java class
        :param returnType: the return type of the registered Java function. The value can be either
            a :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

        >>> from pyspark.sql.types import IntegerType
        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
        >>> spark.sql("SELECT javaStringLength('test')").collect()
        [Row(javaStringLength(test)=4)]

        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
        >>> spark.sql("SELECT javaStringLength2('test')").collect()
        [Row(javaStringLength2(test)=4)]

        >>> spark.udf.registerJavaFunction(
        ...     "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
        >>> spark.sql("SELECT javaStringLength3('test')").collect()
        [Row(javaStringLength3(test)=4)]
        """

        jdt = None
        if returnType is not None:
            if not isinstance(returnType, DataType):
                returnType = _parse_datatype_string(returnType)
            jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
        self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)

    @ignore_unicode_prefix
    @since(2.3)
    def registerJavaUDAF(self, name, javaClassName):
        """Register a Java user-defined aggregate function as a SQL function.

        :param name: name of the user-defined aggregate function
        :param javaClassName: fully qualified name of java class

        >>> spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
        >>> df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "a")], ["id", "name"])
        >>> df.createOrReplaceTempView("df")
        >>> spark.sql("SELECT name, javaUDAF(id) as avg from df group by name order by name desc") \
                .collect()
        [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
        """

        self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)


def _test():
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.udf
    globs = pyspark.sql.udf.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("sql.udf tests")\
        .getOrCreate()
    globs['spark'] = spark
    # Hack to skip the doctests in register; they are covered by proper unit tests.
    # We should re-enable these doctests once we completely drop Python 2.
    del pyspark.sql.udf.UDFRegistration.register
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.udf, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()