#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import warnings

from pyspark import since
from pyspark.rdd import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame


class PandasGroupedOpsMixin(object):
    """
    Mix-in for pandas grouped operations. Currently, only :class:`GroupedData`
    can use this class.
    """

    @since(2.3)
    def apply(self, udf):
        """
        It is an alias of :meth:`pyspark.sql.GroupedData.applyInPandas`; however, it takes a
        :meth:`pyspark.sql.functions.pandas_udf` whereas
        :meth:`pyspark.sql.GroupedData.applyInPandas` takes a Python native function.

        .. note:: It is preferred to use :meth:`pyspark.sql.GroupedData.applyInPandas` over
            this API. This API will be deprecated in future releases.

        :param udf: a grouped map user-defined function returned by
            :func:`pyspark.sql.functions.pandas_udf`.
        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> df = spark.createDataFrame(
        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
        ...     ("id", "v"))
        >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)  # doctest: +SKIP
        ... def normalize(pdf):
        ...     v = pdf.v
        ...     return pdf.assign(v=(v - v.mean()) / v.std())
        >>> df.groupby("id").apply(normalize).show()  # doctest: +SKIP
        +---+-------------------+
        | id|                  v|
        +---+-------------------+
        |  1|-0.7071067811865475|
        |  1| 0.7071067811865475|
        |  2|-0.8320502943378437|
        |  2|-0.2773500981126146|
        |  2| 1.1094003924504583|
        +---+-------------------+
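
        The same result can be obtained through the preferred API; a minimal sketch
        (``normalize2`` is an illustrative name for a plain Python function, since
        :meth:`pyspark.sql.GroupedData.applyInPandas` takes the function itself rather
        than a ``pandas_udf``):

        >>> def normalize2(pdf):  # doctest: +SKIP
        ...     v = pdf.v
        ...     return pdf.assign(v=(v - v.mean()) / v.std())
        >>> df.groupby("id").applyInPandas(
        ...     normalize2, schema="id long, v double").show()  # doctest: +SKIP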

        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`

        """

        # Column objects and non-pandas UDFs do not carry the wrapped Python
        # function and eval type checked below, so reject them early.
        if isinstance(udf, Column) or not hasattr(udf, 'func') \
                or udf.evalType != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
            raise ValueError("Invalid udf: the udf argument must be a pandas_udf of type "
                             "GROUPED_MAP.")

        warnings.warn(
            "It is preferred to use 'applyInPandas' over this "
            "API. This API will be deprecated in future releases. See SPARK-28264 for "
            "more details.", UserWarning)

        # Delegate to applyInPandas with the unwrapped function and its return type.
        return self.applyInPandas(udf.func, schema=udf.returnType)

    @since(3.0)
    def applyInPandas(self, func, schema):
        """
        Maps each group of the current :class:`DataFrame` using a pandas udf and returns
        the result as a :class:`DataFrame`.

        The function should take a `pandas.DataFrame` and return another
        `pandas.DataFrame`. For each group, all columns are passed together as a
        `pandas.DataFrame` to the user function and the returned `pandas.DataFrame`\\s are
        combined as a :class:`DataFrame`.

        The `schema` should be a :class:`StructType` describing the schema of the returned
        `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either
        match the field names in the defined schema if specified as strings, or match the
        field data types by position if not strings, e.g. integer indices.
        The length of the returned `pandas.DataFrame` can be arbitrary.
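
        For example, with ``schema="id long, v double"``, either of the frames below is a
        valid return value (a minimal sketch; ``pd`` is the pandas module imported in the
        examples that follow):

        >>> pd.DataFrame({'id': [1], 'v': [1.0]})  # doctest: +SKIP
        >>> pd.DataFrame([[1, 1.0]])  # doctest: +SKIP

        The first matches the schema fields by its string column labels; the second has
        default integer labels and matches the fields by position.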

        :param func: a Python native function that takes a `pandas.DataFrame`, and outputs a
            `pandas.DataFrame`.
        :param schema: the return type of the `func` in PySpark. The value can be either a
            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

        >>> import pandas as pd  # doctest: +SKIP
        >>> from pyspark.sql.functions import pandas_udf, ceil
        >>> df = spark.createDataFrame(
        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
        ...     ("id", "v"))  # doctest: +SKIP
        >>> def normalize(pdf):
        ...     v = pdf.v
        ...     return pdf.assign(v=(v - v.mean()) / v.std())
        >>> df.groupby("id").applyInPandas(
        ...     normalize, schema="id long, v double").show()  # doctest: +SKIP
        +---+-------------------+
        | id|                  v|
        +---+-------------------+
        |  1|-0.7071067811865475|
        |  1| 0.7071067811865475|
        |  2|-0.8320502943378437|
        |  2|-0.2773500981126146|
        |  2| 1.1094003924504583|
        +---+-------------------+

        Alternatively, the user can pass a function that takes two arguments.
        In this case, the grouping key(s) will be passed as the first argument and the data
        will be passed as the second argument. The grouping key(s) will be passed as a tuple
        of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be
        passed in as a `pandas.DataFrame` containing all columns from the original Spark
        DataFrame. This is useful when the user does not want to hardcode grouping key(s)
        in the function.

        >>> df = spark.createDataFrame(
        ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
        ...     ("id", "v"))  # doctest: +SKIP
        >>> def mean_func(key, pdf):
        ...     # key is a tuple of one numpy.int64, which is the value
        ...     # of 'id' for the current group
        ...     return pd.DataFrame([key + (pdf.v.mean(),)])
        >>> df.groupby('id').applyInPandas(
        ...     mean_func, schema="id long, v double").show()  # doctest: +SKIP
        +---+---+
        | id|  v|
        +---+---+
        |  1|1.5|
        |  2|6.0|
        +---+---+
        >>> def sum_func(key, pdf):
        ...     # key is a tuple of two numpy.int64s, which are the values
        ...     # of 'id' and 'ceil(df.v / 2)' for the current group
        ...     return pd.DataFrame([key + (pdf.v.sum(),)])
        >>> df.groupby(df.id, ceil(df.v / 2)).applyInPandas(
        ...     sum_func, schema="id long, `ceil(v / 2)` long, v double").show()  # doctest: +SKIP
        +---+-----------+----+
        | id|ceil(v / 2)|   v|
        +---+-----------+----+
        |  2|          5|10.0|
        |  1|          1| 3.0|
        |  2|          3| 5.0|
        |  2|          2| 3.0|
        +---+-----------+----+

        .. note:: This function requires a full shuffle. All the data of a group will be
            loaded into memory, so the user should be aware of the potential OOM risk if
            data is skewed and certain groups are too large to fit in memory.

        .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is
            recommended to explicitly index the columns by name to ensure the positions are
            correct, or alternatively use an `OrderedDict`.
            For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or
            `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`.

        .. note:: Experimental

        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`
        """
        from pyspark.sql import GroupedData
        from pyspark.sql.functions import pandas_udf, PandasUDFType

        assert isinstance(self, GroupedData)

        # Wrap the plain Python function as a GROUPED_MAP pandas UDF with the
        # user-provided schema as its return type.
        udf = pandas_udf(
            func, returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
        df = self._df
        # Apply the UDF over all columns so each group's full rows reach the function.
        udf_column = udf(*[df[col] for col in df.columns])
        jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
        return DataFrame(jdf, self.sql_ctx)

    @since(3.0)
    def cogroup(self, other):
        """
        Cogroups this group with another group so that we can run cogrouped operations.

        See :class:`PandasCogroupedOps` for the operations that can be run.
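
        For example (a minimal sketch; ``df1`` and ``df2`` are assumed to be DataFrames
        that both have an ``id`` column, as in the examples of
        :meth:`PandasCogroupedOps.applyInPandas`):

        >>> cogrouped = df1.groupby("id").cogroup(df2.groupby("id"))  # doctest: +SKIP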
0191 """
0192 from pyspark.sql import GroupedData
0193
0194 assert isinstance(self, GroupedData)
0195
0196 return PandasCogroupedOps(self, other)
0197
0198
class PandasCogroupedOps(object):
    """
    A logical grouping of two :class:`GroupedData`,
    created by :func:`GroupedData.cogroup`.

    .. note:: Experimental

    .. versionadded:: 3.0
    """

    def __init__(self, gd1, gd2):
        self._gd1 = gd1
        self._gd2 = gd2
        self.sql_ctx = gd1.sql_ctx

    @since(3.0)
    def applyInPandas(self, func, schema):
        """
        Applies a function to each cogroup using pandas and returns the result
        as a :class:`DataFrame`.

        The function should take two `pandas.DataFrame`\\s and return another
        `pandas.DataFrame`. For each side of the cogroup, all columns are passed together
        as a `pandas.DataFrame` to the user function and the returned `pandas.DataFrame`\\s
        are combined as a :class:`DataFrame`.

        The `schema` should be a :class:`StructType` describing the schema of the returned
        `pandas.DataFrame`. The column labels of the returned `pandas.DataFrame` must either
        match the field names in the defined schema if specified as strings, or match the
        field data types by position if not strings, e.g. integer indices.
        The length of the returned `pandas.DataFrame` can be arbitrary.

        :param func: a Python native function that takes two `pandas.DataFrame`\\s, and
            outputs a `pandas.DataFrame`, or that takes one tuple (grouping keys) and two
            `pandas.DataFrame`\\s, and outputs a `pandas.DataFrame`.
        :param schema: the return type of the `func` in PySpark. The value can be either a
            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.

        >>> import pandas as pd  # doctest: +SKIP
        >>> from pyspark.sql.functions import pandas_udf
        >>> df1 = spark.createDataFrame(
        ...     [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)],
        ...     ("time", "id", "v1"))
        >>> df2 = spark.createDataFrame(
        ...     [(20000101, 1, "x"), (20000101, 2, "y")],
        ...     ("time", "id", "v2"))
        >>> def asof_join(l, r):
        ...     return pd.merge_asof(l, r, on="time", by="id")
        >>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
        ...     asof_join, schema="time int, id int, v1 double, v2 string"
        ... ).show()  # doctest: +SKIP
        +--------+---+---+---+
        |    time| id| v1| v2|
        +--------+---+---+---+
        |20000101|  1|1.0|  x|
        |20000102|  1|3.0|  x|
        |20000101|  2|2.0|  y|
        |20000102|  2|4.0|  y|
        +--------+---+---+---+

        Alternatively, the user can define a function that takes three arguments. In this
        case, the grouping key(s) will be passed as the first argument and the data will be
        passed as the second and third arguments. The grouping key(s) will be passed as a
        tuple of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The data will
        still be passed in as two `pandas.DataFrame`\\s containing all columns from the
        original Spark DataFrames.

        >>> def asof_join(k, l, r):
        ...     if k == (1,):
        ...         return pd.merge_asof(l, r, on="time", by="id")
        ...     else:
        ...         return pd.DataFrame(columns=['time', 'id', 'v1', 'v2'])
        >>> df1.groupby("id").cogroup(df2.groupby("id")).applyInPandas(
        ...     asof_join, "time int, id int, v1 double, v2 string").show()  # doctest: +SKIP
        +--------+---+---+---+
        |    time| id| v1| v2|
        +--------+---+---+---+
        |20000101|  1|1.0|  x|
        |20000102|  1|3.0|  x|
        +--------+---+---+---+

        .. note:: This function requires a full shuffle. All the data of a cogroup will be
            loaded into memory, so the user should be aware of the potential OOM risk if
            data is skewed and certain groups are too large to fit in memory.

        .. note:: If returning a new `pandas.DataFrame` constructed with a dictionary, it is
            recommended to explicitly index the columns by name to ensure the positions are
            correct, or alternatively use an `OrderedDict`.
            For example, `pd.DataFrame({'id': ids, 'a': data}, columns=['id', 'a'])` or
            `pd.DataFrame(OrderedDict([('id', ids), ('a', data)]))`.

        .. note:: Experimental

        .. seealso:: :meth:`pyspark.sql.functions.pandas_udf`

        """
        from pyspark.sql.pandas.functions import pandas_udf

        # Wrap the function as a COGROUPED_MAP pandas UDF with the user-provided
        # schema as its return type.
        udf = pandas_udf(
            func, returnType=schema, functionType=PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF)
        # Pass all columns of both sides so each cogroup's full rows reach the function.
        all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
        udf_column = udf(*all_cols)
        jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr())
        return DataFrame(jdf, self.sql_ctx)

    @staticmethod
    def _extract_cols(gd):
        df = gd._df
        return [df[col] for col in df.columns]


def _test():
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.pandas.group_ops
    globs = pyspark.sql.pandas.group_ops.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("sql.pandas.group tests")\
        .getOrCreate()
    globs['spark'] = spark
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.pandas.group_ops, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()