0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 import os
0019 import unittest
0020
0021 from pyspark import SparkConf, SparkContext
0022 from pyspark.sql import SparkSession, SQLContext, Row
0023 from pyspark.testing.sqlutils import ReusedSQLTestCase
0024 from pyspark.testing.utils import PySparkTestCase
0025
0026
class SparkSessionTests(ReusedSQLTestCase):
    """Tests for SQLContext's reuse of the underlying SparkSession."""

    def test_sqlcontext_reuses_sparksession(self):
        """Two SQLContexts over the same SparkContext share one SparkSession."""
        first = SQLContext(self.sc)
        second = SQLContext(self.sc)
        self.assertTrue(first.sparkSession is second.sparkSession)
0032
0033
class SparkSessionTests1(ReusedSQLTestCase):
    """Session creation after the backing SparkContext has been stopped."""

    def test_sparksession_with_stopped_sparkcontext(self):
        """A fresh SparkSession works after the original context is stopped."""
        self.sc.stop()
        # Bring up a replacement context with the same application name.
        context = SparkContext('local[4]', self.sc.appName)
        session = SparkSession.builder.getOrCreate()
        try:
            frame = session.createDataFrame([(1, 2)], ["c", "c"])
            frame.collect()
        finally:
            session.stop()
            context.stop()
0048
0049
class SparkSessionTests2(PySparkTestCase):
    """Tests around the JVM-side default session bookkeeping."""

    def test_set_jvm_default_session(self):
        """getOrCreate registers a JVM default session; stop clears it."""
        session = SparkSession.builder.getOrCreate()
        try:
            self.assertTrue(session._jvm.SparkSession.getDefaultSession().isDefined())
        finally:
            session.stop()
            self.assertTrue(session._jvm.SparkSession.getDefaultSession().isEmpty())

    def test_jvm_default_session_already_set(self):
        """A default session installed on the JVM side is reused by the builder."""
        jvm_session = self.sc._jvm.SparkSession(self.sc._jsc.sc())
        self.sc._jvm.SparkSession.setDefaultSession(jvm_session)

        session = SparkSession.builder.getOrCreate()
        try:
            self.assertTrue(session._jvm.SparkSession.getDefaultSession().isDefined())
            # The builder must have adopted the pre-installed JVM session.
            self.assertTrue(jvm_session.equals(session._jvm.SparkSession.getDefaultSession().get()))
        finally:
            session.stop()
