#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import datetime
import os
import shutil
import tempfile
from contextlib import contextmanager

from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
from pyspark.testing.utils import ReusedPySparkTestCase
from pyspark.util import _exception_message


pandas_requirement_message = None
try:
    from pyspark.sql.pandas.utils import require_minimum_pandas_version
    require_minimum_pandas_version()
except ImportError as e:
    # If the Pandas version requirement is not satisfied, skip the related tests.
    pandas_requirement_message = _exception_message(e)

pyarrow_requirement_message = None
try:
    from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
    require_minimum_pyarrow_version()
except ImportError as e:
    # If the PyArrow version requirement is not satisfied, skip the related tests.
    pyarrow_requirement_message = _exception_message(e)

test_not_compiled_message = None
try:
    from pyspark.sql.utils import require_test_compiled
    require_test_compiled()
except Exception as e:
    # If the Scala/Java test classes are not compiled, skip the related tests.
    test_not_compiled_message = _exception_message(e)

have_pandas = pandas_requirement_message is None
have_pyarrow = pyarrow_requirement_message is None
test_compiled = test_not_compiled_message is None


class UTCOffsetTimezone(datetime.tzinfo):
    """
    Specifies a timezone as a fixed offset from UTC, in hours.
    """

    def __init__(self, offset=0):
        self.OFFSET = datetime.timedelta(hours=offset)

    def utcoffset(self, dt):
        return self.OFFSET

    def dst(self, dt):
        # A fixed-offset timezone has no daylight saving adjustment.
        return datetime.timedelta(0)
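
# Illustrative usage of UTCOffsetTimezone (a sketch; the values are hypothetical):
#
#     dt = datetime.datetime(2020, 1, 1, 12, 0, 0, tzinfo=UTCOffsetTimezone(2))
#     dt.utcoffset()  # a timedelta of 2 hours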


class ExamplePointUDT(UserDefinedType):
    """
    User-defined type (UDT) for ExamplePoint.
    """

    @classmethod
    def sqlType(cls):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return 'pyspark.sql.tests'

    @classmethod
    def scalaUDT(cls):
        return 'org.apache.spark.sql.test.ExamplePointUDT'

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return ExamplePoint(datum[0], datum[1])


class ExamplePoint:
    """
    An example class to demonstrate UDT in Scala, Java, and Python.
    """

    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "ExamplePoint(%s,%s)" % (self.x, self.y)

    def __str__(self):
        return "(%s,%s)" % (self.x, self.y)

    def __eq__(self, other):
        return isinstance(other, self.__class__) and \
            other.x == self.x and other.y == self.y


class PythonOnlyUDT(UserDefinedType):
    """
    User-defined type (UDT) for PythonOnlyPoint.
    """

    @classmethod
    def sqlType(cls):
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return '__main__'

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return PythonOnlyPoint(datum[0], datum[1])

    @staticmethod
    def foo():
        pass

    @property
    def props(self):
        return {}


class PythonOnlyPoint(ExamplePoint):
    """
    An example class to demonstrate UDT in Python only.
    """
    __UDT__ = PythonOnlyUDT()
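
# Illustrative round trip through the UDT machinery (a sketch; assumes an active
# SparkSession named `spark`):
#
#     df = spark.createDataFrame([(PythonOnlyPoint(1.0, 2.0),)], ["point"])
#     df.first().point  # deserialized back into a PythonOnlyPoint(1.0, 2.0)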


class MyObject(object):
    def __init__(self, key, value):
        self.key = key
        self.value = value


class SQLTestUtils(object):
    """
    This util assumes that the instance using it has a 'spark' attribute holding a Spark
    session. It is usually mixed into the 'ReusedSQLTestCase' class, but it can be used
    by any class whose implementation is guaranteed to have a 'spark' attribute.
    """

    @contextmanager
    def sql_conf(self, pairs):
        """
        A convenient context manager to test some configuration-specific logic. This sets
        each given configuration `key` to its `value`, and restores the original values
        when it exits.
        """
        assert isinstance(pairs, dict), "pairs should be a dictionary."
        assert hasattr(self, "spark"), "it should have a 'spark' attribute holding a Spark session."

        keys = pairs.keys()
        new_values = pairs.values()
        old_values = [self.spark.conf.get(key, None) for key in keys]
        for key, new_value in zip(keys, new_values):
            self.spark.conf.set(key, new_value)
        try:
            yield
        finally:
            # Restore each key to its previous value, or unset it if it had none.
            for key, old_value in zip(keys, old_values):
                if old_value is None:
                    self.spark.conf.unset(key)
                else:
                    self.spark.conf.set(key, old_value)
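
    # Illustrative usage of `sql_conf` (a sketch; the key and value are hypothetical,
    # and `self` is a test case mixing in SQLTestUtils):
    #
    #     with self.sql_conf({"spark.sql.shuffle.partitions": "4"}):
    #         ...  # runs with the overridden value; the old value is restored on exit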

    @contextmanager
    def database(self, *databases):
        """
        A convenient context manager to test with some specific databases. This drops the
        given databases if they exist, and sets the current database to "default" when it
        exits.
        """
        assert hasattr(self, "spark"), "it should have a 'spark' attribute holding a Spark session."

        try:
            yield
        finally:
            for db in databases:
                self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
            self.spark.catalog.setCurrentDatabase("default")

    @contextmanager
    def table(self, *tables):
        """
        A convenient context manager to test with some specific tables. This drops the
        given tables if they exist.
        """
        assert hasattr(self, "spark"), "it should have a 'spark' attribute holding a Spark session."

        try:
            yield
        finally:
            for t in tables:
                self.spark.sql("DROP TABLE IF EXISTS %s" % t)

    @contextmanager
    def tempView(self, *views):
        """
        A convenient context manager to test with some specific views. This drops the
        given views if they exist.
        """
        assert hasattr(self, "spark"), "it should have a 'spark' attribute holding a Spark session."

        try:
            yield
        finally:
            for v in views:
                self.spark.catalog.dropTempView(v)

    @contextmanager
    def function(self, *functions):
        """
        A convenient context manager to test with some specific functions. This drops the
        given functions if they exist.
        """
        assert hasattr(self, "spark"), "it should have a 'spark' attribute holding a Spark session."

        try:
            yield
        finally:
            for f in functions:
                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
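
    # Illustrative usage of the DDL context managers (a sketch; the database and
    # table names are hypothetical):
    #
    #     with self.database("test_db"), self.table("test_db.test_table"):
    #         self.spark.sql("CREATE DATABASE test_db")
    #         self.spark.sql("CREATE TABLE test_db.test_table (i INT) USING parquet")
    #         ...  # both are dropped automatically when the block exits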


class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
    @classmethod
    def setUpClass(cls):
        super(ReusedSQLTestCase, cls).setUpClass()
        cls.spark = SparkSession(cls.sc)
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
        # Unlink the file right away: only its unique name is kept, as a scratch path
        # for tests to write to; it is removed again in tearDownClass.
        os.unlink(cls.tempdir.name)
        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
        cls.df = cls.spark.createDataFrame(cls.testData)

    @classmethod
    def tearDownClass(cls):
        super(ReusedSQLTestCase, cls).tearDownClass()
        cls.spark.stop()
        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
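

# Illustrative subclass of ReusedSQLTestCase (a sketch; the test class and the
# assertion are hypothetical):
#
#     class MyQueryTests(ReusedSQLTestCase):
#         def test_values(self):
#             rows = self.df.filter("key = 1").collect()
#             self.assertEqual(rows[0].value, "1")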