Back to home page

OSCL-LXR

 
 

    


0001 #!/usr/bin/env python3
0002 
0003 #
0004 # Licensed to the Apache Software Foundation (ASF) under one or more
0005 # contributor license agreements.  See the NOTICE file distributed with
0006 # this work for additional information regarding copyright ownership.
0007 # The ASF licenses this file to You under the Apache License, Version 2.0
0008 # (the "License"); you may not use this file except in compliance with
0009 # the License.  You may obtain a copy of the License at
0010 #
0011 #    http://www.apache.org/licenses/LICENSE-2.0
0012 #
0013 # Unless required by applicable law or agreed to in writing, software
0014 # distributed under the License is distributed on an "AS IS" BASIS,
0015 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
0016 # See the License for the specific language governing permissions and
0017 # limitations under the License.
0018 #
0019 
0020 import logging
0021 from argparse import ArgumentParser
0022 import os
0023 import re
0024 import shutil
0025 import subprocess
0026 import sys
0027 import tempfile
0028 from threading import Thread, Lock
0029 import time
0030 import uuid
0031 if sys.version < '3':
0032     import Queue
0033 else:
0034     import queue as Queue
0035 from multiprocessing import Manager
0036 
0037 
0038 # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module
0039 sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../dev/"))
0040 
0041 
0042 from sparktestsupport import SPARK_HOME  # noqa (suppress pep8 warnings)
0043 from sparktestsupport.shellutils import which, subprocess_check_output  # noqa
0044 from sparktestsupport.modules import all_modules, pyspark_sql  # noqa
0045 
0046 
# Modules that can be selected with --modules, keyed by module name: every module
# from sparktestsupport.modules that declares Python test goals, except the one
# named 'root'.
python_modules = {
    module.name: module
    for module in all_modules
    if module.python_test_goals and module.name != 'root'
}
0048 
0049 
def print_red(text):
    """Print ``text`` to stdout wrapped in the ANSI escape codes for red text."""
    red, reset = '\033[31m', '\033[0m'
    print(''.join((red, text, reset)))
0052 
0053 
# Skipped-test output lines, keyed by (python executable, test name). Tests run in
# worker threads of this one process, so an ordinary dict is visible to all of them.
# A multiprocessing Manager is not needed here, and creating one at import time breaks
# under the "spawn" start method, where the manager's child process re-imports this
# script and would try to start another Manager while bootstrapping.
SKIPPED_TESTS = {}
LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log")
# Held while a failed test's output is appended to LOG_FILE and echoed to stdout.
FAILURE_REPORTING_LOCK = Lock()
LOGGER = logging.getLogger()

# Find out where the assembly jars are located.
# TODO: revisit for Scala 2.13
for scala in ["2.12"]:
    build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala)
    if os.path.isdir(build_dir):
        SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*")
        break
else:
    # Reached only when no candidate directory exists (the loop never hit `break`).
    raise Exception("Cannot find assembly build directory, please build Spark first.")
0068 
0069 
# Matches unittest's verbose (verbosity 2) per-test result lines for skipped tests,
# e.g. "test_foo (pyspark.sql.tests.FooTests) ... skipped 'reason'". The " ... "
# separator is escaped so it matches three literal dots rather than any three
# characters.
_SKIPPED_TEST_LINE = re.compile(r'test_.* \(pyspark\..*\) \.\.\. (skip|SKIP)')


