#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
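"""
Tests for Hive-enabled SparkSession features exposed through PySpark: saving and
loading tables, window functions (including unbounded frames and cumulative sums),
collect_list/collect_set, limit/take job planning, to_date, and
SQLContext.getOrCreate.
"""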
import os
import shutil
import sys
import tempfile
import unittest
try:
    from importlib import reload  # Python 3.4+
except ImportError:
    # Python 2: reload() is a builtin, so there is nothing to import.
    pass

import py4j

from pyspark import SparkContext, SQLContext
from pyspark.sql import Row, SparkSession
from pyspark.sql.types import *
from pyspark.sql.window import Window
from pyspark.testing.utils import ReusedPySparkTestCase


class HiveContextSQLTests(ReusedPySparkTestCase):

    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
        cls.hive_available = True
        cls.spark = None
        try:
            # Probe the JVM for Hive classes; the Hive tests are skipped when they are absent.
            cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
        except py4j.protocol.Py4JError:
            cls.tearDownClass()
            cls.hive_available = False
        except TypeError:
            cls.tearDownClass()
            cls.hive_available = False
        if cls.hive_available:
            cls.spark = SparkSession.builder.enableHiveSupport().getOrCreate()

        os.unlink(cls.tempdir.name)
        if cls.hive_available:
            cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
            cls.df = cls.sc.parallelize(cls.testData).toDF()

    def setUp(self):
        if not self.hive_available:
            self.skipTest("Hive is not available.")

    @classmethod
    def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
        if cls.spark is not None:
            cls.spark.stop()
            cls.spark = None

    def test_save_and_load_table(self):
        df = self.df
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)

        # Save as an external JSON table and read it back both through the catalog
        # and through plain SQL.
        df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
        actual = self.spark.catalog.createTable("externalJsonTable", tmpPath, "json")
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
        self.spark.sql("DROP TABLE externalJsonTable")

        # Overwrite the table and create the external table with an explicit schema;
        # unknown options should be ignored.
        df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
        schema = StructType([StructField("value", StringType(), True)])
        actual = self.spark.catalog.createTable("externalJsonTable", source="json",
                                                schema=schema, path=tmpPath,
                                                noUse="this option will not be used")
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertEqual(sorted(df.select("value").collect()),
                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
        self.spark.sql("DROP TABLE savedJsonTable")
        self.spark.sql("DROP TABLE externalJsonTable")

        # Repeat with the default data source switched to JSON so that no explicit
        # source has to be passed, then restore the original default.
        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
                                                    "org.apache.spark.sql.parquet")
        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
        df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
        actual = self.spark.catalog.createTable("externalJsonTable", path=tmpPath)
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
        self.spark.sql("DROP TABLE savedJsonTable")
        self.spark.sql("DROP TABLE externalJsonTable")
        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

        shutil.rmtree(tmpPath)

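    # The two tests below exercise aggregate and ranking functions over window
    # specifications. Frame boundaries passed to rowsBetween() are offsets relative
    # to the current row, so, for example, w.rowsBetween(0, 1) covers the current
    # row and the row that follows it in the window's ordering.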
    def test_window_functions(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        w = Window.partitionBy("value").orderBy("key")
        from pyspark.sql import functions as F
        sel = df.select(df.value, df.key,
                        F.max("key").over(w.rowsBetween(0, 1)),
                        F.min("key").over(w.rowsBetween(0, 1)),
                        F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
                        F.row_number().over(w),
                        F.rank().over(w),
                        F.dense_rank().over(w),
                        F.ntile(2).over(w))
        rs = sorted(sel.collect())
        expected = [
            ("1", 1, 1, 1, 1, 1, 1, 1, 1),
            ("2", 1, 1, 1, 3, 1, 1, 1, 1),
            ("2", 1, 2, 1, 3, 2, 1, 1, 1),
            ("2", 2, 2, 2, 3, 3, 3, 2, 2)
        ]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

    def test_window_functions_without_partitionBy(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        w = Window.orderBy("key", df.value)
        from pyspark.sql import functions as F
        sel = df.select(df.value, df.key,
                        F.max("key").over(w.rowsBetween(0, 1)),
                        F.min("key").over(w.rowsBetween(0, 1)),
                        F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
                        F.row_number().over(w),
                        F.rank().over(w),
                        F.dense_rank().over(w),
                        F.ntile(2).over(w))
        rs = sorted(sel.collect())
        expected = [
            ("1", 1, 1, 1, 4, 1, 1, 1, 1),
            ("2", 1, 1, 1, 4, 2, 2, 2, 1),
            ("2", 1, 2, 1, 4, 3, 2, 2, 2),
            ("2", 2, 2, 2, 4, 4, 4, 3, 2)
        ]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

    def test_window_functions_cumulative_sum(self):
        df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"])
        from pyspark.sql import functions as F

        # Cumulative sum from the start of the partition up to the current row.
        sel = df.select(
            df.key,
            F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding, 0)))
        rs = sorted(sel.collect())
        expected = [("one", 1), ("two", 3)]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

        # A frame start below unboundedPreceding should still behave as an unbounded
        # preceding frame rather than overflowing.
        sel = df.select(
            df.key,
            F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding - 1, 0)))
        rs = sorted(sel.collect())
        expected = [("one", 1), ("two", 3)]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

        # Likewise, a frame end beyond unboundedFollowing should behave as an
        # unbounded following frame.
        frame_end = Window.unboundedFollowing + 1
        sel = df.select(
            df.key,
            F.sum(df.value).over(Window.rowsBetween(Window.currentRow, frame_end)))
        rs = sorted(sel.collect())
        expected = [("one", 3), ("two", 2)]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

    def test_collect_functions(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql import functions

        self.assertEqual(
            sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r),
            [1, 2])
        self.assertEqual(
            sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r),
            [1, 1, 1, 2])
        self.assertEqual(
            sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r),
            ["1", "2"])
        self.assertEqual(
            sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
            ["1", "2", "2", "2"])

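    # take(n) and limit(n).collect() on a multi-partition range should be planned as a
    # single job with one stage and one task instead of scanning every partition; the
    # helper below counts jobs, stages, and tasks through the status tracker for a
    # dedicated job group.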
    def test_limit_and_take(self):
        df = self.spark.range(1, 1000, numPartitions=10)

        def assert_runs_only_one_job_stage_and_task(job_group_name, f):
            tracker = self.sc.statusTracker()
            self.sc.setJobGroup(job_group_name, description="")
            f()
            jobs = tracker.getJobIdsForGroup(job_group_name)
            self.assertEqual(1, len(jobs))
            stages = tracker.getJobInfo(jobs[0]).stageIds
            self.assertEqual(1, len(stages))
            self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks)

        # take(1) should run exactly one job, one stage, and one task.
        assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1))
        # limit(1).collect() should behave the same way as take(1).
        assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect())

    def test_datetime_functions(self):
        from pyspark.sql import functions
        from datetime import date
        df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol")
        parse_result = df.select(functions.to_date(functions.col("dateCol"))).first()
        self.assertEqual(date(2017, 1, 22), parse_result['to_date(`dateCol`)'])

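    # Window.rowsBetween/rangeBetween treat boundaries at or beyond +/- sys.maxsize as
    # unbounded frames. The test below monkey-patches sys.maxsize to several values and
    # reloads pyspark.sql.window so the module picks up each patched value.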
    def test_unbounded_frames(self):
        from pyspark.sql import functions as F
        from pyspark.sql import window

        df = self.spark.range(0, 3)

        def rows_frame_match():
            return "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select(
                F.count("*").over(window.Window.rowsBetween(-sys.maxsize, sys.maxsize))
            ).columns[0]

        def range_frame_match():
            return "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select(
                F.count("*").over(window.Window.rangeBetween(-sys.maxsize, sys.maxsize))
            ).columns[0]

        for new_maxsize in [2 ** 31 - 1, 2 ** 63 - 1, 2 ** 127 - 1]:
            old_maxsize = sys.maxsize
            sys.maxsize = new_maxsize
            try:
                # Reload the window module so it picks up the patched sys.maxsize.
                reload(window)
                self.assertTrue(rows_frame_match())
                self.assertTrue(range_frame_match())
            finally:
                sys.maxsize = old_maxsize

        # Reload once more so the module reflects the real sys.maxsize again.
        reload(window)

class SQLContextTests(unittest.TestCase):

    def test_get_or_create(self):
        sc = None
        sql_context = None
        try:
            sc = SparkContext('local[4]', "SQLContextTests")
            sql_context = SQLContext.getOrCreate(sc)
            self.assertIsInstance(sql_context, SQLContext)
        finally:
            SQLContext._instantiatedContext = None
            if sql_context is not None:
                sql_context.sparkSession.stop()
            if sc is not None:
                sc.stop()


if __name__ == "__main__":
    from pyspark.sql.tests.test_context import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)