#!/usr/bin/env python

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
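
"""
Run the PySpark test suites in parallel against one or more Python executables.

Example invocations (assuming the usual `python/run-tests` launcher script in a
Spark checkout; module names come from dev/sparktestsupport/modules.py):

    ./python/run-tests --modules=pyspark-sql --python-executables=python3
    ./python/run-tests --testnames 'pyspark.sql.tests FooTests'
"""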
from __future__ import print_function  # needed for print(..., end='') on Python 2
import logging
from argparse import ArgumentParser
import os
import re
import shutil
import subprocess
import sys
import tempfile
from threading import Thread, Lock
import time
import uuid
if sys.version < '3':
    import Queue
else:
    import queue as Queue
from multiprocessing import Manager


# Append the "dev" directory to the module search path so that the
# sparktestsupport package can be imported below.
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../dev/"))


from sparktestsupport import SPARK_HOME  # noqa (suppress pep8 warnings)
from sparktestsupport.shellutils import which, subprocess_check_output  # noqa
from sparktestsupport.modules import all_modules, pyspark_sql  # noqa


python_modules = dict((m.name, m) for m in all_modules if m.python_test_goals if m.name != 'root')


def print_red(text):
    print('\033[31m' + text + '\033[0m')


SKIPPED_TESTS = Manager().dict()
LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log")
FAILURE_REPORTING_LOCK = Lock()
LOGGER = logging.getLogger()

# Find out where the assembly jars are located.
for scala in ["2.12"]:
    build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala)
    if os.path.isdir(build_dir):
        SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*")
        break
else:
    raise Exception("Cannot find assembly build directory, please build Spark first.")


def run_individual_python_test(target_dir, test_name, pyspark_python):
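    """Run a single Python test goal (anything bin/pyspark accepts: a module,
    a test class, or an individual test) in a subprocess using `pyspark_python`
    as the interpreter, with a fresh temp directory under `target_dir`.
    Exits the whole process on the first failure.
    """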
    env = dict(os.environ)
    env.update({
        'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH,
        'SPARK_TESTING': '1',
        'SPARK_PREPEND_CLASSES': '1',
        'PYSPARK_PYTHON': which(pyspark_python),
        'PYSPARK_DRIVER_PYTHON': which(pyspark_python),
        'PYSPARK_ROW_FIELD_SORTING_ENABLED': 'true'
    })

    # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is
    # recognized by the tempfile module to override the default system temp directory.
    tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
    while os.path.isdir(tmp_dir):
        tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
    os.mkdir(tmp_dir)
    env["TMPDIR"] = tmp_dir

    # Also point the JVM's temp directory at tmp_dir via driver/executor options.
    java_options = "-Djava.io.tmpdir={0} -Dio.netty.tryReflectionSetAccessible=true".format(
        tmp_dir)
    spark_args = [
        "--conf", "spark.driver.extraJavaOptions='{0}'".format(java_options),
        "--conf", "spark.executor.extraJavaOptions='{0}'".format(java_options),
        "pyspark-shell"
    ]
    env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args)

    LOGGER.info("Starting test(%s): %s", pyspark_python, test_name)
    start_time = time.time()
    try:
        per_test_output = tempfile.TemporaryFile()
        retcode = subprocess.Popen(
            [os.path.join(SPARK_HOME, "bin/pyspark")] + test_name.split(),
            stderr=per_test_output, stdout=per_test_output, env=env).wait()
        shutil.rmtree(tmp_dir, ignore_errors=True)
    except:
        LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python)
        # os._exit() instead of sys.exit() forces Python to exit even when called
        # from a thread other than the main thread.
        os._exit(1)
    duration = time.time() - start_time
    # Exit on the first failure.
    if retcode != 0:
        try:
            with FAILURE_REPORTING_LOCK:
                with open(LOG_FILE, 'ab') as log_file:
                    per_test_output.seek(0)
                    log_file.writelines(per_test_output)
                per_test_output.seek(0)
                for line in per_test_output:
                    decoded_line = line.decode("utf-8", "replace")
                    if not re.match('[0-9]+', decoded_line):
                        print(decoded_line, end='')
                per_test_output.close()
        except:
            LOGGER.exception("Got an exception while trying to print failed test output")
        finally:
            print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python))
            # os._exit() instead of sys.exit() forces Python to exit even when called
            # from a thread other than the main thread.
            os._exit(-1)
    else:
        skipped_counts = 0
        try:
            per_test_output.seek(0)
            # This expects skipped-test output from unittest at verbosity level 2
            # (i.e. when the --verbose option is enabled).
            decoded_lines = map(lambda line: line.decode("utf-8", "replace"),
                                iter(per_test_output))
            skipped_tests = list(filter(
                lambda line: re.search(r'test_.* \(pyspark\..*\) ... (skip|SKIP)', line),
                decoded_lines))
            skipped_counts = len(skipped_tests)
            if skipped_counts > 0:
                key = (pyspark_python, test_name)
                SKIPPED_TESTS[key] = skipped_tests
            per_test_output.close()
        except:
            import traceback
            print_red("\nGot an exception while trying to store "
                      "skipped test output:\n%s" % traceback.format_exc())
            # os._exit() instead of sys.exit() forces Python to exit even when called
            # from a thread other than the main thread.
            os._exit(-1)
        if skipped_counts != 0:
            LOGGER.info(
                "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python,
                test_name, duration, skipped_counts)
        else:
            LOGGER.info(
                "Finished test(%s): %s (%is)", pyspark_python, test_name, duration)


def get_default_python_executables():
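    """Return the Python executables to test by default: any of python3.6,
    python2.7 and pypy found on PATH. If python3.6 is missing, fall back to
    whatever `python3` resolves to, or exit with an error if there is none.
    """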
    python_execs = [x for x in ["python3.6", "python2.7", "pypy"] if which(x)]

    if "python3.6" not in python_execs:
        p = which("python3")
        if not p:
            LOGGER.error("No python3 executable found. Exiting!")
            os._exit(1)
        else:
            python_execs.insert(0, p)
    return python_execs


def parse_opts():
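    """Parse and validate command-line options, rejecting unknown arguments and
    a parallelism below 1.
    """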
    parser = ArgumentParser(
        prog="run-tests"
    )
    parser.add_argument(
        "--python-executables", type=str, default=','.join(get_default_python_executables()),
        help="A comma-separated list of Python executables to test against (default: %(default)s)"
    )
    parser.add_argument(
        "--modules", type=str,
        default=",".join(sorted(python_modules.keys())),
        help="A comma-separated list of Python modules to test (default: %(default)s)"
    )
    parser.add_argument(
        "-p", "--parallelism", type=int, default=4,
        help="The number of suites to test in parallel (default %(default)d)"
    )
    parser.add_argument(
        "--verbose", action="store_true",
        help="Enable additional debug logging"
    )

    group = parser.add_argument_group("Developer Options")
    group.add_argument(
        "--testnames", type=str,
        default=None,
        help=(
            "A comma-separated list of specific modules, classes and functions of doctest "
            "or unittest to test. "
            "For example, 'pyspark.sql.foo' to run the module as unittests or doctests, "
            "'pyspark.sql.tests FooTests' to run the specific class of unittests, "
            "'pyspark.sql.tests FooTests.test_foo' to run the specific unittest in the class. "
            "The '--modules' option is ignored if this option is given.")
    )

    args, unknown = parser.parse_known_args()
    if unknown:
        parser.error("Unsupported arguments: %s" % ' '.join(unknown))
    if args.parallelism < 1:
        parser.error("Parallelism cannot be less than 1")
    return args


def _check_coverage(python_exec):
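    """Exit with an error unless the coverage module is importable by `python_exec`."""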
    try:
        subprocess_check_output(
            [python_exec, "-c", "import coverage"],
            stderr=open(os.devnull, 'w'))
    except:
        print_red("Coverage is not installed in Python executable '%s' "
                  "but 'COVERAGE_PROCESS_START' environment variable is set, "
                  "exiting." % python_exec)
        sys.exit(-1)


def main():
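    """Build a priority queue of (python executable, test goal) tasks and run
    them on a pool of worker threads, failing fast on the first test failure.
    """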
    opts = parse_opts()
    if opts.verbose:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    should_test_modules = opts.testnames is None
    logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
    LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE)
    if os.path.exists(LOG_FILE):
        os.remove(LOG_FILE)
    python_execs = opts.python_executables.split(',')
    LOGGER.info("Will test against the following Python executables: %s", python_execs)

    if should_test_modules:
        modules_to_test = []
        for module_name in opts.modules.split(','):
            if module_name in python_modules:
                modules_to_test.append(python_modules[module_name])
            else:
                print("Error: unrecognized module '%s'. Supported modules: %s" %
                      (module_name, ", ".join(python_modules)))
                sys.exit(-1)
        LOGGER.info("Will test the following Python modules: %s",
                    [x.name for x in modules_to_test])
    else:
        testnames_to_test = opts.testnames.split(',')
        LOGGER.info("Will test the following Python tests: %s", testnames_to_test)

    task_queue = Queue.PriorityQueue()
    for python_exec in python_execs:
        # If COVERAGE_PROCESS_START is set, make sure the coverage module is
        # installed for this executable before queueing any of its tests.
        if "COVERAGE_PROCESS_START" in os.environ:
            _check_coverage(python_exec)

        python_implementation = subprocess_check_output(
            [python_exec, "-c", "import platform; print(platform.python_implementation())"],
            universal_newlines=True).strip()
        LOGGER.info("%s python_implementation is %s", python_exec, python_implementation)
        LOGGER.info("%s version is: %s", python_exec, subprocess_check_output(
            [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip())
        if should_test_modules:
            for module in modules_to_test:
                if python_implementation not in module.blacklisted_python_implementations:
                    for test_goal in module.python_test_goals:
                        heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests',
                                       'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests']
                        # Give the heaviest suites the lowest priority value so the
                        # PriorityQueue schedules them first.
                        if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)):
                            priority = 0
                        else:
                            priority = 100
                        task_queue.put((priority, (python_exec, test_goal)))
        else:
            for test_goal in testnames_to_test:
                task_queue.put((0, (python_exec, test_goal)))

    # Create the target directory before starting the workers to avoid races.
    target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target'))
    if not os.path.isdir(target_dir):
        os.mkdir(target_dir)

    def process_queue(task_queue):
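        """Worker loop: pull (priority, (python_exec, test_goal)) tasks until
        the queue is empty, running each test goal in its own subprocess.
        """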
        while True:
            try:
                (priority, (python_exec, test_goal)) = task_queue.get_nowait()
            except Queue.Empty:
                break
            try:
                run_individual_python_test(target_dir, test_goal, python_exec)
            finally:
                task_queue.task_done()

    start_time = time.time()
    for _ in range(opts.parallelism):
        worker = Thread(target=process_queue, args=(task_queue,))
        # Daemon threads do not keep the process alive, so os._exit() in a worker
        # or an interrupt in the main thread brings the whole run down immediately.
        worker.daemon = True
        worker.start()
    try:
        task_queue.join()
    except (KeyboardInterrupt, SystemExit):
        print_red("Exiting due to interrupt")
        sys.exit(-1)
    total_duration = time.time() - start_time
    LOGGER.info("Tests passed in %i seconds", total_duration)

    for key, lines in sorted(SKIPPED_TESTS.items()):
        pyspark_python, test_name = key
        LOGGER.info("\nSkipped tests in %s with %s:" % (test_name, pyspark_python))
        for line in lines:
            LOGGER.info("    %s" % line.rstrip())


if __name__ == "__main__":
    main()