Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 import sys
0018 import unittest
0019 import inspect
0020 
0021 from pyspark.sql.functions import mean, lit
0022 from pyspark.testing.sqlutils import ReusedSQLTestCase, \
0023     have_pandas, have_pyarrow, pandas_requirement_message, \
0024     pyarrow_requirement_message
0025 from pyspark.sql.pandas.typehints import infer_eval_type
0026 from pyspark.sql.pandas.functions import pandas_udf, PandasUDFType
0027 
if have_pandas:
    import pandas as pd
    # `pandas.util.testing` is deprecated since pandas 1.0 and removed in pandas 2.0;
    # `pandas.testing` is the public home of the assertion helpers.
    from pandas.testing import assert_frame_equal

# Skip reason used by the test class below when the interpreter is older than Python 3.6.
python_requirement_message = "pandas UDF with type hints are supported with Python 3.6+."
0033 
0034 
@unittest.skipIf(
    not have_pandas or not have_pyarrow or sys.version_info[:2] < (3, 6),
    pandas_requirement_message or pyarrow_requirement_message or python_requirement_message)
class PandasUDFTypeHintsTests(ReusedSQLTestCase):
    """Tests for deriving pandas UDF eval types from Python type hints.

    Exercises ``infer_eval_type`` directly on annotated signatures, and checks
    end-to-end that ``pandas_udf``, ``applyInPandas`` and ``mapInPandas``
    work with annotated functions.
    """
    # Note that we should remove `exec` from this class once we drop Python 2:
    # the annotated function definitions live in strings so that this module
    # can still be parsed by Python 2, where annotations are a syntax error.

    def setUp(self):
        # Globals namespace for the `exec`-ed snippets, which refer to pandas as `pd`.
        self.local = {'pd': pd}

    def _signature_of(self, source):
        """Execute ``source`` (which must define ``func``) and return func's signature."""
        exec(source, self.local)
        return inspect.signature(self.local['func'])

    def _assert_eval_type(self, source, expected):
        """Assert that the ``func`` defined by ``source`` infers the eval type ``expected``."""
        self.assertEqual(infer_eval_type(self._signature_of(source)), expected)

    def _assert_inference_fails(self, source, exception, regex):
        """Assert that inferring ``func``'s eval type raises ``exception`` matching ``regex``."""
        self.assertRaisesRegex(
            exception, regex, infer_eval_type, self._signature_of(source))

    def test_type_annotation_scalar(self):
        self._assert_eval_type(
            "def func(col: pd.Series) -> pd.Series: pass",
            PandasUDFType.SCALAR)

        self._assert_eval_type(
            "def func(col: pd.DataFrame, col1: pd.Series) -> pd.DataFrame: pass",
            PandasUDFType.SCALAR)

        self._assert_eval_type(
            "def func(col: pd.DataFrame, *args: pd.Series) -> pd.Series: pass",
            PandasUDFType.SCALAR)

        self._assert_eval_type(
            "def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame) -> pd.Series:\n"
            "    pass",
            PandasUDFType.SCALAR)

        self._assert_eval_type(
            "def func(col: pd.Series, *, col2: pd.DataFrame) -> pd.DataFrame:\n"
            "    pass",
            PandasUDFType.SCALAR)

        self._assert_eval_type(
            "from typing import Union\n"
            "def func(col: Union[pd.Series, pd.DataFrame], *, col2: pd.DataFrame) -> pd.Series:\n"
            "    pass",
            PandasUDFType.SCALAR)

    def test_type_annotation_scalar_iter(self):
        self._assert_eval_type(
            "from typing import Iterator\n"
            "def func(iter: Iterator[pd.Series]) -> Iterator[pd.Series]: pass",
            PandasUDFType.SCALAR_ITER)

        self._assert_eval_type(
            "from typing import Iterator, Tuple\n"
            "def func(iter: Iterator[Tuple[pd.DataFrame, pd.Series]]) -> Iterator[pd.DataFrame]:\n"
            "    pass",
            PandasUDFType.SCALAR_ITER)

        self._assert_eval_type(
            "from typing import Iterator, Tuple\n"
            "def func(iter: Iterator[Tuple[pd.DataFrame, ...]]) -> Iterator[pd.Series]: pass",
            PandasUDFType.SCALAR_ITER)

        self._assert_eval_type(
            "from typing import Iterator, Tuple, Union\n"
            "def func(iter: Iterator[Tuple[Union[pd.DataFrame, pd.Series], ...]])"
            " -> Iterator[pd.Series]: pass",
            PandasUDFType.SCALAR_ITER)

    def test_type_annotation_group_agg(self):
        self._assert_eval_type(
            "def func(col: pd.Series) -> str: pass",
            PandasUDFType.GROUPED_AGG)

        self._assert_eval_type(
            "def func(col: pd.DataFrame, col1: pd.Series) -> int: pass",
            PandasUDFType.GROUPED_AGG)

        self._assert_eval_type(
            "from pyspark.sql import Row\n"
            "def func(col: pd.DataFrame, *args: pd.Series) -> Row: pass",
            PandasUDFType.GROUPED_AGG)

        self._assert_eval_type(
            "def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame) -> str:\n"
            "    pass",
            PandasUDFType.GROUPED_AGG)

        self._assert_eval_type(
            "def func(col: pd.Series, *, col2: pd.DataFrame) -> float:\n"
            "    pass",
            PandasUDFType.GROUPED_AGG)

        self._assert_eval_type(
            "from typing import Union\n"
            "def func(col: Union[pd.Series, pd.DataFrame], *, col2: pd.DataFrame) -> float:\n"
            "    pass",
            PandasUDFType.GROUPED_AGG)

    def test_type_annotation_negative(self):
        # Hints that are not pandas Series/DataFrame are rejected.
        self._assert_inference_fails(
            "def func(col: str) -> pd.Series: pass",
            NotImplementedError, "Unsupported signature.*str")

        self._assert_inference_fails(
            "def func(col: pd.DataFrame, col1: int) -> pd.DataFrame: pass",
            NotImplementedError, "Unsupported signature.*int")

        self._assert_inference_fails(
            "from typing import Union\n"
            "def func(col: Union[pd.DataFrame, str], col1: int) -> pd.DataFrame: pass",
            NotImplementedError, "Unsupported signature.*str")

        self._assert_inference_fails(
            "from typing import Tuple\n"
            "def func(col: pd.Series) -> Tuple[pd.DataFrame]: pass",
            NotImplementedError, "Unsupported signature.*Tuple")

        # Signatures with a missing parameter or return hint are rejected.
        self._assert_inference_fails(
            "def func(col, *args: pd.Series) -> pd.Series: pass",
            ValueError, "should be specified.*Series")

        self._assert_inference_fails(
            "def func(col: pd.Series, *args: pd.Series, **kwargs: pd.DataFrame):\n"
            "    pass",
            ValueError, "should be specified.*Series")

        self._assert_inference_fails(
            "def func(col: pd.Series, *, col2) -> pd.DataFrame:\n"
            "    pass",
            ValueError, "should be specified.*Series")

    def test_scalar_udf_type_hint(self):
        # Only the return type is given to `pandas_udf`; the UDF kind comes from the hints.
        df = self.spark.range(10).selectExpr("id", "id as v")

        exec(
            "import typing\n"
            "def plus_one(v: typing.Union[pd.Series, pd.DataFrame]) -> pd.Series:\n"
            "    return v + 1",
            self.local)

        plus_one = pandas_udf("long")(self.local["plus_one"])

        actual = df.select(plus_one(df.v).alias("plus_one"))
        expected = df.selectExpr("(v + 1) as plus_one")
        assert_frame_equal(expected.toPandas(), actual.toPandas())

    def test_scalar_iter_udf_type_hint(self):
        # Iterator-of-Series hints on a generator function used with `pandas_udf`.
        df = self.spark.range(10).selectExpr("id", "id as v")

        exec(
            "import typing\n"
            "def plus_one(itr: typing.Iterator[pd.Series]) -> typing.Iterator[pd.Series]:\n"
            "    for s in itr:\n"
            "        yield s + 1",
            self.local)

        plus_one = pandas_udf("long")(self.local["plus_one"])

        actual = df.select(plus_one(df.v).alias("plus_one"))
        expected = df.selectExpr("(v + 1) as plus_one")
        assert_frame_equal(expected.toPandas(), actual.toPandas())

    def test_group_agg_udf_type_hint(self):
        # Series inputs with a scalar (float) return hint, used as an aggregate.
        df = self.spark.range(10).selectExpr("id", "id as v")
        exec(
            "import numpy as np\n"
            "def weighted_mean(v: pd.Series, w: pd.Series) -> float:\n"
            "    return np.average(v, weights=w)",
            self.local)

        weighted_mean = pandas_udf("double")(self.local["weighted_mean"])

        actual = df.groupby('id').agg(weighted_mean(df.v, lit(1.0))).sort('id')
        # With constant weights the weighted mean equals the plain mean; the alias
        # matches the column name Spark generates for the UDF call.
        expected = df.groupby('id').agg(mean(df.v).alias('weighted_mean(v, 1.0)')).sort('id')
        assert_frame_equal(expected.toPandas(), actual.toPandas())

    def test_ignore_type_hint_in_group_apply_in_pandas(self):
        # Type hints on a function passed to `applyInPandas` must not get in the way.
        df = self.spark.range(10)
        exec(
            "def pandas_plus_one(v: pd.DataFrame) -> pd.DataFrame:\n"
            "    return v + 1",
            self.local)

        pandas_plus_one = self.local["pandas_plus_one"]

        actual = df.groupby('id').applyInPandas(pandas_plus_one, schema=df.schema).sort('id')
        expected = df.selectExpr("id + 1 as id")
        assert_frame_equal(expected.toPandas(), actual.toPandas())

    def test_ignore_type_hint_in_cogroup_apply_in_pandas(self):
        # Same as above, for the two-argument cogrouped `applyInPandas`.
        df = self.spark.range(10)
        exec(
            "def pandas_plus_one(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:\n"
            "    return left + 1",
            self.local)

        pandas_plus_one = self.local["pandas_plus_one"]

        actual = df.groupby('id').cogroup(
            self.spark.range(10).groupby("id")
        ).applyInPandas(pandas_plus_one, schema=df.schema).sort('id')
        expected = df.selectExpr("id + 1 as id")
        assert_frame_equal(expected.toPandas(), actual.toPandas())

    def test_ignore_type_hint_in_map_in_pandas(self):
        # Iterator-of-DataFrame hints on a function passed to `mapInPandas`.
        df = self.spark.range(10)
        exec(
            "from typing import Iterator\n"
            "def pandas_plus_one(iter: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:\n"
            "    return map(lambda v: v + 1, iter)",
            self.local)

        pandas_plus_one = self.local["pandas_plus_one"]

        actual = df.mapInPandas(pandas_plus_one, schema=df.schema)
        expected = df.selectExpr("id + 1 as id")
        assert_frame_equal(expected.toPandas(), actual.toPandas())
0305 
0306 
if __name__ == "__main__":
    # Pull the test classes in through their fully qualified module path;
    # unittest.main() below collects tests from __main__.
    from pyspark.sql.tests.test_pandas_udf_typehints import *

    # Write XML reports to target/test-reports when xmlrunner is installed;
    # a None runner makes unittest.main() use its default text runner.
    runner = None
    try:
        import xmlrunner
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=runner, verbosity=2)