#!/usr/bin/env python3

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
from argparse import ArgumentParser
import os
import re
import shutil
import subprocess
import sys
import tempfile
from threading import Thread, Lock
import time
import uuid
# Compare the structured version tuple rather than the version string: string
# comparison is lexicographic and is not a reliable way to order versions.
if sys.version_info < (3,):
    import Queue
else:
    import queue as Queue
from multiprocessing import Manager


# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../dev/"))


from sparktestsupport import SPARK_HOME  # noqa (suppress pep8 warnings)
from sparktestsupport.shellutils import which, subprocess_check_output  # noqa
from sparktestsupport.modules import all_modules, pyspark_sql  # noqa


# Map of module name -> module object for every module that declares Python test
# goals, excluding the module named 'root'.
python_modules = {
    m.name: m for m in all_modules if m.python_test_goals and m.name != 'root'
}


def print_red(text):
    """Print ``text`` to stdout wrapped in the ANSI escape codes for red."""
    print('\033[31m' + text + '\033[0m')


# Manager-backed dict keyed by (python executable, test name); the values are the
# lists of "skipped" output lines collected for that run.
SKIPPED_TESTS = Manager().dict()
LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log")
# Held while the output of a failed test is appended to LOG_FILE and echoed.
FAILURE_REPORTING_LOCK = Lock()
LOGGER = logging.getLogger()

# Find out where the assembly jars are located.
# TODO: revisit for Scala 2.13
for scala in ["2.12"]:
    build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala)
    if os.path.isdir(build_dir):
        SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*")
        break
else:
    # Reached only when no candidate build directory exists (the loop never hit `break`).
    raise Exception("Cannot find assembly build directory, please build Spark first.")


def run_individual_python_test(target_dir, test_name, pyspark_python):
    """Run one test goal through ``bin/pyspark`` and record its outcome.

    The test is launched as a subprocess with a fresh, uniquely named temp
    directory under ``target_dir`` (exported as ``TMPDIR`` and as the JVM's
    ``java.io.tmpdir``). The combined stdout/stderr is captured in a temporary
    file.

    * If the subprocess exits non-zero, its output is appended to ``LOG_FILE``,
      echoed to stdout, and the whole process is terminated with
      ``os._exit(-1)``.
    * Otherwise the output lines matching the "skipped test" pattern are stored
      in ``SKIPPED_TESTS`` under ``(pyspark_python, test_name)`` and the
      duration is logged.

    :param target_dir: existing directory in which the per-run temp dir is created.
    :param test_name: test goal; split on whitespace and passed as arguments
        to ``bin/pyspark``.
    :param pyspark_python: Python executable name/path used for both the
        driver and the workers (resolved with ``which``).
    """
    env = dict(os.environ)
    env.update({
        'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH,
        'SPARK_TESTING': '1',
        'SPARK_PREPEND_CLASSES': '1',
        'PYSPARK_PYTHON': which(pyspark_python),
        'PYSPARK_DRIVER_PYTHON': which(pyspark_python),
        'PYSPARK_ROW_FIELD_SORTING_ENABLED': 'true'
    })

    # Create a unique temp directory under 'target/' for each run. The TMPDIR variable is
    # recognized by the tempfile module to override the default system temp directory.
    tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
    while os.path.isdir(tmp_dir):
        tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
    os.mkdir(tmp_dir)
    env["TMPDIR"] = tmp_dir

    # Also override the JVM's temp directory by setting driver and executor options.
    java_options = "-Djava.io.tmpdir={0} -Dio.netty.tryReflectionSetAccessible=true".format(tmp_dir)
    spark_args = [
        "--conf", "spark.driver.extraJavaOptions='{0}'".format(java_options),
        "--conf", "spark.executor.extraJavaOptions='{0}'".format(java_options),
        "pyspark-shell"
    ]
    env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args)

    LOGGER.info("Starting test(%s): %s", pyspark_python, test_name)
    start_time = time.time()
    try:
        per_test_output = tempfile.TemporaryFile()
        retcode = subprocess.Popen(
            [os.path.join(SPARK_HOME, "bin/pyspark")] + test_name.split(),
            stderr=per_test_output, stdout=per_test_output, env=env).wait()
        shutil.rmtree(tmp_dir, ignore_errors=True)
    except BaseException:
        # Deliberately catch everything: any failure to launch/wait aborts the whole run.
        LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python)
        # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
        # this code is invoked from a thread other than the main thread.
        os._exit(1)
    duration = time.time() - start_time
    # Exit on the first failure.
    if retcode != 0:
        try:
            with FAILURE_REPORTING_LOCK:
                with open(LOG_FILE, 'ab') as log_file:
                    per_test_output.seek(0)
                    log_file.writelines(per_test_output)
                per_test_output.seek(0)
                for line in per_test_output:
                    decoded_line = line.decode("utf-8", "replace")
                    # Lines that start with digits are written to the log file above but
                    # are not echoed to the console.
                    if not re.match('[0-9]+', decoded_line):
                        print(decoded_line, end='')
                per_test_output.close()
        except BaseException:
            LOGGER.exception("Got an exception while trying to print failed test output")
        finally:
            print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python))
            # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
            # this code is invoked from a thread other than the main thread.
            os._exit(-1)
    else:
        skipped_counts = 0
        try:
            per_test_output.seek(0)
            # Here expects skipped test output from unittest when verbosity level is
            # 2 (or --verbose option is enabled).
            decoded_lines = (raw.decode("utf-8", "replace") for raw in per_test_output)
            skipped_tests = [
                line for line in decoded_lines
                if re.search(r'test_.* \(pyspark\..*\) ... (skip|SKIP)', line)]
            skipped_counts = len(skipped_tests)
            if skipped_counts > 0:
                key = (pyspark_python, test_name)
                SKIPPED_TESTS[key] = skipped_tests
            per_test_output.close()
        except BaseException:
            import traceback
            print_red("\nGot an exception while trying to store "
                      "skipped test output:\n%s" % traceback.format_exc())
            # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
            # this code is invoked from a thread other than the main thread.
            os._exit(-1)
        if skipped_counts != 0:
            LOGGER.info(
                "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python, test_name,
                duration, skipped_counts)
        else:
            LOGGER.info(
                "Finished test(%s): %s (%is)", pyspark_python, test_name, duration)


