#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase

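# Tests for the pyspark.sql Catalog API: the current database, plus listing
# databases, tables, functions, and columns, including the AnalysisException
# raised for names that do not exist.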
class CatalogTests(ReusedSQLTestCase):

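    # currentDatabase() should start at "default", track setCurrentDatabase(),
    # and raise AnalysisException for a database that does not exist.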
    def test_current_database(self):
        spark = self.spark
        with self.database("some_db"):
            self.assertEqual(spark.catalog.currentDatabase(), "default")
            spark.sql("CREATE DATABASE some_db")
            spark.catalog.setCurrentDatabase("some_db")
            self.assertEqual(spark.catalog.currentDatabase(), "some_db")
            self.assertRaisesRegex(
                AnalysisException,
                "does_not_exist",
                lambda: spark.catalog.setCurrentDatabase("does_not_exist"))

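    # listDatabases() should report only "default" initially and include
    # newly created databases afterwards.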
    def test_list_databases(self):
        spark = self.spark
        with self.database("some_db"):
            databases = [db.name for db in spark.catalog.listDatabases()]
            self.assertEqual(databases, ["default"])
            spark.sql("CREATE DATABASE some_db")
            databases = [db.name for db in spark.catalog.listDatabases()]
            self.assertEqual(sorted(databases), ["default", "some_db"])

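    # listTables() should cover managed tables in both the current and a named
    # database, include temporary views in every result, and fail for an
    # unknown database.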
    def test_list_tables(self):
        from pyspark.sql.catalog import Table
        spark = self.spark
        with self.database("some_db"):
            spark.sql("CREATE DATABASE some_db")
            with self.table("tab1", "some_db.tab2"):
                with self.tempView("temp_tab"):
                    self.assertEqual(spark.catalog.listTables(), [])
                    self.assertEqual(spark.catalog.listTables("some_db"), [])
                    spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
                    spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                    spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
                    tables = sorted(spark.catalog.listTables(), key=lambda t: t.name)
                    tablesDefault = \
                        sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
                    tablesSomeDb = \
                        sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
                    self.assertEqual(tables, tablesDefault)
                    self.assertEqual(len(tables), 2)
                    self.assertEqual(len(tablesSomeDb), 2)
                    self.assertEqual(tables[0], Table(
                        name="tab1",
                        database="default",
                        description=None,
                        tableType="MANAGED",
                        isTemporary=False))
                    self.assertEqual(tables[1], Table(
                        name="temp_tab",
                        database=None,
                        description=None,
                        tableType="TEMPORARY",
                        isTemporary=True))
                    self.assertEqual(tablesSomeDb[0], Table(
                        name="tab2",
                        database="some_db",
                        description=None,
                        tableType="MANAGED",
                        isTemporary=False))
                    self.assertEqual(tablesSomeDb[1], Table(
                        name="temp_tab",
                        database=None,
                        description=None,
                        tableType="TEMPORARY",
                        isTemporary=True))
                    self.assertRaisesRegex(
                        AnalysisException,
                        "does_not_exist",
                        lambda: spark.catalog.listTables("does_not_exist"))

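    # listFunctions() should include the built-ins in every database, and pick
    # up temporary and database-scoped functions once they are registered.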
    def test_list_functions(self):
        from pyspark.sql.catalog import Function
        spark = self.spark
        with self.database("some_db"):
            spark.sql("CREATE DATABASE some_db")
            functions = {f.name: f for f in spark.catalog.listFunctions()}
            functionsDefault = {f.name: f for f in spark.catalog.listFunctions("default")}
            self.assertGreater(len(functions), 200)
            self.assertIn("+", functions)
            self.assertIn("like", functions)
            self.assertIn("month", functions)
            self.assertIn("to_date", functions)
            self.assertIn("to_timestamp", functions)
            self.assertIn("to_unix_timestamp", functions)
            self.assertIn("current_database", functions)
            self.assertEqual(functions["+"], Function(
                name="+",
                description=None,
                className="org.apache.spark.sql.catalyst.expressions.Add",
                isTemporary=True))
            self.assertEqual(functions, functionsDefault)

            with self.function("func1", "some_db.func2"):
                spark.catalog.registerFunction("temp_func", lambda x: str(x))
                spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
                spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'")
                newFunctions = {f.name: f for f in spark.catalog.listFunctions()}
                newFunctionsSomeDb = \
                    {f.name: f for f in spark.catalog.listFunctions("some_db")}
                self.assertTrue(set(functions).issubset(set(newFunctions)))
                self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb)))
                self.assertIn("temp_func", newFunctions)
                self.assertIn("func1", newFunctions)
                self.assertNotIn("func2", newFunctions)
                self.assertIn("temp_func", newFunctionsSomeDb)
                self.assertNotIn("func1", newFunctionsSomeDb)
                self.assertIn("func2", newFunctionsSomeDb)
                self.assertRaisesRegex(
                    AnalysisException,
                    "does_not_exist",
                    lambda: spark.catalog.listFunctions("does_not_exist"))

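    # listColumns() should describe each column's type, nullability, and
    # partition/bucket flags, resolving table names against the current
    # database unless one is given explicitly.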
    def test_list_columns(self):
        from pyspark.sql.catalog import Column
        spark = self.spark
        with self.database("some_db"):
            spark.sql("CREATE DATABASE some_db")
            with self.table("tab1", "some_db.tab2"):
                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                spark.sql(
                    "CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet")
                columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
                columnsDefault = \
                    sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name)
                self.assertEqual(columns, columnsDefault)
                self.assertEqual(len(columns), 2)
                self.assertEqual(columns[0], Column(
                    name="age",
                    description=None,
                    dataType="int",
                    nullable=True,
                    isPartition=False,
                    isBucket=False))
                self.assertEqual(columns[1], Column(
                    name="name",
                    description=None,
                    dataType="string",
                    nullable=True,
                    isPartition=False,
                    isBucket=False))
                columns2 = \
                    sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name)
                self.assertEqual(len(columns2), 2)
                self.assertEqual(columns2[0], Column(
                    name="nickname",
                    description=None,
                    dataType="string",
                    nullable=True,
                    isPartition=False,
                    isBucket=False))
                self.assertEqual(columns2[1], Column(
                    name="tolerance",
                    description=None,
                    dataType="float",
                    nullable=True,
                    isPartition=False,
                    isBucket=False))
                self.assertRaisesRegex(
                    AnalysisException,
                    "tab2",
                    lambda: spark.catalog.listColumns("tab2"))
                self.assertRaisesRegex(
                    AnalysisException,
                    "does_not_exist",
                    lambda: spark.catalog.listColumns("does_not_exist"))


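# When run directly, emit JUnit-style XML reports via xmlrunner if it is
# installed; otherwise fall back to the default unittest runner.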
if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_catalog import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)