import os
import shutil
import tempfile
import time

from pyspark.sql.functions import lit
from pyspark.sql.types import *
from pyspark.testing.sqlutils import ReusedSQLTestCase


class StreamingTests(ReusedSQLTestCase):

    def test_stream_trigger(self):
        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')

        # Should take at least one arg
        try:
            df.writeStream.trigger()
        except ValueError:
            pass

        # Should not take multiple args
        try:
            df.writeStream.trigger(once=True, processingTime='5 seconds')
        except ValueError:
            pass

        # Should not take multiple args
        try:
            df.writeStream.trigger(processingTime='5 seconds', continuous='1 second')
        except ValueError:
            pass

        # Should take only keyword args
        try:
            df.writeStream.trigger('5 seconds')
            self.fail("Should have thrown an exception")
        except TypeError:
            pass

    def test_stream_read_options(self):
        schema = StructType([StructField("data", StringType(), False)])
        df = self.spark.readStream\
            .format('text')\
            .option('path', 'python/test_support/sql/streaming')\
            .schema(schema)\
            .load()
        self.assertTrue(df.isStreaming)
        self.assertEqual(df.schema.simpleString(), "struct<data:string>")

    def test_stream_read_options_overwrite(self):
        bad_schema = StructType([StructField("test", IntegerType(), False)])
        schema = StructType([StructField("data", StringType(), False)])
        df = self.spark.readStream.format('csv').option('path', 'python/test_support/sql/fake') \
            .schema(bad_schema)\
            .load(path='python/test_support/sql/streaming', schema=schema, format='text')
        self.assertTrue(df.isStreaming)
        self.assertEqual(df.schema.simpleString(), "struct<data:string>")

    def test_stream_save_options(self):
        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') \
            .withColumn('id', lit(1))
        for q in self.spark._wrapped.streams.active:
            q.stop()
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        self.assertTrue(df.isStreaming)
        out = os.path.join(tmpPath, 'out')
        chk = os.path.join(tmpPath, 'chk')
        q = df.writeStream.option('checkpointLocation', chk).queryName('this_query') \
            .format('parquet').partitionBy('id').outputMode('append').option('path', out).start()
        try:
            self.assertEqual(q.name, 'this_query')
            self.assertTrue(q.isActive)
            q.processAllAvailable()
            output_files = []
            for _, _, files in os.walk(out):
                output_files.extend([f for f in files if not f.startswith('.')])
            self.assertTrue(len(output_files) > 0)
            self.assertTrue(len(os.listdir(chk)) > 0)
        finally:
            q.stop()
            shutil.rmtree(tmpPath)

    def test_stream_save_options_overwrite(self):
        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
        for q in self.spark._wrapped.streams.active:
            q.stop()
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        self.assertTrue(df.isStreaming)
        out = os.path.join(tmpPath, 'out')
        chk = os.path.join(tmpPath, 'chk')
        fake1 = os.path.join(tmpPath, 'fake1')
        fake2 = os.path.join(tmpPath, 'fake2')
        # The options passed to start() should override the ones set on the writer earlier.
        q = df.writeStream.option('checkpointLocation', fake1)\
            .format('memory').option('path', fake2) \
            .queryName('fake_query').outputMode('append') \
            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)

        try:
            self.assertEqual(q.name, 'this_query')
            self.assertTrue(q.isActive)
            q.processAllAvailable()
            output_files = []
            for _, _, files in os.walk(out):
                output_files.extend([f for f in files if not f.startswith('.')])
            self.assertTrue(len(output_files) > 0)
            self.assertTrue(len(os.listdir(chk)) > 0)
            # The fake output and checkpoint locations should never have been created.
            self.assertFalse(os.path.isdir(fake1))
            self.assertFalse(os.path.isdir(fake2))
        finally:
            q.stop()
            shutil.rmtree(tmpPath)

    def test_stream_status_and_progress(self):
        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
        for q in self.spark._wrapped.streams.active:
            q.stop()
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        self.assertTrue(df.isStreaming)
        out = os.path.join(tmpPath, 'out')
        chk = os.path.join(tmpPath, 'chk')

        def func(x):
            time.sleep(1)
            return x

        from pyspark.sql.functions import col, udf
        sleep_udf = udf(func)

        # Use "sleep_udf" to slow down processing so that lastProgress can be accessed
        # before the first trigger completes.
        q = df.select(sleep_udf(col("value")).alias('value')).writeStream \
            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
        try:
            # "lastProgress" may still be None at this point; we only check that accessing it
            # does not raise.
            q.lastProgress

            q.processAllAvailable()
            lastProgress = q.lastProgress
            recentProgress = q.recentProgress
            status = q.status
            self.assertEqual(lastProgress['name'], q.name)
            self.assertEqual(lastProgress['id'], q.id)
            self.assertTrue(any(p == lastProgress for p in recentProgress))
            self.assertTrue(
                "message" in status and
                "isDataAvailable" in status and
                "isTriggerActive" in status)
        finally:
            q.stop()
            shutil.rmtree(tmpPath)

    def test_stream_await_termination(self):
        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
        for q in self.spark._wrapped.streams.active:
            q.stop()
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        self.assertTrue(df.isStreaming)
        out = os.path.join(tmpPath, 'out')
        chk = os.path.join(tmpPath, 'chk')
        q = df.writeStream\
            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
        try:
            self.assertTrue(q.isActive)
            try:
                q.awaitTermination("hello")
                self.fail("Expected a ValueError")
            except ValueError:
                pass
            now = time.time()
            # test should take at least 2 seconds
            res = q.awaitTermination(2.6)
            duration = time.time() - now
            self.assertTrue(duration >= 2)
            self.assertFalse(res)
        finally:
            q.processAllAvailable()
            q.stop()
            shutil.rmtree(tmpPath)

    def test_stream_exception(self):
        sdf = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
        sq = sdf.writeStream.format('memory').queryName('query_explain').start()
        try:
            sq.processAllAvailable()
            self.assertEqual(sq.exception(), None)
        finally:
            sq.stop()

        from pyspark.sql.functions import col, udf
        from pyspark.sql.utils import StreamingQueryException
        bad_udf = udf(lambda x: 1 / 0)
        sq = sdf.select(bad_udf(col("value")))\
            .writeStream\
            .format('memory')\
            .queryName('this_query')\
            .start()
        try:
            # Process some data to fail the query
            sq.processAllAvailable()
            self.fail("bad udf should fail the query")
        except StreamingQueryException as e:
            # This is expected
            self._assert_exception_tree_contains_msg(e, "ZeroDivisionError")
        finally:
            sq.stop()
        self.assertTrue(type(sq.exception()) is StreamingQueryException)
        self._assert_exception_tree_contains_msg(sq.exception(), "ZeroDivisionError")

    def _assert_exception_tree_contains_msg(self, exception, msg):
        e = exception
        contains = msg in e.desc
        while e.cause is not None and not contains:
            e = e.cause
            contains = msg in e.desc
        self.assertTrue(contains, "Exception tree doesn't contain the expected message: %s" % msg)

    def test_query_manager_await_termination(self):
        df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
        for q in self.spark._wrapped.streams.active:
            q.stop()
        tmpPath = tempfile.mkdtemp()
        shutil.rmtree(tmpPath)
        self.assertTrue(df.isStreaming)
        out = os.path.join(tmpPath, 'out')
        chk = os.path.join(tmpPath, 'chk')
        q = df.writeStream\
            .start(path=out, format='parquet', queryName='this_query', checkpointLocation=chk)
        try:
            self.assertTrue(q.isActive)
            try:
                self.spark._wrapped.streams.awaitAnyTermination("hello")
                self.fail("Expected a ValueError")
            except ValueError:
                pass
            now = time.time()
            # test should take at least 2 seconds
            res = self.spark._wrapped.streams.awaitAnyTermination(2.6)
            duration = time.time() - now
            self.assertTrue(duration >= 2)
            self.assertFalse(res)
        finally:
            q.processAllAvailable()
            q.stop()
            shutil.rmtree(tmpPath)

    class ForeachWriterTester:

        def __init__(self, spark):
            self.spark = spark

        def write_open_event(self, partitionId, epochId):
            self._write_event(
                self.open_events_dir,
                {'partition': partitionId, 'epoch': epochId})

        def write_process_event(self, row):
            self._write_event(self.process_events_dir, {'value': 'text'})

        def write_close_event(self, error):
            self._write_event(self.close_events_dir, {'error': str(error)})

        def write_input_file(self):
            self._write_event(self.input_dir, "text")

        def open_events(self):
            return self._read_events(self.open_events_dir, 'partition INT, epoch INT')

        def process_events(self):
            return self._read_events(self.process_events_dir, 'value STRING')

        def close_events(self):
            return self._read_events(self.close_events_dir, 'error STRING')

        def run_streaming_query_on_writer(self, writer, num_files):
            self._reset()
            try:
                sdf = self.spark.readStream.format('text').load(self.input_dir)
                sq = sdf.writeStream.foreach(writer).start()
                for i in range(num_files):
                    self.write_input_file()
                    sq.processAllAvailable()
            finally:
                self.stop_all()

        def assert_invalid_writer(self, writer, msg=None):
            self._reset()
            try:
                sdf = self.spark.readStream.format('text').load(self.input_dir)
                sq = sdf.writeStream.foreach(writer).start()
                self.write_input_file()
                sq.processAllAvailable()
                self.fail("invalid writer %s did not fail the query" % str(writer))
            except Exception as e:
                if msg:
                    assert msg in str(e), "%s not in %s" % (msg, str(e))
            finally:
                self.stop_all()

        def stop_all(self):
            for q in self.spark._wrapped.streams.active:
                q.stop()

        def _reset(self):
            self.input_dir = tempfile.mkdtemp()
            self.open_events_dir = tempfile.mkdtemp()
            self.process_events_dir = tempfile.mkdtemp()
            self.close_events_dir = tempfile.mkdtemp()

        def _read_events(self, dir, json):
            rows = self.spark.read.schema(json).json(dir).collect()
            dicts = [row.asDict() for row in rows]
            return dicts

        def _write_event(self, dir, event):
            import uuid
            with open(os.path.join(dir, str(uuid.uuid4())), 'w') as f:
                f.write("%s\n" % str(event))

        def __getstate__(self):
            return (self.open_events_dir, self.process_events_dir, self.close_events_dir)

        def __setstate__(self, state):
            self.open_events_dir, self.process_events_dir, self.close_events_dir = state

    def test_streaming_foreach_with_simple_function(self):
        tester = self.ForeachWriterTester(self.spark)

        def foreach_func(row):
            tester.write_process_event(row)

        tester.run_streaming_query_on_writer(foreach_func, 2)
        self.assertEqual(len(tester.process_events()), 2)

    def test_streaming_foreach_with_basic_open_process_close(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def open(self, partitionId, epochId):
                tester.write_open_event(partitionId, epochId)
                return True

            def process(self, row):
                tester.write_process_event(row)

            def close(self, error):
                tester.write_close_event(error)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)

        open_events = tester.open_events()
        self.assertEqual(len(open_events), 2)
        self.assertSetEqual(set([e['epoch'] for e in open_events]), {0, 1})

        self.assertEqual(len(tester.process_events()), 2)

        close_events = tester.close_events()
        self.assertEqual(len(close_events), 2)
        self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})

    def test_streaming_foreach_with_open_returning_false(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def open(self, partition_id, epoch_id):
                tester.write_open_event(partition_id, epoch_id)
                return False

            def process(self, row):
                tester.write_process_event(row)

            def close(self, error):
                tester.write_close_event(error)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)

        self.assertEqual(len(tester.open_events()), 2)

        self.assertEqual(len(tester.process_events()), 0)

        close_events = tester.close_events()
        self.assertEqual(len(close_events), 2)
        self.assertSetEqual(set([e['error'] for e in close_events]), {'None'})

    def test_streaming_foreach_without_open_method(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def process(self, row):
                tester.write_process_event(row)

            def close(self, error):
                tester.write_close_event(error)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
        self.assertEqual(len(tester.open_events()), 0)
        self.assertEqual(len(tester.process_events()), 2)
        self.assertEqual(len(tester.close_events()), 2)

    def test_streaming_foreach_without_close_method(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def open(self, partition_id, epoch_id):
                tester.write_open_event(partition_id, epoch_id)
                return True

            def process(self, row):
                tester.write_process_event(row)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
        self.assertEqual(len(tester.open_events()), 2)
        self.assertEqual(len(tester.process_events()), 2)
        self.assertEqual(len(tester.close_events()), 0)

    def test_streaming_foreach_without_open_and_close_methods(self):
        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def process(self, row):
                tester.write_process_event(row)

        tester.run_streaming_query_on_writer(ForeachWriter(), 2)
        self.assertEqual(len(tester.open_events()), 0)
        self.assertEqual(len(tester.process_events()), 2)
        self.assertEqual(len(tester.close_events()), 0)

    def test_streaming_foreach_with_process_throwing_error(self):
        from pyspark.sql.utils import StreamingQueryException

        tester = self.ForeachWriterTester(self.spark)

        class ForeachWriter:
            def process(self, row):
                raise Exception("test error")

            def close(self, error):
                tester.write_close_event(error)

        try:
            tester.run_streaming_query_on_writer(ForeachWriter(), 1)
            self.fail("bad writer did not fail the query")
        except StreamingQueryException:
            # The error raised in process() is expected to fail the query; the effects are
            # verified through the process and close events below.
            pass

        self.assertEqual(len(tester.process_events()), 0)
        close_events = tester.close_events()
        self.assertEqual(len(close_events), 1)

    def test_streaming_foreach_with_invalid_writers(self):
        tester = self.ForeachWriterTester(self.spark)

        def func_with_iterator_input(iter):
            for x in iter:
                print(x)

        tester.assert_invalid_writer(func_with_iterator_input)

        class WriterWithoutProcess:
            def open(self, partition):
                pass

        tester.assert_invalid_writer(WriterWithoutProcess(), "does not have a 'process'")

        class WriterWithNonCallableProcess():
            process = True

        tester.assert_invalid_writer(WriterWithNonCallableProcess(),
                                     "'process' in provided object is not callable")

        class WriterWithNoParamProcess():
            def process(self):
                pass

        tester.assert_invalid_writer(WriterWithNoParamProcess())

        # Abstract class for tests below
        class WithProcess():
            def process(self, row):
                pass

        class WriterWithNonCallableOpen(WithProcess):
            open = True

        tester.assert_invalid_writer(WriterWithNonCallableOpen(),
                                     "'open' in provided object is not callable")

        class WriterWithNoParamOpen(WithProcess):
            def open(self):
                pass

        tester.assert_invalid_writer(WriterWithNoParamOpen())

        class WriterWithNonCallableClose(WithProcess):
            close = True

        tester.assert_invalid_writer(WriterWithNonCallableClose(),
                                     "'close' in provided object is not callable")

    def test_streaming_foreachBatch(self):
        q = None
        collected = dict()

        def collectBatch(batch_df, batch_id):
            collected[batch_id] = batch_df.collect()

        try:
            df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
            q = df.writeStream.foreachBatch(collectBatch).start()
            q.processAllAvailable()
            self.assertTrue(0 in collected)
            self.assertEqual(len(collected[0]), 2)
        finally:
            if q:
                q.stop()

    def test_streaming_foreachBatch_propagates_python_errors(self):
        from pyspark.sql.utils import StreamingQueryException

        q = None

        def collectBatch(df, id):
            raise Exception("this should fail the query")

        try:
            df = self.spark.readStream.format('text').load('python/test_support/sql/streaming')
            q = df.writeStream.foreachBatch(collectBatch).start()
            q.processAllAvailable()
            self.fail("Expected a failure")
        except StreamingQueryException as e:
            self.assertTrue("this should fail" in str(e))
        finally:
            if q:
                q.stop()


if __name__ == "__main__":
    import unittest
    from pyspark.sql.tests.test_streaming import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)