Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 # http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 import datetime
0019 import os
0020 import shutil
0021 import tempfile
0022 from contextlib import contextmanager
0023 
0024 from pyspark.sql import SparkSession
0025 from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row
0026 from pyspark.testing.utils import ReusedPySparkTestCase
0027 from pyspark.util import _exception_message
0028 
0029 
try:
    from pyspark.sql.pandas.utils import require_minimum_pandas_version
    require_minimum_pandas_version()
except ImportError as e:
    # If Pandas version requirement is not satisfied, skip related tests.
    pandas_requirement_message = _exception_message(e)
else:
    pandas_requirement_message = None

try:
    from pyspark.sql.pandas.utils import require_minimum_pyarrow_version
    require_minimum_pyarrow_version()
except ImportError as e:
    # If Arrow version requirement is not satisfied, skip related tests.
    pyarrow_requirement_message = _exception_message(e)
else:
    pyarrow_requirement_message = None

try:
    from pyspark.sql.utils import require_test_compiled
    require_test_compiled()
except Exception as e:
    # Any exception from the import or the check is recorded as the reason the
    # dependent tests cannot run.
    test_not_compiled_message = _exception_message(e)
else:
    test_not_compiled_message = None

# Boolean flags derived from the probes above: True exactly when the probe left
# its message as None (presumably consumed by tests as skip conditions).
have_pandas = pandas_requirement_message is None
have_pyarrow = pyarrow_requirement_message is None
test_compiled = test_not_compiled_message is None
0056 
0057 
class UTCOffsetTimezone(datetime.tzinfo):
    """
    Specifies timezone in UTC offset

    A fixed-offset tzinfo: every datetime gets the same UTC offset and no
    daylight saving time adjustment.

    :param offset: offset from UTC in hours (``timedelta(hours=...)`` accepts
        negative and fractional values).
    """

    def __init__(self, offset=0):
        # Kept under its historical name for backward compatibility; despite the
        # name, this is the full UTC offset, not zero.
        self.ZERO = datetime.timedelta(hours=offset)

    def utcoffset(self, dt):
        return self.ZERO

    def dst(self, dt):
        # A fixed-offset zone never observes DST, so the DST adjustment is always
        # zero. Returning the UTC offset here made aware datetimes report DST as in
        # effect (e.g. ``timetuple().tm_isdst == 1``) for any non-zero offset.
        return datetime.timedelta(0)
0071 
0072 
class ExamplePointUDT(UserDefinedType):
    """
    User-defined type (UDT) for ExamplePoint.

    A point is stored as an array of doubles ``[x, y]`` whose elements are not
    nullable, and this UDT names the Scala class
    ``org.apache.spark.sql.test.ExamplePointUDT`` as its JVM counterpart.
    """

    @classmethod
    def sqlType(cls):
        """Return the SQL storage type: array<double> with non-nullable elements."""
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        """Return the module name this UDT reports for itself."""
        return 'pyspark.sql.tests'

    @classmethod
    def scalaUDT(cls):
        """Return the fully qualified name of the matching Scala UDT class."""
        return 'org.apache.spark.sql.test.ExamplePointUDT'

    def serialize(self, obj):
        """Convert an ExamplePoint into its ``[x, y]`` SQL representation."""
        return [obj.x, obj.y]

    def deserialize(self, datum):
        """Rebuild an ExamplePoint from an indexable ``[x, y]`` datum."""
        return ExamplePoint(datum[0], datum[1])
0095 
0096 
class ExamplePoint:
    """
    An example class to demonstrate UDT in Scala, Java, and Python.

    Equality is value-based on ``x`` and ``y``; ``__ne__`` and ``__hash__`` are
    defined so that they agree with ``__eq__``.
    """

    # The UDT instance associated with this point class.
    __UDT__ = ExamplePointUDT()

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return "ExamplePoint(%s,%s)" % (self.x, self.y)

    def __str__(self):
        return "(%s,%s)" % (self.x, self.y)

    def __eq__(self, other):
        return isinstance(other, self.__class__) and \
            other.x == self.x and other.y == self.y

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, two equal points
        # would still compare as "not equal" there.
        return not self == other

    def __hash__(self):
        # Defining __eq__ alone makes instances unhashable on Python 3 (and leaves
        # an identity hash inconsistent with __eq__ on Python 2). Hash exactly the
        # fields __eq__ compares so equal points hash equally.
        return hash((self.x, self.y))
