0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 import sys
0018 import unittest
0019 import inspect
0020
0021 from pyspark.sql.functions import mean, lit
0022 from pyspark.testing.sqlutils import ReusedSQLTestCase, \
0023 have_pandas, have_pyarrow, pandas_requirement_message, \
0024 pyarrow_requirement_message
0025 from pyspark.sql.pandas.typehints import infer_eval_type
0026 from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType
0027
if have_pandas:
    import pandas as pd
    # `pandas.util.testing` was deprecated in pandas 1.0 and removed in pandas 2.0.
    # Prefer the public `pandas.testing` module, falling back for very old pandas.
    try:
        from pandas.testing import assert_frame_equal
    except ImportError:
        from pandas.util.testing import assert_frame_equal

python_requirement_message = "pandas UDF with type hints are supported with Python 3.6+."
0033
0034
@unittest.skipIf(
    not have_pandas or not have_pyarrow or sys.version_info[:2] < (3, 6),
    pandas_requirement_message or pyarrow_requirement_message or python_requirement_message)
class PandasUDFTypeHintsTests(ReusedSQLTestCase):
    """Tests that pandas UDF eval types are correctly inferred from Python type hints.

    Functions under test are built with ``exec`` so that this module stays importable
    on Python versions without annotation support in the test strings.
    """

    def setUp(self):
        # Shared exec() namespace for the dynamically defined test functions;
        # exposes pandas as `pd` so the annotation strings can reference it.
        self.local = {'pd': pd}
