Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 import datetime
0018 import os
0019 import random
0020 import shutil
0021 import sys
0022 import tempfile
0023 import time
0024 import unittest
0025 
0026 if sys.version >= '3':
0027     unicode = str
0028 
0029 from datetime import date, datetime
0030 from decimal import Decimal
0031 
0032 from pyspark import TaskContext
0033 from pyspark.rdd import PythonEvalType
0034 from pyspark.sql import Column
0035 from pyspark.sql.functions import array, col, expr, lit, sum, struct, udf, pandas_udf, \
0036     PandasUDFType
0037 from pyspark.sql.types import Row
0038 from pyspark.sql.types import *
0039 from pyspark.sql.utils import AnalysisException
0040 from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled,\
0041     test_not_compiled_message, have_pandas, have_pyarrow, pandas_requirement_message, \
0042     pyarrow_requirement_message
0043 from pyspark.testing.utils import QuietTest
0044 
0045 if have_pandas:
0046     import pandas as pd
0047 
0048 if have_pyarrow:
0049     import pyarrow as pa
0050 
0051 
0052 @unittest.skipIf(
0053     not have_pandas or not have_pyarrow,
0054     pandas_requirement_message or pyarrow_requirement_message)
0055 class ScalarPandasUDFTests(ReusedSQLTestCase):
0056 
0057     @classmethod
0058     def setUpClass(cls):
0059         ReusedSQLTestCase.setUpClass()
0060 
0061         # Synchronize default timezone between Python and Java
0062         cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
0063         tz = "America/Los_Angeles"
0064         os.environ["TZ"] = tz
0065         time.tzset()
0066 
0067         cls.sc.environment["TZ"] = tz
0068         cls.spark.conf.set("spark.sql.session.timeZone", tz)
0069 
0070     @classmethod
0071     def tearDownClass(cls):
0072         del os.environ["TZ"]
0073         if cls.tz_prev is not None:
0074             os.environ["TZ"] = cls.tz_prev
0075         time.tzset()
0076         ReusedSQLTestCase.tearDownClass()
0077 
0078     @property
0079     def nondeterministic_vectorized_udf(self):
0080         import numpy as np
0081 
0082         @pandas_udf('double')
0083         def random_udf(v):
0084             return pd.Series(np.random.random(len(v)))
0085         random_udf = random_udf.asNondeterministic()
0086         return random_udf
0087 
0088     @property
0089     def nondeterministic_vectorized_iter_udf(self):
0090         import numpy as np
0091 
0092         @pandas_udf('double', PandasUDFType.SCALAR_ITER)
0093         def random_udf(it):
0094             for v in it:
0095                 yield pd.Series(np.random.random(len(v)))
0096 
0097         random_udf = random_udf.asNondeterministic()
0098         return random_udf
0099 
0100     def test_pandas_udf_tokenize(self):
0101         tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')),
0102                               ArrayType(StringType()))
0103         self.assertEqual(tokenize.returnType, ArrayType(StringType()))
0104         df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
0105         result = df.select(tokenize("vals").alias("hi"))
0106         self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect())
0107 
0108     def test_pandas_udf_nested_arrays(self):
0109         tokenize = pandas_udf(lambda s: s.apply(lambda str: [str.split(' ')]),
0110                               ArrayType(ArrayType(StringType())))
0111         self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType())))
0112         df = self.spark.createDataFrame([("hi boo",), ("bye boo",)], ["vals"])
0113         result = df.select(tokenize("vals").alias("hi"))
0114         self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect())
0115 
0116     def test_vectorized_udf_basic(self):
0117         df = self.spark.range(10).select(
0118             col('id').cast('string').alias('str'),
0119             col('id').cast('int').alias('int'),
0120             col('id').alias('long'),
0121             col('id').cast('float').alias('float'),
0122             col('id').cast('double').alias('double'),
0123             col('id').cast('decimal').alias('decimal'),
0124             col('id').cast('boolean').alias('bool'),
0125             array(col('id')).alias('array_long'))
0126         f = lambda x: x
0127         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0128             str_f = pandas_udf(f, StringType(), udf_type)
0129             int_f = pandas_udf(f, IntegerType(), udf_type)
0130             long_f = pandas_udf(f, LongType(), udf_type)
0131             float_f = pandas_udf(f, FloatType(), udf_type)
0132             double_f = pandas_udf(f, DoubleType(), udf_type)
0133             decimal_f = pandas_udf(f, DecimalType(), udf_type)
0134             bool_f = pandas_udf(f, BooleanType(), udf_type)
0135             array_long_f = pandas_udf(f, ArrayType(LongType()), udf_type)
0136             res = df.select(str_f(col('str')), int_f(col('int')),
0137                             long_f(col('long')), float_f(col('float')),
0138                             double_f(col('double')), decimal_f('decimal'),
0139                             bool_f(col('bool')), array_long_f('array_long'))
0140             self.assertEquals(df.collect(), res.collect())
0141 
0142     def test_register_nondeterministic_vectorized_udf_basic(self):
0143         random_pandas_udf = pandas_udf(
0144             lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic()
0145         self.assertEqual(random_pandas_udf.deterministic, False)
0146         self.assertEqual(random_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
0147         nondeterministic_pandas_udf = self.spark.catalog.registerFunction(
0148             "randomPandasUDF", random_pandas_udf)
0149         self.assertEqual(nondeterministic_pandas_udf.deterministic, False)
0150         self.assertEqual(nondeterministic_pandas_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
0151         [row] = self.spark.sql("SELECT randomPandasUDF(1)").collect()
0152         self.assertEqual(row[0], 7)
0153 
0154         def random_iter_udf(it):
0155             for i in it:
0156                 yield random.randint(6, 6) + i
0157         random_pandas_iter_udf = pandas_udf(
0158             random_iter_udf, IntegerType(), PandasUDFType.SCALAR_ITER).asNondeterministic()
0159         self.assertEqual(random_pandas_iter_udf.deterministic, False)
0160         self.assertEqual(random_pandas_iter_udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
0161         nondeterministic_pandas_iter_udf = self.spark.catalog.registerFunction(
0162             "randomPandasIterUDF", random_pandas_iter_udf)
0163         self.assertEqual(nondeterministic_pandas_iter_udf.deterministic, False)
0164         self.assertEqual(nondeterministic_pandas_iter_udf.evalType,
0165                          PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
0166         [row] = self.spark.sql("SELECT randomPandasIterUDF(1)").collect()
0167         self.assertEqual(row[0], 7)
0168 
0169     def test_vectorized_udf_null_boolean(self):
0170         data = [(True,), (True,), (None,), (False,)]
0171         schema = StructType().add("bool", BooleanType())
0172         df = self.spark.createDataFrame(data, schema)
0173         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0174             bool_f = pandas_udf(lambda x: x, BooleanType(), udf_type)
0175             res = df.select(bool_f(col('bool')))
0176             self.assertEquals(df.collect(), res.collect())
0177 
0178     def test_vectorized_udf_null_byte(self):
0179         data = [(None,), (2,), (3,), (4,)]
0180         schema = StructType().add("byte", ByteType())
0181         df = self.spark.createDataFrame(data, schema)
0182         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0183             byte_f = pandas_udf(lambda x: x, ByteType(), udf_type)
0184             res = df.select(byte_f(col('byte')))
0185             self.assertEquals(df.collect(), res.collect())
0186 
0187     def test_vectorized_udf_null_short(self):
0188         data = [(None,), (2,), (3,), (4,)]
0189         schema = StructType().add("short", ShortType())
0190         df = self.spark.createDataFrame(data, schema)
0191         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0192             short_f = pandas_udf(lambda x: x, ShortType(), udf_type)
0193             res = df.select(short_f(col('short')))
0194             self.assertEquals(df.collect(), res.collect())
0195 
0196     def test_vectorized_udf_null_int(self):
0197         data = [(None,), (2,), (3,), (4,)]
0198         schema = StructType().add("int", IntegerType())
0199         df = self.spark.createDataFrame(data, schema)
0200         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0201             int_f = pandas_udf(lambda x: x, IntegerType(), udf_type)
0202             res = df.select(int_f(col('int')))
0203             self.assertEquals(df.collect(), res.collect())
0204 
0205     def test_vectorized_udf_null_long(self):
0206         data = [(None,), (2,), (3,), (4,)]
0207         schema = StructType().add("long", LongType())
0208         df = self.spark.createDataFrame(data, schema)
0209         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0210             long_f = pandas_udf(lambda x: x, LongType(), udf_type)
0211             res = df.select(long_f(col('long')))
0212             self.assertEquals(df.collect(), res.collect())
0213 
0214     def test_vectorized_udf_null_float(self):
0215         data = [(3.0,), (5.0,), (-1.0,), (None,)]
0216         schema = StructType().add("float", FloatType())
0217         df = self.spark.createDataFrame(data, schema)
0218         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0219             float_f = pandas_udf(lambda x: x, FloatType(), udf_type)
0220             res = df.select(float_f(col('float')))
0221             self.assertEquals(df.collect(), res.collect())
0222 
0223     def test_vectorized_udf_null_double(self):
0224         data = [(3.0,), (5.0,), (-1.0,), (None,)]
0225         schema = StructType().add("double", DoubleType())
0226         df = self.spark.createDataFrame(data, schema)
0227         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0228             double_f = pandas_udf(lambda x: x, DoubleType(), udf_type)
0229             res = df.select(double_f(col('double')))
0230             self.assertEquals(df.collect(), res.collect())
0231 
0232     def test_vectorized_udf_null_decimal(self):
0233         data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)]
0234         schema = StructType().add("decimal", DecimalType(38, 18))
0235         df = self.spark.createDataFrame(data, schema)
0236         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0237             decimal_f = pandas_udf(lambda x: x, DecimalType(38, 18), udf_type)
0238             res = df.select(decimal_f(col('decimal')))
0239             self.assertEquals(df.collect(), res.collect())
0240 
0241     def test_vectorized_udf_null_string(self):
0242         data = [("foo",), (None,), ("bar",), ("bar",)]
0243         schema = StructType().add("str", StringType())
0244         df = self.spark.createDataFrame(data, schema)
0245         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0246             str_f = pandas_udf(lambda x: x, StringType(), udf_type)
0247             res = df.select(str_f(col('str')))
0248             self.assertEquals(df.collect(), res.collect())
0249 
0250     def test_vectorized_udf_string_in_udf(self):
0251         df = self.spark.range(10)
0252         scalar_f = lambda x: pd.Series(map(str, x))
0253 
0254         def iter_f(it):
0255             for i in it:
0256                 yield scalar_f(i)
0257 
0258         for f, udf_type in [(scalar_f, PandasUDFType.SCALAR), (iter_f, PandasUDFType.SCALAR_ITER)]:
0259             str_f = pandas_udf(f, StringType(), udf_type)
0260             actual = df.select(str_f(col('id')))
0261             expected = df.select(col('id').cast('string'))
0262             self.assertEquals(expected.collect(), actual.collect())
0263 
0264     def test_vectorized_udf_datatype_string(self):
0265         df = self.spark.range(10).select(
0266             col('id').cast('string').alias('str'),
0267             col('id').cast('int').alias('int'),
0268             col('id').alias('long'),
0269             col('id').cast('float').alias('float'),
0270             col('id').cast('double').alias('double'),
0271             col('id').cast('decimal').alias('decimal'),
0272             col('id').cast('boolean').alias('bool'))
0273         f = lambda x: x
0274         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0275             str_f = pandas_udf(f, 'string', udf_type)
0276             int_f = pandas_udf(f, 'integer', udf_type)
0277             long_f = pandas_udf(f, 'long', udf_type)
0278             float_f = pandas_udf(f, 'float', udf_type)
0279             double_f = pandas_udf(f, 'double', udf_type)
0280             decimal_f = pandas_udf(f, 'decimal(38, 18)', udf_type)
0281             bool_f = pandas_udf(f, 'boolean', udf_type)
0282             res = df.select(str_f(col('str')), int_f(col('int')),
0283                             long_f(col('long')), float_f(col('float')),
0284                             double_f(col('double')), decimal_f('decimal'),
0285                             bool_f(col('bool')))
0286             self.assertEquals(df.collect(), res.collect())
0287 
0288     def test_vectorized_udf_null_binary(self):
0289         data = [(bytearray(b"a"),), (None,), (bytearray(b"bb"),), (bytearray(b"ccc"),)]
0290         schema = StructType().add("binary", BinaryType())
0291         df = self.spark.createDataFrame(data, schema)
0292         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0293             str_f = pandas_udf(lambda x: x, BinaryType(), udf_type)
0294             res = df.select(str_f(col('binary')))
0295             self.assertEquals(df.collect(), res.collect())
0296 
0297     def test_vectorized_udf_array_type(self):
0298         data = [([1, 2],), ([3, 4],)]
0299         array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
0300         df = self.spark.createDataFrame(data, schema=array_schema)
0301         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0302             array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()), udf_type)
0303             result = df.select(array_f(col('array')))
0304             self.assertEquals(df.collect(), result.collect())
0305 
0306     def test_vectorized_udf_null_array(self):
0307         data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)]
0308         array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
0309         df = self.spark.createDataFrame(data, schema=array_schema)
0310         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0311             array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()), udf_type)
0312             result = df.select(array_f(col('array')))
0313             self.assertEquals(df.collect(), result.collect())
0314 
0315     def test_vectorized_udf_struct_type(self):
0316         df = self.spark.range(10)
0317         return_type = StructType([
0318             StructField('id', LongType()),
0319             StructField('str', StringType())])
0320 
0321         def scalar_func(id):
0322             return pd.DataFrame({'id': id, 'str': id.apply(unicode)})
0323 
0324         def iter_func(it):
0325             for id in it:
0326                 yield scalar_func(id)
0327 
0328         for func, udf_type in [(scalar_func, PandasUDFType.SCALAR),
0329                                (iter_func, PandasUDFType.SCALAR_ITER)]:
0330             f = pandas_udf(func, returnType=return_type, functionType=udf_type)
0331 
0332             expected = df.select(struct(col('id'), col('id').cast('string').alias('str'))
0333                                  .alias('struct')).collect()
0334 
0335             actual = df.select(f(col('id')).alias('struct')).collect()
0336             self.assertEqual(expected, actual)
0337 
0338             g = pandas_udf(func, 'id: long, str: string', functionType=udf_type)
0339             actual = df.select(g(col('id')).alias('struct')).collect()
0340             self.assertEqual(expected, actual)
0341 
0342             struct_f = pandas_udf(lambda x: x, return_type, functionType=udf_type)
0343             actual = df.select(struct_f(struct(col('id'), col('id').cast('string').alias('str'))))
0344             self.assertEqual(expected, actual.collect())
0345 
0346     def test_vectorized_udf_struct_complex(self):
0347         df = self.spark.range(10)
0348         return_type = StructType([
0349             StructField('ts', TimestampType()),
0350             StructField('arr', ArrayType(LongType()))])
0351 
0352         def _scalar_f(id):
0353             return pd.DataFrame({'ts': id.apply(lambda i: pd.Timestamp(i)),
0354                                  'arr': id.apply(lambda i: [i, i + 1])})
0355 
0356         scalar_f = pandas_udf(_scalar_f, returnType=return_type)
0357 
0358         @pandas_udf(returnType=return_type, functionType=PandasUDFType.SCALAR_ITER)
0359         def iter_f(it):
0360             for id in it:
0361                 yield _scalar_f(id)
0362 
0363         for f, udf_type in [(scalar_f, PandasUDFType.SCALAR), (iter_f, PandasUDFType.SCALAR_ITER)]:
0364             actual = df.withColumn('f', f(col('id'))).collect()
0365             for i, row in enumerate(actual):
0366                 id, f = row
0367                 self.assertEqual(i, id)
0368                 self.assertEqual(pd.Timestamp(i).to_pydatetime(), f[0])
0369                 self.assertListEqual([i, i + 1], f[1])
0370 
0371     def test_vectorized_udf_nested_struct(self):
0372         nested_type = StructType([
0373             StructField('id', IntegerType()),
0374             StructField('nested', StructType([
0375                 StructField('foo', StringType()),
0376                 StructField('bar', FloatType())
0377             ]))
0378         ])
0379 
0380         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0381             with QuietTest(self.sc):
0382                 with self.assertRaisesRegexp(
0383                         Exception,
0384                         'Invalid return type with scalar Pandas UDFs'):
0385                     pandas_udf(lambda x: x, returnType=nested_type, functionType=udf_type)
0386 
0387     def test_vectorized_udf_complex(self):
0388         df = self.spark.range(10).select(
0389             col('id').cast('int').alias('a'),
0390             col('id').cast('int').alias('b'),
0391             col('id').cast('double').alias('c'))
0392         scalar_add = pandas_udf(lambda x, y: x + y, IntegerType())
0393         scalar_power2 = pandas_udf(lambda x: 2 ** x, IntegerType())
0394         scalar_mul = pandas_udf(lambda x, y: x * y, DoubleType())
0395 
0396         @pandas_udf(IntegerType(), PandasUDFType.SCALAR_ITER)
0397         def iter_add(it):
0398             for x, y in it:
0399                 yield x + y
0400 
0401         @pandas_udf(IntegerType(), PandasUDFType.SCALAR_ITER)
0402         def iter_power2(it):
0403             for x in it:
0404                 yield 2 ** x
0405 
0406         @pandas_udf(DoubleType(), PandasUDFType.SCALAR_ITER)
0407         def iter_mul(it):
0408             for x, y in it:
0409                 yield x * y
0410 
0411         for add, power2, mul in [(scalar_add, scalar_power2, scalar_mul),
0412                                  (iter_add, iter_power2, iter_mul)]:
0413             res = df.select(add(col('a'), col('b')), power2(col('a')), mul(col('b'), col('c')))
0414             expected = df.select(expr('a + b'), expr('power(2, a)'), expr('b * c'))
0415             self.assertEquals(expected.collect(), res.collect())
0416 
0417     def test_vectorized_udf_exception(self):
0418         df = self.spark.range(10)
0419         scalar_raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType())
0420 
0421         @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
0422         def iter_raise_exception(it):
0423             for x in it:
0424                 yield x * (1 / 0)
0425 
0426         for raise_exception in [scalar_raise_exception, iter_raise_exception]:
0427             with QuietTest(self.sc):
0428                 with self.assertRaisesRegexp(Exception, 'division( or modulo)? by zero'):
0429                     df.select(raise_exception(col('id'))).collect()
0430 
0431     def test_vectorized_udf_invalid_length(self):
0432         df = self.spark.range(10)
0433         raise_exception = pandas_udf(lambda _: pd.Series(1), LongType())
0434         with QuietTest(self.sc):
0435             with self.assertRaisesRegexp(
0436                     Exception,
0437                     'Result vector from pandas_udf was not the required length'):
0438                 df.select(raise_exception(col('id'))).collect()
0439 
0440         @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
0441         def iter_udf_wong_output_size(it):
0442             for _ in it:
0443                 yield pd.Series(1)
0444 
0445         with QuietTest(self.sc):
0446             with self.assertRaisesRegexp(
0447                     Exception,
0448                     "The length of output in Scalar iterator.*"
0449                     "the length of output was 1"):
0450                 df.select(iter_udf_wong_output_size(col('id'))).collect()
0451 
0452         @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
0453         def iter_udf_not_reading_all_input(it):
0454             for batch in it:
0455                 batch_len = len(batch)
0456                 yield pd.Series([1] * batch_len)
0457                 break
0458 
0459         with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
0460             df1 = self.spark.range(10).repartition(1)
0461             with QuietTest(self.sc):
0462                 with self.assertRaisesRegexp(
0463                         Exception,
0464                         "pandas iterator UDF should exhaust"):
0465                     df1.select(iter_udf_not_reading_all_input(col('id'))).collect()
0466 
0467     def test_vectorized_udf_chained(self):
0468         df = self.spark.range(10)
0469         scalar_f = pandas_udf(lambda x: x + 1, LongType())
0470         scalar_g = pandas_udf(lambda x: x - 1, LongType())
0471 
0472         iter_f = pandas_udf(lambda it: map(lambda x: x + 1, it), LongType(),
0473                             PandasUDFType.SCALAR_ITER)
0474         iter_g = pandas_udf(lambda it: map(lambda x: x - 1, it), LongType(),
0475                             PandasUDFType.SCALAR_ITER)
0476 
0477         for f, g in [(scalar_f, scalar_g), (iter_f, iter_g)]:
0478             res = df.select(g(f(col('id'))))
0479             self.assertEquals(df.collect(), res.collect())
0480 
0481     def test_vectorized_udf_chained_struct_type(self):
0482         df = self.spark.range(10)
0483         return_type = StructType([
0484             StructField('id', LongType()),
0485             StructField('str', StringType())])
0486 
0487         @pandas_udf(return_type)
0488         def scalar_f(id):
0489             return pd.DataFrame({'id': id, 'str': id.apply(unicode)})
0490 
0491         scalar_g = pandas_udf(lambda x: x, return_type)
0492 
0493         @pandas_udf(return_type, PandasUDFType.SCALAR_ITER)
0494         def iter_f(it):
0495             for id in it:
0496                 yield pd.DataFrame({'id': id, 'str': id.apply(unicode)})
0497 
0498         iter_g = pandas_udf(lambda x: x, return_type, PandasUDFType.SCALAR_ITER)
0499 
0500         expected = df.select(struct(col('id'), col('id').cast('string').alias('str'))
0501                              .alias('struct')).collect()
0502 
0503         for f, g in [(scalar_f, scalar_g), (iter_f, iter_g)]:
0504             actual = df.select(g(f(col('id'))).alias('struct')).collect()
0505             self.assertEqual(expected, actual)
0506 
0507     def test_vectorized_udf_wrong_return_type(self):
0508         with QuietTest(self.sc):
0509             for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0510                 with self.assertRaisesRegexp(
0511                         NotImplementedError,
0512                         'Invalid return type.*scalar Pandas UDF.*MapType'):
0513                     pandas_udf(lambda x: x, MapType(LongType(), LongType()), udf_type)
0514 
0515     def test_vectorized_udf_return_scalar(self):
0516         df = self.spark.range(10)
0517         scalar_f = pandas_udf(lambda x: 1.0, DoubleType())
0518         iter_f = pandas_udf(lambda it: map(lambda x: 1.0, it), DoubleType(),
0519                             PandasUDFType.SCALAR_ITER)
0520         for f in [scalar_f, iter_f]:
0521             with QuietTest(self.sc):
0522                 with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'):
0523                     df.select(f(col('id'))).collect()
0524 
0525     def test_vectorized_udf_decorator(self):
0526         df = self.spark.range(10)
0527 
0528         @pandas_udf(returnType=LongType())
0529         def scalar_identity(x):
0530             return x
0531 
0532         @pandas_udf(returnType=LongType(), functionType=PandasUDFType.SCALAR_ITER)
0533         def iter_identity(x):
0534             return x
0535 
0536         for identity in [scalar_identity, iter_identity]:
0537             res = df.select(identity(col('id')))
0538             self.assertEquals(df.collect(), res.collect())
0539 
0540     def test_vectorized_udf_empty_partition(self):
0541         df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
0542         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0543             f = pandas_udf(lambda x: x, LongType(), udf_type)
0544             res = df.select(f(col('id')))
0545             self.assertEquals(df.collect(), res.collect())
0546 
0547     def test_vectorized_udf_struct_with_empty_partition(self):
0548         df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))\
0549             .withColumn('name', lit('John Doe'))
0550 
0551         @pandas_udf("first string, last string")
0552         def scalar_split_expand(n):
0553             return n.str.split(expand=True)
0554 
0555         @pandas_udf("first string, last string", PandasUDFType.SCALAR_ITER)
0556         def iter_split_expand(it):
0557             for n in it:
0558                 yield n.str.split(expand=True)
0559 
0560         for split_expand in [scalar_split_expand, iter_split_expand]:
0561             result = df.select(split_expand('name')).collect()
0562             self.assertEqual(1, len(result))
0563             row = result[0]
0564             self.assertEqual('John', row[0]['first'])
0565             self.assertEqual('Doe', row[0]['last'])
0566 
0567     def test_vectorized_udf_varargs(self):
0568         df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
0569         scalar_f = pandas_udf(lambda *v: v[0], LongType())
0570 
0571         @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
0572         def iter_f(it):
0573             for v in it:
0574                 yield v[0]
0575 
0576         for f in [scalar_f, iter_f]:
0577             res = df.select(f(col('id'), col('id')))
0578             self.assertEquals(df.collect(), res.collect())
0579 
0580     def test_vectorized_udf_unsupported_types(self):
0581         with QuietTest(self.sc):
0582             for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0583                 with self.assertRaisesRegexp(
0584                         NotImplementedError,
0585                         'Invalid return type.*scalar Pandas UDF.*MapType'):
0586                     pandas_udf(lambda x: x, MapType(StringType(), IntegerType()), udf_type)
0587                 with self.assertRaisesRegexp(
0588                         NotImplementedError,
0589                         'Invalid return type.*scalar Pandas UDF.*ArrayType.StructType'):
0590                     pandas_udf(lambda x: x,
0591                                ArrayType(StructType([StructField('a', IntegerType())])), udf_type)
0592 
0593     def test_vectorized_udf_dates(self):
0594         schema = StructType().add("idx", LongType()).add("date", DateType())
0595         data = [(0, date(1969, 1, 1),),
0596                 (1, date(2012, 2, 2),),
0597                 (2, None,),
0598                 (3, date(2100, 4, 4),),
0599                 (4, date(2262, 4, 12),)]
0600         df = self.spark.createDataFrame(data, schema=schema)
0601 
0602         def scalar_check_data(idx, date, date_copy):
0603             msgs = []
0604             is_equal = date.isnull()
0605             for i in range(len(idx)):
0606                 if (is_equal[i] and data[idx[i]][1] is None) or \
0607                         date[i] == data[idx[i]][1]:
0608                     msgs.append(None)
0609                 else:
0610                     msgs.append(
0611                         "date values are not equal (date='%s': data[%d][1]='%s')"
0612                         % (date[i], idx[i], data[idx[i]][1]))
0613             return pd.Series(msgs)
0614 
0615         def iter_check_data(it):
0616             for idx, date, date_copy in it:
0617                 yield scalar_check_data(idx, date, date_copy)
0618 
0619         pandas_scalar_check_data = pandas_udf(scalar_check_data, StringType())
0620         pandas_iter_check_data = pandas_udf(iter_check_data, StringType(),
0621                                             PandasUDFType.SCALAR_ITER)
0622 
0623         for check_data, udf_type in [(pandas_scalar_check_data, PandasUDFType.SCALAR),
0624                                      (pandas_iter_check_data, PandasUDFType.SCALAR_ITER)]:
0625             date_copy = pandas_udf(lambda t: t, returnType=DateType(), functionType=udf_type)
0626             df = df.withColumn("date_copy", date_copy(col("date")))
0627             result = df.withColumn("check_data",
0628                                    check_data(col("idx"), col("date"), col("date_copy"))).collect()
0629 
0630             self.assertEquals(len(data), len(result))
0631             for i in range(len(result)):
0632                 self.assertEquals(data[i][1], result[i][1])  # "date" col
0633                 self.assertEquals(data[i][1], result[i][2])  # "date_copy" col
0634                 self.assertIsNone(result[i][3])  # "check_data" col
0635 
0636     def test_vectorized_udf_timestamps(self):
0637         schema = StructType([
0638             StructField("idx", LongType(), True),
0639             StructField("timestamp", TimestampType(), True)])
0640         data = [(0, datetime(1969, 1, 1, 1, 1, 1)),
0641                 (1, datetime(2012, 2, 2, 2, 2, 2)),
0642                 (2, None),
0643                 (3, datetime(2100, 3, 3, 3, 3, 3))]
0644 
0645         df = self.spark.createDataFrame(data, schema=schema)
0646 
0647         def scalar_check_data(idx, timestamp, timestamp_copy):
0648             msgs = []
0649             is_equal = timestamp.isnull()  # use this array to check values are equal
0650             for i in range(len(idx)):
0651                 # Check that timestamps are as expected in the UDF
0652                 if (is_equal[i] and data[idx[i]][1] is None) or \
0653                         timestamp[i].to_pydatetime() == data[idx[i]][1]:
0654                     msgs.append(None)
0655                 else:
0656                     msgs.append(
0657                         "timestamp values are not equal (timestamp='%s': data[%d][1]='%s')"
0658                         % (timestamp[i], idx[i], data[idx[i]][1]))
0659             return pd.Series(msgs)
0660 
0661         def iter_check_data(it):
0662             for idx, timestamp, timestamp_copy in it:
0663                 yield scalar_check_data(idx, timestamp, timestamp_copy)
0664 
0665         pandas_scalar_check_data = pandas_udf(scalar_check_data, StringType())
0666         pandas_iter_check_data = pandas_udf(iter_check_data, StringType(),
0667                                             PandasUDFType.SCALAR_ITER)
0668 
0669         for check_data, udf_type in [(pandas_scalar_check_data, PandasUDFType.SCALAR),
0670                                      (pandas_iter_check_data, PandasUDFType.SCALAR_ITER)]:
0671             # Check that a timestamp passed through a pandas_udf will not be altered by timezone
0672             # calc
0673             f_timestamp_copy = pandas_udf(lambda t: t,
0674                                           returnType=TimestampType(), functionType=udf_type)
0675             df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp")))
0676             result = df.withColumn("check_data", check_data(col("idx"), col("timestamp"),
0677                                                             col("timestamp_copy"))).collect()
0678             # Check that collection values are correct
0679             self.assertEquals(len(data), len(result))
0680             for i in range(len(result)):
0681                 self.assertEquals(data[i][1], result[i][1])  # "timestamp" col
0682                 self.assertEquals(data[i][1], result[i][2])  # "timestamp_copy" col
0683                 self.assertIsNone(result[i][3])  # "check_data" col
0684 
0685     def test_vectorized_udf_return_timestamp_tz(self):
0686         df = self.spark.range(10)
0687 
0688         @pandas_udf(returnType=TimestampType())
0689         def scalar_gen_timestamps(id):
0690             ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id]
0691             return pd.Series(ts)
0692 
0693         @pandas_udf(returnType=TimestampType(), functionType=PandasUDFType.SCALAR_ITER)
0694         def iter_gen_timestamps(it):
0695             for id in it:
0696                 ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id]
0697                 yield pd.Series(ts)
0698 
0699         for gen_timestamps in [scalar_gen_timestamps, iter_gen_timestamps]:
0700             result = df.withColumn("ts", gen_timestamps(col("id"))).collect()
0701             spark_ts_t = TimestampType()
0702             for r in result:
0703                 i, ts = r
0704                 ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime()
0705                 expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz))
0706                 self.assertEquals(expected, ts)
0707 
0708     def test_vectorized_udf_check_config(self):
0709         with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
0710             df = self.spark.range(10, numPartitions=1)
0711 
0712             @pandas_udf(returnType=LongType())
0713             def scalar_check_records_per_batch(x):
0714                 return pd.Series(x.size).repeat(x.size)
0715 
0716             @pandas_udf(returnType=LongType(), functionType=PandasUDFType.SCALAR_ITER)
0717             def iter_check_records_per_batch(it):
0718                 for x in it:
0719                     yield pd.Series(x.size).repeat(x.size)
0720 
0721             for check_records_per_batch in [scalar_check_records_per_batch,
0722                                             iter_check_records_per_batch]:
0723                 result = df.select(check_records_per_batch(col("id"))).collect()
0724                 for (r,) in result:
0725                     self.assertTrue(r <= 3)
0726 
0727     def test_vectorized_udf_timestamps_respect_session_timezone(self):
0728         schema = StructType([
0729             StructField("idx", LongType(), True),
0730             StructField("timestamp", TimestampType(), True)])
0731         data = [(1, datetime(1969, 1, 1, 1, 1, 1)),
0732                 (2, datetime(2012, 2, 2, 2, 2, 2)),
0733                 (3, None),
0734                 (4, datetime(2100, 3, 3, 3, 3, 3))]
0735         df = self.spark.createDataFrame(data, schema=schema)
0736 
0737         scalar_internal_value = pandas_udf(
0738             lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())
0739 
0740         @pandas_udf(LongType(), PandasUDFType.SCALAR_ITER)
0741         def iter_internal_value(it):
0742             for ts in it:
0743                 yield ts.apply(lambda ts: ts.value if ts is not pd.NaT else None)
0744 
0745         for internal_value, udf_type in [(scalar_internal_value, PandasUDFType.SCALAR),
0746                                          (iter_internal_value, PandasUDFType.SCALAR_ITER)]:
0747             f_timestamp_copy = pandas_udf(lambda ts: ts, TimestampType(), udf_type)
0748             timezone = "America/Los_Angeles"
0749             with self.sql_conf({"spark.sql.session.timeZone": timezone}):
0750                 df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
0751                     .withColumn("internal_value", internal_value(col("timestamp")))
0752                 result_la = df_la.select(col("idx"), col("internal_value")).collect()
0753                 # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
0754                 diff = 3 * 60 * 60 * 1000 * 1000 * 1000
0755                 result_la_corrected = \
0756                     df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
0757 
0758             timezone = "America/New_York"
0759             with self.sql_conf({"spark.sql.session.timeZone": timezone}):
0760                 df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
0761                     .withColumn("internal_value", internal_value(col("timestamp")))
0762                 result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()
0763 
0764                 self.assertNotEqual(result_ny, result_la)
0765                 self.assertEqual(result_ny, result_la_corrected)
0766 
0767     def test_nondeterministic_vectorized_udf(self):
0768         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
0769         @pandas_udf('double')
0770         def scalar_plus_ten(v):
0771             return v + 10
0772 
0773         @pandas_udf('double', PandasUDFType.SCALAR_ITER)
0774         def iter_plus_ten(it):
0775             for v in it:
0776                 yield v + 10
0777 
0778         for plus_ten in [scalar_plus_ten, iter_plus_ten]:
0779             random_udf = self.nondeterministic_vectorized_udf
0780 
0781             df = self.spark.range(10).withColumn('rand', random_udf(col('id')))
0782             result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas()
0783 
0784             self.assertEqual(random_udf.deterministic, False)
0785             self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10))
0786 
0787     def test_nondeterministic_vectorized_udf_in_aggregate(self):
0788         df = self.spark.range(10)
0789         for random_udf in [self.nondeterministic_vectorized_udf,
0790                            self.nondeterministic_vectorized_iter_udf]:
0791             with QuietTest(self.sc):
0792                 with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
0793                     df.groupby(df.id).agg(sum(random_udf(df.id))).collect()
0794                 with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'):
0795                     df.agg(sum(random_udf(df.id))).collect()
0796 
0797     def test_register_vectorized_udf_basic(self):
0798         df = self.spark.range(10).select(
0799             col('id').cast('int').alias('a'),
0800             col('id').cast('int').alias('b'))
0801         scalar_original_add = pandas_udf(lambda x, y: x + y, IntegerType())
0802         self.assertEqual(scalar_original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
0803 
0804         @pandas_udf(IntegerType(), PandasUDFType.SCALAR_ITER)
0805         def iter_original_add(it):
0806             for x, y in it:
0807                 yield x + y
0808 
0809         self.assertEqual(iter_original_add.evalType, PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
0810 
0811         for original_add in [scalar_original_add, iter_original_add]:
0812             self.assertEqual(original_add.deterministic, True)
0813             new_add = self.spark.catalog.registerFunction("add1", original_add)
0814             res1 = df.select(new_add(col('a'), col('b')))
0815             res2 = self.spark.sql(
0816                 "SELECT add1(t.a, t.b) FROM (SELECT id as a, id as b FROM range(10)) t")
0817             expected = df.select(expr('a + b'))
0818             self.assertEquals(expected.collect(), res1.collect())
0819             self.assertEquals(expected.collect(), res2.collect())
0820 
0821     def test_scalar_iter_udf_init(self):
0822         import numpy as np
0823 
0824         @pandas_udf('int', PandasUDFType.SCALAR_ITER)
0825         def rng(batch_iter):
0826             context = TaskContext.get()
0827             part = context.partitionId()
0828             np.random.seed(part)
0829             for batch in batch_iter:
0830                 yield pd.Series(np.random.randint(100, size=len(batch)))
0831         with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 2}):
0832             df = self.spark.range(10, numPartitions=2).select(rng(col("id").alias("v")))
0833             result1 = df.collect()
0834             result2 = df.collect()
0835             self.assertEqual(result1, result2,
0836                              "SCALAR ITER UDF can initialize state and produce deterministic RNG")
0837 
0838     def test_scalar_iter_udf_close(self):
0839         @pandas_udf('int', PandasUDFType.SCALAR_ITER)
0840         def test_close(batch_iter):
0841             try:
0842                 for batch in batch_iter:
0843                     yield batch
0844             finally:
0845                 raise RuntimeError("reached finally block")
0846         with QuietTest(self.sc):
0847             with self.assertRaisesRegexp(Exception, "reached finally block"):
0848                 self.spark.range(1).select(test_close(col("id"))).collect()
0849 
0850     def test_scalar_iter_udf_close_early(self):
0851         tmp_dir = tempfile.mkdtemp()
0852         try:
0853             tmp_file = tmp_dir + '/reach_finally_block'
0854 
0855             @pandas_udf('int', PandasUDFType.SCALAR_ITER)
0856             def test_close(batch_iter):
0857                 generator_exit_caught = False
0858                 try:
0859                     for batch in batch_iter:
0860                         yield batch
0861                         time.sleep(1.0)  # avoid the function finish too fast.
0862                 except GeneratorExit as ge:
0863                     generator_exit_caught = True
0864                     raise ge
0865                 finally:
0866                     assert generator_exit_caught, "Generator exit exception was not caught."
0867                     open(tmp_file, 'a').close()
0868 
0869             with QuietTest(self.sc):
0870                 with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 1,
0871                                     "spark.sql.execution.pandas.udf.buffer.size": 4}):
0872                     self.spark.range(10).repartition(1) \
0873                         .select(test_close(col("id"))).limit(2).collect()
0874                     # wait here because python udf worker will take some time to detect
0875                     # jvm side socket closed and then will trigger `GenerateExit` raised.
0876                     # wait timeout is 10s.
0877                     for i in range(100):
0878                         time.sleep(0.1)
0879                         if os.path.exists(tmp_file):
0880                             break
0881 
0882                     assert os.path.exists(tmp_file), "finally block not reached."
0883 
0884         finally:
0885             shutil.rmtree(tmp_dir)
0886 
0887     # Regression test for SPARK-23314
0888     def test_timestamp_dst(self):
0889         # Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
0890         dt = [datetime(2015, 11, 1, 0, 30),
0891               datetime(2015, 11, 1, 1, 30),
0892               datetime(2015, 11, 1, 2, 30)]
0893         df = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
0894 
0895         for udf_type in [PandasUDFType.SCALAR, PandasUDFType.SCALAR_ITER]:
0896             foo_udf = pandas_udf(lambda x: x, 'timestamp', udf_type)
0897             result = df.withColumn('time', foo_udf(df.time))
0898             self.assertEquals(df.collect(), result.collect())
0899 
0900     @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.")
0901     def test_type_annotation(self):
0902         # Regression test to check if type hints can be used. See SPARK-23569.
0903         # Note that it throws an error during compilation in lower Python versions if 'exec'
0904         # is not used. Also, note that we explicitly use another dictionary to avoid modifications
0905         # in the current 'locals()'.
0906         #
0907         # Hyukjin: I think it's an ugly way to test issues about syntax specific in
0908         # higher versions of Python, which we shouldn't encourage. This was the last resort
0909         # I could come up with at that time.
0910         _locals = {}
0911         exec(
0912             "import pandas as pd\ndef noop(col: pd.Series) -> pd.Series: return col",
0913             _locals)
0914         df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
0915         self.assertEqual(df.first()[0], 0)
0916 
0917     def test_mixed_udf(self):
0918         df = self.spark.range(0, 1).toDF('v')
0919 
0920         # Test mixture of multiple UDFs and Pandas UDFs.
0921 
0922         @udf('int')
0923         def f1(x):
0924             assert type(x) == int
0925             return x + 1
0926 
0927         @pandas_udf('int')
0928         def f2_scalar(x):
0929             assert type(x) == pd.Series
0930             return x + 10
0931 
0932         @pandas_udf('int', PandasUDFType.SCALAR_ITER)
0933         def f2_iter(it):
0934             for x in it:
0935                 assert type(x) == pd.Series
0936                 yield x + 10
0937 
0938         @udf('int')
0939         def f3(x):
0940             assert type(x) == int
0941             return x + 100
0942 
0943         @pandas_udf('int')
0944         def f4_scalar(x):
0945             assert type(x) == pd.Series
0946             return x + 1000
0947 
0948         @pandas_udf('int', PandasUDFType.SCALAR_ITER)
0949         def f4_iter(it):
0950             for x in it:
0951                 assert type(x) == pd.Series
0952                 yield x + 1000
0953 
0954         expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11).collect()
0955         expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111).collect()
0956         expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111).collect()
0957         expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011).collect()
0958         expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101).collect()
0959 
0960         expected_multi = df \
0961             .withColumn('f1', df['v'] + 1) \
0962             .withColumn('f2', df['v'] + 10) \
0963             .withColumn('f3', df['v'] + 100) \
0964             .withColumn('f4', df['v'] + 1000) \
0965             .withColumn('f2_f1', df['v'] + 11) \
0966             .withColumn('f3_f1', df['v'] + 101) \
0967             .withColumn('f4_f1', df['v'] + 1001) \
0968             .withColumn('f3_f2', df['v'] + 110) \
0969             .withColumn('f4_f2', df['v'] + 1010) \
0970             .withColumn('f4_f3', df['v'] + 1100) \
0971             .withColumn('f3_f2_f1', df['v'] + 111) \
0972             .withColumn('f4_f2_f1', df['v'] + 1011) \
0973             .withColumn('f4_f3_f1', df['v'] + 1101) \
0974             .withColumn('f4_f3_f2', df['v'] + 1110) \
0975             .withColumn('f4_f3_f2_f1', df['v'] + 1111) \
0976             .collect()
0977 
0978         for f2, f4 in [(f2_scalar, f4_scalar), (f2_scalar, f4_iter),
0979                        (f2_iter, f4_scalar), (f2_iter, f4_iter)]:
0980             # Test single expression with chained UDFs
0981             df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
0982             df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
0983             df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
0984             df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
0985             df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))
0986 
0987             self.assertEquals(expected_chained_1, df_chained_1.collect())
0988             self.assertEquals(expected_chained_2, df_chained_2.collect())
0989             self.assertEquals(expected_chained_3, df_chained_3.collect())
0990             self.assertEquals(expected_chained_4, df_chained_4.collect())
0991             self.assertEquals(expected_chained_5, df_chained_5.collect())
0992 
0993             # Test multiple mixed UDF expressions in a single projection
0994             df_multi_1 = df \
0995                 .withColumn('f1', f1(col('v'))) \
0996                 .withColumn('f2', f2(col('v'))) \
0997                 .withColumn('f3', f3(col('v'))) \
0998                 .withColumn('f4', f4(col('v'))) \
0999                 .withColumn('f2_f1', f2(col('f1'))) \
1000                 .withColumn('f3_f1', f3(col('f1'))) \
1001                 .withColumn('f4_f1', f4(col('f1'))) \
1002                 .withColumn('f3_f2', f3(col('f2'))) \
1003                 .withColumn('f4_f2', f4(col('f2'))) \
1004                 .withColumn('f4_f3', f4(col('f3'))) \
1005                 .withColumn('f3_f2_f1', f3(col('f2_f1'))) \
1006                 .withColumn('f4_f2_f1', f4(col('f2_f1'))) \
1007                 .withColumn('f4_f3_f1', f4(col('f3_f1'))) \
1008                 .withColumn('f4_f3_f2', f4(col('f3_f2'))) \
1009                 .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))
1010 
1011             # Test mixed udfs in a single expression
1012             df_multi_2 = df \
1013                 .withColumn('f1', f1(col('v'))) \
1014                 .withColumn('f2', f2(col('v'))) \
1015                 .withColumn('f3', f3(col('v'))) \
1016                 .withColumn('f4', f4(col('v'))) \
1017                 .withColumn('f2_f1', f2(f1(col('v')))) \
1018                 .withColumn('f3_f1', f3(f1(col('v')))) \
1019                 .withColumn('f4_f1', f4(f1(col('v')))) \
1020                 .withColumn('f3_f2', f3(f2(col('v')))) \
1021                 .withColumn('f4_f2', f4(f2(col('v')))) \
1022                 .withColumn('f4_f3', f4(f3(col('v')))) \
1023                 .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \
1024                 .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \
1025                 .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \
1026                 .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
1027                 .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))
1028 
1029             self.assertEquals(expected_multi, df_multi_1.collect())
1030             self.assertEquals(expected_multi, df_multi_2.collect())
1031 
1032     def test_mixed_udf_and_sql(self):
1033         df = self.spark.range(0, 1).toDF('v')
1034 
1035         # Test mixture of UDFs, Pandas UDFs and SQL expression.
1036 
1037         @udf('int')
1038         def f1(x):
1039             assert type(x) == int
1040             return x + 1
1041 
1042         def f2(x):
1043             assert type(x) == Column
1044             return x + 10
1045 
1046         @pandas_udf('int')
1047         def f3s(x):
1048             assert type(x) == pd.Series
1049             return x + 100
1050 
1051         @pandas_udf('int', PandasUDFType.SCALAR_ITER)
1052         def f3i(it):
1053             for x in it:
1054                 assert type(x) == pd.Series
1055                 yield x + 100
1056 
1057         expected = df.withColumn('f1', df['v'] + 1) \
1058             .withColumn('f2', df['v'] + 10) \
1059             .withColumn('f3', df['v'] + 100) \
1060             .withColumn('f1_f2', df['v'] + 11) \
1061             .withColumn('f1_f3', df['v'] + 101) \
1062             .withColumn('f2_f1', df['v'] + 11) \
1063             .withColumn('f2_f3', df['v'] + 110) \
1064             .withColumn('f3_f1', df['v'] + 101) \
1065             .withColumn('f3_f2', df['v'] + 110) \
1066             .withColumn('f1_f2_f3', df['v'] + 111) \
1067             .withColumn('f1_f3_f2', df['v'] + 111) \
1068             .withColumn('f2_f1_f3', df['v'] + 111) \
1069             .withColumn('f2_f3_f1', df['v'] + 111) \
1070             .withColumn('f3_f1_f2', df['v'] + 111) \
1071             .withColumn('f3_f2_f1', df['v'] + 111) \
1072             .collect()
1073 
1074         for f3 in [f3s, f3i]:
1075             df1 = df.withColumn('f1', f1(df['v'])) \
1076                 .withColumn('f2', f2(df['v'])) \
1077                 .withColumn('f3', f3(df['v'])) \
1078                 .withColumn('f1_f2', f1(f2(df['v']))) \
1079                 .withColumn('f1_f3', f1(f3(df['v']))) \
1080                 .withColumn('f2_f1', f2(f1(df['v']))) \
1081                 .withColumn('f2_f3', f2(f3(df['v']))) \
1082                 .withColumn('f3_f1', f3(f1(df['v']))) \
1083                 .withColumn('f3_f2', f3(f2(df['v']))) \
1084                 .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \
1085                 .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \
1086                 .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \
1087                 .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \
1088                 .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
1089                 .withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
1090 
1091             self.assertEquals(expected, df1.collect())
1092 
1093     # SPARK-24721
1094     @unittest.skipIf(not test_compiled, test_not_compiled_message)
1095     def test_datasource_with_udf(self):
1096         # Same as SQLTests.test_datasource_with_udf, but with Pandas UDF
1097         # This needs to a separate test because Arrow dependency is optional
1098         import numpy as np
1099 
1100         path = tempfile.mkdtemp()
1101         shutil.rmtree(path)
1102 
1103         try:
1104             self.spark.range(1).write.mode("overwrite").format('csv').save(path)
1105             filesource_df = self.spark.read.option('inferSchema', True).csv(path).toDF('i')
1106             datasource_df = self.spark.read \
1107                 .format("org.apache.spark.sql.sources.SimpleScanSource") \
1108                 .option('from', 0).option('to', 1).load().toDF('i')
1109             datasource_v2_df = self.spark.read \
1110                 .format("org.apache.spark.sql.connector.SimpleDataSourceV2") \
1111                 .load().toDF('i', 'j')
1112 
1113             c1 = pandas_udf(lambda x: x + 1, 'int')(lit(1))
1114             c2 = pandas_udf(lambda x: x + 1, 'int')(col('i'))
1115 
1116             f1 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(lit(1))
1117             f2 = pandas_udf(lambda x: pd.Series(np.repeat(False, len(x))), 'boolean')(col('i'))
1118 
1119             for df in [filesource_df, datasource_df, datasource_v2_df]:
1120                 result = df.withColumn('c', c1)
1121                 expected = df.withColumn('c', lit(2))
1122                 self.assertEquals(expected.collect(), result.collect())
1123 
1124             for df in [filesource_df, datasource_df, datasource_v2_df]:
1125                 result = df.withColumn('c', c2)
1126                 expected = df.withColumn('c', col('i') + 1)
1127                 self.assertEquals(expected.collect(), result.collect())
1128 
1129             for df in [filesource_df, datasource_df, datasource_v2_df]:
1130                 for f in [f1, f2]:
1131                     result = df.filter(f)
1132                     self.assertEquals(0, result.count())
1133         finally:
1134             shutil.rmtree(path)
1135 
1136 
if __name__ == "__main__":
    # Entry point for running this module directly. The star import brings the test
    # classes in by their package module path (pyspark.sql.tests.test_pandas_udf_scalar).
    from pyspark.sql.tests.test_pandas_udf_scalar import *

    try:
        # Write XML test reports under target/test-reports when the optional 'xmlrunner'
        # package is importable.
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        # testRunner=None makes unittest.main fall back to its default text runner.
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)