0074
0075
class SparkSessionTests3(unittest.TestCase):
    """Lifecycle tests for default and active SparkSessions."""

    def test_active_session(self):
        """getActiveSession returns the session the builder just created."""
        session = (SparkSession.builder
                   .master("local")
                   .getOrCreate())
        try:
            current = SparkSession.getActiveSession()
            frame = current.createDataFrame([(1, 'Alice')], ['age', 'name'])
            self.assertEqual(frame.collect(), [Row(age=1, name=u'Alice')])
        finally:
            session.stop()

    def test_get_active_session_when_no_active_session(self):
        """The active session is None before creation and again after stop."""
        self.assertEqual(SparkSession.getActiveSession(), None)
        session = (SparkSession.builder
                   .master("local")
                   .getOrCreate())
        self.assertEqual(SparkSession.getActiveSession(), session)
        session.stop()
        self.assertEqual(SparkSession.getActiveSession(), None)

    def test_SparkSession(self):
        """Smoke test of conf, catalog, table and range facilities."""
        session = (SparkSession.builder
                   .master("local")
                   .config("some-config", "v2")
                   .getOrCreate())
        try:
            self.assertEqual(session.conf.get("some-config"), "v2")
            self.assertEqual(session.sparkContext._conf.get("some-config"), "v2")
            self.assertEqual(session.version, session.sparkContext.version)
            session.sql("CREATE DATABASE test_db")
            session.catalog.setCurrentDatabase("test_db")
            self.assertEqual(session.catalog.currentDatabase(), "test_db")
            session.sql("CREATE TABLE table1 (name STRING, age INT) USING parquet")
            self.assertEqual(session.table("table1").columns, ['name', 'age'])
            self.assertEqual(session.range(3).count(), 3)
        finally:
            # Clean up the database even if an assertion failed above.
            session.sql("DROP DATABASE test_db CASCADE")
            session.stop()

    def test_global_default_session(self):
        """getOrCreate hands back the already-registered default session."""
        session = (SparkSession.builder
                   .master("local")
                   .getOrCreate())
        try:
            self.assertEqual(SparkSession.builder.getOrCreate(), session)
        finally:
            session.stop()

    def test_default_and_active_session(self):
        """Right after creation the JVM default and active sessions coincide."""
        session = (SparkSession.builder
                   .master("local")
                   .getOrCreate())
        jvm_active = session._jvm.SparkSession.getActiveSession()
        jvm_default = session._jvm.SparkSession.getDefaultSession()
        try:
            self.assertEqual(jvm_active, jvm_default)
        finally:
            session.stop()

    def test_config_option_propagated_to_existing_session(self):
        """Builder options set on a second handle update the shared session."""
        first = (SparkSession.builder
                 .master("local")
                 .config("spark-config1", "a")
                 .getOrCreate())
        self.assertEqual(first.conf.get("spark-config1"), "a")
        second = (SparkSession.builder
                  .config("spark-config1", "b")
                  .getOrCreate())
        try:
            self.assertEqual(first, second)
            # The new value is visible through the original handle too.
            self.assertEqual(first.conf.get("spark-config1"), "b")
        finally:
            first.stop()

    def test_new_session(self):
        """newSession yields a distinct session object."""
        base = (SparkSession.builder
                .master("local")
                .getOrCreate())
        forked = base.newSession()
        try:
            self.assertNotEqual(base, forked)
        finally:
            base.stop()
            forked.stop()

    def test_create_new_session_if_old_session_stopped(self):
        """After stop, getOrCreate builds a brand-new session."""
        stale = (SparkSession.builder
                 .master("local")
                 .getOrCreate())
        stale.stop()
        fresh = (SparkSession.builder
                 .master("local")
                 .getOrCreate())
        try:
            self.assertNotEqual(stale, fresh)
        finally:
            fresh.stop()

    def test_active_session_with_None_and_not_None_context(self):
        """Building a SparkSession over a bare context installs a JVM active session."""
        from pyspark.context import SparkContext
        from pyspark.conf import SparkConf
        context = None
        session = None
        try:
            context = SparkContext._active_spark_context
            self.assertEqual(context, None)
            self.assertEqual(SparkSession.getActiveSession(), None)
            context = SparkContext.getOrCreate(SparkConf())
            # A raw context alone must not define a JVM active session.
            self.assertFalse(context._jvm.SparkSession.getActiveSession().isDefined())
            session = SparkSession(context)
            self.assertTrue(context._jvm.SparkSession.getActiveSession().isDefined())
            self.assertNotEqual(SparkSession.getActiveSession(), None)
        finally:
            if session is not None:
                session.stop()
            if context is not None:
                context.stop()
0203
0204
class SparkSessionTests4(ReusedSQLTestCase):
    """createDataFrame promotes its owning session to the active one."""

    def test_get_active_session_after_create_dataframe(self):
        forked = None
        try:
            base = self.spark
            # The shared fixture session starts out active.
            self.assertEqual(base, SparkSession.getActiveSession())
            forked = self.spark.newSession()
            active = SparkSession.getActiveSession()
            self.assertEqual(base, active)
            self.assertNotEqual(forked, active)
            # Creating a DataFrame flips the active session to its owner.
            forked.createDataFrame([(1, 'Alice')], ['age', 'name'])
            self.assertEqual(forked, SparkSession.getActiveSession())
            base.createDataFrame([(1, 'Alice')], ['age', 'name'])
            self.assertEqual(base, SparkSession.getActiveSession())
        finally:
            if forked is not None:
                forked.stop()
