import functools
import sys
import warnings

from pyspark import since
from pyspark.rdd import PythonEvalType
from pyspark.sql.pandas.typehints import infer_eval_type
from pyspark.sql.pandas.utils import require_minimum_pandas_version, \
    require_minimum_pyarrow_version
from pyspark.sql.types import DataType
from pyspark.sql.udf import _create_udf
from pyspark.util import _get_argspec


class PandasUDFType(object):
    """Pandas UDF Types. See :meth:`pyspark.sql.functions.pandas_udf`.
    """
    SCALAR = PythonEvalType.SQL_SCALAR_PANDAS_UDF

    SCALAR_ITER = PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF

    GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF

    GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF


@since(2.3)
def pandas_udf(f=None, returnType=None, functionType=None):
0045 """
0046 Creates a pandas user defined function (a.k.a. vectorized user defined function).
0047
0048 Pandas UDFs are user defined functions that are executed by Spark using Arrow to transfer
0049 data and Pandas to work with the data, which allows vectorized operations. A Pandas UDF
0050 is defined using the `pandas_udf` as a decorator or to wrap the function, and no
0051 additional configuration is required. A Pandas UDF behaves as a regular PySpark function
0052 API in general.
0053
0054 :param f: user-defined function. A python function if used as a standalone function
0055 :param returnType: the return type of the user-defined function. The value can be either a
0056 :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
0057 :param functionType: an enum value in :class:`pyspark.sql.functions.PandasUDFType`.
0058 Default: SCALAR.
0059
0060 .. note:: This parameter exists for compatibility. Using Python type hints is encouraged.
0061
0062 In order to use this API, customarily the below are imported:
0063
0064 >>> import pandas as pd
0065 >>> from pyspark.sql.functions import pandas_udf
0066
0067 From Spark 3.0 with Python 3.6+, `Python type hints <https://www.python.org/dev/peps/pep-0484>`_
0068 detect the function types as below:
0069
0070 >>> @pandas_udf(IntegerType())
0071 ... def slen(s: pd.Series) -> pd.Series:
0072 ... return s.str.len()
0073
0074 Prior to Spark 3.0, the pandas UDF used `functionType` to decide the execution type as below:
0075
0076 >>> from pyspark.sql.functions import PandasUDFType
0077 >>> from pyspark.sql.types import IntegerType
0078 >>> @pandas_udf(IntegerType(), PandasUDFType.SCALAR)
0079 ... def slen(s):
0080 ... return s.str.len()
0081
0082 It is preferred to specify type hints for the pandas UDF instead of specifying pandas UDF
0083 type via `functionType` which will be deprecated in the future releases.
0084
0085 Note that the type hint should use `pandas.Series` in all cases but there is one variant
0086 that `pandas.DataFrame` should be used for its input or output type hint instead when the input
0087 or output column is of :class:`pyspark.sql.types.StructType`. The following example shows
0088 a Pandas UDF which takes long column, string column and struct column, and outputs a struct
0089 column. It requires the function to specify the type hints of `pandas.Series` and
0090 `pandas.DataFrame` as below:
0091
0092 >>> @pandas_udf("col1 string, col2 long")
0093 >>> def func(s1: pd.Series, s2: pd.Series, s3: pd.DataFrame) -> pd.DataFrame:
0094 ... s3['col2'] = s1 + s2.str.len()
0095 ... return s3
0096 ...
0097 >>> # Create a Spark DataFrame that has three columns including a sturct column.
0098 ... df = spark.createDataFrame(
0099 ... [[1, "a string", ("a nested string",)]],
0100 ... "long_col long, string_col string, struct_col struct<col1:string>")
0101 >>> df.printSchema()
0102 root
0103 |-- long_column: long (nullable = true)
0104 |-- string_column: string (nullable = true)
0105 |-- struct_column: struct (nullable = true)
0106 | |-- col1: string (nullable = true)
0107 >>> df.select(func("long_col", "string_col", "struct_col")).printSchema()
0108 |-- func(long_col, string_col, struct_col): struct (nullable = true)
0109 | |-- col1: string (nullable = true)
0110 | |-- col2: long (nullable = true)
0111
0112 In the following sections, it describes the cominations of the supported type hints. For
0113 simplicity, `pandas.DataFrame` variant is omitted.
0114
    * Series to Series
        `pandas.Series`, ... -> `pandas.Series`

        The function takes one or more `pandas.Series` and outputs one `pandas.Series`.
        The output of the function should always be of the same length as the input.

        >>> @pandas_udf("string")
        ... def to_upper(s: pd.Series) -> pd.Series:
        ...     return s.str.upper()
        ...
        >>> df = spark.createDataFrame([("John Doe",)], ("name",))
        >>> df.select(to_upper("name")).show()
        +--------------+
        |to_upper(name)|
        +--------------+
        |      JOHN DOE|
        +--------------+

        >>> @pandas_udf("first string, last string")
        ... def split_expand(s: pd.Series) -> pd.DataFrame:
        ...     return s.str.split(expand=True)
        ...
        >>> df = spark.createDataFrame([("John Doe",)], ("name",))
        >>> df.select(split_expand("name")).show()
        +------------------+
        |split_expand(name)|
        +------------------+
        |       [John, Doe]|
        +------------------+

        .. note:: The length of the input is not that of the whole input column, but is the
            length of an internal batch used for each call to the function.

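        A sketch of how that internal batch size could be tuned (this assumes the standard
        Arrow batching configuration; whether and how to tune it depends on the workload):

        .. code-block:: python

            # Cap the number of rows handed to each invocation of the UDF.
            spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")
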
    * Iterator of Series to Iterator of Series
        `Iterator[pandas.Series]` -> `Iterator[pandas.Series]`

        The function takes an iterator of `pandas.Series` and outputs an iterator of
        `pandas.Series`. In this case, the created pandas UDF instance requires one input
        column when it is called as a PySpark column. The length of the entire output from
        the function should be the same as the length of the entire input; therefore, it can
        prefetch the data from the input iterator as long as the lengths are the same.

        It is also useful when the UDF execution requires initializing some state, although
        internally it works identically to the Series to Series case. The pseudocode below
        illustrates the example.

        .. highlight:: python
        .. code-block:: python

            @pandas_udf("long")
            def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
                # Do some expensive initialization with a state
                state = very_expensive_initialization()
                for x in iterator:
                    # Use that state for the whole iterator.
                    yield calculate_with_state(x, state)

            df.select(calculate("value")).show()

        >>> from typing import Iterator
        >>> @pandas_udf("long")
        ... def plus_one(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
        ...     for s in iterator:
        ...         yield s + 1
        ...
        >>> df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"]))
        >>> df.select(plus_one(df.v)).show()
        +-----------+
        |plus_one(v)|
        +-----------+
        |          2|
        |          3|
        |          4|
        +-----------+

        .. note:: The length of each series is the length of a batch internally used.

    * Iterator of Multiple Series to Iterator of Series
        `Iterator[Tuple[pandas.Series, ...]]` -> `Iterator[pandas.Series]`

        The function takes an iterator of a tuple of multiple `pandas.Series` and outputs an
        iterator of `pandas.Series`. In this case, the created pandas UDF instance requires
        as many input columns as there are series when it is called as a PySpark column.
        Otherwise, it has the same characteristics and restrictions as the Iterator of Series
        to Iterator of Series case.

        >>> from typing import Iterator, Tuple
        >>> from pyspark.sql.functions import struct, col
        >>> @pandas_udf("long")
        ... def multiply(iterator: Iterator[Tuple[pd.Series, pd.DataFrame]]) -> Iterator[pd.Series]:
        ...     for s1, df in iterator:
        ...         yield s1 * df.v
        ...
        >>> df = spark.createDataFrame(pd.DataFrame([1, 2, 3], columns=["v"]))
        >>> df.withColumn('output', multiply(col("v"), struct(col("v")))).show()
        +---+------+
        |  v|output|
        +---+------+
        |  1|     1|
        |  2|     4|
        |  3|     9|
        +---+------+

        .. note:: The length of each series is the length of a batch internally used.

    * Series to Scalar
        `pandas.Series`, ... -> `Any`

        The function takes `pandas.Series` and returns a scalar value. The `returnType`
        should be a primitive data type, and the returned scalar can be either a python
        primitive type, e.g., int or float, or a numpy data type, e.g., numpy.int64 or
        numpy.float64. `Any` should ideally be a specific scalar type accordingly.

        >>> @pandas_udf("double")
        ... def mean_udf(v: pd.Series) -> float:
        ...     return v.mean()
        ...
        >>> df = spark.createDataFrame(
        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
        >>> df.groupby("id").agg(mean_udf(df['v'])).show()
        +---+-----------+
        | id|mean_udf(v)|
        +---+-----------+
        |  1|        1.5|
        |  2|        6.0|
        +---+-----------+

        This UDF can also be used as a window function as below:

        >>> from pyspark.sql import Window
        >>> @pandas_udf("double")
        ... def mean_udf(v: pd.Series) -> float:
        ...     return v.mean()
        ...
        >>> df = spark.createDataFrame(
        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
        >>> w = Window.partitionBy('id').orderBy('v').rowsBetween(-1, 0)
        >>> df.withColumn('mean_v', mean_udf("v").over(w)).show()
        +---+----+------+
        | id|   v|mean_v|
        +---+----+------+
        |  1| 1.0|   1.0|
        |  1| 2.0|   1.5|
        |  2| 3.0|   3.0|
        |  2| 5.0|   4.0|
        |  2|10.0|   7.5|
        +---+----+------+

        .. note:: For performance reasons, the input series to window functions are not copied.
            Therefore, mutating the input series is not allowed and will cause incorrect results.
            For the same reason, users should also not rely on the index of the input series.

        .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window`

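        A minimal sketch of a pattern that respects this restriction (the ``trimmed_mean``
        function below is purely illustrative): copy the series before reordering it and drop
        the original index, so the input Spark passes in is left untouched.

        .. code-block:: python

            @pandas_udf("double")
            def trimmed_mean(v: pd.Series) -> float:
                # Sort a copy rather than the input series itself, and replace the
                # original index instead of relying on it.
                s = v.copy().sort_values().reset_index(drop=True)
                return float(s.iloc[1:-1].mean()) if len(s) > 2 else float(s.mean())
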
    .. note:: The user-defined functions do not support conditional expressions or short
        circuiting in boolean expressions, and they end up being executed all internally. If
        the functions can fail on special rows, the workaround is to incorporate the
        condition into the functions.

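    For instance, if a function cannot handle non-positive values, guard those rows inside the
    function rather than wrapping the call in a conditional expression. A sketch (the
    ``safe_log`` function below is illustrative only):

    .. code-block:: python

        import numpy as np

        @pandas_udf("double")
        def safe_log(v: pd.Series) -> pd.Series:
            # Handle the special rows inside the function instead of relying on a
            # conditional expression around the UDF call.
            return np.log(v.where(v > 0))
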
    .. note:: The user-defined functions do not take keyword arguments on the calling side.

    .. note:: The data type of the returned `pandas.Series` from the user-defined functions
        should match the defined `returnType` (see :meth:`types.to_arrow_type` and
        :meth:`types.from_arrow_type`). When there is a mismatch between them, Spark might do
        conversion on the returned data. The conversion is not guaranteed to be correct and
        results should be checked for accuracy by users.

    .. note:: Currently,
        :class:`pyspark.sql.types.MapType`,
        :class:`pyspark.sql.types.ArrayType` of :class:`pyspark.sql.types.TimestampType` and
        nested :class:`pyspark.sql.types.StructType`
        are not supported as output types.

    .. seealso:: :meth:`pyspark.sql.DataFrame.mapInPandas`
    .. seealso:: :meth:`pyspark.sql.GroupedData.applyInPandas`
    .. seealso:: :meth:`pyspark.sql.PandasCogroupedOps.applyInPandas`
    .. seealso:: :meth:`pyspark.sql.UDFRegistration.register`
    """

    require_minimum_pandas_version()
    require_minimum_pyarrow_version()

    # When used as a decorator, `f` carries the return type (a DDL string or DataType) or is
    # None; when called directly, `f` is the user function itself.
    is_decorator = f is None or isinstance(f, (str, DataType))

    if is_decorator:
        # If a DataType (or DDL string) has been passed as the positional argument,
        # use it as the return type.
        return_type = f or returnType

        if functionType is not None:
            # @pandas_udf(dataType, functionType=functionType)
            # @pandas_udf(returnType=dataType, functionType=functionType)
            eval_type = functionType
        elif returnType is not None and isinstance(returnType, int):
            # @pandas_udf(dataType, functionType)
            eval_type = returnType
        else:
            # @pandas_udf(dataType) or @pandas_udf(returnType=dataType)
            eval_type = None
    else:
        return_type = returnType

        if functionType is not None:
            eval_type = functionType
        else:
            eval_type = None

    if return_type is None:
        raise ValueError("Invalid return type: returnType can not be None")

    if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                         PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                         PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
                         PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
                         PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
                         None]:  # None is resolved later from type hints, defaulting to SCALAR.
        raise ValueError("Invalid function type: "
                         "functionType must be one of the values from PandasUDFType")

    if is_decorator:
        return functools.partial(_create_pandas_udf, returnType=return_type, evalType=eval_type)
    else:
        return _create_pandas_udf(f=f, returnType=return_type, evalType=eval_type)


def _create_pandas_udf(f, returnType, evalType):
    argspec = _get_argspec(f)

    if sys.version_info >= (3, 6):
        from inspect import signature

        if evalType in [PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                        PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                        PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]:
            warnings.warn(
                "In Python 3.6+ and Spark 3.0+, it is preferred to specify type hints for "
                "pandas UDF instead of specifying pandas UDF type which will be deprecated "
                "in the future releases. See SPARK-28264 for more details.", UserWarning)
        elif evalType in [PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                          PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
                          PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF]:
            # GROUPED_MAP is kept for backward compatibility and its deprecation warning is
            # raised where it is applied; MAP_ITER and COGROUPED_MAP always arrive with an
            # explicit eval type, so nothing needs to be inferred here.
            pass
        elif len(argspec.annotations) > 0:
            # Infer the eval type from the Python type hints on the function signature.
            evalType = infer_eval_type(signature(f))
            assert evalType is not None

    if evalType is None:
        # Default to a scalar pandas UDF when nothing was specified or inferred.
        evalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF

    if (evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
            evalType == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF) and \
            len(argspec.args) == 0 and \
            argspec.varargs is None:
        raise ValueError(
            "Invalid function: 0-arg pandas_udfs are not supported. "
            "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
        )

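    # An illustrative (not from the original source) way to satisfy the 0-arg restriction
    # checked above: declare one ignored argument and size the output from it, e.g.
    #
    #     @pandas_udf("long")
    #     def one(v: pd.Series) -> pd.Series:
    #         return pd.Series([1] * len(v))
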
    if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
            and len(argspec.args) not in (1, 2):
        raise ValueError(
            "Invalid function: pandas_udf with function type GROUPED_MAP or "
            "the function in groupby.applyInPandas "
            "must take either one argument (data) or two arguments (key, data).")

    if evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF \
            and len(argspec.args) not in (2, 3):
        raise ValueError(
            "Invalid function: the function in cogroup.applyInPandas "
            "must take either two arguments (left, right) "
            "or three arguments (key, left, right).")

    return _create_udf(f, returnType, evalType)