def get_default_python_executables():
    """Return the default list of Python executables to test against.

    The list holds whichever of ``python3.6``, ``python2.7`` and ``pypy`` are
    found by ``which``. If ``python3.6`` is not among them, the path returned
    by ``which("python3")`` is inserted at the front; if that is not found
    either, an error is logged and the process exits with status 1.
    """
    python_execs = [x for x in ["python3.6", "python2.7", "pypy"] if which(x)]
    if "python3.6" not in python_execs:
        p = which("python3")
        if not p:
            LOGGER.error("No python3 executable found. Exiting!")
            os._exit(1)
        else:
            python_execs.insert(0, p)
    return python_execs


def parse_opts():
    """Parse and validate the command-line options.

    Unknown arguments and a parallelism below 1 are reported through
    ``parser.error``.

    :return: the parsed ``argparse`` namespace with ``python_executables``,
        ``modules``, ``parallelism``, ``verbose`` and ``testnames``.
    """
    parser = ArgumentParser(
        prog="run-tests"
    )
    parser.add_argument(
        "--python-executables", type=str, default=','.join(get_default_python_executables()),
        help="A comma-separated list of Python executables to test against (default: %(default)s)"
    )
    parser.add_argument(
        "--modules", type=str,
        default=",".join(sorted(python_modules.keys())),
        help="A comma-separated list of Python modules to test (default: %(default)s)"
    )
    parser.add_argument(
        "-p", "--parallelism", type=int, default=4,
        help="The number of suites to test in parallel (default %(default)d)"
    )
    parser.add_argument(
        "--verbose", action="store_true",
        help="Enable additional debug logging"
    )

    group = parser.add_argument_group("Developer Options")
    group.add_argument(
        "--testnames", type=str,
        default=None,
        help=(
            "A comma-separated list of specific modules, classes and functions of doctest "
            "or unittest to test. "
            "For example, 'pyspark.sql.foo' to run the module as unittests or doctests, "
            "'pyspark.sql.tests FooTests' to run the specific class of unittests, "
            "'pyspark.sql.tests FooTests.test_foo' to run the specific unittest in the class. "
            "'--modules' option is ignored if they are given.")
    )

    args, unknown = parser.parse_known_args()
    if unknown:
        parser.error("Unsupported arguments: %s" % ' '.join(unknown))
    if args.parallelism < 1:
        parser.error("Parallelism cannot be less than 1")
    return args


