Source listing: python/pyspark/sql/tests/test_pandas_udf.py (Apache Spark)
0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 import unittest
0019 
0020 from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
0021 from pyspark.sql.types import *
0022 from pyspark.sql.utils import ParseException, PythonException
0023 from pyspark.rdd import PythonEvalType
0024 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
0025     pandas_requirement_message, pyarrow_requirement_message
0026 from pyspark.testing.utils import QuietTest
0027 
0028 
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class PandasUDFTests(ReusedSQLTestCase):
    """Unit tests for the ``pandas_udf`` API surface.

    Covers every accepted spelling of ``returnType``/``functionType``
    (positional, keyword, DDL string, DataType instance), the decorator form,
    argument-validation errors, ``StopIteration`` escaping user code, and
    Arrow's safe-cast behavior during pandas -> Arrow conversion.

    Note: all ``assertRaisesRegexp`` calls were migrated to
    ``assertRaisesRegex`` — the former is a deprecated alias (removed in
    Python 3.12).
    """

    def test_pandas_udf_basic(self):
        """Function form: each returnType/functionType spelling must be
        reflected in the wrapper's ``returnType`` and ``evalType``."""
        # Local renamed from ``udf`` so it does not shadow the
        # ``pyspark.sql.functions.udf`` import used elsewhere in this class.
        fn = pandas_udf(lambda x: x, DoubleType())
        self.assertEqual(fn.returnType, DoubleType())
        self.assertEqual(fn.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        fn = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR)
        self.assertEqual(fn.returnType, DoubleType())
        self.assertEqual(fn.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        fn = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR)
        self.assertEqual(fn.returnType, DoubleType())
        self.assertEqual(fn.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        fn = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]),
                        PandasUDFType.GROUPED_MAP)
        self.assertEqual(fn.returnType, StructType([StructField("v", DoubleType())]))
        self.assertEqual(fn.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        fn = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
        self.assertEqual(fn.returnType, StructType([StructField("v", DoubleType())]))
        self.assertEqual(fn.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        fn = pandas_udf(lambda x: x, 'v double',
                        functionType=PandasUDFType.GROUPED_MAP)
        self.assertEqual(fn.returnType, StructType([StructField("v", DoubleType())]))
        self.assertEqual(fn.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        fn = pandas_udf(lambda x: x, returnType='v double',
                        functionType=PandasUDFType.GROUPED_MAP)
        self.assertEqual(fn.returnType, StructType([StructField("v", DoubleType())]))
        self.assertEqual(fn.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

    def test_pandas_udf_decorator(self):
        """Decorator form: same matrix of returnType/functionType spellings."""
        @pandas_udf(DoubleType())
        def foo(x):
            return x
        self.assertEqual(foo.returnType, DoubleType())
        self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        @pandas_udf(returnType=DoubleType())
        def foo(x):
            return x
        self.assertEqual(foo.returnType, DoubleType())
        self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        schema = StructType([StructField("v", DoubleType())])

        @pandas_udf(schema, PandasUDFType.GROUPED_MAP)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, schema)
        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        @pandas_udf('v double', PandasUDFType.GROUPED_MAP)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, schema)
        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, schema)
        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, DoubleType())
        self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, schema)
        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

    def test_udf_wrong_arg(self):
        """Invalid argument combinations must raise with informative messages:
        unparseable DDL strings, missing return types, bad functionType values,
        0-arg pandas UDFs, and GROUPED_MAP with a non-struct return type."""
        with QuietTest(self.sc):
            with self.assertRaises(ParseException):
                @pandas_udf('blah')
                def foo(x):
                    return x
            with self.assertRaisesRegex(ValueError, 'Invalid return type.*None'):
                @pandas_udf(functionType=PandasUDFType.SCALAR)
                def foo(x):
                    return x
            with self.assertRaisesRegex(ValueError, 'Invalid function'):
                @pandas_udf('double', 100)
                def foo(x):
                    return x

            with self.assertRaisesRegex(ValueError, '0-arg pandas_udfs.*not.*supported'):
                pandas_udf(lambda: 1, LongType(), PandasUDFType.SCALAR)
            with self.assertRaisesRegex(ValueError, '0-arg pandas_udfs.*not.*supported'):
                @pandas_udf(LongType(), PandasUDFType.SCALAR)
                def zero_with_type():
                    return 1

            with self.assertRaisesRegex(TypeError, 'Invalid return type'):
                @pandas_udf(returnType=PandasUDFType.GROUPED_MAP)
                def foo(df):
                    return df
            with self.assertRaisesRegex(TypeError, 'Invalid return type'):
                @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP)
                def foo(df):
                    return df
            with self.assertRaisesRegex(ValueError, 'Invalid function'):
                @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
                def foo(k, v, w):
                    return k

    def test_stopiteration_in_udf(self):
        """A StopIteration escaping user code must surface as a PythonException
        for every UDF flavor (plain, scalar pandas, grouped map, grouped agg)
        rather than silently truncating results (SPARK-23754)."""
        def foo(x):
            raise StopIteration()

        def foofoo(x, y):
            raise StopIteration()

        exc_message = "Caught StopIteration thrown from user's code; failing the task"
        df = self.spark.range(0, 100)

        # plain udf (test for SPARK-23754)
        self.assertRaisesRegex(
            PythonException,
            exc_message,
            df.withColumn('v', udf(foo)('id')).collect
        )

        # pandas scalar udf
        self.assertRaisesRegex(
            PythonException,
            exc_message,
            df.withColumn(
                'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
            ).collect
        )

        # pandas grouped map
        self.assertRaisesRegex(
            PythonException,
            exc_message,
            df.groupBy('id').apply(
                pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
            ).collect
        )

        self.assertRaisesRegex(
            PythonException,
            exc_message,
            df.groupBy('id').apply(
                pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP)
            ).collect
        )

        # pandas grouped agg
        self.assertRaisesRegex(
            PythonException,
            exc_message,
            df.groupBy('id').agg(
                pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id')
            ).collect
        )

    def test_pandas_udf_detect_unsafe_type_conversion(self):
        """With convertToArrowArraySafely=True, a lossy float->int cast of the
        UDF result must raise; with it disabled, Arrow performs the cast."""
        import pandas as pd
        import numpy as np

        values = [1.0] * 3
        pdf = pd.DataFrame({'A': values})
        df = self.spark.createDataFrame(pdf).repartition(1)

        # Renamed from ``udf`` so it does not shadow the module-level import.
        @pandas_udf(returnType="int")
        def to_int(column):
            # Fractional values in [0, 1] cannot be safely cast to int.
            return pd.Series(np.linspace(0, 1, len(column)))

        # Since 0.11.0, PyArrow supports the feature to raise an error for unsafe cast.
        with self.sql_conf({
                "spark.sql.execution.pandas.convertToArrowArraySafely": True}):
            with self.assertRaisesRegex(Exception,
                                        "Exception thrown when converting pandas.Series"):
                df.select(['A']).withColumn('udf', to_int('A')).collect()

        # Disabling Arrow safe type check.
        with self.sql_conf({
                "spark.sql.execution.pandas.convertToArrowArraySafely": False}):
            df.select(['A']).withColumn('udf', to_int('A')).collect()

    def test_pandas_udf_arrow_overflow(self):
        """Safe mode must reject a value that overflows the declared 'byte'
        return type; with safe mode off, Arrow casts anyway."""
        import pandas as pd

        df = self.spark.range(0, 1)

        # Renamed from ``udf`` so it does not shadow the module-level import.
        @pandas_udf(returnType="byte")
        def overflowing(column):
            # 128 is one past the maximum signed byte value (127).
            return pd.Series([128] * len(column))

        # When enabling safe type check, Arrow 0.11.0+ disallows overflow cast.
        with self.sql_conf({
                "spark.sql.execution.pandas.convertToArrowArraySafely": True}):
            with self.assertRaisesRegex(Exception,
                                        "Exception thrown when converting pandas.Series"):
                df.withColumn('udf', overflowing('id')).collect()

        # Disabling safe type check, let Arrow do the cast anyway.
        with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
            df.withColumn('udf', overflowing('id')).collect()
0241 
0242 
if __name__ == "__main__":
    # Re-export the test cases at module scope so the runner discovers them.
    from pyspark.sql.tests.test_pandas_udf import *

    try:
        import xmlrunner
    except ImportError:
        # xmlrunner is optional; fall back to the default text runner.
        runner = None
    else:
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=runner, verbosity=2)