def _prepare_test_env(target_dir, pyspark_python):
    """Build the environment for one test run and create its private temp directory.

    Returns ``(env, tmp_dir)``. ``tmp_dir`` is a new uuid-named directory under
    ``target_dir``. It is exported as ``TMPDIR``, which Python's ``tempfile`` module
    honours, and passed to the driver and executor JVMs as ``java.io.tmpdir``.
    """
    # Resolve the interpreter once; the same path is exported as both
    # PYSPARK_PYTHON and PYSPARK_DRIVER_PYTHON.
    python_path = which(pyspark_python)
    env = dict(os.environ)
    env.update({
        'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH,
        'SPARK_TESTING': '1',
        'SPARK_PREPEND_CLASSES': '1',
        'PYSPARK_PYTHON': python_path,
        'PYSPARK_DRIVER_PYTHON': python_path,
        'PYSPARK_ROW_FIELD_SORTING_ENABLED': 'true'
    })

    # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is
    # recognized by the tempfile module to override the default system temp directory.
    tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
    while os.path.isdir(tmp_dir):
        tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
    os.mkdir(tmp_dir)
    env["TMPDIR"] = tmp_dir

    # Also override the JVM's temp directory by setting driver and executor options.
    java_options = "-Djava.io.tmpdir={0} -Dio.netty.tryReflectionSetAccessible=true".format(tmp_dir)
    spark_args = [
        "--conf", "spark.driver.extraJavaOptions='{0}'".format(java_options),
        "--conf", "spark.executor.extraJavaOptions='{0}'".format(java_options),
        "pyspark-shell"
    ]
    env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args)
    return env, tmp_dir


def _report_failure_and_exit(per_test_output, test_name, pyspark_python):
    """Append a failed test's captured output to LOG_FILE, echo it, then exit.

    Output is written while FAILURE_REPORTING_LOCK is held, so concurrently failing
    tests do not interleave. Lines that start with a digit go to the log file but
    are not echoed to stdout. The process always ends with ``os._exit(-1)``, even
    if reporting itself fails.
    """
    try:
        with FAILURE_REPORTING_LOCK:
            with open(LOG_FILE, 'ab') as log_file:
                per_test_output.seek(0)
                log_file.writelines(per_test_output)
            per_test_output.seek(0)
            for line in per_test_output:
                decoded_line = line.decode("utf-8", "replace")
                if not re.match('[0-9]+', decoded_line):
                    print(decoded_line, end='')
            per_test_output.close()
    except BaseException:
        LOGGER.exception("Got an exception while trying to print failed test output")
    finally:
        print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python))
        # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
        # this code is invoked from a thread other than the main thread.
        os._exit(-1)


def _store_skipped_tests(per_test_output, test_name, pyspark_python):
    """Record the skipped-test lines of a passing run and return how many there were.

    Matching lines are stored in SKIPPED_TESTS under ``(pyspark_python, test_name)``,
    but only when there is at least one. ``per_test_output`` is closed afterwards.
    Any error is printed and terminates the process with ``os._exit(-1)``.
    """
    skipped_tests = []
    try:
        per_test_output.seek(0)
        # Here expects skipped test output from unittest when verbosity level is
        # 2 (or --verbose option is enabled).
        decoded_lines = (raw.decode("utf-8", "replace") for raw in per_test_output)
        skipped_tests = [line for line in decoded_lines if _SKIPPED_TEST_LINE.search(line)]
        if skipped_tests:
            SKIPPED_TESTS[(pyspark_python, test_name)] = skipped_tests
        per_test_output.close()
    except BaseException:
        import traceback
        print_red("\nGot an exception while trying to store "
                  "skipped test output:\n%s" % traceback.format_exc())
        # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
        # this code is invoked from a thread other than the main thread.
        os._exit(-1)
    return len(skipped_tests)


def run_individual_python_test(target_dir, test_name, pyspark_python):
    """Run one PySpark test goal with ``bin/pyspark`` under ``pyspark_python``.

    ``test_name`` is split on whitespace and passed as arguments to ``bin/pyspark``
    (e.g. ``"pyspark.sql.tests FooTests"``). The child's stdout and stderr are
    captured in a temporary file. On a non-zero exit code the captured output is
    reported and the whole process exits. Otherwise skipped tests are recorded in
    SKIPPED_TESTS and the run's duration is logged.
    """
    env, tmp_dir = _prepare_test_env(target_dir, pyspark_python)

    LOGGER.info("Starting test(%s): %s", pyspark_python, test_name)
    start_time = time.time()
    try:
        per_test_output = tempfile.TemporaryFile()
        retcode = subprocess.Popen(
            [os.path.join(SPARK_HOME, "bin/pyspark")] + test_name.split(),
            stderr=per_test_output, stdout=per_test_output, env=env).wait()
        shutil.rmtree(tmp_dir, ignore_errors=True)
    except BaseException:
        # Deliberately broad: if this worker thread died quietly instead, its queue
        # task would still be marked done and the test would be silently skipped.
        LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python)
        # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
        # this code is invoked from a thread other than the main thread.
        os._exit(1)
    duration = time.time() - start_time

    # Exit on the first failure.
    if retcode != 0:
        _report_failure_and_exit(per_test_output, test_name, pyspark_python)

    skipped_counts = _store_skipped_tests(per_test_output, test_name, pyspark_python)
    if skipped_counts != 0:
        LOGGER.info(
            "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python, test_name,
            duration, skipped_counts)
    else:
        LOGGER.info(
            "Finished test(%s): %s (%is)", pyspark_python, test_name, duration)
