#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import functools
import pydoc
import shutil
import tempfile
import unittest

from pyspark import SparkContext
from pyspark.sql import SparkSession, Column, Row
from pyspark.sql.functions import UserDefinedFunction, udf
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
from pyspark.testing.utils import QuietTest


class UDFTests(ReusedSQLTestCase):

    def test_udf_with_callable(self):
        d = [Row(number=i, squared=i**2) for i in range(10)]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        class PlusFour:
            def __call__(self, col):
                if col is not None:
                    return col + 4

        call = PlusFour()
        pudf = UserDefinedFunction(call, LongType())
        res = data.select(pudf(data['number']).alias('plus_four'))
        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)

    def test_udf_with_partial_function(self):
        d = [Row(number=i, squared=i**2) for i in range(10)]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        def some_func(col, param):
            if col is not None:
                return col + param

        pfunc = functools.partial(some_func, param=4)
        pudf = UserDefinedFunction(pfunc, LongType())
        res = data.select(pudf(data['number']).alias('plus_four'))
        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)

    def test_udf(self):
        self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
        self.assertEqual(row[0], 5)

        # The deprecated SQLContext.registerFunction alias should keep working.
        sqlContext = self.spark._wrapped
        sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
        [row] = sqlContext.sql("SELECT oneArg('test')").collect()
        self.assertEqual(row[0], 4)

    def test_udf2(self):
        with self.tempView("test"):
            self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
            self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
                .createOrReplaceTempView("test")
            [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
            self.assertEqual(4, res[0])

    def test_udf3(self):
        two_args = self.spark.catalog.registerFunction(
            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y))
        self.assertEqual(two_args.deterministic, True)
        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
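        # No return type was given, so the registered UDF falls back to the default
        # string return type and the result is the string '5'.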
        self.assertEqual(row[0], u'5')

    def test_udf_registration_return_type_none(self):
        two_args = self.spark.catalog.registerFunction(
            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None)
        self.assertEqual(two_args.deterministic, True)
        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
        self.assertEqual(row[0], 5)

    def test_udf_registration_return_type_not_none(self):
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(TypeError, "Invalid return type"):
                self.spark.catalog.registerFunction(
                    "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType())

    def test_nondeterministic_udf(self):
        # The random UDF is evaluated once per row; feeding its result through a second
        # UDF in the same projection must reuse that value, so row[1] == row[0] + 10.
        import random
        udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
        self.assertEqual(udf_random_col.deterministic, False)
        df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
        udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
        [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
        self.assertEqual(row[0] + 10, row[1])

    def test_nondeterministic_udf2(self):
        import random
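        # randint(6, 6) always returns 6, but the UDF is declared nondeterministic; it
        # must behave the same whether called directly, registered, or used via SQL.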
        random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
        self.assertEqual(random_udf.deterministic, False)
        random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
        self.assertEqual(random_udf1.deterministic, False)
        [row] = self.spark.sql("SELECT randInt()").collect()
        self.assertEqual(row[0], 6)
        [row] = self.spark.range(1).select(random_udf1()).collect()
        self.assertEqual(row[0], 6)
        [row] = self.spark.range(1).select(random_udf()).collect()
        self.assertEqual(row[0], 6)

        # Rendering pydoc/help for any flavor of UDF wrapper must not raise.
        pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
        pydoc.render_doc(random_udf)
        pydoc.render_doc(random_udf1)
        pydoc.render_doc(udf(lambda x: x).asNondeterministic)

    def test_nondeterministic_udf3(self):
        # asNondeterministic must reset the cached JVM UDF so that later plans see the
        # updated deterministic flag.
        f = udf(lambda x: x)
        # This call caches the underlying JVM UDF instance.
        self.spark.range(1).select(f("id"))
        # Marking the UDF nondeterministic should invalidate that cache.
        f = f.asNondeterministic()
        # The new plan must report deterministic=False for the projected expression.
        df = self.spark.range(1).select(f("id"))
        deterministic = df._jdf.logicalPlan().projectList().head().deterministic()
        self.assertFalse(deterministic)

    def test_nondeterministic_udf_in_aggregate(self):
        from pyspark.sql.functions import sum
        import random
        udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
        df = self.spark.range(10)

        with QuietTest(self.sc):
            with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
                df.groupby('id').agg(sum(udf_random_col())).collect()
            with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
                df.agg(sum(udf_random_col())).collect()

    def test_chained_udf(self):
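        # A registered UDF can be nested inside itself and combined with ordinary
        # expressions.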
        self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
        [row] = self.spark.sql("SELECT double(1)").collect()
        self.assertEqual(row[0], 2)
        [row] = self.spark.sql("SELECT double(double(1))").collect()
        self.assertEqual(row[0], 4)
        [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
        self.assertEqual(row[0], 6)

    def test_single_udf_with_repeated_argument(self):
        # A single UDF invoked with the same value in more than one argument slot
        # should evaluate correctly.
        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
        row = self.spark.sql("SELECT add(1, 1)").first()
        self.assertEqual(tuple(row), (2, ))

    def test_multiple_udfs(self):
        self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
        [row] = self.spark.sql("SELECT double(1), double(2)").collect()
        self.assertEqual(tuple(row), (2, 4))
        [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
        self.assertEqual(tuple(row), (4, 12))
        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
        [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
        self.assertEqual(tuple(row), (6, 5))

    def test_udf_in_filter_on_top_of_outer_join(self):
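        # A column produced by a UDF after a left outer join must remain filterable.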
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(a=1)])
        df = left.join(right, on='a', how='left_outer')
        df = df.withColumn('b', udf(lambda x: 'x')(df.a))
        self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])

    def test_udf_in_filter_on_top_of_join(self):
        # A UDF in a filter directly on top of a cross join should be planned into the
        # join without error.
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(b=1)])
        f = udf(lambda a, b: a == b, BooleanType())
        df = left.crossJoin(right).filter(f("a", "b"))
        self.assertEqual(df.collect(), [Row(a=1, b=1)])

    def test_udf_in_join_condition(self):
        # A Python UDF that references both sides of the join cannot be evaluated as
        # part of the join itself, so Spark plans a cross join plus a filter.
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(b=1)])
        f = udf(lambda a, b: a == b, BooleanType())

        # The implicit cross join is rejected unless it is explicitly enabled.
        df = left.join(right, f("a", "b"))
        with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
            with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
                df.collect()
        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
            self.assertEqual(df.collect(), [Row(a=1, b=1)])

    def test_udf_in_left_outer_join_condition(self):
        # A UDF that only references attributes from one side of the join can be
        # evaluated before the join, so it is allowed in an outer join condition.
        from pyspark.sql.functions import col
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(b=1)])
        f = udf(lambda a: str(a), StringType())

        df = left.join(right, f("a") == col("b").cast("string"), how="left_outer")
        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
            self.assertEqual(df.collect(), [Row(a=1, b=1)])

    def test_udf_and_common_filter_in_join_condition(self):
        # When a UDF condition is combined with an ordinary equi-join condition, the
        # equi-join part keeps the plan from degenerating into a plain cross join.
        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
        f = udf(lambda a, b: a == b, BooleanType())
        df = left.join(right, [f("a", "b"), left.a1 == right.b1])
        # spark.sql.crossJoin.enabled is not needed because the UDF is not the only
        # join condition.
        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])

    def test_udf_not_supported_in_join_condition(self):
        # A Python UDF that references both sides of the join is only supported in an
        # inner join; every other join type must fail with a clear error.
        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
        f = udf(lambda a, b: a == b, BooleanType())

        def runWithJoinType(join_type, type_string):
            with self.assertRaisesRegexp(
                    AnalysisException,
                    'Using PythonUDF.*%s is not supported.' % type_string):
                left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect()
        runWithJoinType("full", "FullOuter")
        runWithJoinType("left", "LeftOuter")
        runWithJoinType("right", "RightOuter")
        runWithJoinType("leftanti", "LeftAnti")
        runWithJoinType("leftsemi", "LeftSemi")

    def test_udf_as_join_condition(self):
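        # A UDF is allowed in the join condition when each call references only one
        # side of the join.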
        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
        f = udf(lambda a: a, IntegerType())

        df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])

    def test_udf_without_arguments(self):
        self.spark.catalog.registerFunction("foo", lambda: "bar")
        [row] = self.spark.sql("SELECT foo()").collect()
        self.assertEqual(row[0], "bar")

    def test_udf_with_array_type(self):
        with self.tempView("test"):
            d = [Row(l=list(range(3)), d={"key": list(range(5))})]
            rdd = self.sc.parallelize(d)
            self.spark.createDataFrame(rdd).createOrReplaceTempView("test")
            self.spark.catalog.registerFunction(
                "copylist", lambda l: list(l), ArrayType(IntegerType()))
            self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType())
            [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect()
            self.assertEqual(list(range(3)), l1)
            self.assertEqual(1, l2)

    def test_broadcast_in_udf(self):
        bar = {"a": "aa", "b": "bb", "c": "abc"}
        foo = self.sc.broadcast(bar)
        self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
        [res] = self.spark.sql("SELECT MYUDF('c')").collect()
        self.assertEqual("abc", res[0])
        [res] = self.spark.sql("SELECT MYUDF('')").collect()
        self.assertEqual("", res[0])

    def test_udf_with_filter_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql.functions import col
        from pyspark.sql.types import BooleanType

        my_filter = udf(lambda a: a < 2, BooleanType())
        sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
        self.assertEqual(sel.collect(), [Row(key=1, value='1')])

    def test_udf_with_aggregate_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql.functions import col, sum
        from pyspark.sql.types import BooleanType

        my_filter = udf(lambda a: a == 1, BooleanType())
        sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
        self.assertEqual(sel.collect(), [Row(key=1)])

        my_copy = udf(lambda x: x, IntegerType())
        my_add = udf(lambda a, b: int(a + b), IntegerType())
        my_strlen = udf(lambda x: len(x), IntegerType())
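        # UDFs may supply the grouping key, the aggregate input, and the final
        # projection all at once.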
        sel = df.groupBy(my_copy(col("key")).alias("k"))\
            .agg(sum(my_strlen(col("value"))).alias("s"))\
            .select(my_add(col("k"), col("s")).alias("t"))
        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])

    def test_udf_in_generate(self):
        from pyspark.sql.functions import explode
        df = self.spark.range(5)
        f = udf(lambda x: list(range(x)), ArrayType(LongType()))
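        # Exploding range(i) for i in 0..4 yields 0 + 0 + 1 + 3 + 6 = 10 in total.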
        row = df.select(explode(f(*df))).groupBy().sum().first()
        self.assertEqual(row[0], 10)

        df = self.spark.range(3)
        res = df.select("id", explode(f(df.id))).collect()
        self.assertEqual(res[0][0], 1)
        self.assertEqual(res[0][1], 0)
        self.assertEqual(res[1][0], 2)
        self.assertEqual(res[1][1], 0)
        self.assertEqual(res[2][0], 2)
        self.assertEqual(res[2][1], 1)

        range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType()))
        res = df.select("id", explode(range_udf(df.id))).collect()
        self.assertEqual(res[0][0], 0)
        self.assertEqual(res[0][1], -1)
        self.assertEqual(res[1][0], 0)
        self.assertEqual(res[1][1], 0)
        self.assertEqual(res[2][0], 1)
        self.assertEqual(res[2][1], 0)
        self.assertEqual(res[3][0], 1)
        self.assertEqual(res[3][1], 1)

    def test_udf_with_order_by_and_limit(self):
        my_copy = udf(lambda x: x, IntegerType())
        df = self.spark.range(10).orderBy("id")
        res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
        self.assertEqual(res.collect(), [Row(id=0, copy=0)])

    def test_udf_registration_returns_udf(self):
        df = self.spark.range(10)
        add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())
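        # register() returns a wrapped UDF that is also usable through the DataFrame
        # API, and both paths must produce the same results.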

        self.assertListEqual(
            df.selectExpr("add_three(id) AS plus_three").collect(),
            df.select(add_three("id").alias("plus_three")).collect()
        )

        # The deprecated SQLContext.udf.register alias should behave the same way.
        sqlContext = self.spark._wrapped
        add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())

        self.assertListEqual(
            df.selectExpr("add_four(id) AS plus_four").collect(),
            df.select(add_four("id").alias("plus_four")).collect()
        )

    def test_non_existed_udf(self):
        spark = self.spark
        self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
                                lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))

        # The deprecated SQLContext.registerJavaFunction alias should fail the same way.
        sqlContext = spark._wrapped
        self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
                                lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))

    def test_non_existed_udaf(self):
        spark = self.spark
        self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
                                lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))

    def test_udf_with_input_file_name(self):
        from pyspark.sql.functions import input_file_name
        sourceFile = udf(lambda path: path, StringType())
        filePath = "python/test_support/sql/people1.json"
        row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
        self.assertTrue(row[0].find("people1.json") != -1)

    def test_udf_with_input_file_name_for_hadooprdd(self):
        from pyspark.sql.functions import input_file_name

        def filename(path):
            return path

        sameText = udf(filename, StringType())

        rdd = self.sc.textFile('python/test_support/sql/people.json')
        df = self.spark.read.json(rdd).select(input_file_name().alias('file'))
        row = df.select(sameText(df['file'])).first()
        self.assertTrue(row[0].find("people.json") != -1)

        rdd2 = self.sc.newAPIHadoopFile(
            'python/test_support/sql/people.json',
            'org.apache.hadoop.mapreduce.lib.input.TextInputFormat',
            'org.apache.hadoop.io.LongWritable',
            'org.apache.hadoop.io.Text')

        df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file'))
        row2 = df2.select(sameText(df2['file'])).first()
        self.assertTrue(row2[0].find("people.json") != -1)

    def test_udf_defers_judf_initialization(self):
        # This test is kept separate from UDFInitializationTests so that creating the
        # UserDefinedFunction here cannot trigger, or be affected by, context
        # initialization elsewhere in this class.
        f = UserDefinedFunction(lambda x: x, StringType())

        self.assertIsNone(
            f._judf_placeholder,
            "judf should not be initialized before the first call."
        )

        self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.")

        self.assertIsNotNone(
            f._judf_placeholder,
            "judf should be initialized after UDF has been called."
        )

    def test_udf_with_string_return_type(self):
        add_one = UserDefinedFunction(lambda x: x + 1, "integer")
        make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
        make_array = UserDefinedFunction(
            lambda x: [float(x) for x in range(x, x + 3)], "array<double>")

        expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0])
        actual = (self.spark.range(1, 2).toDF("x")
                  .select(add_one("x"), make_pair("x"), make_array("x"))
                  .first())

        self.assertTupleEqual(expected, actual)

    def test_udf_shouldnt_accept_noncallable_object(self):
        non_callable = None
        self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())

    def test_udf_with_decorator(self):
        from pyspark.sql.functions import lit
        from pyspark.sql.types import IntegerType, DoubleType

        @udf(IntegerType())
        def add_one(x):
            if x is not None:
                return x + 1

        @udf(returnType=DoubleType())
        def add_two(x):
            if x is not None:
                return float(x + 2)

        @udf
        def to_upper(x):
            if x is not None:
                return x.upper()

        @udf()
        def to_lower(x):
            if x is not None:
                return x.lower()

        @udf
        def substr(x, start, end):
            if x is not None:
                return x[start:end]

        @udf("long")
        def trunc(x):
            return int(x)

        @udf(returnType="double")
        def as_double(x):
            return float(x)

        df = (
            self.spark
            .createDataFrame(
                [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
            .select(
                add_one("one"), add_two("one"),
                to_upper("Foo"), to_lower("Foo"),
                substr("foobar", lit(0), lit(3)),
                trunc("float"), as_double("one")))

        self.assertListEqual(
            [tpe for _, tpe in df.dtypes],
            ["int", "double", "string", "string", "string", "bigint", "double"]
        )

        self.assertListEqual(
            list(df.first()),
            [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
        )

    def test_udf_wrapper(self):
        from pyspark.sql.types import IntegerType

        def f(x):
            """Identity"""
            return x

        return_type = IntegerType()
        f_ = udf(f, return_type)

        self.assertTrue(f.__doc__ in f_.__doc__)
        self.assertEqual(f, f_.func)
        self.assertEqual(return_type, f_.returnType)

        class F(object):
            """Identity"""
            def __call__(self, x):
                return x

        f = F()
        return_type = IntegerType()
        f_ = udf(f, return_type)

        self.assertTrue(f.__doc__ in f_.__doc__)
        self.assertEqual(f, f_.func)
        self.assertEqual(return_type, f_.returnType)

        f = functools.partial(f, x=1)
        return_type = IntegerType()
        f_ = udf(f, return_type)

        self.assertTrue(f.__doc__ in f_.__doc__)
        self.assertEqual(f, f_.func)
        self.assertEqual(return_type, f_.returnType)

    def test_nonparam_udf_with_aggregate(self):
        import pyspark.sql.functions as f

        df = self.spark.createDataFrame([(1, 2), (1, 2)])
        f_udf = f.udf(lambda: "const_str")
        rows = df.distinct().withColumn("a", f_udf()).collect()
        self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')])

    @unittest.skipIf(not test_compiled, test_not_compiled_message)
    def test_datasource_with_udf(self):
        from pyspark.sql.functions import lit, col

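        # Create a unique temp path, then remove it so the writer can create it fresh.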
        path = tempfile.mkdtemp()
        shutil.rmtree(path)

        try:
            self.spark.range(1).write.mode("overwrite").format('csv').save(path)
            filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
            datasource_df = self.spark.read \
                .format("org.apache.spark.sql.sources.SimpleScanSource") \
                .option('from', 0).option('to', 1).load().toDF('i')
            datasource_v2_df = self.spark.read \
                .format("org.apache.spark.sql.connector.SimpleDataSourceV2") \
                .load().toDF('i', 'j')

            c1 = udf(lambda x: x + 1, 'int')(lit(1))
            c2 = udf(lambda x: x + 1, 'int')(col('i'))

            f1 = udf(lambda x: False, 'boolean')(lit(1))
            f2 = udf(lambda x: False, 'boolean')(col('i'))

            for df in [filesource_df, datasource_df, datasource_v2_df]:
                result = df.withColumn('c', c1)
                expected = df.withColumn('c', lit(2))
                self.assertEqual(expected.collect(), result.collect())

            for df in [filesource_df, datasource_df, datasource_v2_df]:
                result = df.withColumn('c', c2)
                expected = df.withColumn('c', col('i') + 1)
                self.assertEqual(expected.collect(), result.collect())

            for df in [filesource_df, datasource_df, datasource_v2_df]:
                for f in [f1, f2]:
                    result = df.filter(f)
                    self.assertEqual(0, result.count())
        finally:
            shutil.rmtree(path)

    def test_same_accumulator_in_udfs(self):
        data_schema = StructType([StructField("a", IntegerType(), True),
                                  StructField("b", IntegerType(), True)])
        data = self.spark.createDataFrame([[1, 2]], schema=data_schema)

        test_accum = self.sc.accumulator(0)

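        # Both UDFs close over the same accumulator; one pass over the single row
        # should add 1 + 100 = 101.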
        def first_udf(x):
            test_accum.add(1)
            return x

        def second_udf(x):
            test_accum.add(100)
            return x

        func_udf = udf(first_udf, IntegerType())
        func_udf2 = udf(second_udf, IntegerType())
        data = data.withColumn("out1", func_udf(data["a"]))
        data = data.withColumn("out2", func_udf2(data["b"]))
        data.collect()
        self.assertEqual(test_accum.value, 101)

    def test_udf_in_subquery(self):
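        # A temp view whose definition contains a UDF must still work when referenced
        # from an IN subquery.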
        f = udf(lambda x: x, "long")
        with self.tempView("v"):
            self.spark.range(1).filter(f("id") >= 0).createTempView("v")
            sql = self.spark.sql
            result = sql("select i from values(0L) as data(i) where i in (select id from v)")
            self.assertEqual(result.collect(), [Row(i=0)])

    def test_udf_globals_not_overwritten(self):
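        # Deserializing the UDF on the worker must not rebind globals such as 'map'
        # to itertools equivalents.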
        @udf('string')
        def f():
            assert "itertools" not in str(map)

        self.spark.range(1).select(f()).collect()

    def test_worker_original_stdin_closed(self):
        # The worker should close its original stdin; reading it from inside a task
        # must hit EOF immediately instead of consuming data meant for the worker.
        def task(iterator):
            import sys
            res = sys.stdin.read()
            # The original stdin must already be at EOF.
            assert res == '', "Expect read EOF from stdin."
            return iterator

        self.sc.parallelize(range(1), 1).mapPartitions(task).count()

    def test_udf_with_256_args(self):
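        # Invoke a UDF with 256 columns, more than the 255-argument limit that older
        # Python versions imposed on function calls.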
        N = 256
        data = [["data-%d" % i for i in range(N)]] * 5
        df = self.spark.createDataFrame(data)

        def f(*a):
            return "success"

        fUdf = udf(f, StringType())

        r = df.select(fUdf(*df.columns))
        self.assertEqual(r.first()[0], "success")


class UDFInitializationTests(unittest.TestCase):
    def tearDown(self):
        if SparkSession._instantiatedSession is not None:
            SparkSession._instantiatedSession.stop()

        if SparkContext._active_spark_context is not None:
            SparkContext._active_spark_context.stop()

    def test_udf_init_shouldnt_initialize_context(self):
        UserDefinedFunction(lambda x: x, StringType())

        self.assertIsNone(
            SparkContext._active_spark_context,
            "SparkContext shouldn't be initialized when UserDefinedFunction is created."
        )
        self.assertIsNone(
            SparkSession._instantiatedSession,
            "SparkSession shouldn't be initialized when UserDefinedFunction is created."
        )


if __name__ == "__main__":
    from pyspark.sql.tests.test_udf import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)