Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 #    http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 
0018 # To disallow implicit relative import. Remove this once we drop Python 2.
0019 from __future__ import absolute_import
0020 from __future__ import print_function
0021 import sys
0022 import warnings
0023 from functools import reduce
0024 from threading import RLock
0025 
# Python 2/3 compatibility aliases. Compare the numeric major version rather than the
# version string: a lexicographic string comparison would treat e.g. '10.0' as < '3'.
if sys.version_info[0] >= 3:
    basestring = unicode = str
    xrange = range
else:
    from itertools import imap as map
0031 
0032 from pyspark import since
0033 from pyspark.rdd import RDD, ignore_unicode_prefix
0034 from pyspark.sql.conf import RuntimeConfig
0035 from pyspark.sql.dataframe import DataFrame
0036 from pyspark.sql.pandas.conversion import SparkConversionMixin
0037 from pyspark.sql.readwriter import DataFrameReader
0038 from pyspark.sql.streaming import DataStreamReader
0039 from pyspark.sql.types import Row, DataType, StringType, StructType, \
0040     _make_type_verifier, _infer_schema, _has_nulltype, _merge_type, _create_converter, \
0041     _parse_datatype_string
0042 from pyspark.sql.utils import install_exception_handler
0043 
0044 __all__ = ["SparkSession"]
0045 
0046 
def _monkey_patch_RDD(sparkSession):
    """Install a ``toDF`` method on the :class:`RDD` class that delegates to ``sparkSession``.

    The method is assigned on the ``RDD`` class itself, so it applies to every RDD, and each
    call replaces whatever ``toDF`` a previous call installed.

    :param sparkSession: the :class:`SparkSession` whose ``createDataFrame`` backs ``RDD.toDF``
    """
    def toDF(self, schema=None, sampleRatio=None):
        """
        Converts current :class:`RDD` into a :class:`DataFrame`.

        This is a shorthand for ``spark.createDataFrame(rdd, schema, sampleRatio)``

        :param schema: a :class:`pyspark.sql.types.StructType`, a
            :class:`pyspark.sql.types.DataType` or datatype string, or a list of names of columns
        :param sampleRatio: the sample ratio of rows used for inferring
        :return: a DataFrame

        >>> rdd.toDF().collect()
        [Row(name=u'Alice', age=1)]
        """
        return sparkSession.createDataFrame(self, schema, sampleRatio)

    RDD.toDF = toDF
