0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import unittest
0019 import sys
0020
0021 from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
0022 from pyspark.sql.types import DoubleType, StructType, StructField
0023 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
0024 pandas_requirement_message, pyarrow_requirement_message
0025 from pyspark.testing.utils import QuietTest
0026
# Conditional imports: the test class below is skipped entirely when either
# dependency is missing, but module import must still succeed, so these names
# are only bound when the corresponding library is available.
if have_pandas:
    import pandas as pd
    # NOTE(review): pandas.util.testing is deprecated in newer pandas releases
    # (moved to pandas.testing); kept as-is for the pandas versions this
    # suite targets — confirm before upgrading pandas.
    from pandas.util.testing import assert_frame_equal, assert_series_equal

if have_pyarrow:
    import pyarrow as pa
0033
0034
0035
0036
0037 _check_column_type = sys.version >= '3'
0038
0039
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class CogroupedMapInPandasTests(ReusedSQLTestCase):
    """End-to-end tests for cogrouped map:
    ``left.groupby(...).cogroup(right.groupby(...)).applyInPandas(func, schema)``.

    Results computed by Spark are compared against the equivalent pandas
    ``pd.merge`` done driver-side on the collected data.
    """

    @property
    def data1(self):
        # Left-side fixture: ids 0-9 crossed with keys 20-29, v = k * 10.
        return self.spark.range(10).toDF('id') \
            .withColumn("ks", array([lit(i) for i in range(20, 30)])) \
            .withColumn("k", explode(col('ks')))\
            .withColumn("v", col('k') * 10)\
            .drop('ks')

    @property
    def data2(self):
        # Right-side fixture: same (id, k) shape as data1 but with v2 = k * 100.
        return self.spark.range(10).toDF('id') \
            .withColumn("ks", array([lit(i) for i in range(20, 30)])) \
            .withColumn("k", explode(col('ks'))) \
            .withColumn("v2", col('k') * 100) \
            .drop('ks')

    def test_simple(self):
        """Cogrouped merge over two fully-populated sides."""
        self._test_merge(self.data1, self.data2)

    def test_left_group_empty(self):
        """Some left groups have no rows; the UDF still runs per cogroup."""
        left = self.data1.where(col("id") % 2 == 0)
        self._test_merge(left, self.data2)

    def test_right_group_empty(self):
        """Some right groups have no rows; the UDF still runs per cogroup."""
        right = self.data2.where(col("id") % 2 == 0)
        self._test_merge(self.data1, right)

    def test_different_schemas(self):
        """The two sides may have different schemas; output schema is explicit."""
        right = self.data2.withColumn('v3', lit('a'))
        self._test_merge(self.data1, right, 'id long, k int, v int, v2 int, v3 string')

    def test_complex_group_by(self):
        """Grouping by a computed expression (id % 2 == 0) rather than a column."""
        left = pd.DataFrame.from_dict({
            'id': [1, 2, 3],
            'k': [5, 6, 7],
            'v': [9, 10, 11]
        })

        right = pd.DataFrame.from_dict({
            'id': [11, 12, 13],
            'k': [5, 6, 7],
            'v2': [90, 100, 110]
        })

        left_gdf = self.spark\
            .createDataFrame(left)\
            .groupby(col('id') % 2 == 0)

        right_gdf = self.spark \
            .createDataFrame(right) \
            .groupby(col('id') % 2 == 0)

        def merge_pandas(l, r):
            return pd.merge(l[['k', 'v']], r[['k', 'v2']], on=['k'])

        result = left_gdf \
            .cogroup(right_gdf) \
            .applyInPandas(merge_pandas, 'k long, v long, v2 long') \
            .sort(['k']) \
            .toPandas()

        expected = pd.DataFrame.from_dict({
            'k': [5, 6, 7],
            'v': [9, 10, 11],
            'v2': [90, 100, 110]
        })

        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_empty_group_by(self):
        """groupby() with no keys cogroups each side as a single group."""
        left = self.data1
        right = self.data2

        def merge_pandas(l, r):
            return pd.merge(l, r, on=['id', 'k'])

        result = left.groupby().cogroup(right.groupby())\
            .applyInPandas(merge_pandas, 'id long, k int, v int, v2 int') \
            .sort(['id', 'k']) \
            .toPandas()

        left = left.toPandas()
        right = right.toPandas()

        expected = pd \
            .merge(left, right, on=['id', 'k']) \
            .sort_values(by=['id', 'k'])

        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_mixed_scalar_udfs_followed_by_cogrouby_apply(self):
        """Columns produced by both row-at-a-time and pandas scalar UDFs can
        feed a cogrouped apply.  (Method name kept for test-ID stability,
        despite the 'cogrouby' typo.)"""
        df = self.spark.range(0, 10).toDF('v1')
        df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
            .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))

        result = df.groupby().cogroup(df.groupby()) \
            .applyInPandas(lambda x, y: pd.DataFrame([(x.sum().sum(), y.sum().sum())]),
                           'sum1 int, sum2 int').collect()

        # assertEqual, not the deprecated assertEquals alias (removed in
        # Python 3.12).
        self.assertEqual(result[0]['sum1'], 165)
        self.assertEqual(result[0]['sum2'], 165)

    def test_with_key_left(self):
        self._test_with_key(self.data1, self.data1, isLeft=True)

    def test_with_key_right(self):
        self._test_with_key(self.data1, self.data1, isLeft=False)

    def test_with_key_left_group_empty(self):
        left = self.data1.where(col("id") % 2 == 0)
        self._test_with_key(left, self.data1, isLeft=True)

    def test_with_key_right_group_empty(self):
        right = self.data1.where(col("id") % 2 == 0)
        self._test_with_key(self.data1, right, isLeft=False)

    def test_with_key_complex(self):
        """A three-argument UDF receives the grouping key as its first argument,
        here a boolean expression key (id % 2 == 0)."""

        def left_assign_key(key, l, _):
            return l.assign(key=key[0])

        result = self.data1 \
            .groupby(col('id') % 2 == 0)\
            .cogroup(self.data2.groupby(col('id') % 2 == 0)) \
            .applyInPandas(left_assign_key, 'id long, k int, v int, key boolean') \
            .sort(['id', 'k']) \
            .toPandas()

        expected = self.data1.toPandas()
        expected = expected.assign(key=expected.id % 2 == 0)

        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_wrong_return_type(self):
        """Unsupported return types (MapType) are rejected up front."""
        left = self.data1
        right = self.data2
        with QuietTest(self.sc):
            # assertRaisesRegex, not the deprecated assertRaisesRegexp alias
            # (removed in Python 3.12).
            with self.assertRaisesRegex(
                    NotImplementedError,
                    'Invalid return type.*MapType'):
                left.groupby('id').cogroup(right.groupby('id')).applyInPandas(
                    lambda l, r: l, 'id long, v map<int, int>')

    def test_wrong_args(self):
        """A UDF with the wrong arity is rejected with ValueError."""
        left = self.data1
        right = self.data2
        with self.assertRaisesRegex(ValueError, 'Invalid function'):
            left.groupby('id').cogroup(right.groupby('id')) \
                .applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))

    @staticmethod
    def _test_with_key(left, right, isLeft):
        """Run a three-argument (key, left, right) apply and check that the key
        passed to the UDF matches the grouping column of the selected side."""

        def right_assign_key(key, l, r):
            return l.assign(key=key[0]) if isLeft else r.assign(key=key[0])

        result = left \
            .groupby('id') \
            .cogroup(right.groupby('id')) \
            .applyInPandas(right_assign_key, 'id long, k int, v int, key long') \
            .toPandas()

        expected = left.toPandas() if isLeft else right.toPandas()
        expected = expected.assign(key=expected.id)

        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    @staticmethod
    def _test_merge(left, right, output_schema='id long, k int, v int, v2 int'):
        """Merge the two sides on (id, k) via cogrouped apply and compare
        against the same merge done with pandas on the collected data."""

        def merge_pandas(l, r):
            return pd.merge(l, r, on=['id', 'k'])

        result = left \
            .groupby('id') \
            .cogroup(right.groupby('id')) \
            .applyInPandas(merge_pandas, output_schema)\
            .sort(['id', 'k']) \
            .toPandas()

        left = left.toPandas()
        right = right.toPandas()

        expected = pd \
            .merge(left, right, on=['id', 'k']) \
            .sort_values(by=['id', 'k'])

        assert_frame_equal(expected, result, check_column_type=_check_column_type)
0234
0235
if __name__ == "__main__":
    # Re-export the test classes under this module's __main__ namespace so
    # unittest can discover them when the file is executed directly.
    from pyspark.sql.tests.test_pandas_cogrouped_map import *

    # Prefer the XML-emitting runner when available (used by Spark's CI to
    # collect JUnit-style reports); otherwise fall back to the default runner.
    try:
        import xmlrunner
    except ImportError:
        testRunner = None
    else:
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=testRunner, verbosity=2)