0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import unittest
0019
0020 from pyspark.sql.utils import AnalysisException
0021 from pyspark.sql.functions import array, explode, col, lit, mean, min, max, rank, \
0022 udf, pandas_udf, PandasUDFType
0023 from pyspark.sql.window import Window
0024 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
0025 pandas_requirement_message, pyarrow_requirement_message
0026 from pyspark.testing.utils import QuietTest
0027
0028 if have_pandas:
0029 from pandas.util.testing import assert_frame_equal
0030
0031
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class WindowPandasUDFTests(ReusedSQLTestCase):
    """Tests for grouped-aggregate pandas UDFs used as window functions.

    Each test computes the same column(s) twice -- once with a pandas
    (GROUPED_AGG) UDF and once with the equivalent built-in SQL aggregate --
    and asserts the two resulting DataFrames are equal.
    """

    @property
    def data(self):
        # 10 ids; 'vs' is exploded into 10 'v' values per id; 'w' is a
        # constant weight column used by the min/'w' tests.
        return self.spark.range(10).toDF('id') \
            .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
            .withColumn("v", explode(col('vs'))) \
            .drop('vs') \
            .withColumn('w', lit(1.0))

    @property
    def python_plus_one(self):
        """Row-at-a-time Python UDF, for tests that mix UDF evaluation kinds."""
        return udf(lambda v: v + 1, 'double')

    @property
    def pandas_scalar_time_two(self):
        """Scalar pandas UDF, for tests that mix UDF evaluation kinds."""
        return pandas_udf(lambda v: v * 2, 'double')

    @property
    def pandas_agg_count_udf(self):
        """Grouped-aggregate pandas UDF equivalent of SQL count()."""
        @pandas_udf('long', PandasUDFType.GROUPED_AGG)
        def count(v):
            return len(v)
        return count

    @property
    def pandas_agg_mean_udf(self):
        """Grouped-aggregate pandas UDF equivalent of SQL mean()."""
        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def avg(v):
            return v.mean()
        return avg

    @property
    def pandas_agg_max_udf(self):
        """Grouped-aggregate pandas UDF equivalent of SQL max()."""
        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def max(v):
            return v.max()
        return max

    @property
    def pandas_agg_min_udf(self):
        """Grouped-aggregate pandas UDF equivalent of SQL min()."""
        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def min(v):
            return v.min()
        return min

    @property
    def unbounded_window(self):
        """Ordered window whose frame spans the whole partition."""
        return Window.partitionBy('id') \
            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing).orderBy('v')

    @property
    def ordered_window(self):
        """Ordered window with the default frame (for rank() etc.)."""
        return Window.partitionBy('id').orderBy('v')

    @property
    def unpartitioned_window(self):
        """Window with no partitioning: the whole DataFrame is one partition."""
        return Window.partitionBy()

    @property
    def sliding_row_window(self):
        """Row-based frame bounded on both ends."""
        return Window.partitionBy('id').orderBy('v').rowsBetween(-2, 1)

    @property
    def sliding_range_window(self):
        """Range-based frame bounded on both ends."""
        return Window.partitionBy('id').orderBy('v').rangeBetween(-2, 4)

    @property
    def growing_row_window(self):
        """Row-based frame with an unbounded start (grows as rows advance)."""
        return Window.partitionBy('id').orderBy('v').rowsBetween(Window.unboundedPreceding, 3)

    @property
    def growing_range_window(self):
        """Range-based frame with an unbounded start."""
        return Window.partitionBy('id').orderBy('v') \
            .rangeBetween(Window.unboundedPreceding, 4)

    @property
    def shrinking_row_window(self):
        """Row-based frame with an unbounded end (shrinks as rows advance)."""
        return Window.partitionBy('id').orderBy('v').rowsBetween(-2, Window.unboundedFollowing)

    @property
    def shrinking_range_window(self):
        """Range-based frame with an unbounded end."""
        return Window.partitionBy('id').orderBy('v') \
            .rangeBetween(-3, Window.unboundedFollowing)

    def test_simple(self):
        """A pandas-UDF mean over an unbounded window matches SQL mean()."""
        df = self.data
        w = self.unbounded_window

        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w))
        expected1 = df.withColumn('mean_v', mean(df['v']).over(w))

        result2 = df.select(mean_udf(df['v']).over(w))
        expected2 = df.select(mean(df['v']).over(w))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())

    def test_multiple_udfs(self):
        """Several pandas-UDF window aggregates in one projection."""
        df = self.data
        w = self.unbounded_window

        result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \
            .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
            .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w))

        expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \
            .withColumn('max_v', max(df['v']).over(w)) \
            .withColumn('min_w', min(df['w']).over(w))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_replace_existing(self):
        """Overwriting an existing column with a pandas-UDF window result."""
        df = self.data
        w = self.unbounded_window

        result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w))
        expected1 = df.withColumn('v', mean(df['v']).over(w))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_mixed_sql(self):
        """SQL expressions both inside and around a pandas-UDF window call."""
        df = self.data
        w = self.unbounded_window
        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1)
        expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1)

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_mixed_udf(self):
        """Python and scalar pandas UDFs composed around a window aggregate."""
        df = self.data
        w = self.unbounded_window

        plus_one = self.python_plus_one
        time_two = self.pandas_scalar_time_two
        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn(
            'v2',
            plus_one(mean_udf(plus_one(df['v'])).over(w)))
        expected1 = df.withColumn(
            'v2',
            plus_one(mean(plus_one(df['v'])).over(w)))

        result2 = df.withColumn(
            'v2',
            time_two(mean_udf(time_two(df['v'])).over(w)))
        expected2 = df.withColumn(
            'v2',
            time_two(mean(time_two(df['v'])).over(w)))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())

    def test_without_partitionBy(self):
        """Pandas-UDF window aggregate over an unpartitioned window."""
        df = self.data
        w = self.unpartitioned_window
        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn('v2', mean_udf(df['v']).over(w))
        expected1 = df.withColumn('v2', mean(df['v']).over(w))

        result2 = df.select(mean_udf(df['v']).over(w))
        expected2 = df.select(mean(df['v']).over(w))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())

    def test_mixed_sql_and_udf(self):
        """Pandas-UDF and built-in window functions combined in one query."""
        df = self.data
        w = self.unbounded_window
        ow = self.ordered_window
        max_udf = self.pandas_agg_max_udf
        min_udf = self.pandas_agg_min_udf

        result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w))
        expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w))

        # Mixing a pandas UDF with a SQL aggregate in the same expression.
        result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w))
        expected2 = expected1

        # Same computation built column-by-column, then intermediates dropped.
        result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
            .withColumn('min_v', min(df['v']).over(w)) \
            .withColumn('v_diff', col('max_v') - col('min_v')) \
            .drop('max_v', 'min_v')
        expected3 = expected1

        # Mixing a pandas-UDF aggregate with a built-in ranking function.
        result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
            .withColumn('rank', rank().over(ow))
        expected4 = df.withColumn('max_v', max(df['v']).over(w)) \
            .withColumn('rank', rank().over(ow))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())
        assert_frame_equal(expected3.toPandas(), result3.toPandas())
        assert_frame_equal(expected4.toPandas(), result4.toPandas())

    def test_array_type(self):
        """A grouped-aggregate pandas UDF may return an array-typed value."""
        df = self.data
        w = self.unbounded_window

        array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
        result1 = df.withColumn('v2', array_udf(df['v']).over(w))
        # assertEquals is a deprecated unittest alias (removed in Python 3.12).
        self.assertEqual(result1.first()['v2'], [1.0, 2.0])

    def test_invalid_args(self):
        """GROUPED_MAP pandas UDFs are rejected as window functions."""
        df = self.data
        w = self.unbounded_window

        with QuietTest(self.sc):
            # assertRaisesRegexp is a deprecated unittest alias.
            with self.assertRaisesRegex(
                    AnalysisException,
                    '.*not supported within a window function'):
                foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
                df.withColumn('v2', foo_udf(df['v']).over(w))

    def test_bounded_simple(self):
        """Bounded (sliding/shrinking) windows with several aggregates."""
        from pyspark.sql.functions import mean, max, min, count

        df = self.data
        w1 = self.sliding_row_window
        w2 = self.shrinking_range_window

        plus_one = self.python_plus_one
        count_udf = self.pandas_agg_count_udf
        mean_udf = self.pandas_agg_mean_udf
        max_udf = self.pandas_agg_max_udf
        min_udf = self.pandas_agg_min_udf

        result1 = df.withColumn('mean_v', mean_udf(plus_one(df['v'])).over(w1)) \
            .withColumn('count_v', count_udf(df['v']).over(w2)) \
            .withColumn('max_v', max_udf(df['v']).over(w2)) \
            .withColumn('min_v', min_udf(df['v']).over(w1))

        expected1 = df.withColumn('mean_v', mean(plus_one(df['v'])).over(w1)) \
            .withColumn('count_v', count(df['v']).over(w2)) \
            .withColumn('max_v', max(df['v']).over(w2)) \
            .withColumn('min_v', min(df['v']).over(w1))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_growing_window(self):
        """Growing (unbounded-start) row and range windows."""
        from pyspark.sql.functions import mean

        df = self.data
        w1 = self.growing_row_window
        w2 = self.growing_range_window

        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
            .withColumn('m2', mean_udf(df['v']).over(w2))

        expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
            .withColumn('m2', mean(df['v']).over(w2))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_sliding_window(self):
        """Sliding (bounded both ends) row and range windows."""
        from pyspark.sql.functions import mean

        df = self.data
        w1 = self.sliding_row_window
        w2 = self.sliding_range_window

        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
            .withColumn('m2', mean_udf(df['v']).over(w2))

        expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
            .withColumn('m2', mean(df['v']).over(w2))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_shrinking_window(self):
        """Shrinking (unbounded-end) row and range windows."""
        from pyspark.sql.functions import mean

        df = self.data
        w1 = self.shrinking_row_window
        w2 = self.shrinking_range_window

        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
            .withColumn('m2', mean_udf(df['v']).over(w2))

        expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
            .withColumn('m2', mean(df['v']).over(w2))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_bounded_mixed(self):
        """Bounded and unbounded windows mixed in the same projection."""
        from pyspark.sql.functions import mean, max

        df = self.data
        w1 = self.sliding_row_window
        w2 = self.unbounded_window

        mean_udf = self.pandas_agg_mean_udf
        max_udf = self.pandas_agg_max_udf

        # 'mean_unbounded_v' is computed over the unbounded window w2 so the
        # test actually mixes a bounded and an unbounded pandas-UDF aggregate
        # (the original used w1 here, contradicting the column name).
        result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w1)) \
            .withColumn('max_v', max_udf(df['v']).over(w2)) \
            .withColumn('mean_unbounded_v', mean_udf(df['v']).over(w2))

        expected1 = df.withColumn('mean_v', mean(df['v']).over(w1)) \
            .withColumn('max_v', max(df['v']).over(w2)) \
            .withColumn('mean_unbounded_v', mean(df['v']).over(w2))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
0352
0353
if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_udf_window import *

    # Use the XML test runner when xmlrunner is installed (CI report output);
    # otherwise fall back to the default text runner.
    runner = None
    try:
        import xmlrunner
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=runner, verbosity=2)