#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

from pyspark import since
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.column import Column, _to_seq
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.pandas.group_ops import PandasGroupedOpsMixin
from pyspark.sql.types import *

__all__ = ["GroupedData"]


def dfapi(f):
    # Decorator: replaces a documentation-only method stub with a wrapper that
    # calls the JVM method of the same name on ``self._jgd`` and wraps the
    # result in a Python DataFrame.
    def _api(self):
        name = f.__name__
        jdf = getattr(self._jgd, name)()
        return DataFrame(jdf, self.sql_ctx)
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api


def df_varargs_api(f):
    # Variant of dfapi for methods that take a variable number of column names.
    def _api(self, *cols):
        name = f.__name__
        jdf = getattr(self._jgd, name)(_to_seq(self.sql_ctx._sc, cols))
        return DataFrame(jdf, self.sql_ctx)
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api


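# Illustration of what the two decorators above produce: applying ``dfapi``
# to ``count`` yields a method roughly equivalent to
#
#     def count(self):
#         return DataFrame(self._jgd.count(), self.sql_ctx)
#
# while ``df_varargs_api`` additionally converts its ``*cols`` arguments to a
# JVM Seq via ``_to_seq`` before delegating.
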
class GroupedData(PandasGroupedOpsMixin):
    """
    A set of methods for aggregations on a :class:`DataFrame`,
    created by :func:`DataFrame.groupBy`.

    .. versionadded:: 1.3
    """

    def __init__(self, jgd, df):
        self._jgd = jgd
        self._df = df
        self.sql_ctx = df.sql_ctx

    @ignore_unicode_prefix
    @since(1.3)
    def agg(self, *exprs):
        """Computes aggregates and returns the result as a :class:`DataFrame`.

        The available aggregate functions can be:

        1. built-in aggregation functions, such as `avg`, `max`, `min`, `sum`, `count`

        2. group aggregate pandas UDFs, created with :func:`pyspark.sql.functions.pandas_udf`

           .. note:: There is no partial aggregation with group aggregate UDFs, i.e.,
               a full shuffle is required. Also, all the data of a group will be loaded into
               memory, so the user should be aware of the potential OOM risk if data is skewed
               and certain groups are too large to fit in memory.

           .. seealso:: :func:`pyspark.sql.functions.pandas_udf`

        If ``exprs`` is a single :class:`dict` mapping from string to string, then the key
        is the column to perform aggregation on, and the value is the aggregate function.

        Alternatively, ``exprs`` can also be a list of aggregate :class:`Column` expressions.

        .. note:: Built-in aggregation functions and group aggregate pandas UDFs cannot be mixed
            in a single call to this function.

        :param exprs: a dict mapping from column name (string) to aggregate functions (string),
            or a list of :class:`Column`.

        >>> gdf = df.groupBy(df.name)
        >>> sorted(gdf.agg({"*": "count"}).collect())
        [Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)]

        >>> from pyspark.sql import functions as F
        >>> sorted(gdf.agg(F.min(df.age)).collect())
        [Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]

        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> @pandas_udf('int', PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
        ... def min_udf(v):
        ...     return v.min()
        >>> sorted(gdf.agg(min_udf(df.age)).collect())  # doctest: +SKIP
        [Row(name=u'Alice', min_udf(age)=2), Row(name=u'Bob', min_udf(age)=5)]
        """
        assert exprs, "exprs should not be empty"
        if len(exprs) == 1 and isinstance(exprs[0], dict):
            jdf = self._jgd.agg(exprs[0])
        else:
            # Columns
            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
            jdf = self._jgd.agg(exprs[0]._jc,
                                _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
        return DataFrame(jdf, self.sql_ctx)

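    # Illustrative sketch of the two calling conventions documented above,
    # assuming ``F`` is ``pyspark.sql.functions``; both produce a ``max(age)``
    # column here:
    #
    #     df.groupBy("name").agg({"age": "max"})
    #     df.groupBy("name").agg(F.max(df.age))
    #
    # Built-in aggregate Columns and group aggregate pandas UDFs cannot be
    # mixed in one call.
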
    @dfapi
    @since(1.3)
    def count(self):
        """Counts the number of records for each group.

        >>> sorted(df.groupBy(df.age).count().collect())
        [Row(age=2, count=1), Row(age=5, count=1)]
        """

    @df_varargs_api
    @since(1.3)
    def mean(self, *cols):
        """Computes average values for each numeric column for each group.

        :func:`mean` is an alias for :func:`avg`.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().mean('age').collect()
        [Row(avg(age)=3.5)]
        >>> df3.groupBy().mean('age', 'height').collect()
        [Row(avg(age)=3.5, avg(height)=82.5)]
        """

    @df_varargs_api
    @since(1.3)
    def avg(self, *cols):
        """Computes average values for each numeric column for each group.

        :func:`mean` is an alias for :func:`avg`.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().avg('age').collect()
        [Row(avg(age)=3.5)]
        >>> df3.groupBy().avg('age', 'height').collect()
        [Row(avg(age)=3.5, avg(height)=82.5)]
        """

    @df_varargs_api
    @since(1.3)
    def max(self, *cols):
        """Computes the max value for each numeric column for each group.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().max('age').collect()
        [Row(max(age)=5)]
        >>> df3.groupBy().max('age', 'height').collect()
        [Row(max(age)=5, max(height)=85)]
        """

    @df_varargs_api
    @since(1.3)
    def min(self, *cols):
        """Computes the min value for each numeric column for each group.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().min('age').collect()
        [Row(min(age)=2)]
        >>> df3.groupBy().min('age', 'height').collect()
        [Row(min(age)=2, min(height)=80)]
        """

    @df_varargs_api
    @since(1.3)
    def sum(self, *cols):
        """Computes the sum for each numeric column for each group.

        :param cols: list of column names (string). Non-numeric columns are ignored.

        >>> df.groupBy().sum('age').collect()
        [Row(sum(age)=7)]
        >>> df3.groupBy().sum('age', 'height').collect()
        [Row(sum(age)=7, sum(height)=165)]
        """

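    # Illustrative note: on numeric columns the varargs aggregates above
    # (mean/avg, max, min, sum) behave like the corresponding ``agg`` call,
    # e.g., assuming ``F`` is ``pyspark.sql.functions``:
    #
    #     df.groupBy().sum('age')
    #     df.groupBy().agg(F.sum(df.age))
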
    @since(1.6)
    def pivot(self, pivot_col, values=None):
        """
        Pivots a column of the current :class:`DataFrame` and performs the specified aggregation.
        There are two versions of the pivot function: one that requires the caller to specify the
        list of distinct values to pivot on, and one that does not. The latter is more concise but
        less efficient, because Spark needs to first compute the list of distinct values internally.

        :param pivot_col: Name of the column to pivot.
        :param values: List of values that will be translated to columns in the output DataFrame.

        # Compute the sum of earnings for each year by course with each course as a separate column

        >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect()
        [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]

        # Or without specifying column values (less efficient)

        >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
        >>> df5.groupBy("sales.year").pivot("sales.course").sum("sales.earnings").collect()
        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
        """
        if values is None:
            jgd = self._jgd.pivot(pivot_col)
        else:
            jgd = self._jgd.pivot(pivot_col, values)
        return GroupedData(jgd, self._df)


def _test():
    import doctest
    from pyspark.sql import Row, SparkSession
    import pyspark.sql.group
    globs = pyspark.sql.group.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("sql.group tests")\
        .getOrCreate()
    sc = spark.sparkContext
    globs['sc'] = sc
    globs['spark'] = spark
    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
        .toDF(StructType([StructField('age', IntegerType()),
                          StructField('name', StringType())]))
    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
                                   Row(name='Bob', age=5, height=85)]).toDF()
    globs['df4'] = sc.parallelize([Row(course="dotNET", year=2012, earnings=10000),
                                   Row(course="Java",   year=2012, earnings=20000),
                                   Row(course="dotNET", year=2012, earnings=5000),
                                   Row(course="dotNET", year=2013, earnings=48000),
                                   Row(course="Java",   year=2013, earnings=30000)]).toDF()
    globs['df5'] = sc.parallelize([
        Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)),
        Row(training="junior", sales=Row(course="Java",   year=2012, earnings=20000)),
        Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)),
        Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)),
        Row(training="expert", sales=Row(course="Java",   year=2013, earnings=30000))]).toDF()

    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.group, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()
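

# A self-contained usage sketch (illustrative only; it mirrors the doctests
# above and is never invoked). The session setup below, including the
# ``local[2]`` master and the example data, is an assumption of the sketch.
def _example():
    from pyspark.sql import SparkSession, functions as F
    spark = SparkSession.builder \
        .master("local[2]") \
        .appName("GroupedData example") \
        .getOrCreate()
    df = spark.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name'])
    # Dict form: map column name to aggregate function name.
    df.groupBy('name').agg({'age': 'max'}).show()
    # Column form: aggregate expressions from pyspark.sql.functions.
    df.groupBy('name').agg(F.min('age'), F.max('age')).show()
    # Pivot with explicit values avoids an extra pass to compute them.
    df4 = spark.createDataFrame(
        [("dotNET", 2012, 10000), ("Java", 2012, 20000), ("dotNET", 2013, 48000)],
        ["course", "year", "earnings"])
    df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").show()
    spark.stop()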