0043
0044 def test_type_annotation_scalar(self):
0045 exec(
0046 "def func(col: pd.Series) -> pd.Series: pass",
0047 self.local)
0048 self.assertEqual(
0049 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR)
0050
0051 exec(
0052 "def func(col: pd.DataFrame, col1: pd.Series) -> pd.DataFrame: pass",
0053 self.local)
0054 self.assertEqual(
0055 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR)
0056
0057 exec(
0058 "def func(col: pd.DataFrame, *args: pd.Series) -> pd.Series: pass",
0059 self.local)
0060 self.assertEqual(
0061 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR)
0062
0063 exec(
0064 "def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame) -> pd.Series:\n"
0065 " pass",
0066 self.local)
0067 self.assertEqual(
0068 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR)
0069
0070 exec(
0071 "def func(col: pd.Series, *, col2: pd.DataFrame) -> pd.DataFrame:\n"
0072 " pass",
0073 self.local)
0074 self.assertEqual(
0075 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR)
0076
0077 exec(
0078 "from typing import Union\n"
0079 "def func(col: Union[pd.Series, pd.DataFrame], *, col2: pd.DataFrame) -> pd.Series:\n"
0080 " pass",
0081 self.local)
0082 self.assertEqual(
0083 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR)
0084
0085 def test_type_annotation_scalar_iter(self):
0086 exec(
0087 "from typing import Iterator\n"
0088 "def func(iter: Iterator[pd.Series]) -> Iterator[pd.Series]: pass",
0089 self.local)
0090 self.assertEqual(
0091 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR_ITER)
0092
0093 exec(
0094 "from typing import Iterator, Tuple\n"
0095 "def func(iter: Iterator[Tuple[pd.DataFrame, pd.Series]]) -> Iterator[pd.DataFrame]:\n"
0096 " pass",
0097 self.local)
0098 self.assertEqual(
0099 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR_ITER)
0100
0101 exec(
0102 "from typing import Iterator, Tuple\n"
0103 "def func(iter: Iterator[Tuple[pd.DataFrame, ...]]) -> Iterator[pd.Series]: pass",
0104 self.local)
0105 self.assertEqual(
0106 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR_ITER)
0107
0108 exec(
0109 "from typing import Iterator, Tuple, Union\n"
0110 "def func(iter: Iterator[Tuple[Union[pd.DataFrame, pd.Series], ...]])"
0111 " -> Iterator[pd.Series]: pass",
0112 self.local)
0113 self.assertEqual(
0114 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.SCALAR_ITER)
0115
0116 def test_type_annotation_group_agg(self):
0117 exec(
0118 "def func(col: pd.Series) -> str: pass",
0119 self.local)
0120 self.assertEqual(
0121 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.GROUPED_AGG)
0122
0123 exec(
0124 "def func(col: pd.DataFrame, col1: pd.Series) -> int: pass",
0125 self.local)
0126 self.assertEqual(
0127 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.GROUPED_AGG)
0128
0129 exec(
0130 "from pyspark.sql import Row\n"
0131 "def func(col: pd.DataFrame, *args: pd.Series) -> Row: pass",
0132 self.local)
0133 self.assertEqual(
0134 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.GROUPED_AGG)
0135
0136 exec(
0137 "def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame) -> str:\n"
0138 " pass",
0139 self.local)
0140 self.assertEqual(
0141 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.GROUPED_AGG)
0142
0143 exec(
0144 "def func(col: pd.Series, *, col2: pd.DataFrame) -> float:\n"
0145 " pass",
0146 self.local)
0147 self.assertEqual(
0148 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.GROUPED_AGG)
0149
0150 exec(
0151 "from typing import Union\n"
0152 "def func(col: Union[pd.Series, pd.DataFrame], *, col2: pd.DataFrame) -> float:\n"
0153 " pass",
0154 self.local)
0155 self.assertEqual(
0156 infer_eval_type(inspect.signature(self.local['func'])), PandasUDFType.GROUPED_AGG)
0157
0158 def test_type_annotation_negative(self):
0159 exec(
0160 "def func(col: str) -> pd.Series: pass",
0161 self.local)
0162 self.assertRaisesRegex(
0163 NotImplementedError,
0164 "Unsupported signature.*str",
0165 infer_eval_type, inspect.signature(self.local['func']))
0166
0167 exec(
0168 "def func(col: pd.DataFrame, col1: int) -> pd.DataFrame: pass",
0169 self.local)
0170 self.assertRaisesRegex(
0171 NotImplementedError,
0172 "Unsupported signature.*int",
0173 infer_eval_type, inspect.signature(self.local['func']))
0174
0175 exec(
0176 "from typing import Union\n"
0177 "def func(col: Union[pd.DataFrame, str], col1: int) -> pd.DataFrame: pass",
0178 self.local)
0179 self.assertRaisesRegex(
0180 NotImplementedError,
0181 "Unsupported signature.*str",
0182 infer_eval_type, inspect.signature(self.local['func']))
0183
0184 exec(
0185 "from typing import Tuple\n"
0186 "def func(col: pd.Series) -> Tuple[pd.DataFrame]: pass",
0187 self.local)
0188 self.assertRaisesRegex(
0189 NotImplementedError,
0190 "Unsupported signature.*Tuple",
0191 infer_eval_type, inspect.signature(self.local['func']))
0192
0193 exec(
0194 "def func(col, *args: pd.Series) -> pd.Series: pass",
0195 self.local)
0196 self.assertRaisesRegex(
0197 ValueError,
0198 "should be specified.*Series",
0199 infer_eval_type, inspect.signature(self.local['func']))
0200
0201 exec(
0202 "def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame):\n"
0203 " pass",
0204 self.local)
0205 self.assertRaisesRegex(
0206 ValueError,
0207 "should be specified.*Series",
0208 infer_eval_type, inspect.signature(self.local['func']))
0209
0210 exec(
0211 "def func(col: pd.Series, *, col2) -> pd.DataFrame:\n"
0212 " pass",
0213 self.local)
0214 self.assertRaisesRegex(
0215 ValueError,
0216 "should be specified.*Series",
0217 infer_eval_type, inspect.signature(self.local['func']))
0218
0219 def test_scalar_udf_type_hint(self):
0220 df = self.spark.range(10).selectExpr("id", "id as v")
0221
0222 exec(
0223 "import typing\n"
0224 "def plus_one(v: typing.Union[pd.Series, pd.DataFrame]) -> pd.Series:\n"
0225 " return v + 1",
0226 self.local)
0227
0228 plus_one = pandas_udf("long")(self.local["plus_one"])
0229
0230 actual = df.select(plus_one(df.v).alias("plus_one"))
0231 expected = df.selectExpr("(v + 1) as plus_one")
0232 assert_frame_equal(expected.toPandas(), actual.toPandas())
0233
0234 def test_scalar_iter_udf_type_hint(self):
0235 df = self.spark.range(10).selectExpr("id", "id as v")
0236
0237 exec(
0238 "import typing\n"
0239 "def plus_one(itr: typing.Iterator[pd.Series]) -> typing.Iterator[pd.Series]:\n"
0240 " for s in itr:\n"
0241 " yield s + 1",
0242 self.local)
0243
0244 plus_one = pandas_udf("long")(self.local["plus_one"])
0245
0246 actual = df.select(plus_one(df.v).alias("plus_one"))
0247 expected = df.selectExpr("(v + 1) as plus_one")
0248 assert_frame_equal(expected.toPandas(), actual.toPandas())
0249
0250 def test_group_agg_udf_type_hint(self):
0251 df = self.spark.range(10).selectExpr("id", "id as v")
0252 exec(
0253 "import numpy as np\n"
0254 "def weighted_mean(v: pd.Series, w: pd.Series) -> float:\n"
0255 " return np.average(v, weights=w)",
0256 self.local)
0257
0258 weighted_mean = pandas_udf("double")(self.local["weighted_mean"])
0259
0260 actual = df.groupby('id').agg(weighted_mean(df.v, lit(1.0))).sort('id')
0261 expected = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id')
0262 assert_frame_equal(expected.toPandas(), actual.toPandas())
0263
0264 def test_ignore_type_hint_in_group_apply_in_pandas(self):
0265 df = self.spark.range(10)
0266 exec(
0267 "def pandas_plus_one(v: pd.DataFrame) -> pd.DataFrame:\n"
0268 " return v + 1",
0269 self.local)
0270
0271 pandas_plus_one = self.local["pandas_plus_one"]
0272
0273 actual = df.groupby('id').applyInPandas(pandas_plus_one, schema=df.schema).sort('id')
0274 expected = df.selectExpr("id + 1 as id")
0275 assert_frame_equal(expected.toPandas(), actual.toPandas())
0276
0277 def test_ignore_type_hint_in_cogroup_apply_in_pandas(self):
0278 df = self.spark.range(10)
0279 exec(
0280 "def pandas_plus_one(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:\n"
0281 " return left + 1",
0282 self.local)
0283
0284 pandas_plus_one = self.local["pandas_plus_one"]
0285
0286 actual = df.groupby('id').cogroup(
0287 self.spark.range(10).groupby("id")
0288 ).applyInPandas(pandas_plus_one, schema=df.schema).sort('id')
0289 expected = df.selectExpr("id + 1 as id")
0290 assert_frame_equal(expected.toPandas(), actual.toPandas())
0291
0292 def test_ignore_type_hint_in_map_in_pandas(self):
0293 df = self.spark.range(10)
0294 exec(
0295 "from typing import Iterator\n"
0296 "def pandas_plus_one(iter: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:\n"
0297 " return map(lambda v: v + 1, iter)",
0298 self.local)
0299
0300 pandas_plus_one = self.local["pandas_plus_one"]
0301
0302 actual = df.mapInPandas(pandas_plus_one, schema=df.schema)
0303 expected = df.selectExpr("id + 1 as id")
0304 assert_frame_equal(expected.toPandas(), actual.toPandas())
0305
0306
if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_udf_typehints import *

    # Emit JUnit-style XML reports when xmlrunner is available; otherwise fall
    # back to the default text test runner.
    try:
        import xmlrunner
    except ImportError:
        testRunner = None
    else:
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=testRunner, verbosity=2)