#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

0018 from pyspark.sql.utils import AnalysisException
0019 from pyspark.testing.sqlutils import ReusedSQLTestCase
0020
0021
class CatalogTests(ReusedSQLTestCase):
    """End-to-end tests for the ``spark.catalog`` interface.

    Each test creates its fixtures inside ``self.database`` /
    ``self.table`` / ``self.tempView`` / ``self.function`` context
    managers so the catalog is restored to its default state afterwards.
    """

    def test_current_database(self):
        """currentDatabase/setCurrentDatabase round-trip, plus the error path."""
        spark = self.spark
        with self.database("some_db"):
            self.assertEqual(spark.catalog.currentDatabase(), "default")
            spark.sql("CREATE DATABASE some_db")
            spark.catalog.setCurrentDatabase("some_db")
            self.assertEqual(spark.catalog.currentDatabase(), "some_db")
            # Switching to a missing database must raise AnalysisException.
            self.assertRaisesRegex(
                AnalysisException,
                "does_not_exist",
                lambda: spark.catalog.setCurrentDatabase("does_not_exist"))

    def test_list_databases(self):
        """listDatabases reflects newly created databases."""
        spark = self.spark
        with self.database("some_db"):
            databases = [db.name for db in spark.catalog.listDatabases()]
            self.assertEqual(databases, ["default"])
            spark.sql("CREATE DATABASE some_db")
            databases = [db.name for db in spark.catalog.listDatabases()]
            self.assertEqual(sorted(databases), ["default", "some_db"])

    def test_list_tables(self):
        """listTables covers managed tables, temp views and database scoping."""
        from pyspark.sql.catalog import Table
        spark = self.spark
        with self.database("some_db"):
            spark.sql("CREATE DATABASE some_db")
            with self.table("tab1", "some_db.tab2"):
                with self.tempView("temp_tab"):
                    self.assertEqual(spark.catalog.listTables(), [])
                    self.assertEqual(spark.catalog.listTables("some_db"), [])
                    spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
                    spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                    spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
                    tables = sorted(spark.catalog.listTables(), key=lambda t: t.name)
                    tablesDefault = \
                        sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
                    tablesSomeDb = \
                        sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
                    # Listing with no argument is equivalent to listing the
                    # current ("default") database.
                    self.assertEqual(tables, tablesDefault)
                    self.assertEqual(len(tables), 2)
                    self.assertEqual(len(tablesSomeDb), 2)
                    self.assertEqual(tables[0], Table(
                        name="tab1",
                        database="default",
                        description=None,
                        tableType="MANAGED",
                        isTemporary=False))
                    # The temp view has no database and appears in listings
                    # for both databases.
                    self.assertEqual(tables[1], Table(
                        name="temp_tab",
                        database=None,
                        description=None,
                        tableType="TEMPORARY",
                        isTemporary=True))
                    self.assertEqual(tablesSomeDb[0], Table(
                        name="tab2",
                        database="some_db",
                        description=None,
                        tableType="MANAGED",
                        isTemporary=False))
                    self.assertEqual(tablesSomeDb[1], Table(
                        name="temp_tab",
                        database=None,
                        description=None,
                        tableType="TEMPORARY",
                        isTemporary=True))
                    self.assertRaisesRegex(
                        AnalysisException,
                        "does_not_exist",
                        lambda: spark.catalog.listTables("does_not_exist"))

    def test_list_functions(self):
        """listFunctions includes builtins and respects the database argument."""
        from pyspark.sql.catalog import Function
        spark = self.spark
        with self.database("some_db"):
            spark.sql("CREATE DATABASE some_db")
            functions = {f.name: f for f in spark.catalog.listFunctions()}
            functionsDefault = {f.name: f for f in spark.catalog.listFunctions("default")}
            # A fresh session already exposes a large set of built-in functions.
            self.assertGreater(len(functions), 200)
            self.assertIn("+", functions)
            self.assertIn("like", functions)
            self.assertIn("month", functions)
            self.assertIn("to_date", functions)
            self.assertIn("to_timestamp", functions)
            self.assertIn("to_unix_timestamp", functions)
            self.assertIn("current_database", functions)
            self.assertEqual(functions["+"], Function(
                name="+",
                description=None,
                className="org.apache.spark.sql.catalyst.expressions.Add",
                isTemporary=True))
            self.assertEqual(functions, functionsDefault)

            with self.function("func1", "some_db.func2"):
                spark.catalog.registerFunction("temp_func", lambda x: str(x))
                spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
                spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'")
                newFunctions = {f.name: f for f in spark.catalog.listFunctions()}
                newFunctionsSomeDb = \
                    {f.name: f for f in spark.catalog.listFunctions("some_db")}
                self.assertTrue(set(functions).issubset(set(newFunctions)))
                self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb)))
                # The temporary function is visible in both listings; each
                # permanent function only in the database it was created in.
                self.assertIn("temp_func", newFunctions)
                self.assertIn("func1", newFunctions)
                self.assertNotIn("func2", newFunctions)
                self.assertIn("temp_func", newFunctionsSomeDb)
                self.assertNotIn("func1", newFunctionsSomeDb)
                self.assertIn("func2", newFunctionsSomeDb)
                self.assertRaisesRegex(
                    AnalysisException,
                    "does_not_exist",
                    lambda: spark.catalog.listFunctions("does_not_exist"))

    def test_list_columns(self):
        """listColumns returns per-column metadata and scopes by database."""
        from pyspark.sql.catalog import Column
        spark = self.spark
        with self.database("some_db"):
            spark.sql("CREATE DATABASE some_db")
            with self.table("tab1", "some_db.tab2"):
                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                spark.sql(
                    "CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet")
                columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
                columnsDefault = \
                    sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name)
                self.assertEqual(columns, columnsDefault)
                self.assertEqual(len(columns), 2)
                self.assertEqual(columns[0], Column(
                    name="age",
                    description=None,
                    dataType="int",
                    nullable=True,
                    isPartition=False,
                    isBucket=False))
                self.assertEqual(columns[1], Column(
                    name="name",
                    description=None,
                    dataType="string",
                    nullable=True,
                    isPartition=False,
                    isBucket=False))
                columns2 = \
                    sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name)
                self.assertEqual(len(columns2), 2)
                self.assertEqual(columns2[0], Column(
                    name="nickname",
                    description=None,
                    dataType="string",
                    nullable=True,
                    isPartition=False,
                    isBucket=False))
                self.assertEqual(columns2[1], Column(
                    name="tolerance",
                    description=None,
                    dataType="float",
                    nullable=True,
                    isPartition=False,
                    isBucket=False))
                # "tab2" lives in some_db, so looking it up without a database
                # argument (i.e. in the current database) must fail.
                self.assertRaisesRegex(
                    AnalysisException,
                    "tab2",
                    lambda: spark.catalog.listColumns("tab2"))
                self.assertRaisesRegex(
                    AnalysisException,
                    "does_not_exist",
                    lambda: spark.catalog.listColumns("does_not_exist"))
0189
0190
if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_catalog import *

    # Prefer the XML-emitting runner (for CI report collection) when the
    # optional xmlrunner package is installed; otherwise fall back to the
    # default text runner.
    testRunner = None
    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        pass
    unittest.main(testRunner=testRunner, verbosity=2)