import sys

from pyspark import since
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.column import Column, _to_seq
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin
from pyspark.sql.types import *

__all__ = ["GroupedData"]


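# Helper decorators: several GroupedData methods below contain only a docstring.
# ``dfapi`` replaces such a method with a wrapper that calls the JVM method of the
# same name on ``self._jgd`` (with no arguments) and wraps the result in a Python
# DataFrame, while preserving the original method's name and docstring.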
def dfapi(f):
    def _api(self):
        name = f.__name__
        jdf = getattr(self._jgd, name)()
        return DataFrame(jdf, self.sql_ctx)
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api


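# ``df_varargs_api`` does the same as ``dfapi`` for methods that take a varargs list
# of column names (e.g. ``mean``, ``sum``): the names are converted to a JVM Seq with
# ``_to_seq`` before the call is forwarded to ``self._jgd``.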
def df_varargs_api(f):
    def _api(self, *cols):
        name = f.__name__
        jdf = getattr(self._jgd, name)(_to_seq(self.sql_ctx._sc, cols))
        return DataFrame(jdf, self.sql_ctx)
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api


class GroupedData(PandasGroupedOpsMixin):
    """
    A set of methods for aggregations on a :class:`DataFrame`,
    created by :func:`DataFrame.groupBy`.

    .. versionadded:: 1.3
    """

    def __init__(self, jgd, df):
        self._jgd = jgd
        self._df = df
        self.sql_ctx = df.sql_ctx

    @ignore_unicode_prefix
    @since(1.3)
    def agg(self, *exprs):
        """Compute aggregates and return the result as a :class:`DataFrame`.

        The available aggregate functions can be:

        1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count`

        2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf`

           .. note:: There is no partial aggregation with group aggregate UDFs, i.e.,
               a full shuffle is required. Also, all the data of a group will be loaded into
               memory, so the user should be aware of the potential OOM risk if data is skewed
               and certain groups are too large to fit in memory.

           .. seealso:: :func:`pyspark.sql.functions.pandas_udf`

        If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
        is the column to perform aggregation on, and the value is the aggregate function.

        Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.

        .. note:: Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed
            in a single call to this function.

        :param exprs: a dict mapping from column name (string) to aggregate functions (string),
            or a list of :class:`Column`.

        >>> gdf = df.groupBy(df.name)
        >>> sorted(gdf.agg({"*": "count"}).collect())
        [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)]

        >>> from pyspark.sql import functions as F
        >>> sorted(gdf.agg(F.min(df.age)).collect())
        [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]

        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
        ... def min_udf(v):
        ...     return v.min()
        >>> sorted(gdf.agg(min_udf(df.age)).collect())  # doctest: +SKIP
        [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)]
        """
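        # ``exprs`` is either a single dict of {column name: aggregate function name}
        # or one or more Column expressions; the two branches below forward each form
        # to the corresponding JVM ``agg`` overload.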
        assert exprs, "exprs should not be empty"
        if len(exprs) == 1 and isinstance(exprs[0], dict):
            jdf = self._jgd.agg(exprs[0])
        else:
            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
            jdf = self._jgd.agg(exprs[0]._jc,
                                _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
        return DataFrame(jdf, self.sql_ctx)

    @dfapi
    @since(1.3)
    def count(self):
        """Counts the number of records for each group.

        >>> sorted(df.groupBy(df.age).count().collect())
        [Row(age=2, count=1), Row(age=5, count=1)]
        """
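        # The body is intentionally empty: the ``dfapi`` decorator replaces this
        # function with a wrapper that calls the JVM ``count()`` and returns a
        # DataFrame. The varargs methods below work the same way via ``df_varargs_api``.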

    @df_varargs_api
    @since(1.3)
    def mean(self, *cols):
        """Computes average values for each numeric column for each group.

        :func:`mean` is an alias for :func:`avg`.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().mean('age').collect()
        [Row(avg(age)=3.5)]
        >>> df3.groupBy().mean('age', 'height').collect()
        [Row(avg(age)=3.5, avg(height)=82.5)]
        """

    @df_varargs_api
    @since(1.3)
    def avg(self, *cols):
        """Computes average values for each numeric column for each group.

        :func:`mean` is an alias for :func:`avg`.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().avg('age').collect()
        [Row(avg(age)=3.5)]
        >>> df3.groupBy().avg('age', 'height').collect()
        [Row(avg(age)=3.5, avg(height)=82.5)]
        """

    @df_varargs_api
    @since(1.3)
    def max(self, *cols):
        """Computes the max value for each numeric column for each group.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().max('age').collect()
        [Row(max(age)=5)]
        >>> df3.groupBy().max('age', 'height').collect()
        [Row(max(age)=5, max(height)=85)]
        """

    @df_varargs_api
    @since(1.3)
    def min(self, *cols):
        """Computes the min value for each numeric column for each group.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().min('age').collect()
        [Row(min(age)=2)]
        >>> df3.groupBy().min('age', 'height').collect()
        [Row(min(age)=2, min(height)=80)]
        """

    @df_varargs_api
    @since(1.3)
    def sum(self, *cols):
        """Computes the sum for each numeric column for each group.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().sum('age').collect()
        [Row(sum(age)=7)]
        >>> df3.groupBy().sum('age', 'height').collect()
        [Row(sum(age)=7, sum(height)=165)]
        """

    @since(1.6)
    def pivot(self, pivot_col, values=None):
        """
        Pivots a column of the current :class:`DataFrame` and performs the specified aggregation.
        There are two versions of the pivot function: one that requires the caller to specify the
        list of distinct values to pivot on, and one that does not. The latter is more concise but
        less efficient, because Spark needs to first compute the list of distinct values
        internally.

        :param pivot_col: Name of the column to pivot.
        :param values: List of values that will be translated to columns in the output DataFrame.

        # Compute the sum of earnings for each year by course with each course as a separate column

        >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect()
        [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]

        # Or without specifying column values (less efficient)

        >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
        >>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect()
        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
        """
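        # When ``values`` is omitted, the JVM side must first run an extra job to
        # collect the distinct values of ``pivot_col`` before pivoting, which is why
        # passing the values explicitly is the more efficient form documented above.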
        if values is None:
            jgd = self._jgd.pivot(pivot_col)
        else:
            jgd = self._jgd.pivot(pivot_col, values)
        return GroupedData(jgd, self._df)


def _test():
    import doctest
    from pyspark.sql import Row, SparkSession
    import pyspark.sql.group
    globs = pyspark.sql.group.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("sql.group tests")\
        .getOrCreate()
    sc = spark.sparkContext
    globs['sc'] = sc
    globs['spark'] = spark
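    # Example DataFrames referenced by the doctests above: ``df`` and ``df3`` for the
    # simple aggregates, ``df4`` and ``df5`` for the pivot examples.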
    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
        .toDF(StructType([StructField('age', IntegerType()),
                          StructField('name', StringType())]))
    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
                                   Row(name='Bob', age=5, height=85)]).toDF()
    globs['df4'] = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000),
                                   Row(course="Java", year=2012, earnings=20000),
                                   Row(course="dotNET", year=2012, earnings=5000),
                                   Row(course="dotNET", year=2013, earnings=48000),
                                   Row(course="Java", year=2013, earnings=30000)]).toDF()
    globs['df5'] = sc.parallelize([
        Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)),
        Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)),
        Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)),
        Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)),
        Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000))]).toDF()

    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.group, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()