0161 
0162 
def get_default_python_executables():
    """Return the Python executables to test against by default.

    Whichever of python3.6, python2.7 and pypy ``which`` finds are kept, in that
    order. If python3.6 is not among them, the path ``which`` reports for
    ``python3`` is put at the front instead. If there is no python3 either, an
    error is logged and the process exits with status 1.
    """
    found = list(filter(which, ["python3.6", "python2.7", "pypy"]))
    if "python3.6" in found:
        return found

    python3 = which("python3")
    if not python3:
        LOGGER.error("No python3 executable found.  Exiting!")
        os._exit(1)
    return [python3] + found
0174 
0175 
def parse_opts():
    """Build the ``run-tests`` argument parser and return the parsed options.

    Exits through ``parser.error`` when unknown arguments are given or when
    ``--parallelism`` is smaller than 1.
    """
    parser = ArgumentParser(prog="run-tests")

    # (flags, add_argument keyword arguments), registered in this order.
    general_options = [
        (("--python-executables",), {
            "type": str,
            "default": ','.join(get_default_python_executables()),
            "help": "A comma-separated list of Python executables to test against "
                    "(default: %(default)s)",
        }),
        (("--modules",), {
            "type": str,
            "default": ",".join(sorted(python_modules.keys())),
            "help": "A comma-separated list of Python modules to test (default: %(default)s)",
        }),
        (("-p", "--parallelism"), {
            "type": int,
            "default": 4,
            "help": "The number of suites to test in parallel (default %(default)d)",
        }),
        (("--verbose",), {
            "action": "store_true",
            "help": "Enable additional debug logging",
        }),
    ]
    for flags, settings in general_options:
        parser.add_argument(*flags, **settings)

    developer_options = parser.add_argument_group("Developer Options")
    developer_options.add_argument(
        "--testnames",
        type=str,
        default=None,
        help=("A comma-separated list of specific modules, classes and functions of doctest "
              "or unittest to test. "
              "For example, 'pyspark.sql.foo' to run the module as unittests or doctests, "
              "'pyspark.sql.tests FooTests' to run the specific class of unittests, "
              "'pyspark.sql.tests FooTests.test_foo' to run the specific unittest in the class. "
              "'--modules' option is ignored if they are given."),
    )

    parsed, leftovers = parser.parse_known_args()
    if leftovers:
        parser.error("Unsupported arguments: %s" % ' '.join(leftovers))
    if parsed.parallelism < 1:
        parser.error("Parallelism cannot be less than 1")
    return parsed
0217 
0218 
def _check_coverage(python_exec):
    """Exit with status -1 unless ``python_exec`` can ``import coverage``.

    The probe's stderr is discarded. A failure to run the interpreter, or a
    failed import, prints an error mentioning ``COVERAGE_PROCESS_START`` and
    calls ``sys.exit(-1)``.
    """
    # Make sure if coverage is installed.
    try:
        # Close the devnull handle after the probe instead of leaking one per call.
        with open(os.devnull, 'w') as devnull:
            subprocess_check_output(
                [python_exec, "-c", "import coverage"],
                stderr=devnull)
    except Exception:
        # Exception, not a bare except, so Ctrl-C is not misreported as a
        # missing coverage installation.
        print_red("Coverage is not installed in Python executable '%s' "
                  "but 'COVERAGE_PROCESS_START' environment variable is set, "
                  "exiting." % python_exec)
        sys.exit(-1)
