0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import sys
0019 import warnings
0020 from collections import namedtuple
0021
0022 from pyspark import since
0023 from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
0024 from pyspark.sql.dataframe import DataFrame
0025 from pyspark.sql.udf import UserDefinedFunction
0026 from pyspark.sql.types import IntegerType, StringType, StructType
0027
0028
0029 Database = namedtuple("Database", "name description locationUri")
0030 Table = namedtuple("Table", "name database description tableType isTemporary")
0031 Column = namedtuple("Column", "name description dataType nullable isPartition isBucket")
0032 Function = namedtuple("Function", "name description className isTemporary")
0033
0034
0035 class Catalog(object):
0036 """User-facing catalog API, accessible through `SparkSession.catalog`.
0037
0038 This is a thin wrapper around its Scala implementation org.apache.spark.sql.catalog.Catalog.
0039 """
0040
0041 def __init__(self, sparkSession):
0042 """Create a new Catalog that wraps the underlying JVM object."""
0043 self._sparkSession = sparkSession
0044 self._jsparkSession = sparkSession._jsparkSession
0045 self._jcatalog = sparkSession._jsparkSession.catalog()
0046
0047 @ignore_unicode_prefix
0048 @since(2.0)
0049 def currentDatabase(self):
0050 """Returns the current default database in this session."""
0051 return self._jcatalog.currentDatabase()
0052
0053 @ignore_unicode_prefix
0054 @since(2.0)
0055 def setCurrentDatabase(self, dbName):
0056 """Sets the current default database in this session."""
0057 return self._jcatalog.setCurrentDatabase(dbName)
0058
0059 @ignore_unicode_prefix
0060 @since(2.0)
0061 def listDatabases(self):
0062 """Returns a list of databases available across all sessions."""
0063 iter = self._jcatalog.listDatabases().toLocalIterator()
0064 databases = []
0065 while iter.hasNext():
0066 jdb = iter.next()
0067 databases.append(Database(
0068 name=jdb.name(),
0069 description=jdb.description(),
0070 locationUri=jdb.locationUri()))
0071 return databases
0072
0073 @ignore_unicode_prefix
0074 @since(2.0)
0075 def listTables(self, dbName=None):
0076 """Returns a list of tables/views in the specified database.
0077
0078 If no database is specified, the current database is used.
0079 This includes all temporary views.
0080 """
0081 if dbName is None:
0082 dbName = self.currentDatabase()
0083 iter = self._jcatalog.listTables(dbName).toLocalIterator()
0084 tables = []
0085 while iter.hasNext():
0086 jtable = iter.next()
0087 tables.append(Table(
0088 name=jtable.name(),
0089 database=jtable.database(),
0090 description=jtable.description(),
0091 tableType=jtable.tableType(),
0092 isTemporary=jtable.isTemporary()))
0093 return tables
0094
0095 @ignore_unicode_prefix
0096 @since(2.0)
0097 def listFunctions(self, dbName=None):
0098 """Returns a list of functions registered in the specified database.
0099
0100 If no database is specified, the current database is used.
0101 This includes all temporary functions.
0102 """
0103 if dbName is None:
0104 dbName = self.currentDatabase()
0105 iter = self._jcatalog.listFunctions(dbName).toLocalIterator()
0106 functions = []
0107 while iter.hasNext():
0108 jfunction = iter.next()
0109 functions.append(Function(
0110 name=jfunction.name(),
0111 description=jfunction.description(),
0112 className=jfunction.className(),
0113 isTemporary=jfunction.isTemporary()))
0114 return functions
0115
0116 @ignore_unicode_prefix
0117 @since(2.0)
0118 def listColumns(self, tableName, dbName=None):
0119 """Returns a list of columns for the given table/view in the specified database.
0120
0121 If no database is specified, the current database is used.
0122
0123 Note: the order of arguments here is different from that of its JVM counterpart
0124 because Python does not support method overloading.
0125 """
0126 if dbName is None:
0127 dbName = self.currentDatabase()
0128 iter = self._jcatalog.listColumns(dbName, tableName).toLocalIterator()
0129 columns = []
0130 while iter.hasNext():
0131 jcolumn = iter.next()
0132 columns.append(Column(
0133 name=jcolumn.name(),
0134 description=jcolumn.description(),
0135 dataType=jcolumn.dataType(),
0136 nullable=jcolumn.nullable(),
0137 isPartition=jcolumn.isPartition(),
0138 isBucket=jcolumn.isBucket()))
0139 return columns
0140
0141 @since(2.0)
0142 def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
0143 """Creates a table based on the dataset in a data source.
0144
0145 It returns the DataFrame associated with the external table.
0146
0147 The data source is specified by the ``source`` and a set of ``options``.
0148 If ``source`` is not specified, the default data source configured by
0149 ``spark.sql.sources.default`` will be used.
0150
0151 Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
0152 created external table.
0153
0154 :return: :class:`DataFrame`
0155 """
0156 warnings.warn(
0157 "createExternalTable is deprecated since Spark 2.2, please use createTable instead.",
0158 DeprecationWarning)
0159 return self.createTable(tableName, path, source, schema, **options)
0160
0161 @since(2.2)
0162 def createTable(self, tableName, path=None, source=None, schema=None, **options):
0163 """Creates a table based on the dataset in a data source.
0164
0165 It returns the DataFrame associated with the table.
0166
0167 The data source is specified by the ``source`` and a set of ``options``.
0168 If ``source`` is not specified, the default data source configured by
0169 ``spark.sql.sources.default`` will be used. When ``path`` is specified, an external table is
0170 created from the data at the given path. Otherwise a managed table is created.
0171
0172 Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
0173 created table.
0174
0175 :return: :class:`DataFrame`
0176 """
0177 if path is not None:
0178 options["path"] = path
0179 if source is None:
0180 source = self._sparkSession._wrapped._conf.defaultDataSourceName()
0181 if schema is None:
0182 df = self._jcatalog.createTable(tableName, source, options)
0183 else:
0184 if not isinstance(schema, StructType):
0185 raise TypeError("schema should be StructType")
0186 scala_datatype = self._jsparkSession.parseDataType(schema.json())
0187 df = self._jcatalog.createTable(tableName, source, scala_datatype, options)
0188 return DataFrame(df, self._sparkSession._wrapped)
0189
0190 @since(2.0)
0191 def dropTempView(self, viewName):
0192 """Drops the local temporary view with the given view name in the catalog.
0193 If the view has been cached before, then it will also be uncached.
0194 Returns true if this view is dropped successfully, false otherwise.
0195
0196 Note that, the return type of this method was None in Spark 2.0, but changed to Boolean
0197 in Spark 2.1.
0198
0199 >>> spark.createDataFrame([(1, 1)]).createTempView("my_table")
0200 >>> spark.table("my_table").collect()
0201 [Row(_1=1, _2=1)]
0202 >>> spark.catalog.dropTempView("my_table")
0203 >>> spark.table("my_table") # doctest: +IGNORE_EXCEPTION_DETAIL
0204 Traceback (most recent call last):
0205 ...
0206 AnalysisException: ...
0207 """
0208 self._jcatalog.dropTempView(viewName)
0209
0210 @since(2.1)
0211 def dropGlobalTempView(self, viewName):
0212 """Drops the global temporary view with the given view name in the catalog.
0213 If the view has been cached before, then it will also be uncached.
0214 Returns true if this view is dropped successfully, false otherwise.
0215
0216 >>> spark.createDataFrame([(1, 1)]).createGlobalTempView("my_table")
0217 >>> spark.table("global_temp.my_table").collect()
0218 [Row(_1=1, _2=1)]
0219 >>> spark.catalog.dropGlobalTempView("my_table")
0220 >>> spark.table("global_temp.my_table") # doctest: +IGNORE_EXCEPTION_DETAIL
0221 Traceback (most recent call last):
0222 ...
0223 AnalysisException: ...
0224 """
0225 self._jcatalog.dropGlobalTempView(viewName)
0226
0227 @since(2.0)
0228 def registerFunction(self, name, f, returnType=None):
0229 """An alias for :func:`spark.udf.register`.
0230 See :meth:`pyspark.sql.UDFRegistration.register`.
0231
0232 .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
0233 """
0234 warnings.warn(
0235 "Deprecated in 2.3.0. Use spark.udf.register instead.",
0236 DeprecationWarning)
0237 return self._sparkSession.udf.register(name, f, returnType)
0238
0239 @since(2.0)
0240 def isCached(self, tableName):
0241 """Returns true if the table is currently cached in-memory."""
0242 return self._jcatalog.isCached(tableName)
0243
0244 @since(2.0)
0245 def cacheTable(self, tableName):
0246 """Caches the specified table in-memory."""
0247 self._jcatalog.cacheTable(tableName)
0248
0249 @since(2.0)
0250 def uncacheTable(self, tableName):
0251 """Removes the specified table from the in-memory cache."""
0252 self._jcatalog.uncacheTable(tableName)
0253
0254 @since(2.0)
0255 def clearCache(self):
0256 """Removes all cached tables from the in-memory cache."""
0257 self._jcatalog.clearCache()
0258
0259 @since(2.0)
0260 def refreshTable(self, tableName):
0261 """Invalidates and refreshes all the cached data and metadata of the given table."""
0262 self._jcatalog.refreshTable(tableName)
0263
0264 @since('2.1.1')
0265 def recoverPartitions(self, tableName):
0266 """Recovers all the partitions of the given table and update the catalog.
0267
0268 Only works with a partitioned table, and not a view.
0269 """
0270 self._jcatalog.recoverPartitions(tableName)
0271
0272 @since('2.2.0')
0273 def refreshByPath(self, path):
0274 """Invalidates and refreshes all the cached data (and the associated metadata) for any
0275 DataFrame that contains the given data source path.
0276 """
0277 self._jcatalog.refreshByPath(path)
0278
0279 def _reset(self):
0280 """(Internal use only) Drop all existing databases (except "default"), tables,
0281 partitions and functions, and set the current database to "default".
0282
0283 This is mainly used for tests.
0284 """
0285 self._jsparkSession.sessionState().catalog().reset()
0286
0287
0288 def _test():
0289 import os
0290 import doctest
0291 from pyspark.sql import SparkSession
0292 import pyspark.sql.catalog
0293
0294 os.chdir(os.environ["SPARK_HOME"])
0295
0296 globs = pyspark.sql.catalog.__dict__.copy()
0297 spark = SparkSession.builder\
0298 .master("local[4]")\
0299 .appName("sql.catalog tests")\
0300 .getOrCreate()
0301 globs['sc'] = spark.sparkContext
0302 globs['spark'] = spark
0303 (failure_count, test_count) = doctest.testmod(
0304 pyspark.sql.catalog,
0305 globs=globs,
0306 optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
0307 spark.stop()
0308 if failure_count:
0309 sys.exit(-1)
0310
0311 if __name__ == "__main__":
0312 _test()