#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import unittest
import sys

from collections import OrderedDict
from decimal import Decimal

from pyspark.sql import Row
from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType, \
    window
from pyspark.sql.types import *
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
    pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest

if have_pandas:
    import pandas as pd
    from pandas.util.testing import assert_frame_equal

if have_pyarrow:
    import pyarrow as pa


# Tests below use pd.DataFrame.assign, which under Python 2 infers mixed
# (unicode/str) types for column names from kwargs, so set
# check_column_type=False there to skip that check
_check_column_type = sys.version >= '3'
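# Illustrative note (not executed): under Python 2, column names created from
# assign() kwargs arrive as byte strings (e.g. 'v1'), while names decoded from
# a Spark schema are unicode (u'v1'), so a strict column-type comparison would
# fail even when the frames are otherwise equal.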


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class GroupedMapInPandasTests(ReusedSQLTestCase):

    @property
    def data(self):
        return self.spark.range(10).toDF('id') \
            .withColumn("vs", array([lit(i) for i in range(20, 30)])) \
            .withColumn("v", explode(col('vs'))).drop('vs')

    def test_supported_types(self):

        values = [
            1, 2, 3,
            4, 5, 1.1,
            2.2, Decimal(1.123),
            [1, 2, 2], True, 'hello',
            bytearray([0x01, 0x02])
        ]
        output_fields = [
            ('id', IntegerType()), ('byte', ByteType()), ('short', ShortType()),
            ('int', IntegerType()), ('long', LongType()), ('float', FloatType()),
            ('double', DoubleType()), ('decim', DecimalType(10, 3)),
            ('array', ArrayType(IntegerType())), ('bool', BooleanType()), ('str', StringType()),
            ('bin', BinaryType())
        ]

        output_schema = StructType([StructField(*x) for x in output_fields])
        df = self.spark.createDataFrame([values], schema=output_schema)

        # Three forms of grouped map pandas UDF (pdf only, key ignored, key used);
        # all of them produce the same result
        udf1 = pandas_udf(
            lambda pdf: pdf.assign(
                byte=pdf.byte * 2,
                short=pdf.short * 2,
                int=pdf.int * 2,
                long=pdf.long * 2,
                float=pdf.float * 2,
                double=pdf.double * 2,
                decim=pdf.decim * 2,
                bool=False if pdf.bool else True,
                str=pdf.str + 'there',
                array=pdf.array,
                bin=pdf.bin
            ),
            output_schema,
            PandasUDFType.GROUPED_MAP
        )

        udf2 = pandas_udf(
            lambda _, pdf: pdf.assign(
                byte=pdf.byte * 2,
                short=pdf.short * 2,
                int=pdf.int * 2,
                long=pdf.long * 2,
                float=pdf.float * 2,
                double=pdf.double * 2,
                decim=pdf.decim * 2,
                bool=False if pdf.bool else True,
                str=pdf.str + 'there',
                array=pdf.array,
                bin=pdf.bin
            ),
            output_schema,
            PandasUDFType.GROUPED_MAP
        )

        udf3 = pandas_udf(
            lambda key, pdf: pdf.assign(
                id=key[0],
                byte=pdf.byte * 2,
                short=pdf.short * 2,
                int=pdf.int * 2,
                long=pdf.long * 2,
                float=pdf.float * 2,
                double=pdf.double * 2,
                decim=pdf.decim * 2,
                bool=False if pdf.bool else True,
                str=pdf.str + 'there',
                array=pdf.array,
                bin=pdf.bin
            ),
            output_schema,
            PandasUDFType.GROUPED_MAP
        )

        result1 = df.groupby('id').apply(udf1).sort('id').toPandas()
        expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True)

        result2 = df.groupby('id').apply(udf2).sort('id').toPandas()
        expected2 = expected1

        result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
        expected3 = expected1

        assert_frame_equal(expected1, result1, check_column_type=_check_column_type)
        assert_frame_equal(expected2, result2, check_column_type=_check_column_type)
        assert_frame_equal(expected3, result3, check_column_type=_check_column_type)

    def test_array_type_correct(self):
        df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id")

        output_schema = StructType(
            [StructField('id', LongType()),
             StructField('v', IntegerType()),
             StructField('arr', ArrayType(LongType()))])

        udf = pandas_udf(
            lambda pdf: pdf,
            output_schema,
            PandasUDFType.GROUPED_MAP
        )

        result = df.groupby('id').apply(udf).sort('id').toPandas()
        expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True)
        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_register_grouped_map_udf(self):
        foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    ValueError,
                    'f.*SQL_BATCHED_UDF.*SQL_SCALAR_PANDAS_UDF.*SQL_GROUPED_AGG_PANDAS_UDF.*'):
                self.spark.catalog.registerFunction("foo_udf", foo_udf)

    def test_decorator(self):
        df = self.data

        @pandas_udf(
            'id long, v int, v1 double, v2 long',
            PandasUDFType.GROUPED_MAP
        )
        def foo(pdf):
            return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id)

        result = df.groupby('id').apply(foo).sort('id').toPandas()
        expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_coerce(self):
        df = self.data

        foo = pandas_udf(
            lambda pdf: pdf,
            'id long, v double',
            PandasUDFType.GROUPED_MAP
        )

        result = df.groupby('id').apply(foo).sort('id').toPandas()
        expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
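        # The declared schema types v as double while the input column is int, so
        # the grouped map result comes back coerced; mirror that cast here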
        expected = expected.assign(v=expected.v.astype('float64'))
        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_complex_groupby(self):
        df = self.data

        @pandas_udf(
            'id long, v int, norm double',
            PandasUDFType.GROUPED_MAP
        )
        def normalize(pdf):
            v = pdf.v
            return pdf.assign(norm=(v - v.mean()) / v.std())

        result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas()
        pdf = df.toPandas()
        expected = pdf.groupby(pdf['id'] % 2 == 0, as_index=False).apply(normalize.func)
        expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
        expected = expected.assign(norm=expected.norm.astype('float64'))
        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_empty_groupby(self):
        df = self.data

        @pandas_udf(
            'id long, v int, norm double',
            PandasUDFType.GROUPED_MAP
        )
        def normalize(pdf):
            v = pdf.v
            return pdf.assign(norm=(v - v.mean()) / v.std())

        result = df.groupby().apply(normalize).sort('id', 'v').toPandas()
        pdf = df.toPandas()
        expected = normalize.func(pdf)
        expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
        expected = expected.assign(norm=expected.norm.astype('float64'))
        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_datatype_string(self):
        df = self.data

        foo_udf = pandas_udf(
            lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
            'id long, v int, v1 double, v2 long',
            PandasUDFType.GROUPED_MAP
        )

        result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
        expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_wrong_return_type(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    NotImplementedError,
                    'Invalid return type.*grouped map Pandas UDF.*MapType'):
                pandas_udf(
                    lambda pdf: pdf,
                    'id long, v map<int, int>',
                    PandasUDFType.GROUPED_MAP)

    def test_wrong_args(self):
        df = self.data

        with QuietTest(self.sc):
            with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
                df.groupby('id').apply(lambda x: x)
            with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
                df.groupby('id').apply(udf(lambda x: x, DoubleType()))
            with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
                df.groupby('id').apply(sum(df.v))
            with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
                df.groupby('id').apply(df.v + 1)
            with self.assertRaisesRegexp(ValueError, 'Invalid function'):
                df.groupby('id').apply(
                    pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())])))
            with self.assertRaisesRegexp(ValueError, 'Invalid udf'):
                df.groupby('id').apply(pandas_udf(lambda x, y: x, DoubleType()))
            with self.assertRaisesRegexp(ValueError, 'Invalid udf.*GROUPED_MAP'):
                df.groupby('id').apply(
                    pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))

    def test_unsupported_types(self):
        common_err_msg = 'Invalid return type.*grouped map Pandas UDF.*'
        unsupported_types = [
            StructField('map', MapType(StringType(), IntegerType())),
            StructField('arr_ts', ArrayType(TimestampType())),
            StructField('null', NullType()),
            StructField('struct', StructType([StructField('l', LongType())])),
        ]

        for unsupported_type in unsupported_types:
            schema = StructType([StructField('id', LongType(), True), unsupported_type])
            with QuietTest(self.sc):
                with self.assertRaisesRegexp(NotImplementedError, common_err_msg):
                    pandas_udf(lambda x: x, schema, PandasUDFType.GROUPED_MAP)

    # Regression test for SPARK-23314
    def test_timestamp_dst(self):
        # Daylight saving time ends in Los Angeles on Sun, Nov 1, 2015 at 2:00 am
        dt = [datetime.datetime(2015, 11, 1, 0, 30),
              datetime.datetime(2015, 11, 1, 1, 30),
              datetime.datetime(2015, 11, 1, 2, 30)]
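        # 1:30 am occurs twice on this date (once before and once after the
        # clocks fall back), so the list includes an ambiguous local time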
        df = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
        foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP)
        result = df.groupby('time').apply(foo_udf).sort('time')
        assert_frame_equal(df.toPandas(), result.toPandas(), check_column_type=_check_column_type)

    def test_udf_with_key(self):
        import numpy as np

        df = self.data
        pdf = df.toPandas()

        def foo1(key, pdf):
            assert type(key) == tuple
            assert type(key[0]) == np.int64

            return pdf.assign(v1=key[0],
                              v2=pdf.v * key[0],
                              v3=pdf.v * pdf.id,
                              v4=pdf.v * pdf.id.mean())

        def foo2(key, pdf):
            assert type(key) == tuple
            assert type(key[0]) == np.int64
            assert type(key[1]) == np.int32

            return pdf.assign(v1=key[0],
                              v2=key[1],
                              v3=pdf.v * key[0],
                              v4=pdf.v + key[1])

        def foo3(key, pdf):
            assert type(key) == tuple
            assert len(key) == 0
            return pdf.assign(v1=pdf.v * pdf.id)

        # v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
        # v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
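        # Illustrative sketch of those promotion rules (not executed):
        #   np.int64(2) * pd.Series([1], dtype='int32')                    # -> int32
        #   pd.Series([2], dtype='int64') * pd.Series([1], dtype='int32')  # -> int64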
        udf1 = pandas_udf(
            foo1,
            'id long, v int, v1 long, v2 int, v3 long, v4 double',
            PandasUDFType.GROUPED_MAP)

        udf2 = pandas_udf(
            foo2,
            'id long, v int, v1 long, v2 int, v3 int, v4 int',
            PandasUDFType.GROUPED_MAP)

        udf3 = pandas_udf(
            foo3,
            'id long, v int, v1 long',
            PandasUDFType.GROUPED_MAP)

        # Test groupby column
        result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
        expected1 = pdf.groupby('id', as_index=False)\
            .apply(lambda x: udf1.func((x.id.iloc[0],), x))\
            .sort_values(['id', 'v']).reset_index(drop=True)
        assert_frame_equal(expected1, result1, check_column_type=_check_column_type)

        # Test groupby expression
        result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
        expected2 = pdf.groupby(pdf.id % 2, as_index=False)\
            .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
            .sort_values(['id', 'v']).reset_index(drop=True)
        assert_frame_equal(expected2, result2, check_column_type=_check_column_type)

        # Test complex groupby
        result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
        expected3 = pdf.groupby([pdf.id, pdf.v % 2], as_index=False)\
            .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\
            .sort_values(['id', 'v']).reset_index(drop=True)
        assert_frame_equal(expected3, result3, check_column_type=_check_column_type)

        # Test empty groupby
        result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
        expected4 = udf3.func((), pdf)
        assert_frame_equal(expected4, result4, check_column_type=_check_column_type)

    def test_column_order(self):

        # Helper function to set column names from a list
        def rename_pdf(pdf, names):
            pdf.rename(columns={old: new for old, new in
                                zip(pdf.columns, names)}, inplace=True)

        df = self.data
        grouped_df = df.groupby('id')
        grouped_pdf = df.toPandas().groupby('id', as_index=False)

        # Function returns a pdf with the required column names, but in an
        # arbitrary order (built from a dict)
        def change_col_order(pdf):
            # A plain dict does not guarantee ordering, so use OrderedDict to pin
            # a pdf column order that differs from the schema
            return pd.DataFrame.from_dict(OrderedDict([
                ('id', pdf.id),
                ('u', pdf.v * 2),
                ('v', pdf.v)]))

        ordered_udf = pandas_udf(
            change_col_order,
            'id long, v int, u int',
            PandasUDFType.GROUPED_MAP
        )

        # The UDF result should assign columns by name from the pdf
        result = grouped_df.apply(ordered_udf).sort('id', 'v')\
            .select('id', 'u', 'v').toPandas()
        pd_result = grouped_pdf.apply(change_col_order)
        expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
        assert_frame_equal(expected, result, check_column_type=_check_column_type)

        # Function returns a pdf with positional columns, indexed by range
        def range_col_order(pdf):
            # Create a DataFrame with positional columns, fix types to long
            return pd.DataFrame(list(zip(pdf.id, pdf.v * 3, pdf.v)), dtype='int64')

        range_udf = pandas_udf(
            range_col_order,
            'id long, u long, v long',
            PandasUDFType.GROUPED_MAP
        )

        # The UDF result uses positional columns from the pdf
        result = grouped_df.apply(range_udf).sort('id', 'v') \
            .select('id', 'u', 'v').toPandas()
        pd_result = grouped_pdf.apply(range_col_order)
        rename_pdf(pd_result, ['id', 'u', 'v'])
        expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
        assert_frame_equal(expected, result, check_column_type=_check_column_type)

        # Function returns a pdf with columns indexed with integers
        def int_index(pdf):
            return pd.DataFrame(OrderedDict([(0, pdf.id), (1, pdf.v * 4), (2, pdf.v)]))

        int_index_udf = pandas_udf(
            int_index,
            'id long, u int, v int',
            PandasUDFType.GROUPED_MAP
        )

        # The UDF result should assign columns by position of integer index
        result = grouped_df.apply(int_index_udf).sort('id', 'v') \
            .select('id', 'u', 'v').toPandas()
        pd_result = grouped_pdf.apply(int_index)
        rename_pdf(pd_result, ['id', 'u', 'v'])
        expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
        assert_frame_equal(expected, result, check_column_type=_check_column_type)

        @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP)
        def column_name_typo(pdf):
            return pd.DataFrame({'iid': pdf.id, 'v': pdf.v})

        @pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP)
        def invalid_positional_types(pdf):
            return pd.DataFrame([(u'a', 1.2)])

        with QuietTest(self.sc):
            with self.assertRaisesRegexp(Exception, "KeyError: 'id'"):
                grouped_df.apply(column_name_typo).collect()
            with self.assertRaisesRegexp(Exception, "an integer is required"):
                grouped_df.apply(invalid_positional_types).collect()

    def test_positional_assignment_conf(self):
        with self.sql_conf({
                "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}):

            @pandas_udf("a string, b float", PandasUDFType.GROUPED_MAP)
            def foo(_):
                return pd.DataFrame([('hi', 1)], columns=['x', 'y'])
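            # With assignColumnsByName disabled, columns 'x' and 'y' above are
            # matched to schema fields 'a' and 'b' by position, not by name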

            df = self.data
            result = df.groupBy('id').apply(foo).select('a', 'b').collect()
            for r in result:
                self.assertEqual(r.a, 'hi')
                self.assertEqual(r.b, 1)

    def test_self_join_with_pandas(self):
        @pandas_udf('key long, col string', PandasUDFType.GROUPED_MAP)
        def dummy_pandas_udf(df):
            return df[['key', 'col']]

        df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
                                         Row(key=2, col='C')])
        df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf)

        # this was throwing an AnalysisException before SPARK-24208
        res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'),
                                                 col('temp0.key') == col('temp1.key'))
        self.assertEqual(res.count(), 5)

    def test_mixed_scalar_udfs_followed_by_groupby_apply(self):
        df = self.spark.range(0, 10).toDF('v1')
        df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
            .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))

        result = df.groupby() \
            .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]),
                              'sum int',
                              PandasUDFType.GROUPED_MAP))

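        # v1 sums to 45 over rows 0..9, v2 (v1 + 1) to 55, v3 (v1 + 2) to 65,
        # so the grand total is 165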
        self.assertEqual(result.collect()[0]['sum'], 165)

    def test_grouped_with_empty_partition(self):
        data = [Row(id=1, x=2), Row(id=1, x=3), Row(id=2, x=4)]
        expected = [Row(id=1, x=5), Row(id=1, x=5), Row(id=2, x=4)]
        num_parts = len(data) + 1
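        # one more partition than rows, so at least one partition is empty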
        df = self.spark.createDataFrame(self.sc.parallelize(data, numSlices=num_parts))

        f = pandas_udf(lambda pdf: pdf.assign(x=pdf['x'].sum()),
                       'id long, x int', PandasUDFType.GROUPED_MAP)

        result = df.groupBy('id').apply(f).collect()
        self.assertEqual(result, expected)

    def test_grouped_over_window(self):

        data = [(0, 1, "2018-03-10T00:00:00+00:00", [0]),
                (1, 2, "2018-03-11T00:00:00+00:00", [0]),
                (2, 2, "2018-03-12T00:00:00+00:00", [0]),
                (3, 3, "2018-03-15T00:00:00+00:00", [0]),
                (4, 3, "2018-03-16T00:00:00+00:00", [0]),
                (5, 3, "2018-03-17T00:00:00+00:00", [0]),
                (6, 3, "2018-03-21T00:00:00+00:00", [0])]

        expected = {0: [0],
                    1: [1, 2],
                    2: [1, 2],
                    3: [3, 4, 5],
                    4: [3, 4, 5],
                    5: [3, 4, 5],
                    6: [6]}
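        # Expected ids per (group, 5-day window): e.g. ids 1 and 2 are both in
        # group 2 within the window starting 2018-03-10, so each maps to [1, 2]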

        df = self.spark.createDataFrame(data, ['id', 'group', 'ts', 'result'])
        df = df.select(col('id'), col('group'), col('ts').cast('timestamp'), col('result'))

        def f(pdf):
            # Assign each result element the ids of the windowed group
            pdf['result'] = [pdf['id']] * len(pdf)
            return pdf

        result = df.groupby('group', window('ts', '5 days')).applyInPandas(f, df.schema)\
            .select('id', 'result').collect()
        for r in result:
            self.assertListEqual(expected[r[0]], r[1])

    def test_grouped_over_window_with_key(self):

        data = [(0, 1, "2018-03-10T00:00:00+00:00", False),
                (1, 2, "2018-03-11T00:00:00+00:00", False),
                (2, 2, "2018-03-12T00:00:00+00:00", False),
                (3, 3, "2018-03-15T00:00:00+00:00", False),
                (4, 3, "2018-03-16T00:00:00+00:00", False),
                (5, 3, "2018-03-17T00:00:00+00:00", False),
                (6, 3, "2018-03-21T00:00:00+00:00", False)]

        expected_window = [
            {'start': datetime.datetime(2018, 3, 10, 0, 0),
             'end': datetime.datetime(2018, 3, 15, 0, 0)},
            {'start': datetime.datetime(2018, 3, 15, 0, 0),
             'end': datetime.datetime(2018, 3, 20, 0, 0)},
            {'start': datetime.datetime(2018, 3, 20, 0, 0),
             'end': datetime.datetime(2018, 3, 25, 0, 0)},
        ]

        expected = {0: (1, expected_window[0]),
                    1: (2, expected_window[0]),
                    2: (2, expected_window[0]),
                    3: (3, expected_window[1]),
                    4: (3, expected_window[1]),
                    5: (3, expected_window[1]),
                    6: (3, expected_window[2])}

        df = self.spark.createDataFrame(data, ['id', 'group', 'ts', 'result'])
        df = df.select(col('id'), col('group'), col('ts').cast('timestamp'), col('result'))

        @pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
        def f(key, pdf):
            group = key[0]
            window_range = key[1]
            # result is True when the group and window range match the expected values
            is_expected = pdf.id.apply(lambda id: (expected[id][0] == group and
                                                   expected[id][1] == window_range))
            return pdf.assign(result=is_expected)

        result = df.groupby('group', window('ts', '5 days')).apply(f).select('result').collect()

        # Check that all group and window_range values from the udf matched the expected values
        self.assertTrue(all([r[0] for r in result]))


if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_grouped_map import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)