0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import unittest
0019
0020 from pyspark.rdd import PythonEvalType
0021 from pyspark.sql import Row
0022 from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
0023 udf, pandas_udf, PandasUDFType
0024 from pyspark.sql.types import *
0025 from pyspark.sql.utils import AnalysisException
0026 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
0027 pandas_requirement_message, pyarrow_requirement_message
0028 from pyspark.testing.utils import QuietTest
0029
0030 if have_pandas:
0031 import pandas as pd
0032 from pandas.util.testing import assert_frame_equal
0033
0034
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class GroupedAggPandasUDFTests(ReusedSQLTestCase):
    """Tests for grouped aggregate pandas UDFs (``PandasUDFType.GROUPED_AGG``).

    Covers basic aggregation, mixing with SQL expressions / python UDFs /
    scalar pandas UDFs, complex groupby expressions, SQL registration, and
    error cases for unsupported usage.
    """

    @property
    def data(self):
        # 10 ids x 10 values each: for a given id, v takes id + 20.0 .. id + 29.0,
        # plus a constant weight column w == 1.0 (so weighted mean == plain mean).
        return self.spark.range(10).toDF('id') \
            .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
            .withColumn("v", explode(col('vs'))) \
            .drop('vs') \
            .withColumn('w', lit(1.0))

    @property
    def python_plus_one(self):
        # Plain (row-at-a-time) python UDF, used to test mixing with agg UDFs.
        @udf('double')
        def plus_one(v):
            assert isinstance(v, (int, float))
            return v + 1
        return plus_one

    @property
    def pandas_scalar_plus_two(self):
        # Scalar pandas UDF (receives a pd.Series), used to test mixing.
        @pandas_udf('double', PandasUDFType.SCALAR)
        def plus_two(v):
            assert isinstance(v, pd.Series)
            return v + 2
        return plus_two

    @property
    def pandas_agg_mean_udf(self):
        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def avg(v):
            return v.mean()
        return avg

    @property
    def pandas_agg_sum_udf(self):
        # NOTE: deliberately named `sum` (shadowing the builtin locally) so the
        # generated result column is `sum(v)`, matching the SQL `sum` baseline.
        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def sum(v):
            return v.sum()
        return sum

    @property
    def pandas_agg_weighted_mean_udf(self):
        import numpy as np

        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def weighted_mean(v, w):
            return np.average(v, weights=w)
        return weighted_mean

    def test_manual(self):
        """Manually-constructed UDFs (including an array-returning variant)."""
        df = self.data
        sum_udf = self.pandas_agg_sum_udf
        mean_udf = self.pandas_agg_mean_udf
        # Same function as mean_udf, but wrapped to return array<double>.
        mean_arr_udf = pandas_udf(
            self.pandas_agg_mean_udf.func,
            ArrayType(self.pandas_agg_mean_udf.returnType),
            self.pandas_agg_mean_udf.evalType)

        result1 = df.groupby('id').agg(
            sum_udf(df.v),
            mean_udf(df.v),
            mean_arr_udf(array(df.v))).sort('id')
        expected1 = self.spark.createDataFrame(
            [[0, 245.0, 24.5, [24.5]],
             [1, 255.0, 25.5, [25.5]],
             [2, 265.0, 26.5, [26.5]],
             [3, 275.0, 27.5, [27.5]],
             [4, 285.0, 28.5, [28.5]],
             [5, 295.0, 29.5, [29.5]],
             [6, 305.0, 30.5, [30.5]],
             [7, 315.0, 31.5, [31.5]],
             [8, 325.0, 32.5, [32.5]],
             [9, 335.0, 33.5, [33.5]]],
            ['id', 'sum(v)', 'avg(v)', 'avg(array(v))'])

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_basic(self):
        """Basic grouped aggregation against the built-in `mean` baseline."""
        df = self.data
        weighted_mean_udf = self.pandas_agg_weighted_mean_udf

        # Groupby one column and aggregate one UDF with literal weight.
        result1 = df.groupby('id').agg(weighted_mean_udf(df.v, lit(1.0))).sort('id')
        expected1 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id')
        assert_frame_equal(expected1.toPandas(), result1.toPandas())

        # Groupby one expression and aggregate one UDF with literal weight.
        result2 = df.groupby((col('id') + 1)).agg(weighted_mean_udf(df.v, lit(1.0)))\
            .sort(df.id + 1)
        expected2 = df.groupby((col('id') + 1))\
            .agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort(df.id + 1)
        assert_frame_equal(expected2.toPandas(), result2.toPandas())

        # Groupby one column and aggregate one UDF with a column weight
        # (w is constant 1.0, so this still equals the plain mean).
        result3 = df.groupby('id').agg(weighted_mean_udf(df.v, df.w)).sort('id')
        expected3 = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, w)')).sort('id')
        assert_frame_equal(expected3.toPandas(), result3.toPandas())

        # Groupby one expression and aggregate one UDF with a column weight.
        result4 = df.groupby((col('id') + 1).alias('id'))\
            .agg(weighted_mean_udf(df.v, df.w))\
            .sort('id')
        expected4 = df.groupby((col('id') + 1).alias('id'))\
            .agg(mean(df.v).alias('weighted_mean(v, w)'))\
            .sort('id')
        assert_frame_equal(expected4.toPandas(), result4.toPandas())

    def test_unsupported_types(self):
        """Return types not supported by GROUPED_AGG raise NotImplementedError."""
        with QuietTest(self.sc):
            with self.assertRaisesRegex(NotImplementedError, 'not supported'):
                pandas_udf(
                    lambda x: x,
                    ArrayType(ArrayType(TimestampType())),
                    PandasUDFType.GROUPED_AGG)

        with QuietTest(self.sc):
            with self.assertRaisesRegex(NotImplementedError, 'not supported'):
                @pandas_udf('mean double, std double', PandasUDFType.GROUPED_AGG)
                def mean_and_std_udf(v):
                    return v.mean(), v.std()

        with QuietTest(self.sc):
            with self.assertRaisesRegex(NotImplementedError, 'not supported'):
                @pandas_udf(MapType(DoubleType(), DoubleType()), PandasUDFType.GROUPED_AGG)
                def mean_and_std_udf(v):
                    return {v.mean(): v.std()}

    def test_alias(self):
        """The aggregate UDF result column can be aliased."""
        df = self.data
        mean_udf = self.pandas_agg_mean_udf

        result1 = df.groupby('id').agg(mean_udf(df.v).alias('mean_alias'))
        expected1 = df.groupby('id').agg(mean(df.v).alias('mean_alias'))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_mixed_sql(self):
        """
        Test mixing group aggregate pandas UDF with sql expression.
        """
        df = self.data
        sum_udf = self.pandas_agg_sum_udf

        # Mix group aggregate pandas UDF with sql expression.
        result1 = (df.groupby('id')
                   .agg(sum_udf(df.v) + 1)
                   .sort('id'))
        expected1 = (df.groupby('id')
                     .agg(sum(df.v) + 1)
                     .sort('id'))

        # Mix group aggregate pandas UDF with sql expression (order swapped).
        result2 = (df.groupby('id')
                   .agg(sum_udf(df.v + 1))
                   .sort('id'))

        expected2 = (df.groupby('id')
                     .agg(sum(df.v + 1))
                     .sort('id'))

        # Wrap group aggregate pandas UDF with two sql expressions.
        result3 = (df.groupby('id')
                   .agg(sum_udf(df.v + 1) + 2)
                   .sort('id'))
        expected3 = (df.groupby('id')
                     .agg(sum(df.v + 1) + 2)
                     .sort('id'))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())
        assert_frame_equal(expected3.toPandas(), result3.toPandas())

    def test_mixed_udfs(self):
        """
        Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF.
        """
        df = self.data
        plus_one = self.python_plus_one
        plus_two = self.pandas_scalar_plus_two
        sum_udf = self.pandas_agg_sum_udf

        # Mix group aggregate pandas UDF and python UDF.
        result1 = (df.groupby('id')
                   .agg(plus_one(sum_udf(df.v)))
                   .sort('id'))
        expected1 = (df.groupby('id')
                     .agg(plus_one(sum(df.v)))
                     .sort('id'))

        # Mix group aggregate pandas UDF and python UDF (python UDF inside).
        result2 = (df.groupby('id')
                   .agg(sum_udf(plus_one(df.v)))
                   .sort('id'))
        expected2 = (df.groupby('id')
                     .agg(sum(plus_one(df.v)))
                     .sort('id'))

        # Mix group aggregate pandas UDF and scalar pandas UDF (scalar inside).
        result3 = (df.groupby('id')
                   .agg(sum_udf(plus_two(df.v)))
                   .sort('id'))
        expected3 = (df.groupby('id')
                     .agg(sum(plus_two(df.v)))
                     .sort('id'))

        # Mix group aggregate pandas UDF and scalar pandas UDF (scalar outside).
        result4 = (df.groupby('id')
                   .agg(plus_two(sum_udf(df.v)))
                   .sort('id'))
        expected4 = (df.groupby('id')
                     .agg(plus_two(sum(df.v)))
                     .sort('id'))

        # Wrap the group aggregate pandas UDF with a python UDF on both the
        # grouping column and the aggregated value.
        result5 = (df.groupby(plus_one(df.id))
                   .agg(plus_one(sum_udf(plus_one(df.v))))
                   .sort('plus_one(id)'))
        expected5 = (df.groupby(plus_one(df.id))
                     .agg(plus_one(sum(plus_one(df.v))))
                     .sort('plus_one(id)'))

        # Wrap the group aggregate pandas UDF with a scalar pandas UDF on both
        # the grouping column and the aggregated value.
        result6 = (df.groupby(plus_two(df.id))
                   .agg(plus_two(sum_udf(plus_two(df.v))))
                   .sort('plus_two(id)'))
        expected6 = (df.groupby(plus_two(df.id))
                     .agg(plus_two(sum(plus_two(df.v))))
                     .sort('plus_two(id)'))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())
        assert_frame_equal(expected3.toPandas(), result3.toPandas())
        assert_frame_equal(expected4.toPandas(), result4.toPandas())
        assert_frame_equal(expected5.toPandas(), result5.toPandas())
        assert_frame_equal(expected6.toPandas(), result6.toPandas())

    def test_multiple_udfs(self):
        """
        Test multiple group aggregate pandas UDFs in one agg function.
        """
        df = self.data
        mean_udf = self.pandas_agg_mean_udf
        sum_udf = self.pandas_agg_sum_udf
        weighted_mean_udf = self.pandas_agg_weighted_mean_udf

        result1 = (df.groupBy('id')
                   .agg(mean_udf(df.v),
                        sum_udf(df.v),
                        weighted_mean_udf(df.v, df.w))
                   .sort('id')
                   .toPandas())
        expected1 = (df.groupBy('id')
                     .agg(mean(df.v),
                          sum(df.v),
                          mean(df.v).alias('weighted_mean(v, w)'))
                     .sort('id')
                     .toPandas())

        assert_frame_equal(expected1, result1)

    def test_complex_groupby(self):
        """Grouped aggregate pandas UDFs with various groupby expressions."""
        df = self.data
        sum_udf = self.pandas_agg_sum_udf
        plus_one = self.python_plus_one
        plus_two = self.pandas_scalar_plus_two

        # groupby one expression
        result1 = df.groupby(df.v % 2).agg(sum_udf(df.v))
        expected1 = df.groupby(df.v % 2).agg(sum(df.v))

        # empty groupby (global aggregation)
        result2 = df.groupby().agg(sum_udf(df.v))
        expected2 = df.groupby().agg(sum(df.v))

        # groupby one column and one sql expression
        result3 = df.groupby(df.id, df.v % 2).agg(sum_udf(df.v)).orderBy(df.id, df.v % 2)
        expected3 = df.groupby(df.id, df.v % 2).agg(sum(df.v)).orderBy(df.id, df.v % 2)

        # groupby one python UDF
        result4 = df.groupby(plus_one(df.id)).agg(sum_udf(df.v))
        expected4 = df.groupby(plus_one(df.id)).agg(sum(df.v))

        # groupby one scalar pandas UDF
        result5 = df.groupby(plus_two(df.id)).agg(sum_udf(df.v)).sort('sum(v)')
        expected5 = df.groupby(plus_two(df.id)).agg(sum(df.v)).sort('sum(v)')

        # groupby one expression and one python UDF
        result6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum_udf(df.v))
        expected6 = df.groupby(df.v % 2, plus_one(df.id)).agg(sum(df.v))

        # groupby one expression and one scalar pandas UDF
        result7 = (df.groupby(df.v % 2, plus_two(df.id))
                   .agg(sum_udf(df.v)).sort(['sum(v)', 'plus_two(id)']))
        expected7 = (df.groupby(df.v % 2, plus_two(df.id))
                     .agg(sum(df.v)).sort(['sum(v)', 'plus_two(id)']))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())
        assert_frame_equal(expected3.toPandas(), result3.toPandas())
        assert_frame_equal(expected4.toPandas(), result4.toPandas())
        assert_frame_equal(expected5.toPandas(), result5.toPandas())
        assert_frame_equal(expected6.toPandas(), result6.toPandas())
        assert_frame_equal(expected7.toPandas(), result7.toPandas())

    def test_complex_expressions(self):
        """Complex expressions in groupby and agg, mixing UDF flavours."""
        df = self.data
        plus_one = self.python_plus_one
        plus_two = self.pandas_scalar_plus_two
        sum_udf = self.pandas_agg_sum_udf

        # Test complex expressions in both the groupby and agg clauses,
        # mixing the grouped aggregate UDF with a python UDF.
        result1 = (df.withColumn('v1', plus_one(df.v))
                   .withColumn('v2', df.v + 2)
                   .groupby(df.id, df.v % 2)
                   .agg(sum_udf(col('v')),
                        sum_udf(col('v1') + 3),
                        sum_udf(col('v2')) + 5,
                        plus_one(sum_udf(col('v1'))),
                        sum_udf(plus_one(col('v2'))))
                   .sort(['id', '(v % 2)'])
                   .toPandas().sort_values(by=['id', '(v % 2)']))

        expected1 = (df.withColumn('v1', df.v + 1)
                     .withColumn('v2', df.v + 2)
                     .groupby(df.id, df.v % 2)
                     .agg(sum(col('v')),
                          sum(col('v1') + 3),
                          sum(col('v2')) + 5,
                          plus_one(sum(col('v1'))),
                          sum(plus_one(col('v2'))))
                     .sort(['id', '(v % 2)'])
                     .toPandas().sort_values(by=['id', '(v % 2)']))

        # Same shapes as above, but mixing with a scalar pandas UDF instead of
        # a python UDF.
        result2 = (df.withColumn('v1', plus_one(df.v))
                   .withColumn('v2', df.v + 2)
                   .groupby(df.id, df.v % 2)
                   .agg(sum_udf(col('v')),
                        sum_udf(col('v1') + 3),
                        sum_udf(col('v2')) + 5,
                        plus_two(sum_udf(col('v1'))),
                        sum_udf(plus_two(col('v2'))))
                   .sort(['id', '(v % 2)'])
                   .toPandas().sort_values(by=['id', '(v % 2)']))

        expected2 = (df.withColumn('v1', df.v + 1)
                     .withColumn('v2', df.v + 2)
                     .groupby(df.id, df.v % 2)
                     .agg(sum(col('v')),
                          sum(col('v1') + 3),
                          sum(col('v2')) + 5,
                          plus_two(sum(col('v1'))),
                          sum(plus_two(col('v2'))))
                     .sort(['id', '(v % 2)'])
                     .toPandas().sort_values(by=['id', '(v % 2)']))

        # Chained aggregation: aggregate the result of an aggregation.
        result3 = (df.groupby('id')
                   .agg(sum_udf(df.v).alias('v'))
                   .groupby('id')
                   .agg(sum_udf(col('v')))
                   .sort('id')
                   .toPandas())

        expected3 = (df.groupby('id')
                     .agg(sum(df.v).alias('v'))
                     .groupby('id')
                     .agg(sum(col('v')))
                     .sort('id')
                     .toPandas())

        assert_frame_equal(expected1, result1)
        assert_frame_equal(expected2, result2)
        assert_frame_equal(expected3, result3)

    def test_retain_group_columns(self):
        """Works with spark.sql.retainGroupColumns disabled."""
        with self.sql_conf({"spark.sql.retainGroupColumns": False}):
            df = self.data
            sum_udf = self.pandas_agg_sum_udf

            result1 = df.groupby(df.id).agg(sum_udf(df.v))
            expected1 = df.groupby(df.id).agg(sum(df.v))
            assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_array_type(self):
        """A grouped aggregate UDF may return an array type."""
        df = self.data

        array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
        result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2'))
        self.assertEqual(result1.first()['v2'], [1.0, 2.0])

    def test_invalid_args(self):
        """Invalid mixes of UDFs in agg raise AnalysisException."""
        df = self.data
        plus_one = self.python_plus_one
        mean_udf = self.pandas_agg_mean_udf

        # A non-aggregate python UDF is not a valid agg expression.
        with QuietTest(self.sc):
            with self.assertRaisesRegex(
                    AnalysisException,
                    'nor.*aggregate function'):
                df.groupby(df.id).agg(plus_one(df.v)).collect()

        # Nesting an aggregate inside an aggregate is not allowed.
        with QuietTest(self.sc):
            with self.assertRaisesRegex(
                    AnalysisException,
                    'aggregate function.*argument.*aggregate function'):
                df.groupby(df.id).agg(mean_udf(mean_udf(df.v))).collect()

        # Mixing a built-in aggregate with a group aggregate pandas UDF
        # in the same agg is not allowed.
        with QuietTest(self.sc):
            with self.assertRaisesRegex(
                    AnalysisException,
                    'mixture.*aggregate function.*group aggregate pandas UDF'):
                df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()

    def test_register_vectorized_udf_basic(self):
        """A grouped aggregate UDF can be registered and used from SQL."""
        sum_pandas_udf = pandas_udf(
            lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)

        self.assertEqual(sum_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
        group_agg_pandas_udf = self.spark.udf.register("sum_pandas_udf", sum_pandas_udf)
        self.assertEqual(group_agg_pandas_udf.evalType, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
        q = "SELECT sum_pandas_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
        actual = sorted(map(lambda r: r[0], self.spark.sql(q).collect()))
        expected = [1, 5]
        self.assertEqual(actual, expected)

    def test_grouped_with_empty_partition(self):
        """Aggregation is correct when some input partitions are empty."""
        data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
        expected = [Row(id=1, sum=5), Row(id=2, sum=4)]
        # More slices than rows guarantees at least one empty partition.
        num_parts = len(data) + 1
        df = self.spark.createDataFrame(self.sc.parallelize(data, numSlices=num_parts))

        f = pandas_udf(lambda x: x.sum(),
                       'int', PandasUDFType.GROUPED_AGG)

        result = df.groupBy('id').agg(f(df['x']).alias('sum')).collect()
        self.assertEqual(result, expected)

    def test_grouped_without_group_by_clause(self):
        """A grouped aggregate UDF works as a global aggregate (no GROUP BY)."""
        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def max_udf(v):
            return v.max()

        df = self.spark.range(0, 100)
        self.spark.udf.register('max_udf', max_udf)

        with self.tempView("table"):
            df.createTempView('table')

            agg1 = df.agg(max_udf(df['id']))
            agg2 = self.spark.sql("select max_udf(id) from table")
            assert_frame_equal(agg1.toPandas(), agg2.toPandas())

    def test_no_predicate_pushdown_through(self):
        """A filter on an aggregated column must not be pushed below the agg."""
        import numpy as np

        # Deliberately shadows the imported `mean` within this test: the filter
        # must apply to the UDF's output, not be pushed down past it.
        @pandas_udf('float', PandasUDFType.GROUPED_AGG)
        def mean(x):
            return np.mean(x)

        df = self.spark.createDataFrame([
            Row(id=1, foo=42), Row(id=2, foo=1), Row(id=2, foo=2)
        ])

        agg = df.groupBy('id').agg(mean('foo').alias("mean"))
        filtered = agg.filter(agg['mean'] > 40.0)

        self.assertEqual(filtered.collect()[0]["mean"], 42.0)
0510
0511
if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_udf_grouped_agg import *

    # Emit JUnit-style XML reports when xmlrunner is installed; otherwise
    # fall back to unittest's default text runner (testRunner=None).
    try:
        import xmlrunner
    except ImportError:
        runner = None
    else:
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=runner, verbosity=2)