0117 
0118 
class PythonOnlyUDT(UserDefinedType):
    """
    User-defined type (UDT) for PythonOnlyPoint.

    Unlike ExamplePointUDT it does not override ``scalaUDT``, so it names no
    Scala counterpart, and it reports ``'__main__'`` as its module. Points are
    stored as an array of doubles ``[x, y]`` with non-nullable elements.
    """

    @classmethod
    def sqlType(cls):
        """Return the SQL storage type: array<double> with non-nullable elements."""
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        """Return the module name this UDT reports for itself."""
        return '__main__'

    def serialize(self, obj):
        """Convert a point into its ``[x, y]`` SQL representation."""
        return [obj.x, obj.y]

    def deserialize(self, datum):
        """Rebuild a PythonOnlyPoint from an indexable ``[x, y]`` datum."""
        return PythonOnlyPoint(datum[0], datum[1])

    # NOTE(review): the static method and property below are not used by this
    # class itself; presumably they exist so tests can check that such members do
    # not interfere with UDT handling. Confirm against the callers.
    @staticmethod
    def foo():
        pass

    @property
    def props(self):
        return {}
0145 
0146 
class PythonOnlyPoint(ExamplePoint):
    """
    An example class to demonstrate UDT in only Python

    It inherits ExamplePoint's fields, repr/str and equality, but is bound to
    PythonOnlyUDT, which (unlike ExamplePointUDT) does not override scalaUDT.
    """
    # The UDT instance associated with this point class; overrides the one
    # inherited from ExamplePoint.
    __UDT__ = PythonOnlyUDT()
0152 
0153 
class MyObject(object):
    """Plain two-attribute holder exposing ``key`` and ``value``."""

    def __init__(self, key, value):
        self.key, self.value = key, value
0158 
0159 
class SQLTestUtils(object):
    """
    Context managers that help SQL tests change and clean up session state.

    Every helper requires the instance to have a 'spark' attribute holding a Spark
    session. It is usually used with 'ReusedSQLTestCase', but can be used by any class
    that is sure to provide that attribute.
    """

    @contextmanager
    def sql_conf(self, pairs):
        """
        A convenient context manager to test some configuration specific logic. For each
        ``key: value`` in `pairs` this sets the configuration ``key`` to ``value`` and, when
        it exits, restores the previous value (or unsets the key if it had none before).

        :param pairs: dict mapping configuration keys to the values used inside the block.
        """
        assert isinstance(pairs, dict), "pairs should be a dictionary."
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        # Snapshot the requested pairs and the current values before changing anything.
        items = list(pairs.items())
        old_values = {key: self.spark.conf.get(key, None) for key, _ in items}
        applied = []
        try:
            # Set inside the try so that, if one set() fails, the keys already changed
            # are still restored. Only keys that were actually set are touched on exit.
            for key, new_value in items:
                self.spark.conf.set(key, new_value)
                applied.append(key)
            yield
        finally:
            for key in applied:
                old_value = old_values[key]
                if old_value is None:
                    self.spark.conf.unset(key)
                else:
                    self.spark.conf.set(key, old_value)

    @contextmanager
    def database(self, *databases):
        """
        A convenient context manager to test with some specific databases. On exit this
        drops the given databases if they exist (with CASCADE) and sets the current
        database to "default".
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for db in databases:
                self.spark.sql("DROP DATABASE IF EXISTS %s CASCADE" % db)
            self.spark.catalog.setCurrentDatabase("default")

    @contextmanager
    def table(self, *tables):
        """
        A convenient context manager to test with some specific tables. On exit this drops
        the given tables if they exist.
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for t in tables:
                self.spark.sql("DROP TABLE IF EXISTS %s" % t)

    @contextmanager
    def tempView(self, *views):
        """
        A convenient context manager to test with some specific temporary views. On exit
        this drops each of the given views via ``catalog.dropTempView``.
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for v in views:
                self.spark.catalog.dropTempView(v)

    @contextmanager
    def function(self, *functions):
        """
        A convenient context manager to test with some specific functions. On exit this
        drops the given functions if they exist.
        """
        assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session."

        try:
            yield
        finally:
            for f in functions:
                self.spark.sql("DROP FUNCTION IF EXISTS %s" % f)
0246 
0247 
class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils):
    """
    Test case that creates one SparkSession, one scratch path and one small DataFrame
    per test class and shares them across all of the class's tests.
    """

    @classmethod
    def setUpClass(cls):
        super(ReusedSQLTestCase, cls).setUpClass()
        cls.spark = SparkSession(cls.sc)
        # Obtain a unique path name and free it again, presumably so tests can use it
        # as an output location. Close the handle before unlinking: otherwise the file
        # descriptor leaks, and the unlink fails on platforms that refuse to delete
        # open files.
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
        cls.tempdir.close()
        os.unlink(cls.tempdir.name)
        # 100 rows: key=i, value=str(i) for i in 0..99.
        cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
        cls.df = cls.spark.createDataFrame(cls.testData)

    @classmethod
    def tearDownClass(cls):
        try:
            super(ReusedSQLTestCase, cls).tearDownClass()
            cls.spark.stop()
        finally:
            # Remove anything created at the scratch path even if stopping failed.
            shutil.rmtree(cls.tempdir.name, ignore_errors=True)