0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 import sys
0018
0019 from pyspark import since
0020 from pyspark.rdd import PythonEvalType
0021
0022
class PandasMapOpsMixin(object):
    """
    Mix-in for pandas map operations. Currently, only :class:`DataFrame`
    can use this class.
    """

    @since(3.0)
    def mapInPandas(self, func, schema):
        """
        Maps an iterator of batches in the current :class:`DataFrame` using a Python native
        function that takes and outputs a pandas DataFrame, and returns the result as a
        :class:`DataFrame`.

        The function should take an iterator of `pandas.DataFrame`\\s and return
        another iterator of `pandas.DataFrame`\\s. All columns are passed
        together as an iterator of `pandas.DataFrame`\\s to the function and the
        returned iterator of `pandas.DataFrame`\\s are combined as a :class:`DataFrame`.
        Each `pandas.DataFrame` size can be controlled by
        `spark.sql.execution.arrow.maxRecordsPerBatch`.

        :param func: a Python native function that takes an iterator of `pandas.DataFrame`\\s, and
            outputs an iterator of `pandas.DataFrame`\\s.
        :param schema: the return type of the `func` in PySpark. The value can be either a
            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

        >>> from pyspark.sql.functions import pandas_udf
        >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
        >>> def filter_func(iterator):
        ...     for pdf in iterator:
        ...         yield pdf[pdf.id == 1]
        >>> df.mapInPandas(filter_func, df.schema).show()  # doctest: +SKIP
        +---+---+
        | id|age|
        +---+---+
        |  1| 21|
        +---+---+

        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`

        .. note:: Experimental
        """
        # Imported lazily to avoid a circular import (pyspark.sql imports this mixin).
        from pyspark.sql import DataFrame
        from pyspark.sql.pandas.functions import pandas_udf

        # This mixin is only mixed into DataFrame; give a clear message if it is
        # ever invoked on anything else instead of a bare AssertionError.
        assert isinstance(self, DataFrame), \
            "mapInPandas can only be used on a DataFrame (self is %s)" % type(self)

        # Wrap the user function as an iterator-of-batches pandas UDF and apply it
        # over all columns; the JVM side combines the yielded batches.
        udf = pandas_udf(
            func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF)
        udf_column = udf(*[self[col] for col in self.columns])
        jdf = self._jdf.mapInPandas(udf_column._jc.expr())
        return DataFrame(jdf, self.sql_ctx)
0074
0075
def _test():
    """Run this module's doctests against a temporary local SparkSession."""
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.pandas.map_ops

    # Doctests run with a copy of this module's namespace plus a live session.
    test_globals = pyspark.sql.pandas.map_ops.__dict__.copy()
    session = (
        SparkSession.builder
        .master("local[4]")
        .appName("sql.pandas.map_ops tests")
        .getOrCreate()
    )
    test_globals['spark'] = session

    failure_count, test_count = doctest.testmod(
        pyspark.sql.pandas.map_ops,
        globs=test_globals,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF,
    )
    session.stop()
    if failure_count:
        sys.exit(-1)
0092
0093
# Entry point: running this file directly executes the module's doctests.
if __name__ == "__main__":
    _test()