#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
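"""
Tests for scalar (vectorized) Pandas UDFs. Most cases below are exercised twice: once as
a SCALAR UDF (pandas.Series in, pandas.Series or pandas.DataFrame out) and once as a
SCALAR_ITER UDF (a generator consuming an iterator of batches and yielding results).
"""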
import datetime
import os
import random
import shutil
import sys
import tempfile
import time
import unittest

if sys.version >= '3':
    unicode = str

from datetime import date, datetime
from decimal import Decimal

from pyspark import TaskContext
from pyspark.rdd import PythonEvalType
from pyspark.sql import Column
from pyspark.sql.functions import array, col, expr, lit, sum, struct, udf, pandas_udf, \
    PandasUDFType
from pyspark.sql.types import Row
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, \
    test_not_compiled_message, have_pandas, have_pyarrow, pandas_requirement_message, \
    pyarrow_requirement_message
from pyspark.testing.utils import QuietTest

if have_pandas:
    import pandas as pd

if have_pyarrow:
    import pyarrow as pa

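# The whole suite is skipped unless both pandas and pyarrow are importable, since Arrow
# is the format used to ship column batches between the JVM and the Python worker.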
@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class ScalarPandasUDFTests(ReusedSQLTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedSQLTestCase.setUpClass()

        # Synchronize the default timezone between Python and Java.
        cls.tz_prev = os.environ.get("TZ", None)  # save the current timezone, if any
        tz = "America/Los_Angeles"
        os.environ["TZ"] = tz
        time.tzset()

        cls.sc.environment["TZ"] = tz
        cls.spark.conf.set("spark.sql.session.timeZone", tz)

    @classmethod
    def tearDownClass(cls):
        del os.environ["TZ"]
        if cls.tz_prev is not None:
            os.environ["TZ"] = cls.tz_prev
        time.tzset()
        ReusedSQLTestCase.tearDownClass()

    @property
    def nondeterministic_vectorized_udf(self):
        import numpy as np

        @pandas_udf('double')
        def random_udf(v):
            return pd.Series(np.random.random(len(v)))
        random_udf = random_udf.asNondeterministic()
        return random_udf

    @property
    def nondeterministic_vectorized_iter_udf(self):
        import numpy as np

        @pandas_udf('double', PandasUDFType.SCALAR_ITER)
        def random_udf(it):
            for v in it:
                yield pd.Series(np.random.random(len(v)))

        random_udf = random_udf.asNondeterministic()
        return random_udf

    def test_pandas_udf_tokenize(self):
        tokenize = pandas_udf(lambda s: s.apply(lambda v: v.split(' ')),
                              ArrayType(StringType()))
        self.assertEqual(tokenize.returnType, ArrayType(StringType()))
        df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
        result = df.select(tokenize("vals").alias("hi"))
        self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect())

    def test_pandas_udf_nested_arrays(self):
        tokenize = pandas_udf(lambda s: s.apply(lambda v: [v.split(' ')]),
                              ArrayType(ArrayType(StringType())))
        self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType())))
        df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
        result = df.select(tokenize("vals").alias("hi"))
        self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect())

    def test_vectorized_udf_basic(self):
        df = self.spark.range(10).select(
            col('id').cast('string').alias('str'),
            col('id').cast('int').alias('int'),
            col('id').alias('long'),
            col('id').cast('float').alias('float'),
            col('id').cast('double').alias('double'),
            col('id').cast('decimal').alias('decimal'),
            col('id').cast('boolean').alias('bool'),
            array(col('id')).alias('array_long'))
        f = lambda x: x
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            str_f = pandas_udf(f, StringType(), udf_type)
            int_f = pandas_udf(f, IntegerType(), udf_type)
            long_f = pandas_udf(f, LongType(), udf_type)
            float_f = pandas_udf(f, FloatType(), udf_type)
            double_f = pandas_udf(f, DoubleType(), udf_type)
            decimal_f = pandas_udf(f, DecimalType(), udf_type)
            bool_f = pandas_udf(f, BooleanType(), udf_type)
            array_long_f = pandas_udf(f, ArrayType(LongType()), udf_type)
            res = df.select(str_f(col('str')), int_f(col('int')),
                            long_f(col('long')), float_f(col('float')),
                            double_f(col('double')), decimal_f('decimal'),
                            bool_f(col('bool')), array_long_f('array_long'))
            self.assertEqual(df.collect(), res.collect())

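    # random.randint(6, 6) always returns 6, so the "nondeterministic" UDFs below still
    # have a checkable result (7); the point is that registerFunction preserves the
    # nondeterministic flag and the eval type.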
    def test_register_nondeterministic_vectorized_udf_basic(self):
        random_pandas_udf = pandas_udf(
            lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic()
        self.assertEqual(random_pandas_udf.deterministic, False)
        self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
        nondeterministic_pandas_udf = self.spark.catalog.registerFunction(
            "randomPandasUDF", random_pandas_udf)
        self.assertEqual(nondeterministic_pandas_udf.deterministic, False)
        self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
        [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect()
        self.assertEqual(row[0], 7)

        def random_iter_udf(it):
            for i in it:
                yield random.randint(6, 6) + i
        random_pandas_iter_udf = pandas_udf(
            random_iter_udf, IntegerType(), PandasUDFType.SCALAR_ITER).asNondeterministic()
        self.assertEqual(random_pandas_iter_udf.deterministic, False)
        self.assertEqual(random_pandas_iter_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
        nondeterministic_pandas_iter_udf = self.spark.catalog.registerFunction(
            "randomPandasIterUDF", random_pandas_iter_udf)
        self.assertEqual(nondeterministic_pandas_iter_udf.deterministic, False)
        self.assertEqual(nondeterministic_pandas_iter_udf.evalType,
                         PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
        [row] = self.spark.sql("SELECT randomPandasIterUDF(1)").collect()
        self.assertEqual(row[0], 7)

    def test_vectorized_udf_null_boolean(self):
        data = [(True,), (True,), (None,), (False,)]
        schema = StructType().add("bool", BooleanType())
        df = self.spark.createDataFrame(data, schema)
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            bool_f = pandas_udf(lambda x: x, BooleanType(), udf_type)
            res = df.select(bool_f(col('bool')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_null_byte(self):
        data = [(None,), (2,), (3,), (4,)]
        schema = StructType().add("byte", ByteType())
        df = self.spark.createDataFrame(data, schema)
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            byte_f = pandas_udf(lambda x: x, ByteType(), udf_type)
            res = df.select(byte_f(col('byte')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_null_short(self):
        data = [(None,), (2,), (3,), (4,)]
        schema = StructType().add("short", ShortType())
        df = self.spark.createDataFrame(data, schema)
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            short_f = pandas_udf(lambda x: x, ShortType(), udf_type)
            res = df.select(short_f(col('short')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_null_int(self):
        data = [(None,), (2,), (3,), (4,)]
        schema = StructType().add("int", IntegerType())
        df = self.spark.createDataFrame(data, schema)
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            int_f = pandas_udf(lambda x: x, IntegerType(), udf_type)
            res = df.select(int_f(col('int')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_null_long(self):
        data = [(None,), (2,), (3,), (4,)]
        schema = StructType().add("long", LongType())
        df = self.spark.createDataFrame(data, schema)
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            long_f = pandas_udf(lambda x: x, LongType(), udf_type)
            res = df.select(long_f(col('long')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_null_float(self):
        data = [(3.0,), (5.0,), (-1.0,), (None,)]
        schema = StructType().add("float", FloatType())
        df = self.spark.createDataFrame(data, schema)
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            float_f = pandas_udf(lambda x: x, FloatType(), udf_type)
            res = df.select(float_f(col('float')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_null_double(self):
        data = [(3.0,), (5.0,), (-1.0,), (None,)]
        schema = StructType().add("double", DoubleType())
        df = self.spark.createDataFrame(data, schema)
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            double_f = pandas_udf(lambda x: x, DoubleType(), udf_type)
            res = df.select(double_f(col('double')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_null_decimal(self):
        data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)]
        schema = StructType().add("decimal", DecimalType(38, 18))
        df = self.spark.createDataFrame(data, schema)
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18), udf_type)
            res = df.select(decimal_f(col('decimal')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_null_string(self):
        data = [("foo",), (None,), ("bar",), ("bar",)]
        schema = StructType().add("str", StringType())
        df = self.spark.createDataFrame(data, schema)
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            str_f = pandas_udf(lambda x: x, StringType(), udf_type)
            res = df.select(str_f(col('str')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_string_in_udf(self):
        df = self.spark.range(10)
        scalar_f = lambda x: pd.Series(map(str, x))

        def iter_f(it):
            for i in it:
                yield scalar_f(i)

        for f, udf_type in [(scalar_f, PandasUDFType.SCALAR), (iter_f, PandasUDFType.SCALAR_ITER)]:
            str_f = pandas_udf(f, StringType(), udf_type)
            actual = df.select(str_f(col('id')))
            expected = df.select(col('id').cast('string'))
            self.assertEqual(expected.collect(), actual.collect())

    def test_vectorized_udf_datatype_string(self):
        df = self.spark.range(10).select(
            col('id').cast('string').alias('str'),
            col('id').cast('int').alias('int'),
            col('id').alias('long'),
            col('id').cast('float').alias('float'),
            col('id').cast('double').alias('double'),
            col('id').cast('decimal').alias('decimal'),
            col('id').cast('boolean').alias('bool'))
        f = lambda x: x
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            str_f = pandas_udf(f, 'string', udf_type)
            int_f = pandas_udf(f, 'integer', udf_type)
            long_f = pandas_udf(f, 'long', udf_type)
            float_f = pandas_udf(f, 'float', udf_type)
            double_f = pandas_udf(f, 'double', udf_type)
            decimal_f = pandas_udf(f, 'decimal(38, 18)', udf_type)
            bool_f = pandas_udf(f, 'boolean', udf_type)
            res = df.select(str_f(col('str')), int_f(col('int')),
                            long_f(col('long')), float_f(col('float')),
                            double_f(col('double')), decimal_f('decimal'),
                            bool_f(col('bool')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_null_binary(self):
        data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)]
        schema = StructType().add("binary", BinaryType())
        df = self.spark.createDataFrame(data, schema)
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            str_f = pandas_udf(lambda x: x, BinaryType(), udf_type)
            res = df.select(str_f(col('binary')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_array_type(self):
        data = [([1, 2],), ([3, 4],)]
        array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
        df = self.spark.createDataFrame(data, schema=array_schema)
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()), udf_type)
            result = df.select(array_f(col('array')))
            self.assertEqual(df.collect(), result.collect())

    def test_vectorized_udf_null_array(self):
        data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)]
        array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
        df = self.spark.createDataFrame(data, schema=array_schema)
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()), udf_type)
            result = df.select(array_f(col('array')))
            self.assertEqual(df.collect(), result.collect())

    def test_vectorized_udf_struct_type(self):
        df = self.spark.range(10)
        return_type = StructType([
            StructField('id', LongType()),
            StructField('str', StringType())])

        def scalar_func(id):
            return pd.DataFrame({'id': id, 'str': id.apply(unicode)})

        def iter_func(it):
            for id in it:
                yield scalar_func(id)

        for func, udf_type in [(scalar_func, PandasUDFType.SCALAR),
                               (iter_func, PandasUDFType.SCALAR_ITER)]:
            f = pandas_udf(func, returnType=return_type, functionType=udf_type)

            expected = df.select(struct(col('id'), col('id').cast('string').alias('str'))
                                 .alias('struct')).collect()

            actual = df.select(f(col('id')).alias('struct')).collect()
            self.assertEqual(expected, actual)

            g = pandas_udf(func, 'id: long, str: string', functionType=udf_type)
            actual = df.select(g(col('id')).alias('struct')).collect()
            self.assertEqual(expected, actual)

            struct_f = pandas_udf(lambda x: x, return_type, functionType=udf_type)
            actual = df.select(struct_f(struct(col('id'), col('id').cast('string').alias('str'))))
            self.assertEqual(expected, actual.collect())

    def test_vectorized_udf_struct_complex(self):
        df = self.spark.range(10)
        return_type = StructType([
            StructField('ts', TimestampType()),
            StructField('arr', ArrayType(LongType()))])

        def _scalar_f(id):
            return pd.DataFrame({'ts': id.apply(lambda i: pd.Timestamp(i)),
                                 'arr': id.apply(lambda i: [i, i + 1])})

        scalar_f = pandas_udf(_scalar_f, returnType=return_type)

        @pandas_udf(returnType=return_type, functionType=PandasUDFType.SCALAR_ITER)
        def iter_f(it):
            for id in it:
                yield _scalar_f(id)

        for f, udf_type in [(scalar_f, PandasUDFType.SCALAR), (iter_f, PandasUDFType.SCALAR_ITER)]:
            actual = df.withColumn('f', f(col('id'))).collect()
            for i, row in enumerate(actual):
                id, f = row
                self.assertEqual(i, id)
                self.assertEqual(pd.Timestamp(i).to_pydatetime(), f[0])
                self.assertListEqual([i, i + 1], f[1])

    def test_vectorized_udf_nested_struct(self):
        nested_type = StructType([
            StructField('id', IntegerType()),
            StructField('nested', StructType([
                StructField('foo', StringType()),
                StructField('bar', FloatType())
            ]))
        ])

        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            with QuietTest(self.sc):
                with self.assertRaisesRegexp(
                        Exception,
                        'Invalid return type with scalar Pandas UDFs'):
                    pandas_udf(lambda x: x, returnType=nested_type, functionType=udf_type)

    def test_vectorized_udf_complex(self):
        df = self.spark.range(10).select(
            col('id').cast('int').alias('a'),
            col('id').cast('int').alias('b'),
            col('id').cast('double').alias('c'))
        scalar_add = pandas_udf(lambda x, y: x + y, IntegerType())
        scalar_power2 = pandas_udf(lambda x: 2 ** x, IntegerType())
        scalar_mul = pandas_udf(lambda x, y: x * y, DoubleType())

        @pandas_udf(IntegerType(), PandasUDFType.SCALAR_ITER)
        def iter_add(it):
            for x, y in it:
                yield x + y

        @pandas_udf(IntegerType(), PandasUDFType.SCALAR_ITER)
        def iter_power2(it):
            for x in it:
                yield 2 ** x

        @pandas_udf(DoubleType(), PandasUDFType.SCALAR_ITER)
        def iter_mul(it):
            for x, y in it:
                yield x * y

        for add, power2, mul in [(scalar_add, scalar_power2, scalar_mul),
                                 (iter_add, iter_power2, iter_mul)]:
            res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c')))
            expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c'))
            self.assertEqual(expected.collect(), res.collect())

    def test_vectorized_udf_exception(self):
        df = self.spark.range(10)
        scalar_raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType())

        @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
        def iter_raise_exception(it):
            for x in it:
                yield x * (1 / 0)

        for raise_exception in [scalar_raise_exception, iter_raise_exception]:
            with QuietTest(self.sc):
                with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'):
                    df.select(raise_exception(col('id'))).collect()

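    # A scalar Pandas UDF must return output of the same length as its input batch, and
    # an iterator UDF must consume all input batches and match the total output length;
    # each violation below should surface as an error from the Python worker.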
    def test_vectorized_udf_invalid_length(self):
        df = self.spark.range(10)
        raise_exception = pandas_udf(lambda _: pd.Series(1), LongType())
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    Exception,
                    'Result vector from pandas_udf was not the required length'):
                df.select(raise_exception(col('id'))).collect()

        @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
        def iter_udf_wrong_output_size(it):
            for _ in it:
                yield pd.Series(1)

        with QuietTest(self.sc):
            with self.assertRaisesRegexp(
                    Exception,
                    "The length of output in Scalar iterator.*"
                    "the length of output was 1"):
                df.select(iter_udf_wrong_output_size(col('id'))).collect()

        @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
        def iter_udf_not_reading_all_input(it):
            for batch in it:
                batch_len = len(batch)
                yield pd.Series([1] * batch_len)
                break

        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
            df1 = self.spark.range(10).repartition(1)
            with QuietTest(self.sc):
                with self.assertRaisesRegexp(
                        Exception,
                        "pandas iterator UDF should exhaust"):
                    df1.select(iter_udf_not_reading_all_input(col('id'))).collect()

    def test_vectorized_udf_chained(self):
        df = self.spark.range(10)
        scalar_f = pandas_udf(lambda x: x + 1, LongType())
        scalar_g = pandas_udf(lambda x: x - 1, LongType())

        iter_f = pandas_udf(lambda it: map(lambda x: x + 1, it), LongType(),
                            PandasUDFType.SCALAR_ITER)
        iter_g = pandas_udf(lambda it: map(lambda x: x - 1, it), LongType(),
                            PandasUDFType.SCALAR_ITER)

        for f, g in [(scalar_f, scalar_g), (iter_f, iter_g)]:
            res = df.select(g(f(col('id'))))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_chained_struct_type(self):
        df = self.spark.range(10)
        return_type = StructType([
            StructField('id', LongType()),
            StructField('str', StringType())])

        @pandas_udf(return_type)
        def scalar_f(id):
            return pd.DataFrame({'id': id, 'str': id.apply(unicode)})

        scalar_g = pandas_udf(lambda x: x, return_type)

        @pandas_udf(return_type, PandasUDFType.SCALAR_ITER)
        def iter_f(it):
            for id in it:
                yield pd.DataFrame({'id': id, 'str': id.apply(unicode)})

        iter_g = pandas_udf(lambda x: x, return_type, PandasUDFType.SCALAR_ITER)

        expected = df.select(struct(col('id'), col('id').cast('string').alias('str'))
                             .alias('struct')).collect()

        for f, g in [(scalar_f, scalar_g), (iter_f, iter_g)]:
            actual = df.select(g(f(col('id'))).alias('struct')).collect()
            self.assertEqual(expected, actual)

    def test_vectorized_udf_wrong_return_type(self):
        with QuietTest(self.sc):
            for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
                with self.assertRaisesRegexp(
                        NotImplementedError,
                        'Invalid return type.*scalar Pandas UDF.*MapType'):
                    pandas_udf(lambda x: x, MapType(LongType(), LongType()), udf_type)

    def test_vectorized_udf_return_scalar(self):
        df = self.spark.range(10)
        scalar_f = pandas_udf(lambda x: 1.0, DoubleType())
        iter_f = pandas_udf(lambda it: map(lambda x: 1.0, it), DoubleType(),
                            PandasUDFType.SCALAR_ITER)
        for f in [scalar_f, iter_f]:
            with QuietTest(self.sc):
                with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'):
                    df.select(f(col('id'))).collect()

    def test_vectorized_udf_decorator(self):
        df = self.spark.range(10)

        @pandas_udf(returnType=LongType())
        def scalar_identity(x):
            return x

        @pandas_udf(returnType=LongType(), functionType=PandasUDFType.SCALAR_ITER)
        def iter_identity(x):
            return x

        for identity in [scalar_identity, iter_identity]:
            res = df.select(identity(col('id')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_empty_partition(self):
        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            f = pandas_udf(lambda x: x, LongType(), udf_type)
            res = df.select(f(col('id')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_struct_with_empty_partition(self):
        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))\
            .withColumn('name', lit('John Doe'))

        @pandas_udf("first string, last string")
        def scalar_split_expand(n):
            return n.str.split(expand=True)

        @pandas_udf("first string, last string", PandasUDFType.SCALAR_ITER)
        def iter_split_expand(it):
            for n in it:
                yield n.str.split(expand=True)

        for split_expand in [scalar_split_expand, iter_split_expand]:
            result = df.select(split_expand('name')).collect()
            self.assertEqual(1, len(result))
            row = result[0]
            self.assertEqual('John', row[0]['first'])
            self.assertEqual('Doe', row[0]['last'])

    def test_vectorized_udf_varargs(self):
        df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
        scalar_f = pandas_udf(lambda *v: v[0], LongType())

        @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
        def iter_f(it):
            for v in it:
                yield v[0]

        for f in [scalar_f, iter_f]:
            res = df.select(f(col('id'), col('id')))
            self.assertEqual(df.collect(), res.collect())

    def test_vectorized_udf_unsupported_types(self):
        with QuietTest(self.sc):
            for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
                with self.assertRaisesRegexp(
                        NotImplementedError,
                        'Invalid return type.*scalar Pandas UDF.*MapType'):
                    pandas_udf(lambda x: x, MapType(StringType(), IntegerType()), udf_type)
                with self.assertRaisesRegexp(
                        NotImplementedError,
                        'Invalid return type.*scalar Pandas UDF.*ArrayType.StructType'):
                    pandas_udf(lambda x: x,
                               ArrayType(StructType([StructField('a', IntegerType())])), udf_type)

    def test_vectorized_udf_dates(self):
        schema = StructType().add("idx", LongType()).add("date", DateType())
        data = [(0, date(1969, 1, 1),),
                (1, date(2012, 2, 2),),
                (2, None,),
                (3, date(2100, 4, 4),),
                (4, date(2262, 4, 12),)]
        df = self.spark.createDataFrame(data, schema=schema)

        def scalar_check_data(idx, date, date_copy):
            msgs = []
            is_equal = date.isnull()  # use this array to check if values are null
            for i in range(len(idx)):
                if (is_equal[i] and data[idx[i]][1] is None) or \
                        date[i] == data[idx[i]][1]:
                    msgs.append(None)
                else:
                    msgs.append(
                        "date values are not equal (date='%s': data[%d][1]='%s')"
                        % (date[i], idx[i], data[idx[i]][1]))
            return pd.Series(msgs)

        def iter_check_data(it):
            for idx, date, date_copy in it:
                yield scalar_check_data(idx, date, date_copy)

        pandas_scalar_check_data = pandas_udf(scalar_check_data, StringType())
        pandas_iter_check_data = pandas_udf(iter_check_data, StringType(),
                                            PandasUDFType.SCALAR_ITER)

        for check_data, udf_type in [(pandas_scalar_check_data, PandasUDFType.SCALAR),
                                     (pandas_iter_check_data, PandasUDFType.SCALAR_ITER)]:
            date_copy = pandas_udf(lambda t: t, returnType=DateType(), functionType=udf_type)
            df = df.withColumn("date_copy", date_copy(col("date")))
            result = df.withColumn("check_data",
                                   check_data(col("idx"), col("date"), col("date_copy"))).collect()

            self.assertEqual(len(data), len(result))
            for i in range(len(result)):
                self.assertEqual(data[i][1], result[i][1])  # "date" col
                self.assertEqual(data[i][1], result[i][2])  # "date_copy" col
                self.assertIsNone(result[i][3])  # "check_data" col

    def test_vectorized_udf_timestamps(self):
        schema = StructType([
            StructField("idx", LongType(), True),
            StructField("timestamp", TimestampType(), True)])
        data = [(0, datetime(1969, 1, 1, 1, 1, 1)),
                (1, datetime(2012, 2, 2, 2, 2, 2)),
                (2, None),
                (3, datetime(2100, 3, 3, 3, 3, 3))]

        df = self.spark.createDataFrame(data, schema=schema)

        def scalar_check_data(idx, timestamp, timestamp_copy):
            msgs = []
            is_equal = timestamp.isnull()  # use this array to check if values are null
            for i in range(len(idx)):
                # Check that timestamps match the original data
                if (is_equal[i] and data[idx[i]][1] is None) or \
                        timestamp[i].to_pydatetime() == data[idx[i]][1]:
                    msgs.append(None)
                else:
                    msgs.append(
                        "timestamp values are not equal (timestamp='%s': data[%d][1]='%s')"
                        % (timestamp[i], idx[i], data[idx[i]][1]))
            return pd.Series(msgs)

        def iter_check_data(it):
            for idx, timestamp, timestamp_copy in it:
                yield scalar_check_data(idx, timestamp, timestamp_copy)

        pandas_scalar_check_data = pandas_udf(scalar_check_data, StringType())
        pandas_iter_check_data = pandas_udf(iter_check_data, StringType(),
                                            PandasUDFType.SCALAR_ITER)

        for check_data, udf_type in [(pandas_scalar_check_data, PandasUDFType.SCALAR),
                                     (pandas_iter_check_data, PandasUDFType.SCALAR_ITER)]:
            # Check that a timestamp passed through a pandas_udf round-trips unchanged
            f_timestamp_copy = pandas_udf(lambda t: t,
                                          returnType=TimestampType(), functionType=udf_type)
            df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp")))
            result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"),
                                                            col("timestamp_copy"))).collect()

            self.assertEqual(len(data), len(result))
            for i in range(len(result)):
                self.assertEqual(data[i][1], result[i][1])  # "timestamp" col
                self.assertEqual(data[i][1], result[i][2])  # "timestamp_copy" col
                self.assertIsNone(result[i][3])  # "check_data" col

    def test_vectorized_udf_return_timestamp_tz(self):
        df = self.spark.range(10)

        @pandas_udf(returnType=TimestampType())
        def scalar_gen_timestamps(id):
            ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id]
            return pd.Series(ts)

        @pandas_udf(returnType=TimestampType(), functionType=PandasUDFType.SCALAR_ITER)
        def iter_gen_timestamps(it):
            for id in it:
                ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id]
                yield pd.Series(ts)

        for gen_timestamps in [scalar_gen_timestamps, iter_gen_timestamps]:
            result = df.withColumn("ts", gen_timestamps(col("id"))).collect()
            spark_ts_t = TimestampType()
            for r in result:
                i, ts = r
                ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime()
                expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz))
                self.assertEqual(expected, ts)

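    # spark.sql.execution.arrow.maxRecordsPerBatch caps the size of each Arrow batch, so
    # every Series the UDF sees below should hold at most 3 records.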
    def test_vectorized_udf_check_config(self):
        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
            df = self.spark.range(10, numPartitions=1)

            @pandas_udf(returnType=LongType())
            def scalar_check_records_per_batch(x):
                return pd.Series(x.size).repeat(x.size)

            @pandas_udf(returnType=LongType(), functionType=PandasUDFType.SCALAR_ITER)
            def iter_check_records_per_batch(it):
                for x in it:
                    yield pd.Series(x.size).repeat(x.size)

            for check_records_per_batch in [scalar_check_records_per_batch,
                                            iter_check_records_per_batch]:
                result = df.select(check_records_per_batch(col("id"))).collect()
                for (r,) in result:
                    self.assertTrue(r <= 3)

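    # Timestamps cross the Arrow boundary as session-local values, so changing
    # spark.sql.session.timeZone changes the internal values a UDF observes; Los Angeles
    # and New York differ by 3 hours, which the test below corrects for.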
    def test_vectorized_udf_timestamps_respect_session_timezone(self):
        schema = StructType([
            StructField("idx", LongType(), True),
            StructField("timestamp", TimestampType(), True)])
        data = [(1, datetime(1969, 1, 1, 1, 1, 1)),
                (2, datetime(2012, 2, 2, 2, 2, 2)),
                (3, None),
                (4, datetime(2100, 3, 3, 3, 3, 3))]
        df = self.spark.createDataFrame(data, schema=schema)

        scalar_internal_value = pandas_udf(
            lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())

        @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
        def iter_internal_value(it):
            for ts in it:
                yield ts.apply(lambda ts: ts.value if ts is not pd.NaT else None)

        for internal_value, udf_type in [(scalar_internal_value, PandasUDFType.SCALAR),
                                         (iter_internal_value, PandasUDFType.SCALAR_ITER)]:
            f_timestamp_copy = pandas_udf(lambda ts: ts, TimestampType(), udf_type)
            timezone = "America/Los_Angeles"
            with self.sql_conf({"spark.sql.session.timeZone": timezone}):
                df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
                    .withColumn("internal_value", internal_value(col("timestamp")))
                result_la = df_la.select(col("idx"), col("internal_value")).collect()
                # Correct result_la by adjusting for the 3-hour difference (in
                # nanoseconds) between Los Angeles and New York
                diff = 3 * 60 * 60 * 1000 * 1000 * 1000
                result_la_corrected = \
                    df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()

            timezone = "America/New_York"
            with self.sql_conf({"spark.sql.session.timeZone": timezone}):
                df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
                    .withColumn("internal_value", internal_value(col("timestamp")))
                result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()

                self.assertNotEqual(result_ny, result_la)
                self.assertEqual(result_ny, result_la_corrected)

    def test_nondeterministic_vectorized_udf(self):
        # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
        @pandas_udf('double')
        def scalar_plus_ten(v):
            return v + 10

        @pandas_udf('double', PandasUDFType.SCALAR_ITER)
        def iter_plus_ten(it):
            for v in it:
                yield v + 10

        for plus_ten in [scalar_plus_ten, iter_plus_ten]:
            random_udf = self.nondeterministic_vectorized_udf

            df = self.spark.range(10).withColumn('rand', random_udf(col('id')))
            result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas()

            self.assertEqual(random_udf.deterministic, False)
            self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10))

    def test_nondeterministic_vectorized_udf_in_aggregate(self):
        df = self.spark.range(10)
        for random_udf in [self.nondeterministic_vectorized_udf,
                           self.nondeterministic_vectorized_iter_udf]:
            with QuietTest(self.sc):
                with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
                    df.groupby(df.id).agg(sum(random_udf(df.id))).collect()
                with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
                    df.agg(sum(random_udf(df.id))).collect()

    def test_register_vectorized_udf_basic(self):
        df = self.spark.range(10).select(
            col('id').cast('int').alias('a'),
            col('id').cast('int').alias('b'))
        scalar_original_add = pandas_udf(lambda x, y: x + y, IntegerType())
        self.assertEqual(scalar_original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)

        @pandas_udf(IntegerType(), PandasUDFType.SCALAR_ITER)
        def iter_original_add(it):
            for x, y in it:
                yield x + y

        self.assertEqual(iter_original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)

        for original_add in [scalar_original_add, iter_original_add]:
            self.assertEqual(original_add.deterministic, True)
            new_add = self.spark.catalog.registerFunction("add1", original_add)
            res1 = df.select(new_add(col('a'), col('b')))
            res2 = self.spark.sql(
                "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t")
            expected = df.select(expr('a + b'))
            self.assertEqual(expected.collect(), res1.collect())
            self.assertEqual(expected.collect(), res2.collect())

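    # An iterator UDF can run per-task setup before consuming any batch; seeding numpy's
    # RNG with the partition id makes repeated collects return identical results.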
    def test_scalar_iter_udf_init(self):
        import numpy as np

        @pandas_udf('int', PandasUDFType.SCALAR_ITER)
        def rng(batch_iter):
            context = TaskContext.get()
            part = context.partitionId()
            np.random.seed(part)
            for batch in batch_iter:
                yield pd.Series(np.random.randint(100, size=len(batch)))

        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}):
            df = self.spark.range(10, numPartitions=2).select(rng(col("id").alias("v")))
            result1 = df.collect()
            result2 = df.collect()
            self.assertEqual(result1, result2,
                             "SCALAR ITER UDF can initialize state and produce deterministic RNG")

    def test_scalar_iter_udf_close(self):
        @pandas_udf('int', PandasUDFType.SCALAR_ITER)
        def test_close(batch_iter):
            try:
                for batch in batch_iter:
                    yield batch
            finally:
                raise RuntimeError("reached finally block")
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(Exception, "reached finally block"):
                self.spark.range(1).select(test_close(col("id"))).collect()

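    # If the consumer stops early (e.g. because of a limit), the generator should
    # receive GeneratorExit and still run its finally block.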
    def test_scalar_iter_udf_close_early(self):
        tmp_dir = tempfile.mkdtemp()
        try:
            tmp_file = tmp_dir + '/reach_finally_block'

            @pandas_udf('int', PandasUDFType.SCALAR_ITER)
            def test_close(batch_iter):
                generator_exit_caught = False
                try:
                    for batch in batch_iter:
                        yield batch
                        time.sleep(1.0)  # avoid finishing the task before it is stopped
                except GeneratorExit as ge:
                    generator_exit_caught = True
                    raise ge
                finally:
                    assert generator_exit_caught, "Generator exit exception was not caught."
                    open(tmp_file, 'a').close()

            with QuietTest(self.sc):
                with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 1,
                                    "spark.sql.execution.pandas.udf.buffer.size": 4}):
                    self.spark.range(10).repartition(1) \
                        .select(test_close(col("id"))).limit(2).collect()
                    # Wait here because the Python worker needs some time to detect the
                    # early termination, close the generator and touch tmp_file in the
                    # finally block.
                    for i in range(100):
                        time.sleep(0.1)
                        if os.path.exists(tmp_file):
                            break

            assert os.path.exists(tmp_file), "finally block not reached."

        finally:
            shutil.rmtree(tmp_dir)

    def test_timestamp_dst(self):
        # Daylight saving time for Los Angeles in 2015 started on Sun, Nov 1 at 2:00 am
        dt = [datetime(2015, 11, 1, 0, 30),
              datetime(2015, 11, 1, 1, 30),
              datetime(2015, 11, 1, 2, 30)]
        df = self.spark.createDataFrame(dt, 'timestamp').toDF('time')

        for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
            foo_udf = pandas_udf(lambda x: x, 'timestamp', udf_type)
            result = df.withColumn('time', foo_udf(df.time))
            self.assertEqual(df.collect(), result.collect())

    @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.")
    def test_type_annotation(self):
        # Regression test to check that type hints can be used. 'exec' is used here
        # because the annotated function would not compile on older Python versions, and
        # a separate dict is used to avoid modifying the current locals().
        _locals = {}
        exec(
            "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col",
            _locals)
        df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
        self.assertEqual(df.first()[0], 0)

    def test_mixed_udf(self):
        df = self.spark.range(0, 1).toDF('v')

        # Test mixing regular Python UDFs with scalar and iterator Pandas UDFs,
        # both chained and side by side in a single projection.

        @udf('int')
        def f1(x):
            assert type(x) == int
            return x + 1

        @pandas_udf('int')
        def f2_scalar(x):
            assert type(x) == pd.Series
            return x + 10

        @pandas_udf('int', PandasUDFType.SCALAR_ITER)
        def f2_iter(it):
            for x in it:
                assert type(x) == pd.Series
                yield x + 10

        @udf('int')
        def f3(x):
            assert type(x) == int
            return x + 100

        @pandas_udf('int')
        def f4_scalar(x):
            assert type(x) == pd.Series
            return x + 1000

        @pandas_udf('int', PandasUDFType.SCALAR_ITER)
        def f4_iter(it):
            for x in it:
                assert type(x) == pd.Series
                yield x + 1000

        expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11).collect()
        expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111).collect()
        expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111).collect()
        expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011).collect()
        expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101).collect()

        expected_multi = df \
            .withColumn('f1', df['v'] + 1) \
            .withColumn('f2', df['v'] + 10) \
            .withColumn('f3', df['v'] + 100) \
            .withColumn('f4', df['v'] + 1000) \
            .withColumn('f2_f1', df['v'] + 11) \
            .withColumn('f3_f1', df['v'] + 101) \
            .withColumn('f4_f1', df['v'] + 1001) \
            .withColumn('f3_f2', df['v'] + 110) \
            .withColumn('f4_f2', df['v'] + 1010) \
            .withColumn('f4_f3', df['v'] + 1100) \
            .withColumn('f3_f2_f1', df['v'] + 111) \
            .withColumn('f4_f2_f1', df['v'] + 1011) \
            .withColumn('f4_f3_f1', df['v'] + 1101) \
            .withColumn('f4_f3_f2', df['v'] + 1110) \
            .withColumn('f4_f3_f2_f1', df['v'] + 1111) \
            .collect()

        for f2, f4 in [(f2_scalar, f4_scalar), (f2_scalar, f4_iter),
                       (f2_iter, f4_scalar), (f2_iter, f4_iter)]:
            # Chained UDF evaluations
            df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
            df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
            df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
            df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
            df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))

            self.assertEqual(expected_chained_1, df_chained_1.collect())
            self.assertEqual(expected_chained_2, df_chained_2.collect())
            self.assertEqual(expected_chained_3, df_chained_3.collect())
            self.assertEqual(expected_chained_4, df_chained_4.collect())
            self.assertEqual(expected_chained_5, df_chained_5.collect())

            # Multiple mixed UDF expressions in a single projection, reusing
            # previously computed columns
            df_multi_1 = df \
                .withColumn('f1', f1(col('v'))) \
                .withColumn('f2', f2(col('v'))) \
                .withColumn('f3', f3(col('v'))) \
                .withColumn('f4', f4(col('v'))) \
                .withColumn('f2_f1', f2(col('f1'))) \
                .withColumn('f3_f1', f3(col('f1'))) \
                .withColumn('f4_f1', f4(col('f1'))) \
                .withColumn('f3_f2', f3(col('f2'))) \
                .withColumn('f4_f2', f4(col('f2'))) \
                .withColumn('f4_f3', f4(col('f3'))) \
                .withColumn('f3_f2_f1', f3(col('f2_f1'))) \
                .withColumn('f4_f2_f1', f4(col('f2_f1'))) \
                .withColumn('f4_f3_f1', f4(col('f3_f1'))) \
                .withColumn('f4_f3_f2', f4(col('f3_f2'))) \
                .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))

            # The same columns, each computed from scratch in a single expression tree
            df_multi_2 = df \
                .withColumn('f1', f1(col('v'))) \
                .withColumn('f2', f2(col('v'))) \
                .withColumn('f3', f3(col('v'))) \
                .withColumn('f4', f4(col('v'))) \
                .withColumn('f2_f1', f2(f1(col('v')))) \
                .withColumn('f3_f1', f3(f1(col('v')))) \
                .withColumn('f4_f1', f4(f1(col('v')))) \
                .withColumn('f3_f2', f3(f2(col('v')))) \
                .withColumn('f4_f2', f4(f2(col('v')))) \
                .withColumn('f4_f3', f4(f3(col('v')))) \
                .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \
                .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \
                .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \
                .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
                .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))

            self.assertEqual(expected_multi, df_multi_1.collect())
            self.assertEqual(expected_multi, df_multi_2.collect())

    def test_mixed_udf_and_sql(self):
        df = self.spark.range(0, 1).toDF('v')

        # Test mixing UDFs, Pandas UDFs and plain SQL (Column) expressions.

        @udf('int')
        def f1(x):
            assert type(x) == int
            return x + 1

        def f2(x):
            # A plain Python function over Column expressions, not a UDF.
            assert type(x) == Column
            return x + 10

        @pandas_udf('int')
        def f3s(x):
            assert type(x) == pd.Series
            return x + 100

        @pandas_udf('int', PandasUDFType.SCALAR_ITER)
        def f3i(it):
            for x in it:
                assert type(x) == pd.Series
                yield x + 100

        expected = df.withColumn('f1', df['v'] + 1) \
            .withColumn('f2', df['v'] + 10) \
            .withColumn('f3', df['v'] + 100) \
            .withColumn('f1_f2', df['v'] + 11) \
            .withColumn('f1_f3', df['v'] + 101) \
            .withColumn('f2_f1', df['v'] + 11) \
            .withColumn('f2_f3', df['v'] + 110) \
            .withColumn('f3_f1', df['v'] + 101) \
            .withColumn('f3_f2', df['v'] + 110) \
            .withColumn('f1_f2_f3', df['v'] + 111) \
            .withColumn('f1_f3_f2', df['v'] + 111) \
            .withColumn('f2_f1_f3', df['v'] + 111) \
            .withColumn('f2_f3_f1', df['v'] + 111) \
            .withColumn('f3_f1_f2', df['v'] + 111) \
            .withColumn('f3_f2_f1', df['v'] + 111) \
            .collect()

        for f3 in [f3s, f3i]:
            df1 = df.withColumn('f1', f1(df['v'])) \
                .withColumn('f2', f2(df['v'])) \
                .withColumn('f3', f3(df['v'])) \
                .withColumn('f1_f2', f1(f2(df['v']))) \
                .withColumn('f1_f3', f1(f3(df['v']))) \
                .withColumn('f2_f1', f2(f1(df['v']))) \
                .withColumn('f2_f3', f2(f3(df['v']))) \
                .withColumn('f3_f1', f3(f1(df['v']))) \
                .withColumn('f3_f2', f3(f2(df['v']))) \
                .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \
                .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \
                .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \
                .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \
                .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
                .withColumn('f3_f2_f1', f3(f2(f1(df['v']))))

            self.assertEqual(expected, df1.collect())

    @unittest.skipIf(not test_compiled, test_not_compiled_message)
    def test_datasource_with_udf(self):
        # Mirrors the datasource-with-UDF test for regular UDFs, but with Pandas UDFs;
        # kept separate here because the pandas/Arrow dependencies are optional.
        import numpy as np

        path = tempfile.mkdtemp()
        shutil.rmtree(path)

        try:
            self.spark.range(1).write.mode("overwrite").format('csv').save(path)
            filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
            datasource_df = self.spark.read \
                .format("org.apache.spark.sql.sources.SimpleScanSource") \
                .option('from', 0).option('to', 1).load().toDF('i')
            datasource_v2_df = self.spark.read \
                .format("org.apache.spark.sql.connector.SimpleDataSourceV2") \
                .load().toDF('i', 'j')

            c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1))
            c2 = pandas_udf(lambda x: x + 1, 'int')(col('i'))

            f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1))
            f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i'))

            for df in [filesource_df, datasource_df, datasource_v2_df]:
                result = df.withColumn('c', c1)
                expected = df.withColumn('c', lit(2))
                self.assertEqual(expected.collect(), result.collect())

            for df in [filesource_df, datasource_df, datasource_v2_df]:
                result = df.withColumn('c', c2)
                expected = df.withColumn('c', col('i') + 1)
                self.assertEqual(expected.collect(), result.collect())

            for df in [filesource_df, datasource_df, datasource_v2_df]:
                for f in [f1, f2]:
                    result = df.filter(f)
                    self.assertEqual(0, result.count())
        finally:
            shutil.rmtree(path)

if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_udf_scalar import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)