0064 
0065 
class SparkSession(SparkConversionMixin):
    """The entry point to programming Spark with the Dataset and DataFrame API.

    A SparkSession can be used to create :class:`DataFrame`, register :class:`DataFrame` as
    tables, execute SQL over tables, cache tables, and read parquet files.
    To create a SparkSession, use the following builder pattern:

    >>> spark = SparkSession.builder \\
    ...     .master("local") \\
    ...     .appName("Word Count") \\
    ...     .config("spark.some.config.option", "some-value") \\
    ...     .getOrCreate()

    .. autoattribute:: builder
       :annotation:
    """

    class Builder(object):
        """Builder for :class:`SparkSession`.
        """

        # The lock and the ``_options`` dict are class attributes shared by every Builder
        # instance; config() mutates ``_options`` in place and nothing clears it, so options
        # accumulate across getOrCreate() calls. ``_sparkContext`` sets ``_sc`` per instance.
        _lock = RLock()
        _options = {}
        _sc = None

        @since(2.0)
        def config(self, key=None, value=None, conf=None):
            """Sets a config option. Options set using this method are propagated to the
            :class:`SparkConf` of a newly created :class:`SparkContext` and to the
            :class:`SparkSession`'s own configuration.

            For an existing SparkConf, use `conf` parameter.

            >>> from pyspark.conf import SparkConf
            >>> SparkSession.builder.config(conf=SparkConf())
            <pyspark.sql.session...

            For a (key, value) pair, you can omit parameter names.

            >>> SparkSession.builder.config("spark.some.config.option", "some-value")
            <pyspark.sql.session...

            :param key: a key name string for configuration property
            :param value: a value for configuration property; it is stored as ``str(value)``
            :param conf: an instance of :class:`SparkConf`; its pairs from ``getAll()`` are
                stored as-is
            :return: this builder, so calls can be chained
            """
            with self._lock:
                if conf is None:
                    self._options[key] = str(value)
                else:
                    for (k, v) in conf.getAll():
                        self._options[k] = v
                return self

        @since(2.0)
        def master(self, master):
            """Sets the Spark master URL to connect to, such as "local" to run locally, "local[4]"
            to run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone
            cluster.

            :param master: a url for spark master
            """
            return self.config("spark.master", master)

        @since(2.0)
        def appName(self, name):
            """Sets a name for the application, which will be shown in the Spark web UI.

            If no application name is set, a randomly generated name will be used.

            :param name: an application name
            """
            return self.config("spark.app.name", name)

        @since(2.0)
        def enableHiveSupport(self):
            """Enables Hive support, including connectivity to a persistent Hive metastore, support
            for Hive SerDes, and Hive user-defined functions.
            """
            return self.config("spark.sql.catalogImplementation", "hive")

        def _sparkContext(self, sc):
            # Pin an existing SparkContext so getOrCreate() uses it instead of building one.
            with self._lock:
                self._sc = sc
                return self

        @since(2.0)
        def getOrCreate(self):
            """Gets an existing :class:`SparkSession` or, if there is no existing one, creates a
            new one based on the options set in this builder.

            This method first checks whether there is a valid global default SparkSession, and if
            yes, return that one. If no valid global default SparkSession exists, the method
            creates a new SparkSession and assigns the newly created SparkSession as the global
            default.

            >>> s1 = SparkSession.builder.config("k1", "v1").getOrCreate()
            >>> s1.conf.get("k1") == "v1"
            True

            In case an existing SparkSession is returned, the config options specified
            in this builder will be applied to the existing SparkSession.

            >>> s2 = SparkSession.builder.config("k2", "v2").getOrCreate()
            >>> s1.conf.get("k1") == s2.conf.get("k1")
            True
            >>> s1.conf.get("k2") == s2.conf.get("k2")
            True
            """
            with self._lock:
                from pyspark.context import SparkContext
                from pyspark.conf import SparkConf
                session = SparkSession._instantiatedSession
                # A cached session whose SparkContext was stopped (``_jsc`` is None) is unusable.
                if session is None or session._sc._jsc is None:
                    if self._sc is not None:
                        sc = self._sc
                    else:
                        sparkConf = SparkConf()
                        for key, value in self._options.items():
                            sparkConf.set(key, value)
                        # This SparkContext may be an existing one.
                        sc = SparkContext.getOrCreate(sparkConf)
                    # Do not update `SparkConf` for existing `SparkContext`, as it's shared
                    # by all sessions.
                    session = SparkSession(sc)
                # Builder options are applied to the session SQL conf whether it is new or reused.
                for key, value in self._options.items():
                    session._jsparkSession.sessionState().conf().setConfString(key, value)
                return session

    builder = Builder()
    """A class attribute having a :class:`Builder` to construct :class:`SparkSession` instances."""

    # Python-side caches, set by __init__ and reset by stop(). Note: these are single class
    # attributes, not per-thread values.
    _instantiatedSession = None
    _activeSession = None

    @ignore_unicode_prefix
    def __init__(self, sparkContext, jsparkSession=None):
        """Creates a new SparkSession.

        >>> from datetime import datetime
        >>> spark = SparkSession(sc)
        >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1,
        ...     b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
        ...     time=datetime(2014, 8, 1, 14, 1, 5))])
        >>> df = allTypes.toDF()
        >>> df.createOrReplaceTempView("allTypes")
        >>> spark.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
        ...            'from allTypes where b and i > 0').collect()
        [Row((i + CAST(1 AS BIGINT))=2, (d + CAST(1 AS DOUBLE))=2.0, (NOT b)=False, list[1]=2, \
            dict[s]=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)]
        >>> df.rdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect()
        [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
        """
        from pyspark.sql.context import SQLContext
        self._sc = sparkContext
        self._jsc = self._sc._jsc
        self._jvm = self._sc._jvm
        if jsparkSession is None:
            # Fetch the JVM default-session Option once so isDefined() and get() are evaluated
            # on the same value (no window for it to be cleared between the two calls).
            default_session = self._jvm.SparkSession.getDefaultSession()
            if default_session.isDefined() \
                    and not default_session.get().sparkContext().isStopped():
                jsparkSession = default_session.get()
            else:
                jsparkSession = self._jvm.SparkSession(self._jsc.sc())
        self._jsparkSession = jsparkSession
        self._jwrapped = self._jsparkSession.sqlContext()
        self._wrapped = SQLContext(self._sc, self, self._jwrapped)
        _monkey_patch_RDD(self)
        install_exception_handler()
        # If we had an instantiated SparkSession attached with a SparkContext
        # which is stopped now, we need to renew the instantiated SparkSession.
        # Otherwise, we will use invalid SparkSession when we call Builder.getOrCreate.
        if SparkSession._instantiatedSession is None \
                or SparkSession._instantiatedSession._sc._jsc is None:
            SparkSession._instantiatedSession = self
            SparkSession._activeSession = self
            self._jvm.SparkSession.setDefaultSession(self._jsparkSession)
            self._jvm.SparkSession.setActiveSession(self._jsparkSession)

    def _repr_html_(self):
        return """
            <div>
                <p><b>SparkSession - {catalogImplementation}</b></p>
                {sc_HTML}
            </div>
        """.format(
            catalogImplementation=self.conf.get("spark.sql.catalogImplementation"),
            sc_HTML=self.sparkContext._repr_html_()
        )

    @since(2.0)
    def newSession(self):
        """
        Returns a new SparkSession as new session, that has separate SQLConf,
        registered temporary views and UDFs, but shared SparkContext and
        table cache.
        """
        return self.__class__(self._sc, self._jsparkSession.newSession())

    @classmethod
    @since(3.0)
    def getActiveSession(cls):
        """
        Returns the active SparkSession for the current thread, returned by the builder.

        >>> s = SparkSession.getActiveSession()
        >>> l = [('Alice', 1)]
        >>> rdd = s.sparkContext.parallelize(l)
        >>> df = s.createDataFrame(rdd, ['name', 'age'])
        >>> df.select("age").collect()
        [Row(age=1)]
        """
        from pyspark import SparkContext
        sc = SparkContext._active_spark_context
        if sc is None:
            return None
        # Query the JVM once so the isDefined()/get() pair sees a single Option.
        active = sc._jvm.SparkSession.getActiveSession()
        if not active.isDefined():
            return None
        # Constructing the wrapper only updates _activeSession when no valid instantiated
        # session exists yet (see __init__); the cached _activeSession is returned either way.
        SparkSession(sc, active.get())
        return SparkSession._activeSession

    @property
    @since(2.0)
    def sparkContext(self):
        """Returns the underlying :class:`SparkContext`."""
        return self._sc

    @property
    @since(2.0)
    def version(self):
        """The version of Spark on which this application is running."""
        return self._jsparkSession.version()

    @property
    @since(2.0)
    def conf(self):
        """Runtime configuration interface for Spark.

        This is the interface through which the user can get and set all Spark and Hadoop
        configurations that are relevant to Spark SQL. When getting the value of a config,
        this defaults to the value set in the underlying :class:`SparkContext`, if any.
        """
        if not hasattr(self, "_conf"):
            self._conf = RuntimeConfig(self._jsparkSession.conf())
        return self._conf

    @property
    @since(2.0)
    def catalog(self):
        """Interface through which the user may create, drop, alter or query underlying
        databases, tables, functions, etc.

        :return: :class:`Catalog`
        """
        from pyspark.sql.catalog import Catalog
        if not hasattr(self, "_catalog"):
            self._catalog = Catalog(self)
        return self._catalog

    @property
    @since(2.0)
    def udf(self):
        """Returns a :class:`UDFRegistration` for UDF registration.

        :return: :class:`UDFRegistration`
        """
        from pyspark.sql.udf import UDFRegistration
        return UDFRegistration(self)

    @since(2.0)
    def range(self, start, end=None, step=1, numPartitions=None):
        """
        Create a :class:`DataFrame` with single :class:`pyspark.sql.types.LongType` column named
        ``id``, containing elements in a range from ``start`` to ``end`` (exclusive) with
        step value ``step``.

        :param start: the start value
        :param end: the end value (exclusive)
        :param step: the incremental step (default: 1)
        :param numPartitions: the number of partitions of the DataFrame
        :return: :class:`DataFrame`

        >>> spark.range(1, 7, 2).collect()
        [Row(id=1), Row(id=3), Row(id=5)]

        If only one argument is specified, it will be used as the end value.

        >>> spark.range(3).collect()
        [Row(id=0), Row(id=1), Row(id=2)]
        """
        if numPartitions is None:
            numPartitions = self._sc.defaultParallelism

        if end is None:
            jdf = self._jsparkSession.range(0, int(start), int(step), int(numPartitions))
        else:
            jdf = self._jsparkSession.range(int(start), int(end), int(step), int(numPartitions))

        return DataFrame(jdf, self._wrapped)

    def _inferSchemaFromList(self, data, names=None):
        """
        Infer schema from list of Row or tuple.

        Every element of ``data`` is inspected and the per-row schemas are merged.

        :param data: list of Row or tuple
        :param names: list of column names
        :return: :class:`pyspark.sql.types.StructType`
        :raises ValueError: if ``data`` is empty or some types remain undetermined
        """
        if not data:
            raise ValueError("can not infer schema from empty dataset")
        first = data[0]
        if type(first) is dict:
            warnings.warn("inferring schema from dict is deprecated, "
                          "please use pyspark.sql.Row instead")
        schema = reduce(_merge_type, (_infer_schema(row, names) for row in data))
        if _has_nulltype(schema):
            raise ValueError("Some of types cannot be determined after inferring")
        return schema

    def _inferSchema(self, rdd, samplingRatio=None, names=None):
        """
        Infer schema from an RDD of Row or tuple.

        :param rdd: an RDD of Row or tuple
        :param samplingRatio: sampling ratio, or no sampling (default). Without sampling the
            first row is used, plus up to the first 100 rows while types stay undetermined.
        :param names: list of column names, passed through to the per-row inference
        :return: :class:`pyspark.sql.types.StructType`
        """
        first = rdd.first()
        if not first:
            raise ValueError("The first row in RDD is empty, "
                             "can not infer schema")
        if type(first) is dict:
            warnings.warn("Using RDD of dict to inferSchema is deprecated. "
                          "Use pyspark.sql.Row instead")

        if samplingRatio is None:
            schema = _infer_schema(first, names=names)
            if _has_nulltype(schema):
                for row in rdd.take(100)[1:]:
                    schema = _merge_type(schema, _infer_schema(row, names=names))
                    if not _has_nulltype(schema):
                        break
                else:
                    raise ValueError("Some of types cannot be determined by the "
                                     "first 100 rows, please try again with sampling")
        else:
            # A ratio this close to 1 scans the whole RDD rather than sampling it.
            if samplingRatio < 0.99:
                rdd = rdd.sample(False, float(samplingRatio))
            schema = rdd.map(lambda row: _infer_schema(row, names)).reduce(_merge_type)
        return schema

    def _createFromRDD(self, rdd, schema, samplingRatio):
        """
        Create an RDD for DataFrame from an existing RDD, returns the RDD and schema.
        """
        if schema is None or isinstance(schema, (list, tuple)):
            struct = self._inferSchema(rdd, samplingRatio, names=schema)
            converter = _create_converter(struct)
            rdd = rdd.map(converter)
            if isinstance(schema, (list, tuple)):
                for i, name in enumerate(schema):
                    struct.fields[i].name = name
                    struct.names[i] = name
            schema = struct

        elif not isinstance(schema, StructType):
            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)

        # convert python objects to sql data
        rdd = rdd.map(schema.toInternal)
        return rdd, schema

    def _createFromLocal(self, data, schema):
        """
        Create an RDD for DataFrame from a list or pandas.DataFrame, returns
        the RDD and schema.
        """
        # make sure data can be consumed multiple times (inference and conversion both iterate it)
        if not isinstance(data, list):
            data = list(data)

        if schema is None or isinstance(schema, (list, tuple)):
            struct = self._inferSchemaFromList(data, names=schema)
            converter = _create_converter(struct)
            data = map(converter, data)
            if isinstance(schema, (list, tuple)):
                for i, name in enumerate(schema):
                    struct.fields[i].name = name
                    struct.names[i] = name
            schema = struct

        elif not isinstance(schema, StructType):
            raise TypeError("schema should be StructType or list or None, but got: %s" % schema)

        # convert python objects to sql data
        data = [schema.toInternal(row) for row in data]
        return self._sc.parallelize(data), schema

    @staticmethod
    def _create_shell_session():
        """
        Initialize a SparkSession for a pyspark shell session. This is called from shell.py
        to make error handling simpler without needing to declare local variables in that
        script, which would expose those to users.

        Hive support is enabled when ``spark.sql.catalogImplementation`` is unset or ``hive``
        and HiveConf can be instantiated; on a Py4J/Type error it falls back to a plain session.
        """
        import py4j
        from pyspark.conf import SparkConf
        from pyspark.context import SparkContext
        # Bound before the try so the except clause never reads an unbound name when
        # SparkConf() itself is what raised.
        conf = None
        try:
            # Try to access HiveConf, it will raise exception if Hive is not added
            conf = SparkConf()
            if conf.get('spark.sql.catalogImplementation', 'hive').lower() == 'hive':
                SparkContext._jvm.org.apache.hadoop.hive.conf.HiveConf()
                return SparkSession.builder\
                    .enableHiveSupport()\
                    .getOrCreate()
            else:
                return SparkSession.builder.getOrCreate()
        except (py4j.protocol.Py4JError, TypeError):
            if conf is not None \
                    and conf.get('spark.sql.catalogImplementation', '').lower() == 'hive':
                warnings.warn("Fall back to non-hive support because failing to access HiveConf, "
                              "please make sure you build spark with hive")

        return SparkSession.builder.getOrCreate()

    @since(2.0)
    @ignore_unicode_prefix
    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
        """
        Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.

        When ``schema`` is a list of column names, the type of each column
        will be inferred from ``data``.

        When ``schema`` is ``None``, it will try to infer the schema (column names and types)
        from ``data``, which should be an RDD of either :class:`Row`,
        :class:`namedtuple`, or :class:`dict`.

        When ``schema`` is :class:`pyspark.sql.types.DataType` or a datatype string, it must match
        the real data, or an exception will be thrown at runtime. If the given schema is not
        :class:`pyspark.sql.types.StructType`, it will be wrapped into a
        :class:`pyspark.sql.types.StructType` as its only field, and the field name will be "value".
        Each record will also be wrapped into a tuple, which can be converted to row later.

        If schema inference is needed and ``data`` is an RDD, ``samplingRatio`` is used to
        determine the ratio of rows used for schema inference. If ``samplingRatio`` is ``None``,
        the first row is used (plus up to the first 100 rows if some types remain undetermined).
        For a :class:`list`, every row is used and ``samplingRatio`` is ignored.

        :param data: an RDD of any kind of SQL data representation (e.g. row, tuple, int, boolean,
            etc.), :class:`list`, or :class:`pandas.DataFrame`.
        :param schema: a :class:`pyspark.sql.types.DataType` or a datatype string or a list of
            column names, default is ``None``.  The data type string format equals to
            :class:`pyspark.sql.types.DataType.simpleString`, except that top level struct type can
            omit the ``struct<>`` and atomic types use ``typeName()`` as their format, e.g. use
            ``byte`` instead of ``tinyint`` for :class:`pyspark.sql.types.ByteType`. We can also use
            ``int`` as a short name for ``IntegerType``.
        :param samplingRatio: the sample ratio of rows used for inferring
        :param verifySchema: verify data types of every row against schema.
        :return: :class:`DataFrame`

        .. versionchanged:: 2.1
           Added verifySchema.

        .. note:: Usage with spark.sql.execution.arrow.pyspark.enabled=True is experimental.

        .. note:: When Arrow optimization is enabled, strings inside Pandas DataFrame in Python
            2 are converted into bytes as they are bytes in Python 2 whereas regular strings are
            left as strings. When using strings in Python 2, use unicode `u""` as Python standard
            practice.

        >>> l = [('Alice', 1)]
        >>> spark.createDataFrame(l).collect()
        [Row(_1=u'Alice', _2=1)]
        >>> spark.createDataFrame(l, ['name', 'age']).collect()
        [Row(name=u'Alice', age=1)]

        >>> d = [{'name': 'Alice', 'age': 1}]
        >>> spark.createDataFrame(d).collect()
        [Row(age=1, name=u'Alice')]

        >>> rdd = sc.parallelize(l)
        >>> spark.createDataFrame(rdd).collect()
        [Row(_1=u'Alice', _2=1)]
        >>> df = spark.createDataFrame(rdd, ['name', 'age'])
        >>> df.collect()
        [Row(name=u'Alice', age=1)]

        >>> from pyspark.sql import Row
        >>> Person = Row('name', 'age')
        >>> person = rdd.map(lambda r: Person(*r))
        >>> df2 = spark.createDataFrame(person)
        >>> df2.collect()
        [Row(name=u'Alice', age=1)]

        >>> from pyspark.sql.types import *
        >>> schema = StructType([
        ...    StructField("name", StringType(), True),
        ...    StructField("age", IntegerType(), True)])
        >>> df3 = spark.createDataFrame(rdd, schema)
        >>> df3.collect()
        [Row(name=u'Alice', age=1)]

        >>> spark.createDataFrame(df.toPandas()).collect()  # doctest: +SKIP
        [Row(name=u'Alice', age=1)]
        >>> spark.createDataFrame(pandas.DataFrame([[1, 2]])).collect()  # doctest: +SKIP
        [Row(0=1, 1=2)]

        >>> spark.createDataFrame(rdd, "a: string, b: int").collect()
        [Row(a=u'Alice', b=1)]
        >>> rdd = rdd.map(lambda row: row[1])
        >>> spark.createDataFrame(rdd, "int").collect()
        [Row(value=1)]
        >>> spark.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        Py4JJavaError: ...
        """
        SparkSession._activeSession = self
        self._jvm.SparkSession.setActiveSession(self._jsparkSession)
        if isinstance(data, DataFrame):
            raise TypeError("data is already a DataFrame")

        if isinstance(schema, basestring):
            schema = _parse_datatype_string(schema)
        elif isinstance(schema, (list, tuple)):
            # Must re-encode any unicode strings to be consistent with StructField names.
            # On Python 2 this turns unicode names into byte strings; on Python 3 str names
            # pass through unchanged.
            schema = [x.encode('utf-8') if not isinstance(x, str) else x for x in schema]

        try:
            import pandas
            has_pandas = True
        except Exception:
            has_pandas = False
        if has_pandas and isinstance(data, pandas.DataFrame):
            # Create a DataFrame from pandas DataFrame.
            return super(SparkSession, self).createDataFrame(
                data, schema, samplingRatio, verifySchema)
        return self._create_dataframe(data, schema, samplingRatio, verifySchema)

    def _create_dataframe(self, data, schema, samplingRatio, verifySchema):
        # Build a per-record ``prepare`` step: verify against a StructType as-is, wrap records
        # in a 1-tuple for a bare DataType (exposed as a single "value" column), or pass through.
        if isinstance(schema, StructType):
            verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True

            def prepare(obj):
                verify_func(obj)
                return obj
        elif isinstance(schema, DataType):
            dataType = schema
            schema = StructType().add("value", schema)

            verify_func = _make_type_verifier(
                dataType, name="field value") if verifySchema else lambda _: True

            def prepare(obj):
                verify_func(obj)
                return obj,
        else:
            prepare = lambda obj: obj

        if isinstance(data, RDD):
            rdd, schema = self._createFromRDD(data.map(prepare), schema, samplingRatio)
        else:
            rdd, schema = self._createFromLocal(map(prepare, data), schema)
        jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
        jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
        df = DataFrame(jdf, self._wrapped)
        df._schema = schema
        return df

    @ignore_unicode_prefix
    @since(2.0)
    def sql(self, sqlQuery):
        """Returns a :class:`DataFrame` representing the result of the given query.

        :return: :class:`DataFrame`

        >>> df.createOrReplaceTempView("table1")
        >>> df2 = spark.sql("SELECT field1 AS f1, field2 as f2 from table1")
        >>> df2.collect()
        [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
        """
        return DataFrame(self._jsparkSession.sql(sqlQuery), self._wrapped)

    @since(2.0)
    def table(self, tableName):
        """Returns the specified table as a :class:`DataFrame`.

        :return: :class:`DataFrame`

        >>> df.createOrReplaceTempView("table1")
        >>> df2 = spark.table("table1")
        >>> sorted(df.collect()) == sorted(df2.collect())
        True
        """
        return DataFrame(self._jsparkSession.table(tableName), self._wrapped)

    @property
    @since(2.0)
    def read(self):
        """
        Returns a :class:`DataFrameReader` that can be used to read data
        in as a :class:`DataFrame`.

        :return: :class:`DataFrameReader`
        """
        return DataFrameReader(self._wrapped)

    @property
    @since(2.0)
    def readStream(self):
        """
        Returns a :class:`DataStreamReader` that can be used to read data streams
        as a streaming :class:`DataFrame`.

        .. note:: Evolving.

        :return: :class:`DataStreamReader`
        """
        return DataStreamReader(self._wrapped)

    @property
    @since(2.0)
    def streams(self):
        """Returns a :class:`StreamingQueryManager` that allows managing all the
        :class:`StreamingQuery` instances active on `this` context.

        .. note:: Evolving.

        :return: :class:`StreamingQueryManager`
        """
        from pyspark.sql.streaming import StreamingQueryManager
        return StreamingQueryManager(self._jsparkSession.streams())

    @since(2.0)
    def stop(self):
        """Stop the underlying :class:`SparkContext` and clear the default and active sessions
        on both the JVM and Python sides.
        """
        self._sc.stop()
        # We should clean the default session up. See SPARK-23228.
        self._jvm.SparkSession.clearDefaultSession()
        self._jvm.SparkSession.clearActiveSession()
        SparkSession._instantiatedSession = None
        SparkSession._activeSession = None

    @since(2.0)
    def __enter__(self):
        """
        Enable 'with SparkSession.builder.(...).getOrCreate() as session: app' syntax.
        """
        return self

    @since(2.0)
    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Enable 'with SparkSession.builder.(...).getOrCreate() as session: app' syntax.

        Specifically stop the SparkSession on exit of the with block.
        """
        self.stop()
0724 
0725 
def _test():
    """Run this module's doctests on a local SparkContext; exit with status -1 on failure."""
    import doctest
    import os
    import pyspark.sql.session
    from pyspark.context import SparkContext
    from pyspark.sql import Row

    os.chdir(os.environ["SPARK_HOME"])

    namespace = dict(pyspark.sql.session.__dict__)
    context = SparkContext('local[4]', 'PythonTest')
    namespace['sc'] = context
    # Building the session installs RDD.toDF, which the 'df' fixture below depends on.
    namespace['spark'] = SparkSession(context)
    fixture_rows = [Row(field1=idx, field2="row%d" % idx) for idx in (1, 2, 3)]
    namespace['rdd'] = fixture_rdd = context.parallelize(fixture_rows)
    namespace['df'] = fixture_rdd.toDF()

    flags = doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
    failures, _ = doctest.testmod(pyspark.sql.session, globs=namespace, optionflags=flags)
    namespace['sc'].stop()
    if failures:
        sys.exit(-1)
0750 
if __name__ == "__main__":
    # Executing this module directly runs its doctests via _test().
    _test()