Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 import functools
0019 import pydoc
0020 import shutil
0021 import tempfile
0022 import unittest
0023 
0024 from pyspark import SparkContext
0025 from pyspark.sql import SparkSession, Column, Row
0026 from pyspark.sql.functions import UserDefinedFunction, udf
0027 from pyspark.sql.types import *
0028 from pyspark.sql.utils import AnalysisException
0029 from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
0030 from pyspark.testing.utils import QuietTest
0031 
0032 
class UDFTests(ReusedSQLTestCase):
    """Tests for Python UDFs.

    Covers creating and registering UDFs, deterministic vs. nondeterministic
    UDFs, and using UDFs inside selects, filters, joins, aggregates,
    generators, subqueries and data source scans.
    """

    def test_udf_with_callable(self):
        # A callable object (instance with __call__) can back a UDF.
        d = [Row(number=i, squared=i**2) for i in range(10)]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        class PlusFour:
            def __call__(self, col):
                if col is not None:
                    return col + 4

        call = PlusFour()
        pudf = UserDefinedFunction(call, LongType())
        res = data.select(pudf(data['number']).alias('plus_four'))
        # sum(0..9) + 10 * 4 == 45 + 40 == 85
        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)

    def test_udf_with_partial_function(self):
        # A functools.partial object can back a UDF.
        d = [Row(number=i, squared=i**2) for i in range(10)]
        rdd = self.sc.parallelize(d)
        data = self.spark.createDataFrame(rdd)

        def some_func(col, param):
            if col is not None:
                return col + param

        pfunc = functools.partial(some_func, param=4)
        pudf = UserDefinedFunction(pfunc, LongType())
        res = data.select(pudf(data['number']).alias('plus_four'))
        self.assertEqual(res.agg({'plus_four': 'sum'}).collect()[0][0], 85)

    def test_udf(self):
        self.spark.catalog.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
        self.assertEqual(row[0], 5)

        # This is to check if a deprecated 'SQLContext.registerFunction' can call its alias.
        sqlContext = self.spark._wrapped
        sqlContext.registerFunction("oneArg", lambda x: len(x), IntegerType())
        [row] = sqlContext.sql("SELECT oneArg('test')").collect()
        self.assertEqual(row[0], 4)

    def test_udf2(self):
        # A registered UDF can be used in both the projection and the WHERE clause.
        with self.tempView("test"):
            self.spark.catalog.registerFunction("strlen", lambda string: len(string), IntegerType())
            self.spark.createDataFrame(self.sc.parallelize([Row(a="test")]))\
                .createOrReplaceTempView("test")
            [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
            self.assertEqual(4, res[0])

    def test_udf3(self):
        # Registering an existing UserDefinedFunction returns a (deterministic) UDF;
        # with no explicit return type the SQL result comes back as a string.
        two_args = self.spark.catalog.registerFunction(
            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y))
        self.assertEqual(two_args.deterministic, True)
        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
        self.assertEqual(row[0], u'5')

    def test_udf_registration_return_type_none(self):
        # An explicit returnType=None keeps the UDF's own "integer" return type.
        two_args = self.spark.catalog.registerFunction(
            "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y, "integer"), None)
        self.assertEqual(two_args.deterministic, True)
        [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect()
        self.assertEqual(row[0], 5)

    def test_udf_registration_return_type_not_none(self):
        # Supplying a return type when registering a UserDefinedFunction is rejected.
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(TypeError, "Invalid return type"):
                self.spark.catalog.registerFunction(
                    "f", UserDefinedFunction(lambda x, y: len(x) + y, StringType()), StringType())

    def test_nondeterministic_udf(self):
        # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
        import random
        udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
        self.assertEqual(udf_random_col.deterministic, False)
        df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
        udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
        [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
        # If RAND were re-evaluated for the second column the two values would disagree.
        self.assertEqual(row[0] + 10, row[1])

    def test_nondeterministic_udf2(self):
        import random
        # randint(6, 6) always yields 6, so results are checkable despite the flag.
        random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
        self.assertEqual(random_udf.deterministic, False)
        random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
        self.assertEqual(random_udf1.deterministic, False)
        [row] = self.spark.sql("SELECT randInt()").collect()
        self.assertEqual(row[0], 6)
        [row] = self.spark.range(1).select(random_udf1()).collect()
        self.assertEqual(row[0], 6)
        [row] = self.spark.range(1).select(random_udf()).collect()
        self.assertEqual(row[0], 6)
        # render_doc() reproduces the help() exception without printing output
        pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType()))
        pydoc.render_doc(random_udf)
        pydoc.render_doc(random_udf1)
        pydoc.render_doc(udf(lambda x: x).asNondeterministic)

    def test_nondeterministic_udf3(self):
        # regression test for SPARK-23233
        f = udf(lambda x: x)
        # Here we cache the JVM UDF instance.
        self.spark.range(1).select(f("id"))
        # This should reset the cache to set the deterministic status correctly.
        f = f.asNondeterministic()
        # Check the deterministic status of udf.
        df = self.spark.range(1).select(f("id"))
        deterministic = df._jdf.logicalPlan().projectList().head().deterministic()
        self.assertFalse(deterministic)

    def test_nondeterministic_udf_in_aggregate(self):
        from pyspark.sql.functions import sum
        import random
        udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
        df = self.spark.range(10)

        # Nondeterministic expressions inside aggregate functions are rejected at analysis.
        with QuietTest(self.sc):
            with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
                df.groupby('id').agg(sum(udf_random_col())).collect()
            with self.assertRaisesRegexp(AnalysisException, "nondeterministic"):
                df.agg(sum(udf_random_col())).collect()

    def test_chained_udf(self):
        self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType())
        [row] = self.spark.sql("SELECT double(1)").collect()
        self.assertEqual(row[0], 2)
        [row] = self.spark.sql("SELECT double(double(1))").collect()
        self.assertEqual(row[0], 4)
        [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
        self.assertEqual(row[0], 6)

    def test_single_udf_with_repeated_argument(self):
        # regression test for SPARK-20685
        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
        row = self.spark.sql("SELECT add(1, 1)").first()
        self.assertEqual(tuple(row), (2, ))

    def test_multiple_udfs(self):
        self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
        [row] = self.spark.sql("SELECT double(1), double(2)").collect()
        self.assertEqual(tuple(row), (2, 4))
        [row] = self.spark.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
        self.assertEqual(tuple(row), (4, 12))
        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
        [row] = self.spark.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect()
        self.assertEqual(tuple(row), (6, 5))

    def test_udf_in_filter_on_top_of_outer_join(self):
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(a=1)])
        df = left.join(right, on='a', how='left_outer')
        df = df.withColumn('b', udf(lambda x: 'x')(df.a))
        self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])

    def test_udf_in_filter_on_top_of_join(self):
        # regression test for SPARK-18589
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(b=1)])
        f = udf(lambda a, b: a == b, BooleanType())
        df = left.crossJoin(right).filter(f("a", "b"))
        self.assertEqual(df.collect(), [Row(a=1, b=1)])

    def test_udf_in_join_condition(self):
        # regression test for SPARK-25314
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(b=1)])
        f = udf(lambda a, b: a == b, BooleanType())
        # The udf uses attributes from both sides of join, so it is pulled out as Filter +
        # Cross join.
        df = left.join(right, f("a", "b"))
        with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
            with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
                df.collect()
        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
            self.assertEqual(df.collect(), [Row(a=1, b=1)])

    def test_udf_in_left_outer_join_condition(self):
        # regression test for SPARK-26147
        from pyspark.sql.functions import col
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(b=1)])
        f = udf(lambda a: str(a), StringType())
        # The join condition can't be pushed down, as it refers to attributes from both sides.
        # The Python UDF only refers to attributes from one side, so it's evaluable.
        df = left.join(right, f("a") == col("b").cast("string"), how="left_outer")
        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
            self.assertEqual(df.collect(), [Row(a=1, b=1)])

    def test_udf_and_common_filter_in_join_condition(self):
        # regression test for SPARK-25314
        # test the complex scenario with both udf and common filter
        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
        f = udf(lambda a, b: a == b, BooleanType())
        df = left.join(right, [f("a", "b"), left.a1 == right.b1])
        # spark.sql.crossJoin.enabled=true is not needed because the udf is not the only
        # join condition.
        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])

    def test_udf_not_supported_in_join_condition(self):
        # regression test for SPARK-25314
        # test python udf is not supported in join type except inner join.
        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
        f = udf(lambda a, b: a == b, BooleanType())

        def runWithJoinType(join_type, type_string):
            with self.assertRaisesRegexp(
                    AnalysisException,
                    'Using PythonUDF.*%s is not supported.' % type_string):
                left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect()
        runWithJoinType("full", "FullOuter")
        runWithJoinType("left", "LeftOuter")
        runWithJoinType("right", "RightOuter")
        runWithJoinType("leftanti", "LeftAnti")
        runWithJoinType("leftsemi", "LeftSemi")

    def test_udf_as_join_condition(self):
        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
        f = udf(lambda a: a, IntegerType())

        df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])

    def test_udf_without_arguments(self):
        self.spark.catalog.registerFunction("foo", lambda: "bar")
        [row] = self.spark.sql("SELECT foo()").collect()
        self.assertEqual(row[0], "bar")

    def test_udf_with_array_type(self):
        with self.tempView("test"):
            d = [Row(l=list(range(3)), d={"key": list(range(5))})]
            rdd = self.sc.parallelize(d)
            self.spark.createDataFrame(rdd).createOrReplaceTempView("test")
            self.spark.catalog.registerFunction(
                "copylist", lambda l: list(l), ArrayType(IntegerType()))
            self.spark.catalog.registerFunction("maplen", lambda d: len(d), IntegerType())
            [(l1, l2)] = self.spark.sql("select copylist(l), maplen(d) from test").collect()
            self.assertEqual(list(range(3)), l1)
            self.assertEqual(1, l2)

    def test_broadcast_in_udf(self):
        bar = {"a": "aa", "b": "bb", "c": "abc"}
        foo = self.sc.broadcast(bar)
        self.spark.catalog.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
        [res] = self.spark.sql("SELECT MYUDF('c')").collect()
        self.assertEqual("abc", res[0])
        [res] = self.spark.sql("SELECT MYUDF('')").collect()
        self.assertEqual("", res[0])

    def test_udf_with_filter_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql.functions import col
        from pyspark.sql.types import BooleanType

        my_filter = udf(lambda a: a < 2, BooleanType())
        sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
        self.assertEqual(sel.collect(), [Row(key=1, value='1')])

    def test_udf_with_aggregate_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql.functions import col, sum
        from pyspark.sql.types import BooleanType

        my_filter = udf(lambda a: a == 1, BooleanType())
        sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
        self.assertEqual(sel.collect(), [Row(key=1)])

        my_copy = udf(lambda x: x, IntegerType())
        my_add = udf(lambda a, b: int(a + b), IntegerType())
        my_strlen = udf(lambda x: len(x), IntegerType())
        sel = df.groupBy(my_copy(col("key")).alias("k"))\
            .agg(sum(my_strlen(col("value"))).alias("s"))\
            .select(my_add(col("k"), col("s")).alias("t"))
        # key 1 -> 1 + 3 == 4, key 2 -> 2 + 1 == 3. Group output order after the
        # shuffle is not guaranteed, so compare the rows order-insensitively.
        self.assertEqual(sorted(sel.collect()), [Row(t=3), Row(t=4)])

    def test_udf_in_generate(self):
        from pyspark.sql.functions import explode
        df = self.spark.range(5)
        f = udf(lambda x: list(range(x)), ArrayType(LongType()))
        # `*df` unpacks the DataFrame's columns (only `id` here); the exploded
        # values for ids 0..4 sum to 0 + 0 + 1 + 3 + 6 == 10.
        row = df.select(explode(f(*df))).groupBy().sum().first()
        self.assertEqual(row[0], 10)

        df = self.spark.range(3)
        res = df.select("id", explode(f(df.id))).collect()
        # id 0 explodes to nothing, id 1 -> [0], id 2 -> [0, 1].
        self.assertEqual([tuple(r) for r in res], [(1, 0), (2, 0), (2, 1)])

        range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType()))
        res = df.select("id", explode(range_udf(df.id))).collect()
        self.assertEqual([tuple(r) for r in res[:4]], [(0, -1), (0, 0), (1, 0), (1, 1)])

    def test_udf_with_order_by_and_limit(self):
        my_copy = udf(lambda x: x, IntegerType())
        df = self.spark.range(10).orderBy("id")
        res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
        self.assertEqual(res.collect(), [Row(id=0, copy=0)])

    def test_udf_registration_returns_udf(self):
        df = self.spark.range(10)
        add_three = self.spark.udf.register("add_three", lambda x: x + 3, IntegerType())

        # The SQL-registered name and the returned Python UDF must give the same results.
        self.assertListEqual(
            df.selectExpr("add_three(id) AS plus_three").collect(),
            df.select(add_three("id").alias("plus_three")).collect()
        )

        # This is to check if a 'SQLContext.udf' can call its alias.
        sqlContext = self.spark._wrapped
        add_four = sqlContext.udf.register("add_four", lambda x: x + 4, IntegerType())

        self.assertListEqual(
            df.selectExpr("add_four(id) AS plus_four").collect(),
            df.select(add_four("id").alias("plus_four")).collect()
        )

    def test_non_existed_udf(self):
        spark = self.spark
        self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
                                lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))

        # This is to check if a deprecated 'SQLContext.registerJavaFunction' can call its alias.
        sqlContext = spark._wrapped
        self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
                                lambda: sqlContext.registerJavaFunction("udf1", "non_existed_udf"))

    def test_non_existed_udaf(self):
        spark = self.spark
        self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
                                lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))

    def test_udf_with_input_file_name(self):
        from pyspark.sql.functions import input_file_name
        sourceFile = udf(lambda path: path, StringType())
        filePath = "python/test_support/sql/people1.json"
        row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
        self.assertIn("people1.json", row[0])

    def test_udf_with_input_file_name_for_hadooprdd(self):
        from pyspark.sql.functions import input_file_name

        def filename(path):
            return path

        sameText = udf(filename, StringType())

        rdd = self.sc.textFile('python/test_support/sql/people.json')
        df = self.spark.read.json(rdd).select(input_file_name().alias('file'))
        row = df.select(sameText(df['file'])).first()
        self.assertIn("people.json", row[0])

        rdd2 = self.sc.newAPIHadoopFile(
            'python/test_support/sql/people.json',
            'org.apache.hadoop.mapreduce.lib.input.TextInputFormat',
            'org.apache.hadoop.io.LongWritable',
            'org.apache.hadoop.io.Text')

        df2 = self.spark.read.json(rdd2).select(input_file_name().alias('file'))
        row2 = df2.select(sameText(df2['file'])).first()
        self.assertIn("people.json", row2[0])

    def test_udf_defers_judf_initialization(self):
        # This is separate from UDFInitializationTests to avoid context
        # initialization when the udf is called.
        f = UserDefinedFunction(lambda x: x, StringType())

        self.assertIsNone(
            f._judf_placeholder,
            "judf should not be initialized before the first call."
        )

        self.assertIsInstance(f("foo"), Column, "UDF call should return a Column.")

        self.assertIsNotNone(
            f._judf_placeholder,
            "judf should be initialized after UDF has been called."
        )

    def test_udf_with_string_return_type(self):
        # Return types may be given as DDL-formatted type strings.
        add_one = UserDefinedFunction(lambda x: x + 1, "integer")
        make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
        make_array = UserDefinedFunction(
            lambda x: [float(x) for x in range(x, x + 3)], "array<double>")

        expected = (2, Row(x=-1, y=1), [1.0, 2.0, 3.0])
        actual = (self.spark.range(1, 2).toDF("x")
                  .select(add_one("x"), make_pair("x"), make_array("x"))
                  .first())

        self.assertTupleEqual(expected, actual)

    def test_udf_shouldnt_accept_noncallable_object(self):
        non_callable = None
        self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())

    def test_udf_with_decorator(self):
        from pyspark.sql.functions import lit
        from pyspark.sql.types import IntegerType, DoubleType

        # Exercise every supported decorator form: positional type, keyword type,
        # bare @udf, empty call, and DDL type strings.
        @udf(IntegerType())
        def add_one(x):
            if x is not None:
                return x + 1

        @udf(returnType=DoubleType())
        def add_two(x):
            if x is not None:
                return float(x + 2)

        @udf
        def to_upper(x):
            if x is not None:
                return x.upper()

        @udf()
        def to_lower(x):
            if x is not None:
                return x.lower()

        @udf
        def substr(x, start, end):
            if x is not None:
                return x[start:end]

        @udf("long")
        def trunc(x):
            return int(x)

        @udf(returnType="double")
        def as_double(x):
            return float(x)

        df = (
            self.spark
                .createDataFrame(
                    [(1, "Foo", "foobar", 3.0)], ("one", "Foo", "foobar", "float"))
                .select(
                    add_one("one"), add_two("one"),
                    to_upper("Foo"), to_lower("Foo"),
                    substr("foobar", lit(0), lit(3)),
                    trunc("float"), as_double("one")))

        self.assertListEqual(
            [tpe for _, tpe in df.dtypes],
            ["int", "double", "string", "string", "string", "bigint", "double"]
        )

        self.assertListEqual(
            list(df.first()),
            [2, 3.0, "FOO", "foo", "foo", 3, 1.0]
        )

    def test_udf_wrapper(self):
        from pyspark.sql.types import IntegerType

        # The wrapper returned by udf() keeps the wrapped callable's docstring,
        # the callable itself (.func) and the return type (.returnType).
        def f(x):
            """Identity"""
            return x

        return_type = IntegerType()
        f_ = udf(f, return_type)

        self.assertIn(f.__doc__, f_.__doc__)
        self.assertEqual(f, f_.func)
        self.assertEqual(return_type, f_.returnType)

        class F(object):
            """Identity"""
            def __call__(self, x):
                return x

        f = F()
        return_type = IntegerType()
        f_ = udf(f, return_type)

        self.assertIn(f.__doc__, f_.__doc__)
        self.assertEqual(f, f_.func)
        self.assertEqual(return_type, f_.returnType)

        f = functools.partial(f, x=1)
        return_type = IntegerType()
        f_ = udf(f, return_type)

        self.assertIn(f.__doc__, f_.__doc__)
        self.assertEqual(f, f_.func)
        self.assertEqual(return_type, f_.returnType)

    def test_nonparam_udf_with_aggregate(self):
        import pyspark.sql.functions as f

        df = self.spark.createDataFrame([(1, 2), (1, 2)])
        f_udf = f.udf(lambda: "const_str")
        rows = df.distinct().withColumn("a", f_udf()).collect()
        self.assertEqual(rows, [Row(_1=1, _2=2, a=u'const_str')])

    # SPARK-24721
    @unittest.skipIf(not test_compiled, test_not_compiled_message)
    def test_datasource_with_udf(self):
        from pyspark.sql.functions import lit, col

        # Reserve a unique path, then remove it so the writer can create it fresh.
        path = tempfile.mkdtemp()
        shutil.rmtree(path)

        try:
            self.spark.range(1).write.mode("overwrite").format('csv').save(path)
            filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
            datasource_df = self.spark.read \
                .format("org.apache.spark.sql.sources.SimpleScanSource") \
                .option('from', 0).option('to', 1).load().toDF('i')
            datasource_v2_df = self.spark.read \
                .format("org.apache.spark.sql.connector.SimpleDataSourceV2") \
                .load().toDF('i', 'j')

            c1 = udf(lambda x: x + 1, 'int')(lit(1))
            c2 = udf(lambda x: x + 1, 'int')(col('i'))

            f1 = udf(lambda x: False, 'boolean')(lit(1))
            f2 = udf(lambda x: False, 'boolean')(col('i'))

            for df in [filesource_df, datasource_df, datasource_v2_df]:
                result = df.withColumn('c', c1)
                expected = df.withColumn('c', lit(2))
                self.assertEqual(expected.collect(), result.collect())

            for df in [filesource_df, datasource_df, datasource_v2_df]:
                result = df.withColumn('c', c2)
                expected = df.withColumn('c', col('i') + 1)
                self.assertEqual(expected.collect(), result.collect())

            for df in [filesource_df, datasource_df, datasource_v2_df]:
                for f in [f1, f2]:
                    result = df.filter(f)
                    self.assertEqual(0, result.count())
        finally:
            # Best-effort cleanup: if the write failed before creating `path`,
            # a strict rmtree would raise and mask the real test failure.
            shutil.rmtree(path, ignore_errors=True)

    # SPARK-25591
    def test_same_accumulator_in_udfs(self):
        data_schema = StructType([StructField("a", IntegerType(), True),
                                  StructField("b", IntegerType(), True)])
        data = self.spark.createDataFrame([[1, 2]], schema=data_schema)

        test_accum = self.sc.accumulator(0)

        def first_udf(x):
            test_accum.add(1)
            return x

        def second_udf(x):
            test_accum.add(100)
            return x

        func_udf = udf(first_udf, IntegerType())
        func_udf2 = udf(second_udf, IntegerType())
        data = data.withColumn("out1", func_udf(data["a"]))
        data = data.withColumn("out2", func_udf2(data["b"]))
        data.collect()
        # One row: +1 from first_udf and +100 from second_udf, both applied once.
        self.assertEqual(test_accum.value, 101)

    # SPARK-26293
    def test_udf_in_subquery(self):
        f = udf(lambda x: x, "long")
        with self.tempView("v"):
            self.spark.range(1).filter(f("id") >= 0).createTempView("v")
            sql = self.spark.sql
            result = sql("select i from values(0L) as data(i) where i in (select id from v)")
            self.assertEqual(result.collect(), [Row(i=0)])

    def test_udf_globals_not_overwritten(self):
        # Fails if a `map` from itertools leaks into the UDF's globals on the worker.
        @udf('string')
        def f():
            assert "itertools" not in str(map)

        self.spark.range(1).select(f()).collect()

    def test_worker_original_stdin_closed(self):
        # Test if it closes the original standard input of worker inherited from the daemon,
        # and replaces it with '/dev/null'.  See SPARK-26175.
        def task(iterator):
            import sys
            res = sys.stdin.read()
            # Because the standard input is '/dev/null', it reaches to EOF.
            assert res == '', "Expect read EOF from stdin."
            return iterator

        self.sc.parallelize(range(1), 1).mapPartitions(task).count()

    def test_udf_with_256_args(self):
        N = 256
        data = [["data-%d" % i for i in range(N)]] * 5
        df = self.spark.createDataFrame(data)

        def f(*a):
            return "success"

        fUdf = udf(f, StringType())

        r = df.select(fUdf(*df.columns))
        self.assertEqual(r.first()[0], "success")
0644 
0645 
class UDFInitializationTests(unittest.TestCase):
    """Verify that constructing a UserDefinedFunction starts neither a
    SparkContext nor a SparkSession."""

    def tearDown(self):
        # Stop whatever the test may have started, session first, then any
        # context that is still active afterwards.
        session = SparkSession._instantiatedSession
        if session is not None:
            session.stop()

        context = SparkContext._active_spark_context
        if context is not None:
            context.stop()

    def test_udf_init_shouldnt_initialize_context(self):
        UserDefinedFunction(lambda x: x, StringType())

        # Each pair is (global Spark handle, failure message if it was created).
        expectations = [
            (SparkContext._active_spark_context,
             "SparkContext shouldn't be initialized when UserDefinedFunction is created."),
            (SparkSession._instantiatedSession,
             "SparkSession shouldn't be initialized when UserDefinedFunction is created."),
        ]
        for handle, message in expectations:
            self.assertIsNone(handle, message)
0665 
0666 
if __name__ == "__main__":
    from pyspark.sql.tests.test_udf import *

    # Emit JUnit-style XML reports when xmlrunner is installed; otherwise fall
    # back to unittest's default text runner.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)