#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import shutil
import sys
import tempfile
import unittest
try:
    from importlib import reload  # Python 3.4+ only.
except ImportError:
    # Otherwise, we will stick to Python 2's built-in reload.
    pass

import py4j

from pyspark import SparkContext, SQLContext
from pyspark.sql import Row, SparkSession
from pyspark.sql.types import *
from pyspark.sql.window import Window
from pyspark.testing.utils import ReusedPySparkTestCase


class HiveContextSQLTests(ReusedPySparkTestCase):
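    # These tests require a Hive-enabled SparkSession: setUpClass probes for the
    # Hive classes through the JVM gateway, and setUp skips every test when they
    # are not available on the classpath.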

    @classmethod
    def setUpClass(cls):
        ReusedPySparkTestCase.setUpClass()
        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
        cls.hive_available = True
        cls.spark = None
        try:
            cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
        except py4j.protocol.Py4JError:
            cls.tearDownClass()
            cls.hive_available = False
        except TypeError:
            cls.tearDownClass()
            cls.hive_available = False
        if cls.hive_available:
            cls.spark = SparkSession.builder.enableHiveSupport().getOrCreate()

        os.unlink(cls.tempdir.name)
        if cls.hive_available:
            cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
            cls.df = cls.sc.parallelize(cls.testData).toDF()

    def setUp(self):
        if not self.hive_available:
            self.skipTest("Hive is not available.")

    @classmethod
    def tearDownClass(cls):
        ReusedPySparkTestCase.tearDownClass()
        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
        if cls.spark is not None:
            cls.spark.stop()
            cls.spark = None

    def test_save_and_load_table(self):
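        # Round-trips a DataFrame through saveAsTable/createTable with the JSON
        # source: first with an explicit source and mode, then with an explicit
        # schema (unknown options such as `noUse` are ignored), and finally by
        # falling back to spark.sql.sources.default.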
        df = self.df
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        df.write.saveAsTable("savedJsonTable", "json", "append", path=tmpPath)
        actual = self.spark.catalog.createTable("externalJsonTable", tmpPath, "json")
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
        self.spark.sql("DROP TABLE externalJsonTable")

        df.write.saveAsTable("savedJsonTable", "json", "overwrite", path=tmpPath)
        schema = StructType([StructField("value", StringType(), True)])
        actual = self.spark.catalog.createTable("externalJsonTable", source="json",
                                                schema=schema, path=tmpPath,
                                                noUse="this option will not be used")
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertEqual(sorted(df.select("value").collect()),
                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertEqual(sorted(df.select("value").collect()), sorted(actual.collect()))
        self.spark.sql("DROP TABLE savedJsonTable")
        self.spark.sql("DROP TABLE externalJsonTable")

        defaultDataSourceName = self.spark.conf.get("spark.sql.sources.default",
                                                    "org.apache.spark.sql.parquet")
        self.spark.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
        df.write.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
        actual = self.spark.catalog.createTable("externalJsonTable", path=tmpPath)
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM savedJsonTable").collect()))
        self.assertEqual(sorted(df.collect()),
                         sorted(self.spark.sql("SELECT * FROM externalJsonTable").collect()))
        self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
        self.spark.sql("DROP TABLE savedJsonTable")
        self.spark.sql("DROP TABLE externalJsonTable")
        self.spark.sql("SET spark.sql.sources.default=" + defaultDataSourceName)

        shutil.rmtree(tmpPath)

    def test_window_functions(self):
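        # Checks max/min/count over row frames plus the ranking functions
        # (row_number, rank, dense_rank, ntile) for a window partitioned by
        # "value" and ordered by "key".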
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        w = Window.partitionBy("value").orderBy("key")
        from pyspark.sql import functions as F
        sel = df.select(df.value, df.key,
                        F.max("key").over(w.rowsBetween(0, 1)),
                        F.min("key").over(w.rowsBetween(0, 1)),
                        F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
                        F.row_number().over(w),
                        F.rank().over(w),
                        F.dense_rank().over(w),
                        F.ntile(2).over(w))
        rs = sorted(sel.collect())
        expected = [
            ("1", 1, 1, 1, 1, 1, 1, 1, 1),
            ("2", 1, 1, 1, 3, 1, 1, 1, 1),
            ("2", 1, 2, 1, 3, 2, 1, 1, 1),
            ("2", 2, 2, 2, 3, 3, 3, 2, 2)
        ]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

    def test_window_functions_without_partitionBy(self):
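        # Same window functions as above, but over a global (unpartitioned)
        # window ordered by ("key", "value").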
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        w = Window.orderBy("key", df.value)
        from pyspark.sql import functions as F
        sel = df.select(df.value, df.key,
                        F.max("key").over(w.rowsBetween(0, 1)),
                        F.min("key").over(w.rowsBetween(0, 1)),
                        F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))),
                        F.row_number().over(w),
                        F.rank().over(w),
                        F.dense_rank().over(w),
                        F.ntile(2).over(w))
        rs = sorted(sel.collect())
        expected = [
            ("1", 1, 1, 1, 4, 1, 1, 1, 1),
            ("2", 1, 1, 1, 4, 2, 2, 2, 1),
            ("2", 1, 2, 1, 4, 3, 2, 2, 2),
            ("2", 2, 2, 2, 4, 4, 4, 3, 2)
        ]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

    def test_window_functions_cumulative_sum(self):
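        # Cumulative sums over unbounded frames; also checks that frame
        # boundaries beyond the JVM's Long.MinValue/Long.MaxValue are handled
        # without overflowing.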
        df = self.spark.createDataFrame([("one", 1), ("two", 2)], ["key", "value"])
        from pyspark.sql import functions as F

        # Test cumulative sum
        sel = df.select(
            df.key,
            F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding, 0)))
        rs = sorted(sel.collect())
        expected = [("one", 1), ("two", 3)]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

        # Test boundary values less than JVM's Long.MinValue and make sure we don't overflow
        sel = df.select(
            df.key,
            F.sum(df.value).over(Window.rowsBetween(Window.unboundedPreceding - 1, 0)))
        rs = sorted(sel.collect())
        expected = [("one", 1), ("two", 3)]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

        # Test boundary values greater than JVM's Long.MaxValue and make sure we don't overflow
        frame_end = Window.unboundedFollowing + 1
        sel = df.select(
            df.key,
            F.sum(df.value).over(Window.rowsBetween(Window.currentRow, frame_end)))
        rs = sorted(sel.collect())
        expected = [("one", 3), ("two", 2)]
        for r, ex in zip(rs, expected):
            self.assertEqual(tuple(r), ex[:len(r)])

    def test_collect_functions(self):
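        # collect_set deduplicates values while collect_list keeps duplicates,
        # for both integer and string columns.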
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
        from pyspark.sql import functions

        self.assertEqual(
            sorted(df.select(functions.collect_set(df.key).alias('r')).collect()[0].r),
            [1, 2])
        self.assertEqual(
            sorted(df.select(functions.collect_list(df.key).alias('r')).collect()[0].r),
            [1, 1, 1, 2])
        self.assertEqual(
            sorted(df.select(functions.collect_set(df.value).alias('r')).collect()[0].r),
            ["1", "2"])
        self.assertEqual(
            sorted(df.select(functions.collect_list(df.value).alias('r')).collect()[0].r),
            ["1", "2", "2", "2"])

    def test_limit_and_take(self):
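        # take(n) and limit(n).collect() should each run exactly one job with
        # one stage and one task, verified through the status tracker.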
        df = self.spark.range(1, 1000, numPartitions=10)

        def assert_runs_only_one_job_stage_and_task(job_group_name, f):
            tracker = self.sc.statusTracker()
            self.sc.setJobGroup(job_group_name, description="")
            f()
            jobs = tracker.getJobIdsForGroup(job_group_name)
            self.assertEqual(1, len(jobs))
            stages = tracker.getJobInfo(jobs[0]).stageIds
            self.assertEqual(1, len(stages))
            self.assertEqual(1, tracker.getStageInfo(stages[0]).numTasks)

        # Regression test for SPARK-10731: take should delegate to Scala implementation
        assert_runs_only_one_job_stage_and_task("take", lambda: df.take(1))
        # Regression test for SPARK-17514: limit(n).collect() should perform the same as take(n)
        assert_runs_only_one_job_stage_and_task("collect_limit", lambda: df.limit(1).collect())

    def test_datetime_functions(self):
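        # to_date should parse the string column into a datetime.date.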
        from pyspark.sql import functions
        from datetime import date
        df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol")
        parse_result = df.select(functions.to_date(functions.col("dateCol"))).first()
        self.assertEqual(date(2017, 1, 22), parse_result['to_date(`dateCol`)'])

    def test_unbounded_frames(self):
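        # Monkey-patches sys.maxsize and reloads pyspark.sql.window so that
        # rowsBetween/rangeBetween(-sys.maxsize, sys.maxsize) still map to
        # UNBOUNDED PRECEDING/FOLLOWING frames.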
        from pyspark.sql import functions as F
        from pyspark.sql import window

        df = self.spark.range(0, 3)

        def rows_frame_match():
            return "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select(
                F.count("*").over(window.Window.rowsBetween(-sys.maxsize, sys.maxsize))
            ).columns[0]

        def range_frame_match():
            return "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" in df.select(
                F.count("*").over(window.Window.rangeBetween(-sys.maxsize, sys.maxsize))
            ).columns[0]

        for new_maxsize in [2 ** 31 - 1, 2 ** 63 - 1, 2 ** 127 - 1]:
            old_maxsize = sys.maxsize
            sys.maxsize = new_maxsize
            try:
                # Manually reload window module to use monkey-patched sys.maxsize.
                reload(window)
                self.assertTrue(rows_frame_match())
                self.assertTrue(range_frame_match())
            finally:
                sys.maxsize = old_maxsize

        reload(window)


class SQLContextTests(unittest.TestCase):

    def test_get_or_create(self):
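        # SQLContext.getOrCreate should return a SQLContext bound to the given
        # SparkContext; the session and context are cleaned up even on failure.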
        sc = None
        sql_context = None
        try:
            sc = SparkContext('local[4]', "SQLContextTests")
            sql_context = SQLContext.getOrCreate(sc)
            self.assertIsInstance(sql_context, SQLContext)
        finally:
            SQLContext._instantiatedContext = None
            if sql_context is not None:
                sql_context.sparkSession.stop()
            if sc is not None:
                sc.stop()


if __name__ == "__main__":
    from pyspark.sql.tests.test_context import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)