0230 
0231 
def main():
    """Run the selected PySpark test goals against each requested Python executable.

    Test goals come from ``--modules`` (expanded through ``python_modules``) or,
    when given, from ``--testnames``. They are put in a priority queue and drained
    by ``--parallelism`` daemon worker threads, each calling
    ``run_individual_python_test``. A failing test ends the process from inside
    that function via ``os._exit``. After the queue drains, the total duration and
    any recorded skipped tests are logged.
    """
    opts = parse_opts()
    log_level = logging.DEBUG if opts.verbose else logging.INFO
    should_test_modules = opts.testnames is None
    logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
    LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE)
    if os.path.exists(LOG_FILE):
        os.remove(LOG_FILE)
    python_execs = opts.python_executables.split(',')
    LOGGER.info("Will test against the following Python executables: %s", python_execs)

    if should_test_modules:
        modules_to_test = []
        for module_name in opts.modules.split(','):
            if module_name in python_modules:
                modules_to_test.append(python_modules[module_name])
            else:
                print("Error: unrecognized module '%s'. Supported modules: %s" %
                      (module_name, ", ".join(python_modules)))
                sys.exit(-1)
        LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
    else:
        testnames_to_test = opts.testnames.split(',')
        LOGGER.info("Will test the following Python tests: %s", testnames_to_test)

    # Goals with these prefixes get priority 0, so the PriorityQueue (lowest value
    # first) hands them out before everything else. Presumably these are the
    # longest-running suites. Built once instead of once per test goal.
    heavy_tests = ('pyspark.streaming.tests', 'pyspark.mllib.tests',
                   'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests')

    task_queue = Queue.PriorityQueue()
    for python_exec in python_execs:
        # Check if the python executable has coverage installed when 'COVERAGE_PROCESS_START'
        # environmental variable is set.
        if "COVERAGE_PROCESS_START" in os.environ:
            _check_coverage(python_exec)

        python_implementation = subprocess_check_output(
            [python_exec, "-c", "import platform; print(platform.python_implementation())"],
            universal_newlines=True).strip()
        LOGGER.info("%s python_implementation is %s", python_exec, python_implementation)
        LOGGER.info("%s version is: %s", python_exec, subprocess_check_output(
            [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip())
        if should_test_modules:
            for module in modules_to_test:
                if python_implementation not in module.blacklisted_python_implementations:
                    for test_goal in module.python_test_goals:
                        priority = 0 if test_goal.startswith(heavy_tests) else 100
                        task_queue.put((priority, (python_exec, test_goal)))
        else:
            for test_goal in testnames_to_test:
                task_queue.put((0, (python_exec, test_goal)))

    # Create the target directory before starting tasks to avoid races.
    target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target'))
    if not os.path.isdir(target_dir):
        os.mkdir(target_dir)

    def process_queue(task_queue):
        # Worker loop: run queued goals until the queue is empty.
        while True:
            try:
                (priority, (python_exec, test_goal)) = task_queue.get_nowait()
            except Queue.Empty:
                break
            try:
                run_individual_python_test(target_dir, test_goal, python_exec)
            finally:
                task_queue.task_done()

    start_time = time.time()
    for _ in range(opts.parallelism):
        worker = Thread(target=process_queue, args=(task_queue,))
        worker.daemon = True
        worker.start()
    try:
        task_queue.join()
    except (KeyboardInterrupt, SystemExit):
        print_red("Exiting due to interrupt")
        sys.exit(-1)
    total_duration = time.time() - start_time
    LOGGER.info("Tests passed in %i seconds", total_duration)

    for key, lines in sorted(SKIPPED_TESTS.items()):
        pyspark_python, test_name = key
        LOGGER.info("\nSkipped tests in %s with %s:", test_name, pyspark_python)
        for line in lines:
            LOGGER.info("    %s", line.rstrip())


if __name__ == "__main__":
    main()