0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017 import glob
0018 import os
0019 import struct
0020 import sys
0021 import unittest
0022 from time import time, sleep
0023
0024 from pyspark import SparkContext, SparkConf
0025
0026
# Optional-dependency probes: some tests are skipped when scipy/numpy are
# not installed, so record availability instead of failing at import time.
have_scipy = False
have_numpy = False
try:
    import scipy.sparse
    have_scipy = True
except ImportError:
    # scipy is optional; tests that need it check `have_scipy` and skip.
    pass
try:
    import numpy as np
    have_numpy = True
except ImportError:
    # numpy is optional; tests that need it check `have_numpy` and skip.
    pass
0041
0042
# Root of the Spark installation; a missing SPARK_HOME fails fast here with
# a KeyError rather than producing confusing path errors later in the tests.
SPARK_HOME = os.environ["SPARK_HOME"]
0044
0045
def read_int(b):
    """Decode *b* (4 bytes) as a big-endian signed 32-bit integer."""
    (value,) = struct.unpack("!i", b)
    return value
0048
0049
def write_int(i):
    """Encode *i* as 4 bytes: a big-endian signed 32-bit integer."""
    return struct.Struct("!i").pack(i)
0052
0053
def eventually(condition, timeout=30.0, catch_assertions=False):
    """
    Wait a given amount of time for a condition to pass, else fail with an error.
    This is a helper utility for PySpark tests.

    :param condition: Function that checks for termination conditions.
                      condition() can return:
                       - True: Conditions met. Return without error.
                       - other value: Conditions not met yet. Continue. Upon timeout,
                         include last such value in error message.
                      Note that this method may be called at any time during
                      streaming execution (e.g., even before any results
                      have been created).
    :param timeout: Number of seconds to wait. Default 30 seconds.
    :param catch_assertions: If False (default), do not catch AssertionErrors.
                             If True, catch AssertionErrors; continue, but save
                             error to throw upon timeout.
    """
    deadline = time() + timeout
    last_value = None
    while time() < deadline:
        try:
            last_value = condition()
        except AssertionError as err:
            if not catch_assertions:
                raise
            # Remember the failure; it is re-raised if we time out.
            last_value = err
        if last_value is True:
            return
        # Brief pause to avoid busy-waiting on the condition.
        sleep(0.01)
    if isinstance(last_value, AssertionError):
        raise last_value
    raise AssertionError(
        "Test failed due to timeout after %g sec, with last condition returning: %s"
        % (timeout, last_value))
0091
0092
class QuietTest(object):
    """Context manager that temporarily raises the JVM root logger to FATAL,
    silencing expected Spark error output for the duration of a test."""

    def __init__(self, sc):
        # Gateway to log4j classes inside Spark's JVM.
        self.log4j = sc._jvm.org.apache.log4j

    def __enter__(self):
        root = self.log4j.LogManager.getRootLogger()
        # Remember the current level so __exit__ can restore it.
        self.old_level = root.getLevel()
        root.setLevel(self.log4j.Level.FATAL)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.log4j.LogManager.getRootLogger().setLevel(self.old_level)
0103
0104
class PySparkTestCase(unittest.TestCase):
    """Base test case that creates a fresh local SparkContext per test."""

    def setUp(self):
        # Snapshot sys.path so tearDown can undo any additions a test makes.
        self._old_sys_path = sys.path[:]
        self.sc = SparkContext('local[4]', self.__class__.__name__)

    def tearDown(self):
        self.sc.stop()
        sys.path = self._old_sys_path
0115
0116
class ReusedPySparkTestCase(unittest.TestCase):
    """Base test case that shares one SparkContext across every test in the
    class, avoiding the cost of starting a new context per test."""

    @classmethod
    def conf(cls):
        """Override this in subclasses to supply a more specific conf."""
        return SparkConf()

    @classmethod
    def setUpClass(cls):
        # One context for the whole class; named after the test class itself.
        cls.sc = SparkContext('local[4]', cls.__name__, conf=cls.conf())

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()
0133
0134
class ByteArrayOutput(object):
    """Minimal file-like sink that accumulates all written bytes in memory."""

    def __init__(self):
        self.buffer = bytearray()

    def write(self, b):
        # extend() appends in place, matching `+=` semantics on a bytearray.
        self.buffer.extend(b)

    def close(self):
        # Nothing to release; present only to satisfy the file-like protocol.
        pass
0144
0145
def search_jar(project_relative_path, sbt_jar_name_prefix, mvn_jar_name_prefix):
    """Locate the jar built for a sub-project under ``$SPARK_HOME``.

    Both the sbt output directory (``target/scala-*/``) and the maven output
    directory (``target/``) are searched, and javadoc/sources/tests artifacts
    are ignored.

    :param project_relative_path: sub-project directory relative to SPARK_HOME.
    :param sbt_jar_name_prefix: jar filename prefix produced by the sbt build.
    :param mvn_jar_name_prefix: jar filename prefix produced by the maven build.
    :return: path of the single matching jar, or None if none was found.
    :raises Exception: if more than one candidate jar is found.
    """
    base = os.path.join(
        os.environ["SPARK_HOME"], project_relative_path)

    ignored_jar_suffixes = ("javadoc.jar", "sources.jar", "test-sources.jar", "tests.jar")

    candidates = (
        glob.glob(os.path.join(base, "target/scala-*/%s*.jar" % sbt_jar_name_prefix))
        + glob.glob(os.path.join(base, "target/%s*.jar" % mvn_jar_name_prefix))
    )
    jars = [path for path in candidates if not path.endswith(ignored_jar_suffixes)]

    if len(jars) > 1:
        raise Exception("Found multiple JARs: %s; please remove all but one" % (", ".join(jars)))
    return jars[0] if jars else None