Back to home page

OSCL-LXR

 
 

    


0001 #
0002 # Licensed to the Apache Software Foundation (ASF) under one or more
0003 # contributor license agreements.  See the NOTICE file distributed with
0004 # this work for additional information regarding copyright ownership.
0005 # The ASF licenses this file to You under the Apache License, Version 2.0
0006 # (the "License"); you may not use this file except in compliance with
0007 # the License.  You may obtain a copy of the License at
0008 #
0009 # http://www.apache.org/licenses/LICENSE-2.0
0010 #
0011 # Unless required by applicable law or agreed to in writing, software
0012 # distributed under the License is distributed on an "AS IS" BASIS,
0013 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0014 # See the License for the specific language governing permissions and
0015 # limitations under the License.
0016 #
0017 import os
0018 import tempfile
0019 import time
0020 import unittest
0021 
0022 from pyspark import SparkConf, SparkContext, RDD
0023 from pyspark.streaming import StreamingContext
0024 from pyspark.testing.utils import search_jar
0025 
0026 
# Must be same as the variable and condition defined in KinesisTestUtils.scala and modules.py
kinesis_test_environ_var = "ENABLE_KINESIS_TESTS"
should_skip_kinesis_tests = os.environ.get(kinesis_test_environ_var) != '1'

# kinesis_requirement_message ends up as the reason the Kinesis tests must be skipped,
# or None when they can run. The opt-in environment variable is checked first; only
# when it is set do we look for the Kinesis ASL assembly JAR. When the JAR is found,
# it is prepended to PYSPARK_SUBMIT_ARGS as a --jars option so that Spark applications
# launched afterwards from this process receive it.
if not should_skip_kinesis_tests:
    kinesis_asl_assembly_jar = search_jar("external/kinesis-asl-assembly",
                                          "spark-streaming-kinesis-asl-assembly-",
                                          "spark-streaming-kinesis-asl-assembly_")
    if kinesis_asl_assembly_jar is not None:
        existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
        jars_args = "--jars {}".format(kinesis_asl_assembly_jar)
        os.environ["PYSPARK_SUBMIT_ARGS"] = jars_args + " " + existing_args
        kinesis_requirement_message = None
    else:
        kinesis_requirement_message = (
            "Skipping all Kinesis Python tests as the optional Kinesis project was "
            "not compiled into a JAR. To run these tests, "
            "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package "
            "streaming-kinesis-asl-assembly/assembly' or "
            "'build/mvn -Pkinesis-asl package' before running this test.")
else:
    kinesis_requirement_message = (
        "Skipping all Kinesis Python tests as environmental variable 'ENABLE_KINESIS_TESTS' "
        "was not set.")

