#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest
import sys

from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
from pyspark.sql.types import DoubleType, StructType, StructField
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
    pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest

if have_pandas:
    import pandas as pd
    from pandas.util.testing import assert_frame_equal, assert_series_equal

if have_pyarrow:
    import pyarrow as pa


# Tests below use pd.DataFrame.assign, which infers mixed (unicode/str) types for
# column names from kwargs under Python 2, so set check_column_type=False to skip that check.
_check_column_type = sys.version >= '3'


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class CogroupedMapInPandasTests(ReusedSQLTestCase):
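    """Tests for cogrouped map, i.e. applyInPandas on cogrouped DataFrames."""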

    @property
    def data1(self):
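        # 100 rows: ids 0-9, each exploded against keys 20-29, with v = k * 10.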
        return self.spark.range(10).toDF('id') \
            .withColumn("ks", array([lit(i) for i in range(20, 30)])) \
            .withColumn("k", explode(col('ks'))) \
            .withColumn("v", col('k') * 10) \
            .drop('ks')

    @property
    def data2(self):
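        # Same shape as data1, but with v2 = k * 100 in place of v.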
        return self.spark.range(10).toDF('id') \
            .withColumn("ks", array([lit(i) for i in range(20, 30)])) \
            .withColumn("k", explode(col('ks'))) \
            .withColumn("v2", col('k') * 100) \
            .drop('ks')

    def test_simple(self):
        self._test_merge(self.data1, self.data2)

    def test_left_group_empty(self):
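        # Keep only even ids on the left, so odd-id cogroups have an empty left side.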
        left = self.data1.where(col("id") % 2 == 0)
        self._test_merge(left, self.data2)

    def test_right_group_empty(self):
        right = self.data2.where(col("id") % 2 == 0)
        self._test_merge(self.data1, right)

    def test_different_schemas(self):
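        # The extra string column on the right must be reflected in the output schema.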
        right = self.data2.withColumn('v3', lit('a'))
        self._test_merge(self.data1, right, 'id long, k int, v int, v2 int, v3 string')

    def test_complex_group_by(self):
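        # Group both sides by a computed boolean expression rather than a plain column.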
        left = pd.DataFrame.from_dict({
            'id': [1, 2, 3],
            'k': [5, 6, 7],
            'v': [9, 10, 11]
        })

        right = pd.DataFrame.from_dict({
            'id': [11, 12, 13],
            'k': [5, 6, 7],
            'v2': [90, 100, 110]
        })

        left_gdf = self.spark \
            .createDataFrame(left) \
            .groupby(col('id') % 2 == 0)

        right_gdf = self.spark \
            .createDataFrame(right) \
            .groupby(col('id') % 2 == 0)

        def merge_pandas(l, r):
            return pd.merge(l[['k', 'v']], r[['k', 'v2']], on=['k'])

        result = left_gdf \
            .cogroup(right_gdf) \
            .applyInPandas(merge_pandas, 'k long, v long, v2 long') \
            .sort(['k']) \
            .toPandas()

        expected = pd.DataFrame.from_dict({
            'k': [5, 6, 7],
            'v': [9, 10, 11],
            'v2': [90, 100, 110]
        })

        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_empty_group_by(self):
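        # groupby() with no keys puts each whole DataFrame into a single group,
        # so the merge sees all rows at once.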
        left = self.data1
        right = self.data2

        def merge_pandas(l, r):
            return pd.merge(l, r, on=['id', 'k'])

        result = left.groupby().cogroup(right.groupby()) \
            .applyInPandas(merge_pandas, 'id long, k int, v int, v2 int') \
            .sort(['id', 'k']) \
            .toPandas()

        left = left.toPandas()
        right = right.toPandas()

        expected = pd \
            .merge(left, right, on=['id', 'k']) \
            .sort_values(by=['id', 'k'])

        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_mixed_scalar_udfs_followed_by_cogroup_apply(self):
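        # Mix a plain Python UDF and a pandas scalar UDF before the cogrouped apply.
        # Expected per-side total: sum(v1) + sum(v2) + sum(v3) = 45 + 55 + 65 = 165.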
        df = self.spark.range(0, 10).toDF('v1')
        df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
            .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))

        result = df.groupby().cogroup(df.groupby()) \
            .applyInPandas(lambda x, y: pd.DataFrame([(x.sum().sum(), y.sum().sum())]),
                           'sum1 int, sum2 int').collect()

        self.assertEqual(result[0]['sum1'], 165)
        self.assertEqual(result[0]['sum2'], 165)

    def test_with_key_left(self):
        self._test_with_key(self.data1, self.data1, isLeft=True)

    def test_with_key_right(self):
        self._test_with_key(self.data1, self.data1, isLeft=False)

    def test_with_key_left_group_empty(self):
        left = self.data1.where(col("id") % 2 == 0)
        self._test_with_key(left, self.data1, isLeft=True)

    def test_with_key_right_group_empty(self):
        right = self.data1.where(col("id") % 2 == 0)
        self._test_with_key(self.data1, right, isLeft=False)

    def test_with_key_complex(self):
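        # The grouping key is the boolean expression itself, so key[0] is True or False.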

        def left_assign_key(key, l, _):
            return l.assign(key=key[0])

        result = self.data1 \
            .groupby(col('id') % 2 == 0) \
            .cogroup(self.data2.groupby(col('id') % 2 == 0)) \
            .applyInPandas(left_assign_key, 'id long, k int, v int, key boolean') \
            .sort(['id', 'k']) \
            .toPandas()

        expected = self.data1.toPandas()
        expected = expected.assign(key=expected.id % 2 == 0)

        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    def test_wrong_return_type(self):
        # Test that we get a sensible exception when invalid values are passed to apply.
        left = self.data1
        right = self.data2
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    NotImplementedError,
                    'Invalid return type.*MapType'):
                left.groupby('id').cogroup(right.groupby('id')).applyInPandas(
                    lambda l, r: l, 'id long, v map<int, int>')

    def test_wrong_args(self):
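        # A zero-argument function is rejected: the cogrouped function must take
        # either (left, right) or (key, left, right).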
        left = self.data1
        right = self.data2
        with self.assertRaisesRegexp(ValueError, 'Invalid function'):
            left.groupby('id').cogroup(right.groupby('id')) \
                .applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))

    @staticmethod
    def _test_with_key(left, right, isLeft):
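        # Tags each row with its grouping key via the three-argument
        # (key, left, right) form; isLeft selects which side to return.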

        def right_assign_key(key, l, r):
            return l.assign(key=key[0]) if isLeft else r.assign(key=key[0])

        result = left \
            .groupby('id') \
            .cogroup(right.groupby('id')) \
            .applyInPandas(right_assign_key, 'id long, k int, v int, key long') \
            .toPandas()

        expected = left.toPandas() if isLeft else right.toPandas()
        expected = expected.assign(key=expected.id)

        assert_frame_equal(expected, result, check_column_type=_check_column_type)

    @staticmethod
    def _test_merge(left, right, output_schema='id long, k int, v int, v2 int'):
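        # Compares a distributed cogrouped merge against the same merge done locally
        # in pandas.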

        def merge_pandas(l, r):
            return pd.merge(l, r, on=['id', 'k'])

        result = left \
            .groupby('id') \
            .cogroup(right.groupby('id')) \
            .applyInPandas(merge_pandas, output_schema) \
            .sort(['id', 'k']) \
            .toPandas()

        left = left.toPandas()
        right = right.toPandas()

        expected = pd \
            .merge(left, right, on=['id', 'k']) \
            .sort_values(by=['id', 'k'])

        assert_frame_equal(expected, result, check_column_type=_check_column_type)


if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_cogrouped_map import *

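    # Prefer XML test reports when xmlrunner is available; otherwise fall back to
    # the default text runner.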
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)