diff --git a/python/run-tests-with-coverage b/python/run-tests-with-coverage index 6d74b563e9..457821037d 100755 --- a/python/run-tests-with-coverage +++ b/python/run-tests-with-coverage @@ -50,8 +50,6 @@ export SPARK_CONF_DIR="$COVERAGE_DIR/conf" # This environment variable enables the coverage. export COVERAGE_PROCESS_START="$FWDIR/.coveragerc" -# If you'd like to run a specific unittest class, you could do such as -# SPARK_TESTING=1 ../bin/pyspark pyspark.sql.tests VectorizedUDFTests ./run-tests "$@" # Don't run coverage for the coverage command itself diff --git a/python/run-tests.py b/python/run-tests.py index 01a6e81264..e45268c137 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -19,7 +19,7 @@ from __future__ import print_function import logging -from optparse import OptionParser +from optparse import OptionParser, OptionGroup import os import re import shutil @@ -99,7 +99,7 @@ def run_individual_python_test(target_dir, test_name, pyspark_python): try: per_test_output = tempfile.TemporaryFile() retcode = subprocess.Popen( - [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], + [os.path.join(SPARK_HOME, "bin/pyspark")] + test_name.split(), stderr=per_test_output, stdout=per_test_output, env=env).wait() shutil.rmtree(tmp_dir, ignore_errors=True) except: @@ -190,6 +190,20 @@ def parse_opts(): help="Enable additional debug logging" ) + group = OptionGroup(parser, "Developer Options") + group.add_option( + "--testnames", type="string", + default=None, + help=( + "A comma-separated list of specific modules, classes and functions of doctest " + "or unittest to test. " + "For example, 'pyspark.sql.foo' to run the module as unittests or doctests, " + "'pyspark.sql.tests FooTests' to run the specific class of unittests, " + "'pyspark.sql.tests FooTests.test_foo' to run the specific unittest in the class. " + "'--modules' option is ignored if they are given.") + ) + parser.add_option_group(group) + (opts, args) = parser.parse_args() if args: parser.error("Unsupported arguments: %s" % ' '.join(args)) @@ -213,25 +227,31 @@ def _check_coverage(python_exec): def main(): opts = parse_opts() - if (opts.verbose): + if opts.verbose: log_level = logging.DEBUG else: log_level = logging.INFO + should_test_modules = opts.testnames is None logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s") LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE) if os.path.exists(LOG_FILE): os.remove(LOG_FILE) python_execs = opts.python_executables.split(',') - modules_to_test = [] - for module_name in opts.modules.split(','): - if module_name in python_modules: - modules_to_test.append(python_modules[module_name]) - else: - print("Error: unrecognized module '%s'. Supported modules: %s" % - (module_name, ", ".join(python_modules))) - sys.exit(-1) LOGGER.info("Will test against the following Python executables: %s", python_execs) - LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test]) + + if should_test_modules: + modules_to_test = [] + for module_name in opts.modules.split(','): + if module_name in python_modules: + modules_to_test.append(python_modules[module_name]) + else: + print("Error: unrecognized module '%s'. Supported modules: %s" % + (module_name, ", ".join(python_modules))) + sys.exit(-1) + LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test]) + else: + testnames_to_test = opts.testnames.split(',') + LOGGER.info("Will test the following Python tests: %s", testnames_to_test) task_queue = Queue.PriorityQueue() for python_exec in python_execs: @@ -246,16 +266,20 @@ def main(): LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation) LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output( [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()) - for module in modules_to_test: - if python_implementation not in module.blacklisted_python_implementations: - for test_goal in module.python_test_goals: - heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests', - 'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests'] - if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)): - priority = 0 - else: - priority = 100 - task_queue.put((priority, (python_exec, test_goal))) + if should_test_modules: + for module in modules_to_test: + if python_implementation not in module.blacklisted_python_implementations: + for test_goal in module.python_test_goals: + heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests', + 'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests'] + if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)): + priority = 0 + else: + priority = 100 + task_queue.put((priority, (python_exec, test_goal))) + else: + for test_goal in testnames_to_test: + task_queue.put((0, (python_exec, test_goal))) # Create the target directory before starting tasks to avoid races. target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target'))