0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import unittest
0019
0020 from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
0021 from pyspark.sql.types import *
0022 from pyspark.sql.utils import ParseException, PythonException
0023 from pyspark.rdd import PythonEvalType
0024 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
0025 pandas_requirement_message, pyarrow_requirement_message
0026 from pyspark.testing.utils import QuietTest
0027
0028
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class PandasUDFTests(ReusedSQLTestCase):
    """Tests for the ``pandas_udf`` declaration API.

    Covers: return-type / eval-type inference for every calling convention
    (positional, DDL-string schema, keyword arguments, decorator form),
    argument validation errors, StopIteration propagation from user code,
    and Arrow array-conversion safety checks.
    """

    def test_pandas_udf_basic(self):
        """returnType and evalType are derived correctly for direct calls."""
        # NOTE: local renamed from `udf` to avoid shadowing the imported
        # pyspark.sql.functions.udf used elsewhere in this class.
        udf_obj = pandas_udf(lambda x: x, DoubleType())
        self.assertEqual(udf_obj.returnType, DoubleType())
        self.assertEqual(udf_obj.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        udf_obj = pandas_udf(lambda x: x, DoubleType(), PandasUDFType.SCALAR)
        self.assertEqual(udf_obj.returnType, DoubleType())
        self.assertEqual(udf_obj.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        # DDL-formatted type string is parsed into the equivalent DataType.
        udf_obj = pandas_udf(lambda x: x, 'double', PandasUDFType.SCALAR)
        self.assertEqual(udf_obj.returnType, DoubleType())
        self.assertEqual(udf_obj.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        udf_obj = pandas_udf(lambda x: x, StructType([StructField("v", DoubleType())]),
                             PandasUDFType.GROUPED_MAP)
        self.assertEqual(udf_obj.returnType, StructType([StructField("v", DoubleType())]))
        self.assertEqual(udf_obj.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        udf_obj = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
        self.assertEqual(udf_obj.returnType, StructType([StructField("v", DoubleType())]))
        self.assertEqual(udf_obj.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        udf_obj = pandas_udf(lambda x: x, 'v double',
                             functionType=PandasUDFType.GROUPED_MAP)
        self.assertEqual(udf_obj.returnType, StructType([StructField("v", DoubleType())]))
        self.assertEqual(udf_obj.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        udf_obj = pandas_udf(lambda x: x, returnType='v double',
                             functionType=PandasUDFType.GROUPED_MAP)
        self.assertEqual(udf_obj.returnType, StructType([StructField("v", DoubleType())]))
        self.assertEqual(udf_obj.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

    def test_pandas_udf_decorator(self):
        """returnType and evalType are derived correctly in decorator form."""
        @pandas_udf(DoubleType())
        def foo(x):
            return x
        self.assertEqual(foo.returnType, DoubleType())
        self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        @pandas_udf(returnType=DoubleType())
        def foo(x):
            return x
        self.assertEqual(foo.returnType, DoubleType())
        self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        schema = StructType([StructField("v", DoubleType())])

        @pandas_udf(schema, PandasUDFType.GROUPED_MAP)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, schema)
        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        @pandas_udf('v double', PandasUDFType.GROUPED_MAP)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, schema)
        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        @pandas_udf(schema, functionType=PandasUDFType.GROUPED_MAP)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, schema)
        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

        @pandas_udf(returnType='double', functionType=PandasUDFType.SCALAR)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, DoubleType())
        self.assertEqual(foo.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        @pandas_udf(returnType=schema, functionType=PandasUDFType.GROUPED_MAP)
        def foo(x):
            return x
        self.assertEqual(foo.returnType, schema)
        self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)

    def test_udf_wrong_arg(self):
        """Invalid declarations raise ParseException / ValueError / TypeError."""
        with QuietTest(self.sc):
            # Unparseable DDL type string.
            with self.assertRaises(ParseException):
                @pandas_udf('blah')
                def foo(x):
                    return x
            # Missing return type.
            # NOTE: assertRaisesRegexp is deprecated (removed in Python 3.12);
            # use assertRaisesRegex throughout.
            with self.assertRaisesRegex(ValueError, 'Invalid return type.*None'):
                @pandas_udf(functionType=PandasUDFType.SCALAR)
                def foo(x):
                    return x
            # Unknown function type.
            with self.assertRaisesRegex(ValueError, 'Invalid function'):
                @pandas_udf('double', 100)
                def foo(x):
                    return x

            # Zero-argument pandas UDFs are rejected.
            with self.assertRaisesRegex(ValueError, '0-arg pandas_udfs.*not.*supported'):
                pandas_udf(lambda: 1, LongType(), PandasUDFType.SCALAR)
            with self.assertRaisesRegex(ValueError, '0-arg pandas_udfs.*not.*supported'):
                @pandas_udf(LongType(), PandasUDFType.SCALAR)
                def zero_with_type():
                    return 1

            # GROUPED_MAP requires a StructType return type.
            with self.assertRaisesRegex(TypeError, 'Invalid return type'):
                @pandas_udf(returnType=PandasUDFType.GROUPED_MAP)
                def foo(df):
                    return df
            with self.assertRaisesRegex(TypeError, 'Invalid return type'):
                @pandas_udf(returnType='double', functionType=PandasUDFType.GROUPED_MAP)
                def foo(df):
                    return df
            # GROUPED_MAP functions accept at most two arguments.
            with self.assertRaisesRegex(ValueError, 'Invalid function'):
                @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
                def foo(k, v, w):
                    return k

    def test_stopiteration_in_udf(self):
        """StopIteration raised in user code fails the task with a clear
        PythonException instead of silently truncating results."""
        def foo(x):
            raise StopIteration()

        def foofoo(x, y):
            raise StopIteration()

        exc_message = "Caught StopIteration thrown from user's code; failing the task"
        df = self.spark.range(0, 100)

        # plain Python udf
        self.assertRaisesRegex(
            PythonException,
            exc_message,
            df.withColumn('v', udf(foo)('id')).collect
        )

        # scalar pandas udf
        self.assertRaisesRegex(
            PythonException,
            exc_message,
            df.withColumn(
                'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
            ).collect
        )

        # grouped-map pandas udf (one-arg and two-arg forms)
        self.assertRaisesRegex(
            PythonException,
            exc_message,
            df.groupBy('id').apply(
                pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
            ).collect
        )

        self.assertRaisesRegex(
            PythonException,
            exc_message,
            df.groupBy('id').apply(
                pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP)
            ).collect
        )

        # grouped-aggregate pandas udf
        self.assertRaisesRegex(
            PythonException,
            exc_message,
            df.groupBy('id').agg(
                pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id')
            ).collect
        )

    def test_pandas_udf_detect_unsafe_type_conversion(self):
        """Unsafe float->int casts fail only when safe conversion is enabled."""
        import pandas as pd
        import numpy as np

        values = [1.0] * 3
        pdf = pd.DataFrame({'A': values})
        df = self.spark.createDataFrame(pdf).repartition(1)

        # Returns non-integral floats for an int column -> unsafe cast.
        @pandas_udf(returnType="int")
        def to_int_udf(column):
            return pd.Series(np.linspace(0, 1, len(column)))

        # Safe-conversion mode rejects the lossy cast.
        with self.sql_conf({
                "spark.sql.execution.pandas.convertToArrowArraySafely": True}):
            with self.assertRaisesRegex(Exception,
                                        "Exception thrown when converting pandas.Series"):
                df.select(['A']).withColumn('udf', to_int_udf('A')).collect()

        # Disabling the check falls back to permissive (truncating) behavior.
        with self.sql_conf({
                "spark.sql.execution.pandas.convertToArrowArraySafely": False}):
            df.select(['A']).withColumn('udf', to_int_udf('A')).collect()

    def test_pandas_udf_arrow_overflow(self):
        """Values overflowing the declared type fail only in safe mode."""
        import pandas as pd

        df = self.spark.range(0, 1)

        # 128 does not fit in a signed byte (-128..127).
        @pandas_udf(returnType="byte")
        def overflow_udf(column):
            return pd.Series([128] * len(column))

        # Safe-conversion mode detects the overflow.
        with self.sql_conf({
                "spark.sql.execution.pandas.convertToArrowArraySafely": True}):
            with self.assertRaisesRegex(Exception,
                                        "Exception thrown when converting pandas.Series"):
                df.withColumn('udf', overflow_udf('id')).collect()

        # Permissive mode silently wraps/truncates instead.
        with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
            df.withColumn('udf', overflow_udf('id')).collect()
0241
0242
if __name__ == "__main__":
    # Re-export the test classes under the canonical module path so that
    # unittest discovery picks them up by name.
    from pyspark.sql.tests.test_pandas_udf import *

    # Prefer XML report output when xmlrunner is available; otherwise fall
    # back to the default text runner.
    runner = None
    try:
        import xmlrunner
    except ImportError:
        pass
    else:
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    unittest.main(testRunner=runner, verbosity=2)