0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 from pyspark.sql import Row
0019 from pyspark.testing.sqlutils import ReusedSQLTestCase
0020
0021
class GroupTests(ReusedSQLTestCase):
    """Tests for aggregations on grouped DataFrames produced by ``DataFrame.groupBy``."""

    def test_aggregator(self):
        """Check dict-style, ``mean()``, and function-based aggregations on a global grouping.

        ``self.df`` is supplied by ``ReusedSQLTestCase``. The expected values
        below assume that it has:

        - 100 rows;
        - integer ``key`` values 0..99;
        - 100 distinct ``value`` strings, with ``'99'`` in the last row.

        Confirm this against the base-class fixture if these assertions start failing.
        """
        df = self.df
        # groupBy() with no columns: each aggregation below yields a single row
        # (collect()[0] / first()) covering the whole DataFrame.
        g = df.groupBy()
        self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
        # NOTE(review): pyspark Row equality is presumably tuple-based, so the
        # "AVG(key#0)" field name may not actually be compared; only 49.5 is - TODO confirm.
        self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())

        from pyspark.sql import functions
        # NOTE(review): first()/last() over an unordered DataFrame depend on row
        # order, so this assertion relies on the fixture's partition layout.
        self.assertEqual((0, u'99'),
                         tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
        # approx_count_distinct is approximate, so only a lower bound is checked.
        # assertGreater reports the actual estimate on failure, unlike assertTrue(95 < x).
        self.assertGreater(g.agg(functions.approx_count_distinct(df.key)).first()[0], 95)
        self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
0035
0036
if __name__ == "__main__":
    import unittest
    # Re-export the test classes so unittest.main() discovers them from this script.
    from pyspark.sql.tests.test_group import *

    # Default to unittest's built-in text runner. Use a JUnit-style XML runner
    # instead when the optional xmlrunner package can be imported.
    runner = None
    try:
        import xmlrunner
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        # xmlrunner is optional; keep the default runner.
        pass

    unittest.main(testRunner=runner, verbosity=2)