def _check_coverage(python_exec):
    """Exit with status -1 unless ``import coverage`` succeeds under ``python_exec``.

    :param python_exec: Python executable used to attempt the import.
    """
    # Make sure if coverage is installed.
    try:
        # Open the null device in a `with` block so the handle is closed again
        # instead of being leaked for the lifetime of the process.
        with open(os.devnull, 'w') as devnull:
            subprocess_check_output(
                [python_exec, "-c", "import coverage"],
                stderr=devnull)
    except Exception:
        print_red("Coverage is not installed in Python executable '%s' "
                  "but 'COVERAGE_PROCESS_START' environment variable is set, "
                  "exiting." % python_exec)
        sys.exit(-1)


def main():
    """Entry point: build the queue of test goals and run them on worker threads.

    For every requested Python executable, each selected test goal is queued
    (goals with a "heavy" prefix get priority 0, others 100; explicitly named
    tests get priority 0). ``--parallelism`` daemon threads then drain the
    queue through ``run_individual_python_test``. Once the queue is drained the
    total duration and the collected skipped tests are logged.
    """
    opts = parse_opts()
    if opts.verbose:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    should_test_modules = opts.testnames is None
    logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
    LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE)
    if os.path.exists(LOG_FILE):
        os.remove(LOG_FILE)
    python_execs = opts.python_executables.split(',')
    LOGGER.info("Will test against the following Python executables: %s", python_execs)

    if should_test_modules:
        modules_to_test = []
        for module_name in opts.modules.split(','):
            if module_name in python_modules:
                modules_to_test.append(python_modules[module_name])
            else:
                print("Error: unrecognized module '%s'. Supported modules: %s" %
                      (module_name, ", ".join(python_modules)))
                sys.exit(-1)
        LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
    else:
        testnames_to_test = opts.testnames.split(',')
        LOGGER.info("Will test the following Python tests: %s", testnames_to_test)

    # Goals starting with one of these prefixes are queued with the smallest priority
    # value so that they are picked up first. Defined once, outside the loops.
    heavy_test_prefixes = (
        'pyspark.streaming.tests',
        'pyspark.mllib.tests',
        'pyspark.tests',
        'pyspark.sql.tests',
        'pyspark.ml.tests',
    )

    task_queue = Queue.PriorityQueue()
    for python_exec in python_execs:
        # Check if the python executable has coverage installed when 'COVERAGE_PROCESS_START'
        # environmental variable is set.
        if "COVERAGE_PROCESS_START" in os.environ:
            _check_coverage(python_exec)

        python_implementation = subprocess_check_output(
            [python_exec, "-c", "import platform; print(platform.python_implementation())"],
            universal_newlines=True).strip()
        LOGGER.info("%s python_implementation is %s", python_exec, python_implementation)
        LOGGER.info("%s version is: %s", python_exec, subprocess_check_output(
            [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip())
        if should_test_modules:
            for module in modules_to_test:
                if python_implementation not in module.blacklisted_python_implementations:
                    for test_goal in module.python_test_goals:
                        if test_goal.startswith(heavy_test_prefixes):
                            priority = 0
                        else:
                            priority = 100
                        task_queue.put((priority, (python_exec, test_goal)))
        else:
            for test_goal in testnames_to_test:
                task_queue.put((0, (python_exec, test_goal)))

    # Create the target directory before starting tasks to avoid races.
    target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target'))
    if not os.path.isdir(target_dir):
        os.mkdir(target_dir)

    def process_queue(task_queue):
        """Run queued tests until the queue is empty."""
        while True:
            try:
                (_, (python_exec, test_goal)) = task_queue.get_nowait()
            except Queue.Empty:
                break
            try:
                run_individual_python_test(target_dir, test_goal, python_exec)
            finally:
                # Always mark the task done so that task_queue.join() can return.
                task_queue.task_done()

    start_time = time.time()
    for _ in range(opts.parallelism):
        worker = Thread(target=process_queue, args=(task_queue,))
        worker.daemon = True
        worker.start()
    try:
        task_queue.join()
    except (KeyboardInterrupt, SystemExit):
        print_red("Exiting due to interrupt")
        sys.exit(-1)
    total_duration = time.time() - start_time
    LOGGER.info("Tests passed in %i seconds", total_duration)

    for key, lines in sorted(SKIPPED_TESTS.items()):
        pyspark_python, test_name = key
        LOGGER.info("\nSkipped tests in %s with %s:", test_name, pyspark_python)
        for line in lines:
            LOGGER.info(" %s", line.rstrip())


if __name__ == "__main__":
    main()