should_test_kinesis = kinesis_requirement_message is None
0053 
0054 
class PySparkStreamingTestCase(unittest.TestCase):
    """
    Base class for PySpark Streaming tests.

    All tests of a subclass share one SparkContext, created in ``setUpClass`` with
    ``spark.default.parallelism`` set to 1 and a fresh temporary checkpoint directory.
    Each test gets its own StreamingContext, built in ``setUp`` with a batch duration
    of ``duration``. The helpers below register output operations on DStreams, start
    the StreamingContext and wait up to ``timeout`` seconds for the expected output.
    """

    timeout = 30  # seconds
    duration = .5

    @classmethod
    def setUpClass(cls):
        class_name = cls.__name__
        conf = SparkConf().set("spark.default.parallelism", 1)
        cls.sc = SparkContext(appName=class_name, conf=conf)
        cls.sc.setCheckpointDir(tempfile.mkdtemp())

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()
        # Clean up in the JVM just in case there has been some issues in Python API.
        # Best-effort only: failures are ignored so they cannot mask test results, but
        # catching ``Exception`` (rather than a bare ``except``) lets KeyboardInterrupt
        # and SystemExit propagate.
        # NOTE(review): confirm the JVM-side SparkContext exposes a static ``get()``
        # returning an Option; if it does not, this call raises and cleanup is skipped.
        try:
            jSparkContextOption = SparkContext._jvm.SparkContext.get()
            if jSparkContextOption.nonEmpty():
                jSparkContextOption.get().stop()
        except Exception:
            pass

    def setUp(self):
        self.ssc = StreamingContext(self.sc, self.duration)

    def tearDown(self):
        if self.ssc is not None:
            self.ssc.stop(False)
        # Clean up in the JVM just in case there has been some issues in Python API.
        # The active *streaming* context is looked up on the JVM StreamingContext (not
        # SparkContext). The lookup goes through the class-level ``SparkContext._jvm``
        # view because the Python StreamingContext assigns ``_jvm`` in its constructor,
        # so ``StreamingContext._jvm`` raised and this cleanup was silently skipped.
        try:
            jStreamingContext = SparkContext._jvm.org.apache.spark.streaming.StreamingContext
            jStreamingContextOption = jStreamingContext.getActive()
            if jStreamingContextOption.nonEmpty():
                # stop(False): stop streaming only; the shared SparkContext is stopped
                # in tearDownClass.
                jStreamingContextOption.get().stop(False)
        except Exception:
            pass

    def wait_for(self, result, n):
        """
        Poll until ``result`` holds at least ``n`` items or ``timeout`` seconds pass.

        On timeout this only prints a message; the caller's subsequent assertion is
        what reports the failure.
        """
        start_time = time.time()
        while len(result) < n and time.time() - start_time < self.timeout:
            time.sleep(0.01)
        if len(result) < n:
            print("timeout after", self.timeout)

    def _take(self, dstream, n):
        """
        Return the first `n` elements in the stream.

        Starts the StreamingContext (it is stopped in ``tearDown``) and waits up to
        ``timeout`` seconds, so fewer than `n` elements are returned on timeout.
        """
        results = []

        def take(_, rdd):
            if rdd and len(results) < n:
                results.extend(rdd.take(n - len(results)))

        dstream.foreachRDD(take)

        self.ssc.start()
        self.wait_for(results, n)
        return results

    def _collect(self, dstream, n, block=True):
        """
        Collect the contents of each non-empty RDD of `dstream` into the returned list.

        :param dstream: stream whose batches are collected.
        :param n: number of non-empty batches to collect.
        :param block: if True, start the StreamingContext and wait up to ``timeout``
            seconds for `n` batches. If False, only register the output operation and
            return immediately; the caller must start the context, and the list is
            filled in as batches arrive.
        :return: list with one list of collected items per non-empty batch.
        """
        result = []

        def get_output(_, rdd):
            if rdd and len(result) < n:
                r = rdd.collect()
                if r:
                    result.append(r)

        dstream.foreachRDD(get_output)

        if not block:
            return result

        self.ssc.start()
        self.wait_for(result, n)
        return result

    def _test_func(self, input, func, expected, sort=False, input2=None):
        """
        Run `func` over queue streams built from the inputs and compare its output.

        :param input: dataset for the test. This should be list of lists (or of RDDs);
            each element becomes one batch.
        :param func: wrapped function. This function should return PythonDStream object.
            It receives one input DStream, or two when `input2` is non-empty.
        :param expected: expected output for this testcase, one list per batch.
        :param sort: if True, sort every batch of both the expected and the actual
            output by the first value of each record before comparing. The caller's
            `expected` lists are not modified.
        :param input2: optional second dataset, in the same format as `input`.
        """
        # Guard against empty input: indexing input[0] would raise IndexError.
        if input and not isinstance(input[0], RDD):
            input = [self.sc.parallelize(d, 1) for d in input]
        input_stream = self.ssc.queueStream(input)
        if input2 and not isinstance(input2[0], RDD):
            input2 = [self.sc.parallelize(d, 1) for d in input2]

        # Apply test function to stream. The second queue stream is only created when
        # it is actually passed to `func`.
        if input2:
            stream = func(input_stream, self.ssc.queueStream(input2))
        else:
            stream = func(input_stream)

        result = self._collect(stream, len(expected))
        if sort:
            # Sort copies so the caller's `expected` batches are left untouched.
            expected = [list(batch) for batch in expected]
            self._sort_result_based_on_key(result)
            self._sort_result_based_on_key(expected)
        self.assertEqual(expected, result)

    def _sort_result_based_on_key(self, outputs):
        """Sort each list in `outputs` in place by the first value of its elements."""
        for output in outputs:
            output.sort(key=lambda x: x[0])