Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 import sys
0019 import warnings
0020 from collections import namedtuple
0021 
0022 from pyspark import since
0023 from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
0024 from pyspark.sql.dataframe import DataFrame
0025 from pyspark.sql.udf import UserDefinedFunction
0026 from pyspark.sql.types import IntegerType, StringType, StructType
0027 
0028 
# Lightweight, immutable records used as the Python-side return values of the
# Catalog.list* methods defined in this module.
Database = namedtuple("Database", ["name", "description", "locationUri"])
Table = namedtuple(
    "Table",
    ["name", "database", "description", "tableType", "isTemporary"])
Column = namedtuple(
    "Column",
    ["name", "description", "dataType", "nullable", "isPartition", "isBucket"])
Function = namedtuple(
    "Function",
    ["name", "description", "className", "isTemporary"])
0033 
0034 
class Catalog(object):
    """User-facing catalog API, accessible through `SparkSession.catalog`.

    This is a thin wrapper around its Scala implementation org.apache.spark.sql.catalog.Catalog.
    """

    def __init__(self, sparkSession):
        """Create a new Catalog that wraps the underlying JVM object.

        :param sparkSession: the Python session whose ``_jsparkSession`` provides the JVM catalog.
        """
        self._sparkSession = sparkSession
        self._jsparkSession = sparkSession._jsparkSession
        self._jcatalog = sparkSession._jsparkSession.catalog()

    @staticmethod
    def _iter_jvm(jiterator):
        """Yield each element of a JVM iterator that exposes ``hasNext()`` and ``next()``."""
        while jiterator.hasNext():
            yield jiterator.next()

    @ignore_unicode_prefix
    @since(2.0)
    def currentDatabase(self):
        """Returns the current default database in this session."""
        return self._jcatalog.currentDatabase()

    @ignore_unicode_prefix
    @since(2.0)
    def setCurrentDatabase(self, dbName):
        """Sets the current default database in this session.

        :param dbName: name of the database to make current.
        """
        return self._jcatalog.setCurrentDatabase(dbName)

    @ignore_unicode_prefix
    @since(2.0)
    def listDatabases(self):
        """Returns a list of databases available across all sessions.

        :return: a list of :class:`Database` named tuples.
        """
        jdatabases = self._jcatalog.listDatabases().toLocalIterator()
        return [
            Database(
                name=jdb.name(),
                description=jdb.description(),
                locationUri=jdb.locationUri())
            for jdb in self._iter_jvm(jdatabases)]

    @ignore_unicode_prefix
    @since(2.0)
    def listTables(self, dbName=None):
        """Returns a list of tables/views in the specified database.

        If no database is specified, the current database is used.
        This includes all temporary views.

        :param dbName: database to list; defaults to :func:`currentDatabase`.
        :return: a list of :class:`Table` named tuples.
        """
        if dbName is None:
            dbName = self.currentDatabase()
        jtables = self._jcatalog.listTables(dbName).toLocalIterator()
        return [
            Table(
                name=jtable.name(),
                database=jtable.database(),
                description=jtable.description(),
                tableType=jtable.tableType(),
                isTemporary=jtable.isTemporary())
            for jtable in self._iter_jvm(jtables)]

    @ignore_unicode_prefix
    @since(2.0)
    def listFunctions(self, dbName=None):
        """Returns a list of functions registered in the specified database.

        If no database is specified, the current database is used.
        This includes all temporary functions.

        :param dbName: database to list; defaults to :func:`currentDatabase`.
        :return: a list of :class:`Function` named tuples.
        """
        if dbName is None:
            dbName = self.currentDatabase()
        jfunctions = self._jcatalog.listFunctions(dbName).toLocalIterator()
        return [
            Function(
                name=jfunction.name(),
                description=jfunction.description(),
                className=jfunction.className(),
                isTemporary=jfunction.isTemporary())
            for jfunction in self._iter_jvm(jfunctions)]

    @ignore_unicode_prefix
    @since(2.0)
    def listColumns(self, tableName, dbName=None):
        """Returns a list of columns for the given table/view in the specified database.

        If no database is specified, the current database is used.

        Note: the order of arguments here is different from that of its JVM counterpart
        because Python does not support method overloading.

        :param tableName: name of the table/view whose columns are listed.
        :param dbName: database containing the table; defaults to :func:`currentDatabase`.
        :return: a list of :class:`Column` named tuples.
        """
        if dbName is None:
            dbName = self.currentDatabase()
        # The JVM method takes the database first, then the table.
        jcolumns = self._jcatalog.listColumns(dbName, tableName).toLocalIterator()
        return [
            Column(
                name=jcolumn.name(),
                description=jcolumn.description(),
                dataType=jcolumn.dataType(),
                nullable=jcolumn.nullable(),
                isPartition=jcolumn.isPartition(),
                isBucket=jcolumn.isBucket())
            for jcolumn in self._iter_jvm(jcolumns)]

    @since(2.0)
    def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
        """Creates a table based on the dataset in a data source.

        It returns the DataFrame associated with the external table.

        The data source is specified by the ``source`` and a set of ``options``.
        If ``source`` is not specified, the default data source configured by
        ``spark.sql.sources.default`` will be used.

        Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
        created external table.

        This method emits a :class:`DeprecationWarning` and delegates to :func:`createTable`.

        :return: :class:`DataFrame`
        """
        warnings.warn(
            "createExternalTable is deprecated since Spark 2.2, please use createTable instead.",
            DeprecationWarning)
        return self.createTable(tableName, path, source, schema, **options)

    @since(2.2)
    def createTable(self, tableName, path=None, source=None, schema=None, **options):
        """Creates a table based on the dataset in a data source.

        It returns the DataFrame associated with the table.

        The data source is specified by the ``source`` and a set of ``options``.
        If ``source`` is not specified, the default data source configured by
        ``spark.sql.sources.default`` will be used. When ``path`` is specified, an external table is
        created from the data at the given path. Otherwise a managed table is created.

        Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
        created table.

        :param tableName: name of the table to create.
        :param path: optional data location; forwarded to the JVM as the ``"path"`` option.
        :param source: optional data source name.
        :param schema: optional :class:`StructType` describing the table.
        :param options: additional data source options.
        :return: :class:`DataFrame`
        :raises TypeError: if ``schema`` is given but is not a :class:`StructType`.
        """
        if path is not None:
            options["path"] = path
        if source is None:
            source = self._sparkSession._wrapped._conf.defaultDataSourceName()
        if schema is None:
            df = self._jcatalog.createTable(tableName, source, options)
        else:
            if not isinstance(schema, StructType):
                raise TypeError("schema should be StructType")
            # The JVM side needs its own DataType, rebuilt from the schema's JSON form.
            scala_datatype = self._jsparkSession.parseDataType(schema.json())
            df = self._jcatalog.createTable(tableName, source, scala_datatype, options)
        return DataFrame(df, self._sparkSession._wrapped)

    @since(2.0)
    def dropTempView(self, viewName):
        """Drops the local temporary view with the given view name in the catalog.
        If the view has been cached before, then it will also be uncached.
        Returns true if this view is dropped successfully, false otherwise.

        Note that, the return type of this method was None in Spark 2.0, but changed to Boolean
        in Spark 2.1.

        :param viewName: name of the temporary view to drop.
        :return: the result reported by the JVM catalog's ``dropTempView``.

        >>> spark.createDataFrame([(1, 1)]).createTempView("my_table")
        >>> spark.table("my_table").collect()
        [Row(_1=1, _2=1)]
        >>> spark.catalog.dropTempView("my_table")
        True
        >>> spark.table("my_table") # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        AnalysisException: ...
        """
        # Propagate the JVM result so the documented Boolean actually reaches the caller.
        return self._jcatalog.dropTempView(viewName)

    @since(2.1)
    def dropGlobalTempView(self, viewName):
        """Drops the global temporary view with the given view name in the catalog.
        If the view has been cached before, then it will also be uncached.
        Returns true if this view is dropped successfully, false otherwise.

        :param viewName: name of the global temporary view to drop.
        :return: the result reported by the JVM catalog's ``dropGlobalTempView``.

        >>> spark.createDataFrame([(1, 1)]).createGlobalTempView("my_table")
        >>> spark.table("global_temp.my_table").collect()
        [Row(_1=1, _2=1)]
        >>> spark.catalog.dropGlobalTempView("my_table")
        True
        >>> spark.table("global_temp.my_table") # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        AnalysisException: ...
        """
        # Propagate the JVM result so the documented Boolean actually reaches the caller.
        return self._jcatalog.dropGlobalTempView(viewName)

    @since(2.0)
    def registerFunction(self, name, f, returnType=None):
        """An alias for :func:`spark.udf.register`.
        See :meth:`pyspark.sql.UDFRegistration.register`.

        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
        """
        warnings.warn(
            "Deprecated in 2.3.0. Use spark.udf.register instead.",
            DeprecationWarning)
        return self._sparkSession.udf.register(name, f, returnType)

    @since(2.0)
    def isCached(self, tableName):
        """Returns true if the table is currently cached in-memory."""
        return self._jcatalog.isCached(tableName)

    @since(2.0)
    def cacheTable(self, tableName):
        """Caches the specified table in-memory."""
        self._jcatalog.cacheTable(tableName)

    @since(2.0)
    def uncacheTable(self, tableName):
        """Removes the specified table from the in-memory cache."""
        self._jcatalog.uncacheTable(tableName)

    @since(2.0)
    def clearCache(self):
        """Removes all cached tables from the in-memory cache."""
        self._jcatalog.clearCache()

    @since(2.0)
    def refreshTable(self, tableName):
        """Invalidates and refreshes all the cached data and metadata of the given table."""
        self._jcatalog.refreshTable(tableName)

    @since('2.1.1')
    def recoverPartitions(self, tableName):
        """Recovers all the partitions of the given table and update the catalog.

        Only works with a partitioned table, and not a view.
        """
        self._jcatalog.recoverPartitions(tableName)

    @since('2.2.0')
    def refreshByPath(self, path):
        """Invalidates and refreshes all the cached data (and the associated metadata) for any
        DataFrame that contains the given data source path.
        """
        self._jcatalog.refreshByPath(path)

    def _reset(self):
        """(Internal use only) Drop all existing databases (except "default"), tables,
        partitions and functions, and set the current database to "default".

        This is mainly used for tests.
        """
        self._jsparkSession.sessionState().catalog().reset()
0286 
0287 
def _test():
    """Run this module's doctests against a local SparkSession.

    Changes into the directory named by the ``SPARK_HOME`` environment variable,
    runs the doctests with ``spark`` and ``sc`` available as globals, stops the
    session, and exits with status -1 if any doctest failed.
    """
    import doctest
    import os
    from pyspark.sql import SparkSession
    import pyspark.sql.catalog

    os.chdir(os.environ["SPARK_HOME"])

    # Doctests see the module's own namespace plus the session objects added below.
    namespace = pyspark.sql.catalog.__dict__.copy()
    builder = SparkSession.builder.master("local[4]").appName("sql.catalog tests")
    session = builder.getOrCreate()
    namespace['sc'] = session.sparkContext
    namespace['spark'] = session

    flags = doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
    outcome = doctest.testmod(pyspark.sql.catalog, globs=namespace, optionflags=flags)
    session.stop()

    # testmod returns (failure_count, test_count); only the failures matter here.
    if outcome[0]:
        sys.exit(-1)
0310 
# Run the module's doctests when this file is executed as a script.
if __name__ == "__main__":
    _test()