#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import shutil
import tempfile
import time

from pyspark.sql.functions import lit
from pyspark.sql.types import IntegerType, StringType, StructField, StructType
from pyspark.testing.sqlutils import ReusedSQLTestCase


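# Tests for the Structured Streaming Python API: trigger configuration, streaming
# reads and writes, query status and progress, awaitTermination, foreach, and
# foreachBatch. Most tests read the small text fixture under
# python/test_support/sql/streaming.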
class StreamingTests(ReusedSQLTestCase):

    def test_stream_trigger(self):
        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')

        # Should take at least one arg
        try:
            df.writeStream.trigger()
            self.fail("Should have thrown an exception")
        except ValueError:
            pass

        # Should not take multiple args
        try:
            df.writeStream.trigger(once=True, processingTime='5 seconds')
            self.fail("Should have thrown an exception")
        except ValueError:
            pass

        # Should not take multiple args
        try:
            df.writeStream.trigger(processingTime='5 seconds', continuous='1 second')
            self.fail("Should have thrown an exception")
        except ValueError:
            pass

        # Should take only keyword args
        try:
            df.writeStream.trigger('5 seconds')
            self.fail("Should have thrown an exception")
        except TypeError:
            pass

    def test_stream_read_options(self):
        schema = StructType([StructField("data", StringType(), False)])
        df = self.spark.readStream\
            .format('text')\
            .option('path', 'python/test_support/sql/streaming')\
            .schema(schema)\
            .load()
        self.assertTrue(df.isStreaming)
        self.assertEqual(df.schema.simpleString(), "struct<data:string>")

    def test_stream_read_options_overwrite(self):
        bad_schema = StructType([StructField("test", IntegerType(), False)])
        schema = StructType([StructField("data", StringType(), False)])
        df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \
            .schema(bad_schema)\
            .load(path='python/test_support/sql/streaming', schema=schema, format='text')
        self.assertTrue(df.isStreaming)
        self.assertEqual(df.schema.simpleString(), "struct<data:string>")

    def test_stream_save_options(self):
        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \
            .withColumn('id', lit(1))
        for q in self.spark._wrapped.streams.active:
            q.stop()
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        self.assertTrue(df.isStreaming)
        out = os.path.join(tmpPath, 'out')
        chk = os.path.join(tmpPath, 'chk')
        q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \
            .format('parquet').partitionBy('id').outputMode('append').option('path', out).start()
        try:
            self.assertEqual(q.name, 'this_query')
            self.assertTrue(q.isActive)
            q.processAllAvailable()
            output_files = []
            for _, _, files in os.walk(out):
                output_files.extend([f for f in files if not f.startswith('.')])
            self.assertTrue(len(output_files) > 0)
            self.assertTrue(len(os.listdir(chk)) > 0)
        finally:
            q.stop()
            shutil.rmtree(tmpPath)

    def test_stream_save_options_overwrite(self):
        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
        for q in self.spark._wrapped.streams.active:
            q.stop()
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        self.assertTrue(df.isStreaming)
        out = os.path.join(tmpPath, 'out')
        chk = os.path.join(tmpPath, 'chk')
        fake1 = os.path.join(tmpPath, 'fake1')
        fake2 = os.path.join(tmpPath, 'fake2')
        q = df.writeStream.option('checkpointLocation', fake1)\
            .format('memory').option('path', fake2) \
            .queryName('fake_query').outputMode('append') \
            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)

        try:
            self.assertEqual(q.name, 'this_query')
            self.assertTrue(q.isActive)
            q.processAllAvailable()
            output_files = []
            for _, _, files in os.walk(out):
                output_files.extend([f for f in files if not f.startswith('.')])
            self.assertTrue(len(output_files) > 0)
            self.assertTrue(len(os.listdir(chk)) > 0)
            self.assertFalse(os.path.isdir(fake1))  # should not have been created
            self.assertFalse(os.path.isdir(fake2))  # should not have been created
        finally:
            q.stop()
            shutil.rmtree(tmpPath)

    def test_stream_status_and_progress(self):
        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
        for q in self.spark._wrapped.streams.active:
            q.stop()
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        self.assertTrue(df.isStreaming)
        out = os.path.join(tmpPath, 'out')
        chk = os.path.join(tmpPath, 'chk')

        def func(x):
            time.sleep(1)
            return x

        from pyspark.sql.functions import col, udf
        sleep_udf = udf(func)

        # Use "sleep_udf" to delay the progress update so that we can test `lastProgress`
        # when there were no updates.
        q = df.select(sleep_udf(col("value")).alias('value')).writeStream \
            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
        try:
            # "lastProgress" will return None in most cases. However, as it may be flaky when
            # Jenkins is very slow, we don't assert on its value. If something is wrong,
            # "lastProgress" is very likely to throw an error and fail this test, so broken
            # code can still be detected.
            q.lastProgress

            q.processAllAvailable()
            lastProgress = q.lastProgress
            recentProgress = q.recentProgress
            status = q.status
            self.assertEqual(lastProgress['name'], q.name)
            self.assertEqual(lastProgress['id'], q.id)
            self.assertTrue(any(p == lastProgress for p in recentProgress))
            self.assertTrue(
                "message" in status and
                "isDataAvailable" in status and
                "isTriggerActive" in status)
        finally:
            q.stop()
            shutil.rmtree(tmpPath)

    def test_stream_await_termination(self):
        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
        for q in self.spark._wrapped.streams.active:
            q.stop()
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        self.assertTrue(df.isStreaming)
        out = os.path.join(tmpPath, 'out')
        chk = os.path.join(tmpPath, 'chk')
        q = df.writeStream\
            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
        try:
            self.assertTrue(q.isActive)
            try:
                q.awaitTermination("hello")
                self.fail("Expected a ValueError")
            except ValueError:
                pass
            now = time.time()
            # test should take at least 2 seconds
            res = q.awaitTermination(2.6)
            duration = time.time() - now
            self.assertTrue(duration >= 2)
            self.assertFalse(res)
        finally:
            q.processAllAvailable()
            q.stop()
            shutil.rmtree(tmpPath)

    def test_stream_exception(self):
        sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
        sq = sdf.writeStream.format('memory').queryName('query_explain').start()
        try:
            sq.processAllAvailable()
            self.assertEqual(sq.exception(), None)
        finally:
            sq.stop()

        from pyspark.sql.functions import col, udf
        from pyspark.sql.utils import StreamingQueryException
        bad_udf = udf(lambda x: 1 / 0)
        sq = sdf.select(bad_udf(col("value")))\
            .writeStream\
            .format('memory')\
            .queryName('this_query')\
            .start()
        try:
            # Process some data to fail the query
            sq.processAllAvailable()
            self.fail("bad udf should fail the query")
        except StreamingQueryException as e:
            # This is expected
            self._assert_exception_tree_contains_msg(e, "ZeroDivisionError")
        finally:
            sq.stop()
        self.assertTrue(type(sq.exception()) is StreamingQueryException)
        self._assert_exception_tree_contains_msg(sq.exception(), "ZeroDivisionError")

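    # Walk the exception and its chain of causes, asserting that `msg` appears in at
    # least one of their descriptions.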
    def _assert_exception_tree_contains_msg(self, exception, msg):
        e = exception
        contains = msg in e.desc
        while e.cause is not None and not contains:
            e = e.cause
            contains = msg in e.desc
        self.assertTrue(contains, "Exception tree doesn't contain the expected message: %s" % msg)

    def test_query_manager_await_termination(self):
        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
        for q in self.spark._wrapped.streams.active:
            q.stop()
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        self.assertTrue(df.isStreaming)
        out = os.path.join(tmpPath, 'out')
        chk = os.path.join(tmpPath, 'chk')
        q = df.writeStream\
            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
        try:
            self.assertTrue(q.isActive)
            try:
                self.spark._wrapped.streams.awaitAnyTermination("hello")
                self.fail("Expected a ValueError")
            except ValueError:
                pass
            now = time.time()
            # test should take at least 2 seconds
            res = self.spark._wrapped.streams.awaitAnyTermination(2.6)
            duration = time.time() - now
            self.assertTrue(duration >= 2)
            self.assertFalse(res)
        finally:
            q.processAllAvailable()
            q.stop()
            shutil.rmtree(tmpPath)

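    # Helper for the foreach() tests below. It records open/process/close calls by
    # writing one small event file per call into temporary directories and reading
    # them back with spark.read.json to make assertions. __getstate__/__setstate__
    # keep only the event directories so the tester can be pickled to executors
    # without the (unpicklable) SparkSession.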
    class ForeachWriterTester:

        def __init__(self, spark):
            self.spark = spark

        def write_open_event(self, partitionId, epochId):
            self._write_event(
                self.open_events_dir,
                {'partition': partitionId, 'epoch': epochId})

        def write_process_event(self, row):
            self._write_event(self.process_events_dir, {'value': 'text'})

        def write_close_event(self, error):
            self._write_event(self.close_events_dir, {'error': str(error)})

        def write_input_file(self):
            self._write_event(self.input_dir, "text")

        def open_events(self):
            return self._read_events(self.open_events_dir, 'partition INT, epoch INT')

        def process_events(self):
            return self._read_events(self.process_events_dir, 'value STRING')

        def close_events(self):
            return self._read_events(self.close_events_dir, 'error STRING')

        def run_streaming_query_on_writer(self, writer, num_files):
            self._reset()
            try:
                sdf = self.spark.readStream.format('text').load(self.input_dir)
                sq = sdf.writeStream.foreach(writer).start()
                for i in range(num_files):
                    self.write_input_file()
                    sq.processAllAvailable()
            finally:
                self.stop_all()

        def assert_invalid_writer(self, writer, msg=None):
            self._reset()
            try:
                sdf = self.spark.readStream.format('text').load(self.input_dir)
                sq = sdf.writeStream.foreach(writer).start()
                self.write_input_file()
                sq.processAllAvailable()
                self.fail("invalid writer %s did not fail the query" % str(writer))  # not expected
            except Exception as e:
                if msg:
                    assert msg in str(e), "%s not in %s" % (msg, str(e))
            finally:
                self.stop_all()

        def stop_all(self):
            for q in self.spark._wrapped.streams.active:
                q.stop()

        def _reset(self):
            self.input_dir = tempfile.mkdtemp()
            self.open_events_dir = tempfile.mkdtemp()
            self.process_events_dir = tempfile.mkdtemp()
            self.close_events_dir = tempfile.mkdtemp()

        def _read_events(self, dir, json):
            rows = self.spark.read.schema(json).json(dir).collect()
            dicts = [row.asDict() for row in rows]
            return dicts

        def _write_event(self, dir, event):
            import uuid
            with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f:
                f.write("%s\n" % str(event))

        def __getstate__(self):
            return (self.open_events_dir, self.process_events_dir, self.close_events_dir)

        def __setstate__(self, state):
            self.open_events_dir, self.process_events_dir, self.close_events_dir = state

    # These foreach tests fail on Python 3.6 with macOS High Sierra because of the rules
    # described at http://sealiesoftware.com/blog/archive/2017/6/5/Objective-C_and_fork_in_macOS_1013.html
    # To work around this, set OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES.
    def test_streaming_foreach_with_simple_function(self):
        tester = self.ForeachWriterTester(self.spark)

        def foreach_func(row):
            tester.write_process_event(row)

        tester.run_streaming_query_on_writer(foreach_func, 2)
        self.assertEqual(len(tester.process_events()), 2)

    def test_streaming_foreach_with_basic_open_process_close(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def open(self, partitionId, epochId):
                tester.write_open_event(partitionId, epochId)
                return True

            def process(self, row):
                tester.write_process_event(row)

            def close(self, error):
                tester.write_close_event(error)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)

        open_events = tester.open_events()
        self.assertEqual(len(open_events), 2)
        self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1})

        self.assertEqual(len(tester.process_events()), 2)

        close_events = tester.close_events()
        self.assertEqual(len(close_events), 2)
        self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})

    def test_streaming_foreach_with_open_returning_false(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def open(self, partition_id, epoch_id):
                tester.write_open_event(partition_id, epoch_id)
                return False

            def process(self, row):
                tester.write_process_event(row)

            def close(self, error):
                tester.write_close_event(error)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)

        self.assertEqual(len(tester.open_events()), 2)

        self.assertEqual(len(tester.process_events()), 0)  # no row was processed

        close_events = tester.close_events()
        self.assertEqual(len(close_events), 2)
        self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})

    def test_streaming_foreach_without_open_method(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def process(self, row):
                tester.write_process_event(row)

            def close(self, error):
                tester.write_close_event(error)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
        self.assertEqual(len(tester.open_events()), 0)  # no open events
        self.assertEqual(len(tester.process_events()), 2)
        self.assertEqual(len(tester.close_events()), 2)

    def test_streaming_foreach_without_close_method(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def open(self, partition_id, epoch_id):
                tester.write_open_event(partition_id, epoch_id)
                return True

            def process(self, row):
                tester.write_process_event(row)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
        self.assertEqual(len(tester.open_events()), 2)  # open was called twice
        self.assertEqual(len(tester.process_events()), 2)
        self.assertEqual(len(tester.close_events()), 0)

    def test_streaming_foreach_without_open_and_close_methods(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def process(self, row):
                tester.write_process_event(row)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
        self.assertEqual(len(tester.open_events()), 0)  # no open events
        self.assertEqual(len(tester.process_events()), 2)
        self.assertEqual(len(tester.close_events()), 0)

    def test_streaming_foreach_with_process_throwing_error(self):
        from pyspark.sql.utils import StreamingQueryException

        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def process(self, row):
                raise Exception("test error")

            def close(self, error):
                tester.write_close_event(error)

        try:
            tester.run_streaming_query_on_writer(ForeachWriter(), 1)
            self.fail("bad writer did not fail the query")  # this is not expected
        except StreamingQueryException as e:
            # TODO: Verify whether original error message is inside the exception
            pass

        self.assertEqual(len(tester.process_events()), 0)  # no row was processed
        close_events = tester.close_events()
        self.assertEqual(len(close_events), 1)
        # TODO: Verify whether original error message is inside the exception

    def test_streaming_foreach_with_invalid_writers(self):

        tester = self.ForeachWriterTester(self.spark)

        def func_with_iterator_input(iter):
            for x in iter:
                print(x)

        tester.assert_invalid_writer(func_with_iterator_input)

        class WriterWithoutProcess:
            def open(self, partition):
                pass

        tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'")

        class WriterWithNonCallableProcess:
            process = True

        tester.assert_invalid_writer(WriterWithNonCallableProcess(),
                                     "'process' in provided object is not callable")

        class WriterWithNoParamProcess:
            def process(self):
                pass

        tester.assert_invalid_writer(WriterWithNoParamProcess())

        # Base class providing a valid process() for the open/close tests below
        class WithProcess:
            def process(self, row):
                pass

        class WriterWithNonCallableOpen(WithProcess):
            open = True

        tester.assert_invalid_writer(WriterWithNonCallableOpen(),
                                     "'open' in provided object is not callable")

        class WriterWithNoParamOpen(WithProcess):
            def open(self):
                pass

        tester.assert_invalid_writer(WriterWithNoParamOpen())

        class WriterWithNonCallableClose(WithProcess):
            close = True

        tester.assert_invalid_writer(WriterWithNonCallableClose(),
                                     "'close' in provided object is not callable")

    def test_streaming_foreachBatch(self):
        q = None
        collected = dict()

        def collectBatch(batch_df, batch_id):
            collected[batch_id] = batch_df.collect()

        try:
            df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
            q = df.writeStream.foreachBatch(collectBatch).start()
            q.processAllAvailable()
            self.assertTrue(0 in collected)
            self.assertEqual(len(collected[0]), 2)
        finally:
            if q:
                q.stop()

    def test_streaming_foreachBatch_propagates_python_errors(self):
        from pyspark.sql.utils import StreamingQueryException

        q = None

        def collectBatch(df, id):
            raise Exception("this should fail the query")

        try:
            df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
            q = df.writeStream.foreachBatch(collectBatch).start()
            q.processAllAvailable()
            self.fail("Expected a failure")
        except StreamingQueryException as e:
            self.assertTrue("this should fail" in str(e))
        finally:
            if q:
                q.stop()


if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_streaming import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)