0226
0227
class SparkSessionBuilderTests(unittest.TestCase):
    """Builder behavior when a context or session already exists."""

    def test_create_spark_context_first_then_spark_session(self):
        """A session built over an existing context merges conf into the session only."""
        context = None
        session = None
        try:
            conf = SparkConf().set("key1", "value1")
            context = SparkContext('local[4]', "SessionBuilderTests", conf=conf)
            session = SparkSession.builder.config("key2", "value2").getOrCreate()

            # The session sees both keys ...
            self.assertEqual(session.conf.get("key1"), "value1")
            self.assertEqual(session.conf.get("key2"), "value2")
            self.assertEqual(session.sparkContext, context)

            # ... but the context conf only carries the one it was born with.
            self.assertFalse(context.getConf().contains("key2"))
            self.assertEqual(context.getConf().get("key1"), "value1")
        finally:
            if session is not None:
                session.stop()
            if context is not None:
                context.stop()

    def test_another_spark_session(self):
        """Two getOrCreate handles share one session, so conf entries merge."""
        first = None
        second = None
        try:
            first = SparkSession.builder.config("key1", "value1").getOrCreate()
            second = SparkSession.builder.config("key2", "value2").getOrCreate()

            self.assertEqual(first.conf.get("key1"), "value1")
            self.assertEqual(second.conf.get("key1"), "value1")
            self.assertEqual(first.conf.get("key2"), "value2")
            self.assertEqual(second.conf.get("key2"), "value2")
            self.assertEqual(first.sparkContext, second.sparkContext)

            # Session-level conf does not leak down into the SparkContext conf.
            self.assertEqual(first.sparkContext.getConf().get("key1"), "value1")
            self.assertFalse(first.sparkContext.getConf().contains("key2"))
        finally:
            if first is not None:
                first.stop()
            if second is not None:
                second.stop()
0270
0271
class SparkExtensionsTest(unittest.TestCase):
    """Verifies that spark.sql.extensions injects custom planner strategies and rules.

    Skips the whole class when the compiled Scala test classes from
    SparkSessionExtensionSuite are not present in the build output.
    """

    @classmethod
    def setUpClass(cls):
        import glob
        from pyspark.find_spark_home import _find_spark_home

        spark_home = _find_spark_home()
        class_pattern = (
            "sql/core/target/scala-*/test-classes/org/apache/spark/sql/"
            "SparkSessionExtensionSuite.class")
        if not glob.glob(os.path.join(spark_home, class_pattern)):
            raise unittest.SkipTest(
                "'org.apache.spark.sql.SparkSessionExtensionSuite' is not "
                "available. Will skip the related tests.")

        cls.spark = (SparkSession.builder
                     .master("local[4]")
                     .appName(cls.__name__)
                     .config(
                         "spark.sql.extensions",
                         "org.apache.spark.sql.MyExtensions")
                     .getOrCreate())

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

    def test_use_custom_class_for_extensions(self):
        jvm = self.spark._jvm
        jsession = self.spark._jsparkSession
        state = jsession.sessionState()
        strategy = jvm.org.apache.spark.sql.MySparkStrategy(jsession)
        self.assertTrue(
            state.planner().strategies().contains(strategy),
            "MySparkStrategy not found in active planner strategies")
        rule = jvm.org.apache.spark.sql.MyRule(jsession)
        self.assertTrue(
            state.analyzer().extendedResolutionRules().contains(rule),
            "MyRule not found in extended resolution rules")
0312
0313
if __name__ == "__main__":
    from pyspark.sql.tests.test_session import *

    # Prefer the XML runner (for CI report collection) when available;
    # fall back to the default text runner otherwise.
    try:
        import xmlrunner
        runner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        runner = None
    unittest.main(testRunner=runner, verbosity=2)