0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019 from __future__ import absolute_import
0020 from __future__ import print_function
0021 import sys
0022 import warnings
0023 from functools import reduce
0024 from threading import RLock
0025
0026 if sys.version >= '3':
0027 basestring = unicode = str
0028 xrange = range
0029 else:
0030 from itertools import imap as map
0031
0032 from pyspark import since
0033 from pyspark.rdd import RDD, ignore_unicode_prefix
0034 from pyspark.sql.conf import RuntimeConfig
0035 from pyspark.sql.dataframe import DataFrame
0036 from pyspark.sql.pandas.conversion import SparkConversionMixin
0037 from pyspark.sql.readwriter import DataFrameReader
0038 from pyspark.sql.streaming import DataStreamReader
0039 from pyspark.sql.types import Row, DataType, StringType, StructType, \
0040 _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
0041 _parse_datatype_string
0042 from pyspark.sql.utils import install_exception_handler
0043
0044 __all__ = ["SparkSession"]
0045
0046
0047 def _monkey_patch_RDD(sparkSession):
0048 def toDF(self, schema=None, sampleRatio=None):
0049 """
0050 Converts current :class:`RDD` into a :class:`DataFrame`
0051
0052 This is a shorthand for ``spark.createDataFrame(rdd, schema, sampleRatio)``
0053
0054 :param schema: a :class:`pyspark.sql.types.StructType` or list of names of columns
0055 :param samplingRatio: the sample ratio of rows used for inferring
0056 :return: a DataFrame
0057
0058 >>> rdd.toDF().collect()
0059 [Row(name=u'Alice', age=1)]
0060 """
0061 return sparkSession.createDataFrame(self, schema, sampleRatio)
0062
0063 RDD.toDF = toDF
0064
0065
0066 class SparkSession(SparkConversionMixin):
0067 """The entry point to programming Spark with the Dataset and DataFrame API.
0068
0069 A SparkSession can be used create :class:`DataFrame`, register :class:`DataFrame` as
0070 tables, execute SQL over tables, cache tables, and read parquet files.
0071 To create a SparkSession, use the following builder pattern:
0072
0073 >>> spark = SparkSession.builder \\
0074 ... .master("local") \\
0075 ... .appName("Word Count") \\
0076 ... .config("spark.some.config.option", "some-value") \\
0077 ... .getOrCreate()
0078
0079 .. autoattribute:: builder
0080 :annotation:
0081 """
0082
0083 class Builder(object):
0084 """Builder for :class:`SparkSession`.
0085 """
0086
0087 _lock = RLock()
0088 _options = {}
0089 _sc = None
0090
0091 @since(2.0)
0092 def config(self, key=None, value=None, conf=None):
0093 """Sets a config option. Options set using this method are automatically propagated to
0094 both :class:`SparkConf` and :class:`SparkSession`'s own configuration.
0095
0096 For an existing SparkConf, use `conf` parameter.
0097
0098 >>> from pyspark.conf import SparkConf
0099 >>> SparkSession.builder.config(conf=SparkConf())
0100 <pyspark.sql.session...
0101
0102 For a (key, value) pair, you can omit parameter names.
0103
0104 >>> SparkSession.builder.config("spark.some.config.option", "some-value")
0105 <pyspark.sql.session...
0106
0107 :param key: a key name string for configuration property
0108 :param value: a value for configuration property
0109 :param conf: an instance of :class:`SparkConf`
0110 """
0111 with self._lock:
0112 if conf is None:
0113 self._options[key] = str(value)
0114 else:
0115 for (k, v) in conf.getAll():
0116 self._options[k] = v
0117 return self
0118
0119 @since(2.0)
0120 def master(self, master):
0121 """Sets the Spark master URL to connect to, such as "local" to run locally, "local[4]"
0122 to run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone
0123 cluster.
0124
0125 :param master: a url for spark master
0126 """
0127 return self.config("spark.master", master)
0128
0129 @since(2.0)
0130 def appName(self, name):
0131 """Sets a name for the application, which will be shown in the Spark web UI.
0132
0133 If no application name is set, a randomly generated name will be used.
0134
0135 :param name: an application name
0136 """
0137 return self.config("spark.app.name", name)
0138
0139 @since(2.0)
0140 def enableHiveSupport(self):
0141 """Enables Hive support, including connectivity to a persistent Hive metastore, support
0142 for Hive SerDes, and Hive user-defined functions.
0143 """
0144 return self.config("spark.sql.catalogImplementation", "hive")
0145
0146 def _sparkContext(self, sc):
0147 with self._lock:
0148 self._sc = sc
0149 return self
0150
0151 @since(2.0)
0152 def getOrCreate(self):
0153 """Gets an existing :class:`SparkSession` or, if there is no existing one, creates a
0154 new one based on the options set in this builder.
0155
0156 This method first checks whether there is a valid global default SparkSession, and if
0157 yes, return that one. If no valid global default SparkSession exists, the method
0158 creates a new SparkSession and assigns the newly created SparkSession as the global
0159 default.
0160
0161 >>> s1 = SparkSession.builder.config("k1", "v1").getOrCreate()
0162 >>> s1.conf.get("k1") == "v1"
0163 True
0164
0165 In case an existing SparkSession is returned, the config options specified
0166 in this builder will be applied to the existing SparkSession.
0167
0168 >>> s2 = SparkSession.builder.config("k2", "v2").getOrCreate()
0169 >>> s1.conf.get("k1") == s2.conf.get("k1")
0170 True
0171 >>> s1.conf.get("k2") == s2.conf.get("k2")
0172 True
0173 """
0174 with self._lock:
0175 from pyspark.context import SparkContext
0176 from pyspark.conf import SparkConf
0177 session = SparkSession._instantiatedSession
0178 if session is None or session._sc._jsc is None:
0179 if self._sc is not None:
0180 sc = self._sc
0181 else:
0182 sparkConf = SparkConf()
0183 for key, value in self._options.items():
0184 sparkConf.set(key, value)
0185
0186 sc = SparkContext.getOrCreate(sparkConf)
0187
0188
0189 session = SparkSession(sc)
0190 for key, value in self._options.items():
0191 session._jsparkSession.sessionState().conf().setConfString(key, value)
0192 return session
0193
0194 builder = Builder()
0195 """A class attribute having a :class:`Builder` to construct :class:`SparkSession` instances."""
0196
0197 _instantiatedSession = None
0198 _activeSession = None
0199
0200 @ignore_unicode_prefix
0201 def __init__(self, sparkContext, jsparkSession=None):
0202 """Creates a new SparkSession.
0203
0204 >>> from datetime import datetime
0205 >>> spark = SparkSession(sc)
0206 >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1,
0207 ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
0208 ... time=datetime(2014, 8, 1, 14, 1, 5))])
0209 >>> df = allTypes.toDF()
0210 >>> df.createOrReplaceTempView("allTypes")
0211 >>> spark.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
0212 ... 'from allTypes where b and i > 0').collect()
0213 [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \
0214 dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
0215 >>> df.rdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
0216 [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
0217 """
0218 from pyspark.sql.context import SQLContext
0219 self._sc = sparkContext
0220 self._jsc = self._sc._jsc
0221 self._jvm = self._sc._jvm
0222 if jsparkSession is None:
0223 if self._jvm.SparkSession.getDefaultSession().isDefined() \
0224 and not self._jvm.SparkSession.getDefaultSession().get() \
0225 .sparkContext().isStopped():
0226 jsparkSession = self._jvm.SparkSession.getDefaultSession().get()
0227 else:
0228 jsparkSession = self._jvm.SparkSession(self._jsc.sc())
0229 self._jsparkSession = jsparkSession
0230 self._jwrapped = self._jsparkSession.sqlContext()
0231 self._wrapped = SQLContext(self._sc, self, self._jwrapped)
0232 _monkey_patch_RDD(self)
0233 install_exception_handler()
0234
0235
0236
0237 if SparkSession._instantiatedSession is None \
0238 or SparkSession._instantiatedSession._sc._jsc is None:
0239 SparkSession._instantiatedSession = self
0240 SparkSession._activeSession = self
0241 self._jvm.SparkSession.setDefaultSession(self._jsparkSession)
0242 self._jvm.SparkSession.setActiveSession(self._jsparkSession)
0243
0244 def _repr_html_(self):
0245 return """
0246 <div>
0247 <p><b>SparkSession - {catalogImplementation}</b></p>
0248 {sc_HTML}
0249 </div>
0250 """.format(
0251 catalogImplementation=self.conf.get("spark.sql.catalogImplementation"),
0252 sc_HTML=self.sparkContext._repr_html_()
0253 )
0254
0255 @since(2.0)
0256 def newSession(self):
0257 """
0258 Returns a new SparkSession as new session, that has separate SQLConf,
0259 registered temporary views and UDFs, but shared SparkContext and
0260 table cache.
0261 """
0262 return self.__class__(self._sc, self._jsparkSession.newSession())
0263
0264 @classmethod
0265 @since(3.0)
0266 def getActiveSession(cls):
0267 """
0268 Returns the active SparkSession for the current thread, returned by the builder.
0269 >>> s = SparkSession.getActiveSession()
0270 >>> l = [('Alice', 1)]
0271 >>> rdd = s.sparkContext.parallelize(l)
0272 >>> df = s.createDataFrame(rdd, ['name', 'age'])
0273 >>> df.select("age").collect()
0274 [Row(age=1)]
0275 """
0276 from pyspark import SparkContext
0277 sc = SparkContext._active_spark_context
0278 if sc is None:
0279 return None
0280 else:
0281 if sc._jvm.SparkSession.getActiveSession().isDefined():
0282 SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get())
0283 return SparkSession._activeSession
0284 else:
0285 return None
0286
0287 @property
0288 @since(2.0)
0289 def sparkContext(self):
0290 """Returns the underlying :class:`SparkContext`."""
0291 return self._sc
0292
0293 @property
0294 @since(2.0)
0295 def version(self):
0296 """The version of Spark on which this application is running."""
0297 return self._jsparkSession.version()
0298
0299 @property
0300 @since(2.0)
0301 def conf(self):
0302 """Runtime configuration interface for Spark.
0303
0304 This is the interface through which the user can get and set all Spark and Hadoop
0305 configurations that are relevant to Spark SQL. When getting the value of a config,
0306 this defaults to the value set in the underlying :class:`SparkContext`, if any.
0307 """
0308 if not hasattr(self, "_conf"):
0309 self._conf = RuntimeConfig(self._jsparkSession.conf())
0310 return self._conf
0311
0312 @property
0313 @since(2.0)
0314 def catalog(self):
0315 """Interface through which the user may create, drop, alter or query underlying
0316 databases, tables, functions, etc.
0317
0318 :return: :class:`Catalog`
0319 """
0320 from pyspark.sql.catalog import Catalog
0321 if not hasattr(self, "_catalog"):
0322 self._catalog = Catalog(self)
0323 return self._catalog
0324
0325 @property
0326 @since(2.0)
0327 def udf(self):
0328 """Returns a :class:`UDFRegistration` for UDF registration.
0329
0330 :return: :class:`UDFRegistration`
0331 """
0332 from pyspark.sql.udf import UDFRegistration
0333 return UDFRegistration(self)
0334
0335 @since(2.0)
0336 def range(self, start, end=None, step=1, numPartitions=None):
0337 """
0338 Create a :class:`DataFrame` with single :class:`pyspark.sql.types.LongType` column named
0339 ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with
0340 step value ``step``.
0341
0342 :param start: the start value
0343 :param end: the end value (exclusive)
0344 :param step: the incremental step (default: 1)
0345 :param numPartitions: the number of partitions of the DataFrame
0346 :return: :class:`DataFrame`
0347
0348 >>> spark.range(1, 7, 2).collect()
0349 [Row(id=1), Row(id=3), Row(id=5)]
0350
0351 If only one argument is specified, it will be used as the end value.
0352
0353 >>> spark.range(3).collect()
0354 [Row(id=0), Row(id=1), Row(id=2)]
0355 """
0356 if numPartitions is None:
0357 numPartitions = self._sc.defaultParallelism
0358
0359 if end is None:
0360 jdf = self._jsparkSession.range(0, int(start), int(step), int(numPartitions))
0361 else:
0362 jdf = self._jsparkSession.range(int(start), int(end), int(step), int(numPartitions))
0363
0364 return DataFrame(jdf, self._wrapped)
0365
0366 def _inferSchemaFromList(self, data, names=None):
0367 """
0368 Infer schema from list of Row or tuple.
0369
0370 :param data: list of Row or tuple
0371 :param names: list of column names
0372 :return: :class:`pyspark.sql.types.StructType`
0373 """
0374 if not data:
0375 raise ValueError("can not infer schema from empty dataset")
0376 first = data[0]
0377 if type(first) is dict:
0378 warnings.warn("inferring schema from dict is deprecated,"
0379 "please use pyspark.sql.Row instead")
0380 schema = reduce(_merge_type, (_infer_schema(row, names) for row in data))
0381 if _has_nulltype(schema):
0382 raise ValueError("Some of types cannot be determined after inferring")
0383 return schema
0384
0385 def _inferSchema(self, rdd, samplingRatio=None, names=None):
0386 """
0387 Infer schema from an RDD of Row or tuple.
0388
0389 :param rdd: an RDD of Row or tuple
0390 :param samplingRatio: sampling ratio, or no sampling (default)
0391 :return: :class:`pyspark.sql.types.StructType`
0392 """
0393 first = rdd.first()
0394 if not first:
0395 raise ValueError("The first row in RDD is empty, "
0396 "can not infer schema")
0397 if type(first) is dict:
0398 warnings.warn("Using RDD of dict to inferSchema is deprecated. "
0399 "Use pyspark.sql.Row instead")
0400
0401 if samplingRatio is None:
0402 schema = _infer_schema(first, names=names)
0403 if _has_nulltype(schema):
0404 for row in rdd.take(100)[1:]:
0405 schema = _merge_type(schema, _infer_schema(row, names=names))
0406 if not _has_nulltype(schema):
0407 break
0408 else:
0409 raise ValueError("Some of types cannot be determined by the "
0410 "first 100 rows, please try again with sampling")
0411 else:
0412 if samplingRatio < 0.99:
0413 rdd = rdd.sample(False, float(samplingRatio))
0414 schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type)
0415 return schema
0416
0417 def _createFromRDD(self, rdd, schema, samplingRatio):
0418 """
0419 Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
0420 """
0421 if schema is None or isinstance(schema, (list, tuple)):
0422 struct = self._inferSchema(rdd, samplingRatio, names=schema)
0423 converter = _create_converter(struct)
0424 rdd = rdd.map(converter)
0425 if isinstance(schema, (list, tuple)):
0426 for i, name in enumerate(schema):
0427 struct.fields[i].name = name
0428 struct.names[i] = name
0429 schema = struct
0430
0431 elif not isinstance(schema, StructType):
0432 raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
0433
0434
0435 rdd = rdd.map(schema.toInternal)
0436 return rdd, schema
0437
0438 def _createFromLocal(self, data, schema):
0439 """
0440 Create an RDD for DataFrame from a list or pandas.DataFrame, returns
0441 the RDD and schema.
0442 """
0443
0444 if not isinstance(data, list):
0445 data = list(data)
0446
0447 if schema is None or isinstance(schema, (list, tuple)):
0448 struct = self._inferSchemaFromList(data, names=schema)
0449 converter = _create_converter(struct)
0450 data = map(converter, data)
0451 if isinstance(schema, (list, tuple)):
0452 for i, name in enumerate(schema):
0453 struct.fields[i].name = name
0454 struct.names[i] = name
0455 schema = struct
0456
0457 elif not isinstance(schema, StructType):
0458 raise TypeError("schema should be StructType or list or None, but got: %s" % schema)
0459
0460
0461 data = [schema.toInternal(row) for row in data]
0462 return self._sc.parallelize(data), schema
0463
0464 @staticmethod
0465 def _create_shell_session():
0466 """
0467 Initialize a SparkSession for a pyspark shell session. This is called from shell.py
0468 to make error handling simpler without needing to declare local variables in that
0469 script, which would expose those to users.
0470 """
0471 import py4j
0472 from pyspark.conf import SparkConf
0473 from pyspark.context import SparkContext
0474 try:
0475
0476 conf = SparkConf()
0477 if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive':
0478 SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf()
0479 return SparkSession.builder\
0480 .enableHiveSupport()\
0481 .getOrCreate()
0482 else:
0483 return SparkSession.builder.getOrCreate()
0484 except (py4j.protocol.Py4JError, TypeError):
0485 if conf.get('spark.sql.catalogImplementation', '').lower() == 'hive':
0486 warnings.warn("Fall back to non-hive support because failing to access HiveConf, "
0487 "please make sure you build spark with hive")
0488
0489 return SparkSession.builder.getOrCreate()
0490
0491 @since(2.0)
0492 @ignore_unicode_prefix
0493 def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
0494 """
0495 Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
0496
0497 When ``schema`` is a list of column names, the type of each column
0498 will be inferred from ``data``.
0499
0500 When ``schema`` is ``None``, it will try to infer the schema (column names and types)
0501 from ``data``, which should be an RDD of either :class:`Row`,
0502 :class:`namedtuple`, or :class:`dict`.
0503
0504 When ``schema`` is :class:`pyspark.sql.types.DataType` or a datatype string, it must match
0505 the real data, or an exception will be thrown at runtime. If the given schema is not
0506 :class:`pyspark.sql.types.StructType`, it will be wrapped into a
0507 :class:`pyspark.sql.types.StructType` as its only field, and the field name will be "value".
0508 Each record will also be wrapped into a tuple, which can be converted to row later.
0509
0510 If schema inference is needed, ``samplingRatio`` is used to determined the ratio of
0511 rows used for schema inference. The first row will be used if ``samplingRatio`` is ``None``.
0512
0513 :param data: an RDD of any kind of SQL data representation (e.g. row, tuple, int, boolean,
0514 etc.), :class:`list`, or :class:`pandas.DataFrame`.
0515 :param schema: a :class:`pyspark.sql.types.DataType` or a datatype string or a list of
0516 column names, default is ``None``. The data type string format equals to
0517 :class:`pyspark.sql.types.DataType.simpleString`, except that top level struct type can
0518 omit the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use
0519 ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
0520 ``int`` as a short name for ``IntegerType``.
0521 :param samplingRatio: the sample ratio of rows used for inferring
0522 :param verifySchema: verify data types of every row against schema.
0523 :return: :class:`DataFrame`
0524
0525 .. versionchanged:: 2.1
0526 Added verifySchema.
0527
0528 .. note:: Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental.
0529
0530 .. note:: When Arrow optimization is enabled, strings inside Pandas DataFrame in Python
0531 2 are converted into bytes as they are bytes in Python 2 whereas regular strings are
0532 left as strings. When using strings in Python 2, use unicode `u""` as Python standard
0533 practice.
0534
0535 >>> l = [('Alice', 1)]
0536 >>> spark.createDataFrame(l).collect()
0537 [Row(_1=u'Alice', _2=1)]
0538 >>> spark.createDataFrame(l, ['name', 'age']).collect()
0539 [Row(name=u'Alice', age=1)]
0540
0541 >>> d = [{'name': 'Alice', 'age': 1}]
0542 >>> spark.createDataFrame(d).collect()
0543 [Row(age=1, name=u'Alice')]
0544
0545 >>> rdd = sc.parallelize(l)
0546 >>> spark.createDataFrame(rdd).collect()
0547 [Row(_1=u'Alice', _2=1)]
0548 >>> df = spark.createDataFrame(rdd, ['name', 'age'])
0549 >>> df.collect()
0550 [Row(name=u'Alice', age=1)]
0551
0552 >>> from pyspark.sql import Row
0553 >>> Person = Row('name', 'age')
0554 >>> person = rdd.map(lambda r: Person(*r))
0555 >>> df2 = spark.createDataFrame(person)
0556 >>> df2.collect()
0557 [Row(name=u'Alice', age=1)]
0558
0559 >>> from pyspark.sql.types import *
0560 >>> schema = StructType([
0561 ... StructField("name", StringType(), True),
0562 ... StructField("age", IntegerType(), True)])
0563 >>> df3 = spark.createDataFrame(rdd, schema)
0564 >>> df3.collect()
0565 [Row(name=u'Alice', age=1)]
0566
0567 >>> spark.createDataFrame(df.toPandas()).collect() # doctest: +SKIP
0568 [Row(name=u'Alice', age=1)]
0569 >>> spark.createDataFrame(pandas.DataFrame([[1, 2]])).collect() # doctest: +SKIP
0570 [Row(0=1, 1=2)]
0571
0572 >>> spark.createDataFrame(rdd, "a: string, b: int").collect()
0573 [Row(a=u'Alice', b=1)]
0574 >>> rdd = rdd.map(lambda row: row[1])
0575 >>> spark.createDataFrame(rdd, "int").collect()
0576 [Row(value=1)]
0577 >>> spark.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL
0578 Traceback (most recent call last):
0579 ...
0580 Py4JJavaError: ...
0581 """
0582 SparkSession._activeSession = self
0583 self._jvm.SparkSession.setActiveSession(self._jsparkSession)
0584 if isinstance(data, DataFrame):
0585 raise TypeError("data is already a DataFrame")
0586
0587 if isinstance(schema, basestring):
0588 schema = _parse_datatype_string(schema)
0589 elif isinstance(schema, (list, tuple)):
0590
0591 schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]
0592
0593 try:
0594 import pandas
0595 has_pandas = True
0596 except Exception:
0597 has_pandas = False
0598 if has_pandas and isinstance(data, pandas.DataFrame):
0599
0600 return super(SparkSession, self).createDataFrame(
0601 data, schema, samplingRatio, verifySchema)
0602 return self._create_dataframe(data, schema, samplingRatio, verifySchema)
0603
0604 def _create_dataframe(self, data, schema, samplingRatio, verifySchema):
0605 if isinstance(schema, StructType):
0606 verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True
0607
0608 def prepare(obj):
0609 verify_func(obj)
0610 return obj
0611 elif isinstance(schema, DataType):
0612 dataType = schema
0613 schema = StructType().add("value", schema)
0614
0615 verify_func = _make_type_verifier(
0616 dataType, name="field value") if verifySchema else lambda _: True
0617
0618 def prepare(obj):
0619 verify_func(obj)
0620 return obj,
0621 else:
0622 prepare = lambda obj: obj
0623
0624 if isinstance(data, RDD):
0625 rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
0626 else:
0627 rdd, schema = self._createFromLocal(map(prepare, data), schema)
0628 jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
0629 jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
0630 df = DataFrame(jdf, self._wrapped)
0631 df._schema = schema
0632 return df
0633
0634 @ignore_unicode_prefix
0635 @since(2.0)
0636 def sql(self, sqlQuery):
0637 """Returns a :class:`DataFrame` representing the result of the given query.
0638
0639 :return: :class:`DataFrame`
0640
0641 >>> df.createOrReplaceTempView("table1")
0642 >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
0643 >>> df2.collect()
0644 [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
0645 """
0646 return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped)
0647
0648 @since(2.0)
0649 def table(self, tableName):
0650 """Returns the specified table as a :class:`DataFrame`.
0651
0652 :return: :class:`DataFrame`
0653
0654 >>> df.createOrReplaceTempView("table1")
0655 >>> df2 = spark.table("table1")
0656 >>> sorted(df.collect()) == sorted(df2.collect())
0657 True
0658 """
0659 return DataFrame(self._jsparkSession.table(tableName), self._wrapped)
0660
0661 @property
0662 @since(2.0)
0663 def read(self):
0664 """
0665 Returns a :class:`DataFrameReader` that can be used to read data
0666 in as a :class:`DataFrame`.
0667
0668 :return: :class:`DataFrameReader`
0669 """
0670 return DataFrameReader(self._wrapped)
0671
0672 @property
0673 @since(2.0)
0674 def readStream(self):
0675 """
0676 Returns a :class:`DataStreamReader` that can be used to read data streams
0677 as a streaming :class:`DataFrame`.
0678
0679 .. note:: Evolving.
0680
0681 :return: :class:`DataStreamReader`
0682 """
0683 return DataStreamReader(self._wrapped)
0684
0685 @property
0686 @since(2.0)
0687 def streams(self):
0688 """Returns a :class:`StreamingQueryManager` that allows managing all the
0689 :class:`StreamingQuery` instances active on `this` context.
0690
0691 .. note:: Evolving.
0692
0693 :return: :class:`StreamingQueryManager`
0694 """
0695 from pyspark.sql.streaming import StreamingQueryManager
0696 return StreamingQueryManager(self._jsparkSession.streams())
0697
0698 @since(2.0)
0699 def stop(self):
0700 """Stop the underlying :class:`SparkContext`.
0701 """
0702 self._sc.stop()
0703
0704 self._jvm.SparkSession.clearDefaultSession()
0705 self._jvm.SparkSession.clearActiveSession()
0706 SparkSession._instantiatedSession = None
0707 SparkSession._activeSession = None
0708
0709 @since(2.0)
0710 def __enter__(self):
0711 """
0712 Enable 'with SparkSession.builder.(...).getOrCreate() as session: app' syntax.
0713 """
0714 return self
0715
0716 @since(2.0)
0717 def __exit__(self, exc_type, exc_val, exc_tb):
0718 """
0719 Enable 'with SparkSession.builder.(...).getOrCreate() as session: app' syntax.
0720
0721 Specifically stop the SparkSession on exit of the with block.
0722 """
0723 self.stop()
0724
0725
0726 def _test():
0727 import os
0728 import doctest
0729 from pyspark.context import SparkContext
0730 from pyspark.sql import Row
0731 import pyspark.sql.session
0732
0733 os.chdir(os.environ["SPARK_HOME"])
0734
0735 globs = pyspark.sql.session.__dict__.copy()
0736 sc = SparkContext('local[4]', 'PythonTest')
0737 globs['sc'] = sc
0738 globs['spark'] = SparkSession(sc)
0739 globs['rdd'] = rdd = sc.parallelize(
0740 [Row(field1=1, field2="row1"),
0741 Row(field1=2, field2="row2"),
0742 Row(field1=3, field2="row3")])
0743 globs['df'] = rdd.toDF()
0744 (failure_count, test_count) = doctest.testmod(
0745 pyspark.sql.session, globs=globs,
0746 optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
0747 globs['sc'].stop()
0748 if failure_count:
0749 sys.exit(-1)
0750
0751 if __name__ == "__main__":
0752 _test()