Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 import unittest
0019 
0020 from pyspark.sql.utils import AnalysisException
0021 from pyspark.sql.functions import array, explode, col, lit, mean, min, max, rank, \
0022     udf, pandas_udf, PandasUDFType
0023 from pyspark.sql.window import Window
0024 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
0025     pandas_requirement_message, pyarrow_requirement_message
0026 from pyspark.testing.utils import QuietTest
0027 
0028 if have_pandas:
0029     from pandas.util.testing import assert_frame_equal
0030 
0031 
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class WindowPandasUDFTests(ReusedSQLTestCase):
    """Tests for pandas GROUPED_AGG UDFs evaluated over window specifications.

    Each test computes the same column twice -- once with a pandas UDF and
    once with the equivalent built-in SQL aggregate -- and asserts that the
    two DataFrames are equal after conversion to pandas.
    """

    @property
    def data(self):
        # 100-row frame: 10 ids, each exploded into 10 values derived from the
        # id, plus a constant weight column 'w'.
        return self.spark.range(10).toDF('id') \
            .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
            .withColumn("v", explode(col('vs'))) \
            .drop('vs') \
            .withColumn('w', lit(1.0))

    @property
    def python_plus_one(self):
        # Plain row-at-a-time Python UDF, used to test mixing UDF kinds.
        return udf(lambda v: v + 1, 'double')

    @property
    def pandas_scalar_time_two(self):
        # Scalar pandas UDF, used to test mixing with window aggregates.
        return pandas_udf(lambda v: v * 2, 'double')

    @property
    def pandas_agg_count_udf(self):
        @pandas_udf('long', PandasUDFType.GROUPED_AGG)
        def count(v):
            return len(v)
        return count

    @property
    def pandas_agg_mean_udf(self):
        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def avg(v):
            return v.mean()
        return avg

    @property
    def pandas_agg_max_udf(self):
        # The inner name deliberately mirrors the SQL function name (it shows
        # up in the UDF's name); it shadows the builtin only inside this scope.
        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def max(v):
            return v.max()
        return max

    @property
    def pandas_agg_min_udf(self):
        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
        def min(v):
            return v.min()
        return min

    @property
    def unbounded_window(self):
        # NOTE(review): orderBy is applied after the frame spec here; the
        # frame stays fully unbounded, so ordering does not change results.
        return Window.partitionBy('id') \
            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing).orderBy('v')

    @property
    def ordered_window(self):
        return Window.partitionBy('id').orderBy('v')

    @property
    def unpartitioned_window(self):
        return Window.partitionBy()

    @property
    def sliding_row_window(self):
        return Window.partitionBy('id').orderBy('v').rowsBetween(-2, 1)

    @property
    def sliding_range_window(self):
        return Window.partitionBy('id').orderBy('v').rangeBetween(-2, 4)

    @property
    def growing_row_window(self):
        return Window.partitionBy('id').orderBy('v').rowsBetween(Window.unboundedPreceding, 3)

    @property
    def growing_range_window(self):
        return Window.partitionBy('id').orderBy('v') \
            .rangeBetween(Window.unboundedPreceding, 4)

    @property
    def shrinking_row_window(self):
        return Window.partitionBy('id').orderBy('v').rowsBetween(-2, Window.unboundedFollowing)

    @property
    def shrinking_range_window(self):
        return Window.partitionBy('id').orderBy('v') \
            .rangeBetween(-3, Window.unboundedFollowing)

    def test_simple(self):
        """A single GROUPED_AGG UDF over an unbounded window matches mean()."""
        df = self.data
        w = self.unbounded_window

        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w))
        expected1 = df.withColumn('mean_v', mean(df['v']).over(w))

        result2 = df.select(mean_udf(df['v']).over(w))
        expected2 = df.select(mean(df['v']).over(w))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())

    def test_multiple_udfs(self):
        """Several window UDFs chained in one query plan."""
        df = self.data
        w = self.unbounded_window

        result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \
            .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
            .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w))

        expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \
            .withColumn('max_v', max(df['v']).over(w)) \
            .withColumn('min_w', min(df['w']).over(w))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_replace_existing(self):
        """A window UDF result can overwrite an existing column."""
        df = self.data
        w = self.unbounded_window

        result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w))
        expected1 = df.withColumn('v', mean(df['v']).over(w))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_mixed_sql(self):
        """SQL expressions both inside and outside the window UDF call."""
        df = self.data
        w = self.unbounded_window
        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1)
        expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1)

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_mixed_udf(self):
        """Python and scalar pandas UDFs composed around a window UDF."""
        df = self.data
        w = self.unbounded_window

        plus_one = self.python_plus_one
        time_two = self.pandas_scalar_time_two
        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn(
            'v2',
            plus_one(mean_udf(plus_one(df['v'])).over(w)))
        expected1 = df.withColumn(
            'v2',
            plus_one(mean(plus_one(df['v'])).over(w)))

        result2 = df.withColumn(
            'v2',
            time_two(mean_udf(time_two(df['v'])).over(w)))
        expected2 = df.withColumn(
            'v2',
            time_two(mean(time_two(df['v'])).over(w)))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())

    def test_without_partitionBy(self):
        """A window UDF over an unpartitioned (global) window."""
        df = self.data
        w = self.unpartitioned_window
        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn('v2', mean_udf(df['v']).over(w))
        expected1 = df.withColumn('v2', mean(df['v']).over(w))

        result2 = df.select(mean_udf(df['v']).over(w))
        expected2 = df.select(mean(df['v']).over(w))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())

    def test_mixed_sql_and_udf(self):
        """Window UDFs combined with built-in SQL window functions."""
        df = self.data
        w = self.unbounded_window
        ow = self.ordered_window
        max_udf = self.pandas_agg_max_udf
        min_udf = self.pandas_agg_min_udf

        result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w))
        expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w))

        # Test mixing sql window function and window udf in the same expression
        result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w))
        expected2 = expected1

        # Test chaining sql aggregate function and udf
        result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
            .withColumn('min_v', min(df['v']).over(w)) \
            .withColumn('v_diff', col('max_v') - col('min_v')) \
            .drop('max_v', 'min_v')
        expected3 = expected1

        # Test mixing sql window function and udf
        result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
            .withColumn('rank', rank().over(ow))
        expected4 = df.withColumn('max_v', max(df['v']).over(w)) \
            .withColumn('rank', rank().over(ow))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
        assert_frame_equal(expected2.toPandas(), result2.toPandas())
        assert_frame_equal(expected3.toPandas(), result3.toPandas())
        assert_frame_equal(expected4.toPandas(), result4.toPandas())

    def test_array_type(self):
        """A GROUPED_AGG UDF may return an array-typed value."""
        df = self.data
        w = self.unbounded_window

        array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
        result1 = df.withColumn('v2', array_udf(df['v']).over(w))
        # assertEquals is a deprecated alias removed in Python 3.12.
        self.assertEqual(result1.first()['v2'], [1.0, 2.0])

    def test_invalid_args(self):
        """Using a GROUPED_MAP UDF as a window function raises AnalysisException."""
        df = self.data
        w = self.unbounded_window

        with QuietTest(self.sc):
            # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
            with self.assertRaisesRegex(
                    AnalysisException,
                    '.*not supported within a window function'):
                foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
                df.withColumn('v2', foo_udf(df['v']).over(w))

    def test_bounded_simple(self):
        """Multiple aggregates over bounded row and range windows."""
        from pyspark.sql.functions import mean, max, min, count

        df = self.data
        w1 = self.sliding_row_window
        w2 = self.shrinking_range_window

        plus_one = self.python_plus_one
        count_udf = self.pandas_agg_count_udf
        mean_udf = self.pandas_agg_mean_udf
        max_udf = self.pandas_agg_max_udf
        min_udf = self.pandas_agg_min_udf

        result1 = df.withColumn('mean_v', mean_udf(plus_one(df['v'])).over(w1)) \
            .withColumn('count_v', count_udf(df['v']).over(w2)) \
            .withColumn('max_v', max_udf(df['v']).over(w2)) \
            .withColumn('min_v', min_udf(df['v']).over(w1))

        expected1 = df.withColumn('mean_v', mean(plus_one(df['v'])).over(w1)) \
            .withColumn('count_v', count(df['v']).over(w2)) \
            .withColumn('max_v', max(df['v']).over(w2)) \
            .withColumn('min_v', min(df['v']).over(w1))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_growing_window(self):
        """Mean over growing (unbounded-preceding) row and range windows."""
        from pyspark.sql.functions import mean

        df = self.data
        w1 = self.growing_row_window
        w2 = self.growing_range_window

        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
            .withColumn('m2', mean_udf(df['v']).over(w2))

        expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
            .withColumn('m2', mean(df['v']).over(w2))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_sliding_window(self):
        """Mean over sliding row and range windows."""
        from pyspark.sql.functions import mean

        df = self.data
        w1 = self.sliding_row_window
        w2 = self.sliding_range_window

        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
            .withColumn('m2', mean_udf(df['v']).over(w2))

        expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
            .withColumn('m2', mean(df['v']).over(w2))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_shrinking_window(self):
        """Mean over shrinking (unbounded-following) row and range windows."""
        from pyspark.sql.functions import mean

        df = self.data
        w1 = self.shrinking_row_window
        w2 = self.shrinking_range_window

        mean_udf = self.pandas_agg_mean_udf

        result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
            .withColumn('m2', mean_udf(df['v']).over(w2))

        expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
            .withColumn('m2', mean(df['v']).over(w2))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())

    def test_bounded_mixed(self):
        """Bounded and unbounded window UDFs mixed in one plan."""
        from pyspark.sql.functions import mean, max

        df = self.data
        w1 = self.sliding_row_window
        w2 = self.unbounded_window

        mean_udf = self.pandas_agg_mean_udf
        max_udf = self.pandas_agg_max_udf

        # NOTE(review): 'mean_unbounded_v' is computed over the *sliding*
        # window w1 on both the result and expected sides, so the assertion is
        # still valid, but the column name suggests w2 was intended -- confirm.
        result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w1)) \
            .withColumn('max_v', max_udf(df['v']).over(w2)) \
            .withColumn('mean_unbounded_v', mean_udf(df['v']).over(w1))

        expected1 = df.withColumn('mean_v', mean(df['v']).over(w1)) \
            .withColumn('max_v', max(df['v']).over(w2)) \
            .withColumn('mean_unbounded_v', mean(df['v']).over(w1))

        assert_frame_equal(expected1.toPandas(), result1.toPandas())
0352 
0353 
if __name__ == "__main__":
    # Re-import this module under its canonical name so unittest discovery
    # picks up the test classes when the file is run directly.
    from pyspark.sql.tests.test_pandas_udf_window import *

    try:
        import xmlrunner
    except ImportError:
        # xmlrunner is optional; fall back to the default text runner.
        testRunner = None
    else:
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=testRunner, verbosity=2)