#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import print_function
import sys
import warnings

if sys.version >= '3':
    basestring = unicode = str

from pyspark import since, _NoValue
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.session import _monkey_patch_RDD, SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.streaming import DataStreamReader
from pyspark.sql.types import IntegerType, Row, StringType
from pyspark.sql.udf import UDFRegistration
from pyspark.sql.utils import install_exception_handler

__all__ = ["SQLContext", "HiveContext"]


class SQLContext(object):
    """The entry point for working with structured data (rows and columns) in Spark 1.x.

    As of Spark 2.0, this is replaced by :class:`SparkSession`. However, we are keeping the class
    here for backward compatibility.

    A SQLContext can be used to create :class:`DataFrame`, register :class:`DataFrame` as
    tables, execute SQL over tables, cache tables, and read parquet files.

    :param sparkContext: The :class:`SparkContext` backing this SQLContext.
    :param sparkSession: The :class:`SparkSession` around which this SQLContext wraps.
    :param jsqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
        SQLContext in the JVM, instead we make all calls to this object.
    """

    _instantiatedContext = None

    @ignore_unicode_prefix
    def __init__(self, sparkContext, sparkSession=None, jsqlContext=None):
        """Creates a new SQLContext.

        .. note:: Deprecated in 3.0.0. Use :func:`SparkSession.builder.getOrCreate()` instead.

        >>> from datetime import datetime
        >>> sqlContext = SQLContext(sc)
        >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1,
        ...     b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
        ...     time=datetime(2014, 8, 1, 14, 1, 5))])
        >>> df = allTypes.toDF()
        >>> df.createOrReplaceTempView("allTypes")
        >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
        ...     'from allTypes where b and i > 0').collect()
        [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \
            dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
        >>> df.rdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
        [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
        """
        warnings.warn(
            "Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.",
            DeprecationWarning)

        self._sc = sparkContext
        self._jsc = self._sc._jsc
        self._jvm = self._sc._jvm
        if sparkSession is None:
            sparkSession = SparkSession.builder.getOrCreate()
        if jsqlContext is None:
            jsqlContext = sparkSession._jwrapped
        self.sparkSession = sparkSession
        self._jsqlContext = jsqlContext
        _monkey_patch_RDD(self.sparkSession)
        install_exception_handler()
        if SQLContext._instantiatedContext is None:
            SQLContext._instantiatedContext = self

    @property
    def _ssql_ctx(self):
        """Accessor for the JVM Spark SQL context.

        Subclasses can override this property to provide their own
        JVM Contexts.
        """
        return self._jsqlContext

    @property
    def _conf(self):
        """Accessor for the JVM SQL-specific configurations"""
        return self.sparkSession._jsparkSession.sessionState().conf()

    @classmethod
    @since(1.6)
    def getOrCreate(cls, sc):
        """
        Get the existing SQLContext or create a new one with the given SparkContext.

        :param sc: SparkContext

        .. note:: Deprecated in 3.0.0. Use :func:`SparkSession.builder.getOrCreate()` instead.
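
        A minimal usage sketch (illustrative only, not executed as a doctest;
        repeated calls return the same instance):

        >>> ctx = SQLContext.getOrCreate(sc)  # doctest: +SKIP
        >>> ctx is SQLContext.getOrCreate(sc)  # doctest: +SKIP
        True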
        """
        warnings.warn(
            "Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead.",
            DeprecationWarning)

        if cls._instantiatedContext is None:
            jsqlContext = sc._jvm.SparkSession.builder().sparkContext(
                sc._jsc.sc()).getOrCreate().sqlContext()
            sparkSession = SparkSession(sc, jsqlContext.sparkSession())
            cls(sc, sparkSession, jsqlContext)
        return cls._instantiatedContext

    @since(1.6)
    def newSession(self):
        """
        Returns a new SQLContext as a new session, which has separate SQLConf,
        registered temporary views and UDFs, but shared SparkContext and
        table cache.
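
        A minimal usage sketch (illustrative only, not executed as a doctest;
        temporary views created in the new session are not visible in this one):

        >>> newContext = sqlContext.newSession()  # doctest: +SKIP
        >>> newContext.range(3).createOrReplaceTempView("onlyInNewSession")  # doctest: +SKIP
        >>> "onlyInNewSession" in sqlContext.tableNames()  # doctest: +SKIP
        False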
        """
        return self.__class__(self._sc, self.sparkSession.newSession())

    @since(1.3)
    def setConf(self, key, value):
        """Sets the given Spark SQL configuration property.
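
        A minimal usage sketch (illustrative only, not executed as a doctest):

        >>> sqlContext.setConf("spark.sql.shuffle.partitions", u"10")  # doctest: +SKIP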
        """
        self.sparkSession.conf.set(key, value)

    @ignore_unicode_prefix
    @since(1.3)
    def getConf(self, key, defaultValue=_NoValue):
        """Returns the value of Spark SQL configuration property for the given key.

        If the key is not set and defaultValue is set, return
        defaultValue. If the key is not set and defaultValue is not set, return
        the system default value.

        >>> sqlContext.getConf("spark.sql.shuffle.partitions")
        u'200'
        >>> sqlContext.getConf("spark.sql.shuffle.partitions", u"10")
        u'10'
        >>> sqlContext.setConf("spark.sql.shuffle.partitions", u"50")
        >>> sqlContext.getConf("spark.sql.shuffle.partitions", u"10")
        u'50'
        """
        return self.sparkSession.conf.get(key, defaultValue)

    @property
    @since("1.3.1")
    def udf(self):
        """Returns a :class:`UDFRegistration` for UDF registration.

        :return: :class:`UDFRegistration`
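
        A minimal usage sketch (illustrative only, not executed as a doctest;
        the UDF name ``strLen`` is chosen just for this example):

        >>> sqlContext.udf.register("strLen", lambda s: len(s), IntegerType())  # doctest: +SKIP
        >>> sqlContext.sql("SELECT strLen('test')").collect()  # doctest: +SKIP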
        """
        return self.sparkSession.udf

    @since(1.4)
    def range(self, start, end=None, step=1, numPartitions=None):
        """
        Create a :class:`DataFrame` with single :class:`pyspark.sql.types.LongType` column named
        ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with
        step value ``step``.

        :param start: the start value
        :param end: the end value (exclusive)
        :param step: the incremental step (default: 1)
        :param numPartitions: the number of partitions of the DataFrame
        :return: :class:`DataFrame`

        >>> sqlContext.range(1, 7, 2).collect()
        [Row(id=1), Row(id=3), Row(id=5)]

        If only one argument is specified, it will be used as the end value.

        >>> sqlContext.range(3).collect()
        [Row(id=0), Row(id=1), Row(id=2)]
        """
        return self.sparkSession.range(start, end, step, numPartitions)

    @since(1.2)
    def registerFunction(self, name, f, returnType=None):
        """An alias for :func:`spark.udf.register`.
        See :meth:`pyspark.sql.UDFRegistration.register`.

        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.register` instead.
        """
        warnings.warn(
            "Deprecated in 2.3.0. Use spark.udf.register instead.",
            DeprecationWarning)
        return self.sparkSession.udf.register(name, f, returnType)

    @since(2.1)
    def registerJavaFunction(self, name, javaClassName, returnType=None):
        """An alias for :func:`spark.udf.registerJavaFunction`.
        See :meth:`pyspark.sql.UDFRegistration.registerJavaFunction`.

        .. note:: Deprecated in 2.3.0. Use :func:`spark.udf.registerJavaFunction` instead.
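
        A minimal usage sketch (illustrative only, not executed as a doctest; the class name
        ``com.example.StringLengthUDF`` stands in for a hypothetical Java UDF on the classpath):

        >>> sqlContext.registerJavaFunction(
        ...     "javaStringLength", "com.example.StringLengthUDF", IntegerType())  # doctest: +SKIP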
        """
        warnings.warn(
            "Deprecated in 2.3.0. Use spark.udf.registerJavaFunction instead.",
            DeprecationWarning)
        return self.sparkSession.udf.registerJavaFunction(name, javaClassName, returnType)

    def _inferSchema(self, rdd, samplingRatio=None):
        """
        Infer schema from an RDD of Row or tuple.

        :param rdd: an RDD of Row or tuple
        :param samplingRatio: sampling ratio, or no sampling (default)
        :return: :class:`pyspark.sql.types.StructType`
        """
        return self.sparkSession._inferSchema(rdd, samplingRatio)

    @since(1.3)
    @ignore_unicode_prefix
    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
        """
        Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.

        When ``schema`` is a list of column names, the type of each column
        will be inferred from ``data``.

        When ``schema`` is ``None``, it will try to infer the schema (column names and types)
        from ``data``, which should be an RDD of :class:`Row`,
        or :class:`namedtuple`, or :class:`dict`.

        When ``schema`` is :class:`pyspark.sql.types.DataType` or a datatype string it must match
        the real data, or an exception will be thrown at runtime. If the given schema is not
        :class:`pyspark.sql.types.StructType`, it will be wrapped into a
        :class:`pyspark.sql.types.StructType` as its only field, and the field name will be
        "value". Each record will also be wrapped into a tuple, which can be converted to a
        row later.

        If schema inference is needed, ``samplingRatio`` is used to determine the ratio of
        rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``.

        :param data: an RDD of any kind of SQL data representation (e.g. :class:`Row`,
            :class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`, or
            :class:`pandas.DataFrame`.
        :param schema: a :class:`pyspark.sql.types.DataType` or a datatype string or a list of
            column names, default is None. The data type string format equals to
            :class:`pyspark.sql.types.DataType.simpleString`, except that top level struct type can
            omit the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use
            ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`.
            We can also use ``int`` as a short name for :class:`pyspark.sql.types.IntegerType`.
        :param samplingRatio: the sample ratio of rows used for inferring
        :param verifySchema: verify data types of every row against schema.
        :return: :class:`DataFrame`

        .. versionchanged:: 2.0
           The ``schema`` parameter can be a :class:`pyspark.sql.types.DataType` or a
           datatype string after 2.0.
           If it's not a :class:`pyspark.sql.types.StructType`, it will be wrapped into a
           :class:`pyspark.sql.types.StructType` and each record will also be wrapped into a tuple.

        .. versionchanged:: 2.1
           Added verifySchema.

        >>> l = [('Alice', 1)]
        >>> sqlContext.createDataFrame(l).collect()
        [Row(_1=u'Alice', _2=1)]
        >>> sqlContext.createDataFrame(l, ['name', 'age']).collect()
        [Row(name=u'Alice', age=1)]

        >>> d = [{'name': 'Alice', 'age': 1}]
        >>> sqlContext.createDataFrame(d).collect()
        [Row(age=1, name=u'Alice')]

        >>> rdd = sc.parallelize(l)
        >>> sqlContext.createDataFrame(rdd).collect()
        [Row(_1=u'Alice', _2=1)]
        >>> df = sqlContext.createDataFrame(rdd, ['name', 'age'])
        >>> df.collect()
        [Row(name=u'Alice', age=1)]

        >>> from pyspark.sql import Row
        >>> Person = Row('name', 'age')
        >>> person = rdd.map(lambda r: Person(*r))
        >>> df2 = sqlContext.createDataFrame(person)
        >>> df2.collect()
        [Row(name=u'Alice', age=1)]

        >>> from pyspark.sql.types import *
        >>> schema = StructType([
        ...     StructField("name", StringType(), True),
        ...     StructField("age", IntegerType(), True)])
        >>> df3 = sqlContext.createDataFrame(rdd, schema)
        >>> df3.collect()
        [Row(name=u'Alice', age=1)]

        >>> sqlContext.createDataFrame(df.toPandas()).collect()  # doctest: +SKIP
        [Row(name=u'Alice', age=1)]
        >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]])).collect()  # doctest: +SKIP
        [Row(0=1, 1=2)]

        >>> sqlContext.createDataFrame(rdd, "a: string, b: int").collect()
        [Row(a=u'Alice', b=1)]
        >>> rdd = rdd.map(lambda row: row[1])
        >>> sqlContext.createDataFrame(rdd, "int").collect()
        [Row(value=1)]
        >>> sqlContext.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        Py4JJavaError: ...
        """
        return self.sparkSession.createDataFrame(data, schema, samplingRatio, verifySchema)

    @since(1.3)
    def registerDataFrameAsTable(self, df, tableName):
        """Registers the given :class:`DataFrame` as a temporary table in the catalog.

        Temporary tables exist only during the lifetime of this instance of :class:`SQLContext`.

        >>> sqlContext.registerDataFrameAsTable(df, "table1")
        """
        df.createOrReplaceTempView(tableName)

    @since(1.6)
    def dropTempTable(self, tableName):
        """Removes the temporary table from the catalog.

        >>> sqlContext.registerDataFrameAsTable(df, "table1")
        >>> sqlContext.dropTempTable("table1")
        """
        self.sparkSession.catalog.dropTempView(tableName)

    @since(1.3)
    def createExternalTable(self, tableName, path=None, source=None, schema=None, **options):
        """Creates an external table based on the dataset in a data source.

        It returns the DataFrame associated with the external table.

        The data source is specified by the ``source`` and a set of ``options``.
        If ``source`` is not specified, the default data source configured by
        ``spark.sql.sources.default`` will be used.

        Optionally, a schema can be provided as the schema of the returned :class:`DataFrame` and
        created external table.

        :return: :class:`DataFrame`
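
        A minimal usage sketch (illustrative only, not executed as a doctest; the path
        ``/tmp/people.parquet`` stands in for a hypothetical location of existing data):

        >>> people = sqlContext.createExternalTable(
        ...     "people", path="/tmp/people.parquet")  # doctest: +SKIP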
        """
        return self.sparkSession.catalog.createExternalTable(
            tableName, path, source, schema, **options)

    @ignore_unicode_prefix
    @since(1.0)
    def sql(self, sqlQuery):
        """Returns a :class:`DataFrame` representing the result of the given query.

        :return: :class:`DataFrame`

        >>> sqlContext.registerDataFrameAsTable(df, "table1")
        >>> df2 = sqlContext.sql("SELECT field1 AS f1, field2 as f2 from table1")
        >>> df2.collect()
        [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
        """
        return self.sparkSession.sql(sqlQuery)

    @since(1.0)
    def table(self, tableName):
        """Returns the specified table or view as a :class:`DataFrame`.

        :return: :class:`DataFrame`

        >>> sqlContext.registerDataFrameAsTable(df, "table1")
        >>> df2 = sqlContext.table("table1")
        >>> sorted(df.collect()) == sorted(df2.collect())
        True
        """
        return self.sparkSession.table(tableName)

    @ignore_unicode_prefix
    @since(1.3)
    def tables(self, dbName=None):
        """Returns a :class:`DataFrame` containing names of tables in the given database.

        If ``dbName`` is not specified, the current database will be used.

        The returned DataFrame has three columns, ``database``, ``tableName`` and ``isTemporary``
        (a column with :class:`BooleanType` indicating if a table is a temporary one or not).

        :param dbName: string, name of the database to use.
        :return: :class:`DataFrame`

        >>> sqlContext.registerDataFrameAsTable(df, "table1")
        >>> df2 = sqlContext.tables()
        >>> df2.filter("tableName = 'table1'").first()
        Row(database=u'', tableName=u'table1', isTemporary=True)
        """
        if dbName is None:
            return DataFrame(self._ssql_ctx.tables(), self)
        else:
            return DataFrame(self._ssql_ctx.tables(dbName), self)

    @since(1.3)
    def tableNames(self, dbName=None):
        """Returns a list of names of tables in the database ``dbName``.

        :param dbName: string, name of the database to use. Default to the current database.
        :return: list of table names, in string

        >>> sqlContext.registerDataFrameAsTable(df, "table1")
        >>> "table1" in sqlContext.tableNames()
        True
        >>> "table1" in sqlContext.tableNames("default")
        True
        """
        if dbName is None:
            return [name for name in self._ssql_ctx.tableNames()]
        else:
            return [name for name in self._ssql_ctx.tableNames(dbName)]

    @since(1.0)
    def cacheTable(self, tableName):
        """Caches the specified table in-memory."""
        self._ssql_ctx.cacheTable(tableName)

    @since(1.0)
    def uncacheTable(self, tableName):
        """Removes the specified table from the in-memory cache."""
        self._ssql_ctx.uncacheTable(tableName)

    @since(1.3)
    def clearCache(self):
        """Removes all cached tables from the in-memory cache."""
        self._ssql_ctx.clearCache()

    @property
    @since(1.4)
    def read(self):
        """
        Returns a :class:`DataFrameReader` that can be used to read data
        in as a :class:`DataFrame`.

        :return: :class:`DataFrameReader`
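
        A minimal usage sketch (illustrative only, not executed as a doctest; the path
        ``/tmp/people.json`` stands in for a hypothetical JSON file):

        >>> df = sqlContext.read.json("/tmp/people.json")  # doctest: +SKIP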
        """
        return DataFrameReader(self)

    @property
    @since(2.0)
    def readStream(self):
        """
        Returns a :class:`DataStreamReader` that can be used to read data streams
        as a streaming :class:`DataFrame`.

        .. note:: Evolving.

        :return: :class:`DataStreamReader`

        >>> text_sdf = sqlContext.readStream.text(tempfile.mkdtemp())
        >>> text_sdf.isStreaming
        True
        """
        return DataStreamReader(self)

    @property
    @since(2.0)
    def streams(self):
        """Returns a :class:`StreamingQueryManager` that allows managing all the
        :class:`StreamingQuery` instances active on ``this`` context.

        .. note:: Evolving.
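
        A minimal usage sketch (illustrative only, not executed as a doctest;
        assumes no streaming queries are currently active):

        >>> sqlContext.streams.active  # doctest: +SKIP
        []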
        """
        from pyspark.sql.streaming import StreamingQueryManager
        return StreamingQueryManager(self._ssql_ctx.streams())


class HiveContext(SQLContext):
    """A variant of Spark SQL that integrates with data stored in Hive.

    Configuration for Hive is read from ``hive-site.xml`` on the classpath.
    It supports running both SQL and HiveQL commands.

    :param sparkContext: The SparkContext to wrap.
    :param jhiveContext: An optional JVM Scala HiveContext. If set, we do not instantiate a new
        :class:`HiveContext` in the JVM, instead we make all calls to this object.

    .. note:: Deprecated in 2.0.0. Use SparkSession.builder.enableHiveSupport().getOrCreate().
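
    A minimal construction sketch of the deprecated path (illustrative only, not executed as a
    doctest; prefer ``SparkSession.builder.enableHiveSupport().getOrCreate()``):

    >>> hiveContext = HiveContext(sc)  # doctest: +SKIP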
    """

    def __init__(self, sparkContext, jhiveContext=None):
        warnings.warn(
            "HiveContext is deprecated in Spark 2.0.0. Please use " +
            "SparkSession.builder.enableHiveSupport().getOrCreate() instead.",
            DeprecationWarning)
        if jhiveContext is None:
            sparkContext._conf.set("spark.sql.catalogImplementation", "hive")
            sparkSession = SparkSession.builder._sparkContext(sparkContext).getOrCreate()
        else:
            sparkSession = SparkSession(sparkContext, jhiveContext.sparkSession())
        SQLContext.__init__(self, sparkContext, sparkSession, jhiveContext)

    @classmethod
    def _createForTesting(cls, sparkContext):
        """(Internal use only) Create a new HiveContext for testing.

        All test code that touches HiveContext *must* go through this method. Otherwise,
        you may end up launching multiple derby instances and encounter incredibly
        confusing error messages.
        """
        jsc = sparkContext._jsc.sc()
        jtestHive = sparkContext._jvm.org.apache.spark.sql.hive.test.TestHiveContext(jsc, False)
        return cls(sparkContext, jtestHive)

    def refreshTable(self, tableName):
        """Invalidate and refresh all the cached metadata of the given
        table. For performance reasons, Spark SQL or the external data source
        library it uses might cache certain metadata about a table, such as the
        location of blocks. When those change outside of Spark SQL, users should
        call this function to invalidate the cache.
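
        A minimal usage sketch (illustrative only, not executed as a doctest; ``my_table``
        stands in for a hypothetical table whose files changed outside of Spark SQL):

        >>> hiveContext.refreshTable("my_table")  # doctest: +SKIP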
        """
        self._ssql_ctx.refreshTable(tableName)


def _test():
    import os
    import doctest
    import tempfile
    from pyspark.context import SparkContext
    from pyspark.sql import Row, SQLContext
    import pyspark.sql.context

    os.chdir(os.environ["SPARK_HOME"])

    globs = pyspark.sql.context.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['tempfile'] = tempfile
    globs['os'] = os
    globs['sc'] = sc
    globs['sqlContext'] = SQLContext(sc)
    globs['rdd'] = rdd = sc.parallelize(
        [Row(field1=1, field2="row1"),
         Row(field1=2, field2="row2"),
         Row(field1=3, field2="row3")]
    )
    globs['df'] = rdd.toDF()
    jsonStrings = [
        '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
        '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},"field6":[{"field7": "row2"}]}',
        '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}'
    ]
    globs['jsonStrings'] = jsonStrings
    globs['json'] = sc.parallelize(jsonStrings)
    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.context, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
    globs['sc'].stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()