#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import tempfile
import time
import unittest

from pyspark import SparkConf, SparkContext, RDD
from pyspark.streaming import StreamingContext
from pyspark.testing.utils import search_jar


kinesis_test_environ_var = "ENABLE_KINESIS_TESTS"
should_skip_kinesis_tests = os.environ.get(kinesis_test_environ_var) != '1'
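# To enable the Kinesis tests, export the variable before running the test
# suite, e.g. (illustrative shell invocation; the runner and module name may
# differ in your checkout):
#
#     ENABLE_KINESIS_TESTS=1 ./python/run-tests --modules=pyspark-streaming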

if should_skip_kinesis_tests:
    kinesis_requirement_message = (
        "Skipping all Kinesis Python tests as the environment variable "
        "'ENABLE_KINESIS_TESTS' was not set.")
else:
    kinesis_asl_assembly_jar = search_jar("external/kinesis-asl-assembly",
                                          "spark-streaming-kinesis-asl-assembly-",
                                          "spark-streaming-kinesis-asl-assembly_")
    if kinesis_asl_assembly_jar is None:
        kinesis_requirement_message = (
            "Skipping all Kinesis Python tests as the optional Kinesis project was "
            "not compiled into a JAR. To run these tests, "
            "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package "
            "streaming-kinesis-asl-assembly/assembly' or "
            "'build/mvn -Pkinesis-asl package' before running this test.")
    else:
        existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        jars_args = "--jars %s" % kinesis_asl_assembly_jar
        os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args])
        kinesis_requirement_message = None

should_test_kinesis = kinesis_requirement_message is None
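
# Illustrative sketch (hypothetical test class, not part of this module):
# Kinesis test modules are expected to gate themselves on these flags, e.g.
#
#     @unittest.skipIf(not should_test_kinesis, kinesis_requirement_message)
#     class KinesisStreamTests(PySparkStreamingTestCase):
#         ...
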

class PySparkStreamingTestCase(unittest.TestCase):

    timeout = 30  # seconds to wait for expected output before giving up
    duration = .5  # batch duration for the StreamingContext, in seconds

    @classmethod
    def setUpClass(cls):
        class_name = cls.__name__
        conf = SparkConf().set("spark.default.parallelism", 1)
        cls.sc = SparkContext(appName=class_name, conf=conf)
        cls.sc.setCheckpointDir(tempfile.mkdtemp())

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()

        # Also stop any SparkContext lingering in the JVM in case the
        # Python-side stop failed; errors here are deliberately swallowed.
        try:
            jSparkContextOption = SparkContext._jvm.SparkContext.get()
            if jSparkContextOption.nonEmpty():
                jSparkContextOption.get().stop()
        except Exception:
            pass

    def setUp(self):
        self.ssc = StreamingContext(self.sc, self.duration)

    def tearDown(self):
        if self.ssc is not None:
            self.ssc.stop(False)

        # Clean up any StreamingContext still active in the JVM as well;
        # stop(False) keeps the shared SparkContext alive for the next test.
        try:
            jStreamingContextOption = StreamingContext._jvm.StreamingContext.getActive()
            if jStreamingContextOption.nonEmpty():
                jStreamingContextOption.get().stop(False)
        except Exception:
            pass

    def wait_for(self, result, n):
        """Poll until `result` holds at least `n` items or `timeout` expires."""
        start_time = time.time()
        while len(result) < n and time.time() - start_time < self.timeout:
            time.sleep(0.01)
        if len(result) < n:
            print("timeout after", self.timeout)

    def _take(self, dstream, n):
        """
        Return the first `n` elements in the stream (this starts the
        StreamingContext, so call it at most once per test).
        """
        results = []

        def take(_, rdd):
            if rdd and len(results) < n:
                results.extend(rdd.take(n - len(results)))

        dstream.foreachRDD(take)

        self.ssc.start()
        self.wait_for(results, n)
        return results
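
    # Usage sketch (hypothetical): grab a small prefix of a stream in a test,
    #
    #     first_three = self._take(input_stream.map(str), 3)
    #
    # Note that this starts self.ssc, so it can be used at most once per test.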

    def _collect(self, dstream, n, block=True):
        """
        Collect the elements of each RDD in `dstream` into the returned list.

        :param block: if True, start the StreamingContext and wait until `n`
            batches have been collected; otherwise return the (still empty)
            list immediately and let the caller drive the context.
        :return: list, which will hold one collected list per batch.
        """
        result = []

        def get_output(_, rdd):
            if rdd and len(result) < n:
                r = rdd.collect()
                if r:
                    result.append(r)

        dstream.foreachRDD(get_output)

        if not block:
            return result

        self.ssc.start()
        self.wait_for(result, n)
        return result
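
    # Non-blocking use, sketched (hypothetical caller code): with block=False
    # the list comes back empty and fills as batches arrive once the caller
    # starts the context itself:
    #
    #     result = self._collect(stream, 2, block=False)
    #     self.ssc.start()
    #     self.wait_for(result, 2)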

    def _test_func(self, input, func, expected, sort=False, input2=None):
        """
        :param input: dataset for the test. This should be a list of lists
            (one inner list per batch) or a list of RDDs.
        :param func: wrapped function. This function should return a DStream.
        :param expected: expected output for this testcase, as a list with
            the collected output of each batch.
        :param sort: if True, sort both the result and `expected` by key
            before comparing.
        :param input2: optional second input dataset; when given, `func` is
            called with two input streams.
        """
        if not isinstance(input[0], RDD):
            input = [self.sc.parallelize(d, 1) for d in input]
        input_stream = self.ssc.queueStream(input)
        if input2 and not isinstance(input2[0], RDD):
            input2 = [self.sc.parallelize(d, 1) for d in input2]
        input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None

        # Apply the wrapped function and collect the resulting stream.
        if input2:
            stream = func(input_stream, input_stream2)
        else:
            stream = func(input_stream)

        result = self._collect(stream, len(expected))
        if sort:
            self._sort_result_based_on_key(result)
            self._sort_result_based_on_key(expected)
        self.assertEqual(expected, result)

    def _sort_result_based_on_key(self, outputs):
        """Sort each output batch in place by the first element of each pair."""
        for output in outputs:
            output.sort(key=lambda x: x[0])
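

# Usage sketch (hypothetical test case, not part of this module): a concrete
# test expresses a transformation as input batches, a function over the
# DStream, and the expected output of each batch, e.g.
#
#     class BasicOperationTests(PySparkStreamingTestCase):
#         def test_map(self):
#             input = [range(1, 5), range(5, 9)]
#
#             def func(dstream):
#                 return dstream.map(str)
#
#             expected = [list(map(str, x)) for x in input]
#             self._test_func(input, func, expected)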