From 2748b2dbc502b2c51a16b40fe375506acfbf3683 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google]" Date: Mon, 21 Nov 2022 08:25:56 +0000 Subject: [PATCH 1/3] Demonstrate test.regrtest unittest sharding. An incomplete implementation with details to be worked out, but it works! It makes our long tail tests take significantly less time. At least when run on their own. Example: ~25 seconds wall time to run test_multiprocessing_spawn and test_concurrent_futures on a 12-thread machine. `python -m test -r -j 20 test_multiprocessing_spawn test_concurrent_futures` Known Issues to work out: result reporting and libregrtest accounting. You see any sharded test "complete" multiple times and your total tests run count goes higher than the total number of tests. :joy: Real caveat: This exposes ordering and concurrency weaknesses in some tests like test_asyncio that'll need fixing. Which tests get sharded is explicitly opt-in. The opt-in list is currently not in a maintainable spot. How best to maintain that needs to be worked out, but I expect we only ever have 10-20 test modules that we declare as worth sharding. 
This implementation is inspired by, and its unittest TestLoader bits are derived directly from, the Apache 2.0 licensed https://github.com/abseil/abseil-py/blob/v1.3.0/absl/testing/absltest.py#L2359 ``` :~/oss/cpython (performance/test-sharding)$ ../b/python -m test -r -j 20 test_multiprocessing_spawn test_concurrent_futures Using random seed 8555091 0:00:00 load avg: 0.98 Run tests in parallel using 20 child processes 0:00:08 load avg: 1.30 [1/2] test_multiprocessing_spawn passed 0:00:10 load avg: 1.68 [2/2] test_concurrent_futures passed 0:00:11 load avg: 1.68 [3/2] test_multiprocessing_spawn passed 0:00:12 load avg: 1.68 [4/2] test_multiprocessing_spawn passed 0:00:12 load avg: 1.68 [5/2] test_multiprocessing_spawn passed 0:00:14 load avg: 1.87 [6/2] test_multiprocessing_spawn passed 0:00:15 load avg: 1.87 [7/2] test_multiprocessing_spawn passed 0:00:16 load avg: 1.87 [8/2] test_concurrent_futures passed 0:00:16 load avg: 1.87 [9/2] test_multiprocessing_spawn passed 0:00:18 load avg: 1.87 [10/2] test_concurrent_futures passed 0:00:20 load avg: 1.72 [11/2] test_concurrent_futures passed 0:00:20 load avg: 1.72 [12/2] test_concurrent_futures passed 0:00:21 load avg: 1.72 [13/2] test_multiprocessing_spawn passed 0:00:21 load avg: 1.72 [14/2] test_concurrent_futures passed 0:00:22 load avg: 1.72 [15/2] test_concurrent_futures passed 0:00:25 load avg: 1.58 [16/2] test_concurrent_futures passed == Tests result: SUCCESS == All 16 tests OK. 
Total duration: 25.6 sec Tests result: SUCCESS ``` --- Lib/test/libregrtest/main.py | 17 ++++++++ Lib/test/libregrtest/runtest_mp.py | 66 ++++++++++++++++++++++++------ Lib/unittest/loader.py | 64 +++++++++++++++++++++++++++-- 3 files changed, 132 insertions(+), 15 deletions(-) diff --git a/Lib/test/libregrtest/main.py b/Lib/test/libregrtest/main.py index 3eeef029b22d48..0a3764db46c731 100644 --- a/Lib/test/libregrtest/main.py +++ b/Lib/test/libregrtest/main.py @@ -64,6 +64,23 @@ def __init__(self): # tests self.tests = [] self.selected = [] + self.tests_to_shard = set() + # TODO(gpshead): this list belongs elsewhere - it'd be nice to tag + # these within the test module/package itself but loading everything + # to detect those tags is complicated. As is a feedback mechanism + # from a shard file. + # Our slowest tests per a "-o" run: + self.tests_to_shard.add('test_concurrent_futures') + self.tests_to_shard.add('test_multiprocessing_spawn') + self.tests_to_shard.add('test_asyncio') + self.tests_to_shard.add('test_tools') + self.tests_to_shard.add('test_multiprocessing_forkserver') + self.tests_to_shard.add('test_multiprocessing_fork') + self.tests_to_shard.add('test_signal') + self.tests_to_shard.add('test_socket') + self.tests_to_shard.add('test_io') + self.tests_to_shard.add('test_imaplib') + self.tests_to_shard.add('test_subprocess') # test results self.good = [] diff --git a/Lib/test/libregrtest/runtest_mp.py b/Lib/test/libregrtest/runtest_mp.py index a12fcb46e0fd0b..2b849e9573b76c 100644 --- a/Lib/test/libregrtest/runtest_mp.py +++ b/Lib/test/libregrtest/runtest_mp.py @@ -1,4 +1,5 @@ import faulthandler +from dataclasses import dataclass import json import os.path import queue @@ -9,7 +10,7 @@ import threading import time import traceback -from typing import NamedTuple, NoReturn, Literal, Any, TextIO +from typing import Iterator, NamedTuple, NoReturn, Literal, Any, TextIO from test import support from test.support import os_helper @@ -42,6 +43,13 @@ 
USE_PROCESS_GROUP = (hasattr(os, "setsid") and hasattr(os, "killpg")) +@dataclass +class ShardInfo: + number: int + total_shards: int + status_file: str = "" + + def must_stop(result: TestResult, ns: Namespace) -> bool: if isinstance(result, Interrupted): return True @@ -56,7 +64,7 @@ def parse_worker_args(worker_args) -> tuple[Namespace, str]: return (ns, test_name) -def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh: TextIO) -> subprocess.Popen: +def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh: TextIO, shard: ShardInfo|None = None) -> subprocess.Popen: ns_dict = vars(ns) worker_args = (ns_dict, testname) worker_args = json.dumps(worker_args) @@ -75,6 +83,13 @@ def run_test_in_subprocess(testname: str, ns: Namespace, tmp_dir: str, stdout_fh env['TEMP'] = tmp_dir env['TMP'] = tmp_dir + if shard: + # This follows the "Bazel test sharding protocol" + shard.status_file = os.path.join(tmp_dir, 'sharded') + env['TEST_SHARD_STATUS_FILE'] = shard.status_file + env['TEST_SHARD_INDEX'] = str(shard.number) + env['TEST_TOTAL_SHARDS'] = str(shard.total_shards) + # Running the child from the same working directory as regrtest's original # invocation ensures that TEMPDIR for the child is the same when # sysconfig.is_python_build() is true. See issue 15300. 
@@ -109,7 +124,7 @@ class MultiprocessIterator: """A thread-safe iterator over tests for multiprocess mode.""" - def __init__(self, tests_iter): + def __init__(self, tests_iter: Iterator[tuple[str, ShardInfo|None]]): self.lock = threading.Lock() self.tests_iter = tests_iter @@ -215,12 +230,17 @@ def mp_result_error( test_result.duration_sec = time.monotonic() - self.start_time return MultiprocessResult(test_result, stdout, err_msg) - def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int: + def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO, + shard: ShardInfo|None = None) -> int: self.start_time = time.monotonic() - self.current_test_name = test_name + if shard: + self.current_test_name = f'{test_name}-shard-{shard.number:02}/{shard.total_shards-1:02}' + else: + self.current_test_name = test_name try: - popen = run_test_in_subprocess(test_name, self.ns, tmp_dir, stdout_fh) + popen = run_test_in_subprocess( + test_name, self.ns, tmp_dir, stdout_fh, shard) self._killed = False self._popen = popen @@ -240,6 +260,17 @@ def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int: # gh-94026: stdout+stderr are written to tempfile retcode = popen.wait(timeout=self.timeout) assert retcode is not None + if shard and shard.status_file: + if os.path.exists(shard.status_file): + try: + os.unlink(shard.status_file) + except IOError: + pass + else: + print_warning( + f"{self.current_test_name} process exited " + f"{retcode} without touching a shard status " + f"file. 
Does it really support sharding?") return retcode except subprocess.TimeoutExpired: if self._stopped: @@ -269,7 +300,7 @@ def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO) -> int: self._popen = None self.current_test_name = None - def _runtest(self, test_name: str) -> MultiprocessResult: + def _runtest(self, test_name: str, shard: ShardInfo|None) -> MultiprocessResult: if sys.platform == 'win32': # gh-95027: When stdout is not a TTY, Python uses the ANSI code # page for the sys.stdout encoding. If the main process runs in a @@ -290,7 +321,7 @@ def _runtest(self, test_name: str) -> MultiprocessResult: tmp_dir = tempfile.mkdtemp(prefix="test_python_") tmp_dir = os.path.abspath(tmp_dir) try: - retcode = self._run_process(test_name, tmp_dir, stdout_fh) + retcode = self._run_process(test_name, tmp_dir, stdout_fh, shard) finally: tmp_files = os.listdir(tmp_dir) os_helper.rmtree(tmp_dir) @@ -335,11 +366,11 @@ def run(self) -> None: while not self._stopped: try: try: - test_name = next(self.pending) + test_name, shard_info = next(self.pending) except StopIteration: break - mp_result = self._runtest(test_name) + mp_result = self._runtest(test_name, shard_info) self.output.put((False, mp_result)) if must_stop(mp_result.result, self.ns): @@ -402,8 +433,19 @@ def __init__(self, regrtest: Regrtest) -> None: self.regrtest = regrtest self.log = self.regrtest.log self.ns = regrtest.ns + self.num_procs: int = self.ns.use_mp self.output: queue.Queue[QueueOutput] = queue.Queue() - self.pending = MultiprocessIterator(self.regrtest.tests) + tests_and_shards = [] + for test in self.regrtest.tests: + if self.num_procs > 2 and test in self.regrtest.tests_to_shard: + # Split shardable tests across multiple processes to run + # distinct subsets of tests within a given test module. 
+ shards = min(self.num_procs//2+1, 8) # avoid diminishing returns + for shard_no in range(shards): + tests_and_shards.append((test, ShardInfo(shard_no, shards))) + else: + tests_and_shards.append((test, None)) + self.pending = MultiprocessIterator(iter(tests_and_shards)) if self.ns.timeout is not None: # Rely on faulthandler to kill a worker process. This timouet is # when faulthandler fails to kill a worker process. Give a maximum @@ -416,7 +458,7 @@ def __init__(self, regrtest: Regrtest) -> None: def start_workers(self) -> None: self.workers = [TestWorkerProcess(index, self) - for index in range(1, self.ns.use_mp + 1)] + for index in range(1, self.num_procs + 1)] msg = f"Run tests in parallel using {len(self.workers)} child processes" if self.ns.timeout: msg += (" (timeout: %s, worker timeout: %s)" diff --git a/Lib/unittest/loader.py b/Lib/unittest/loader.py index eb18cd0b49cd26..ea2f08eebe8b8b 100644 --- a/Lib/unittest/loader.py +++ b/Lib/unittest/loader.py @@ -1,11 +1,12 @@ """Loading unittests.""" +import itertools +import functools import os import re import sys import traceback import types -import functools import warnings from fnmatch import fnmatch, fnmatchcase @@ -63,7 +64,7 @@ def _jython_aware_splitext(path): return os.path.splitext(path)[0] -class TestLoader(object): +class TestLoader: """ This class is responsible for loading tests according to various criteria and returning them wrapped in a TestSuite @@ -73,6 +74,43 @@ class TestLoader(object): testNamePatterns = None suiteClass = suite.TestSuite _top_level_dir = None + _sharding_setup_complete = False + _shard_bucket_iterator = None + _shard_index = None + + def __new__(cls, *args, **kwargs): + new_instance = super().__new__(cls, *args, **kwargs) + if cls._sharding_setup_complete: + return new_instance + # This assumes single threaded TestLoader construction. 
+ cls._sharding_setup_complete = True + + # It may be useful to write the shard file even if the other sharding + # environment variables are not set. Test runners may use this functionality + # to query whether a test binary implements the test sharding protocol. + if 'TEST_SHARD_STATUS_FILE' in os.environ: + status_name = os.environ['TEST_SHARD_STATUS_FILE'] + try: + with open(status_name, 'w') as f: + f.write('') + except IOError as error: + raise RuntimeError( + f'Error opening TEST_SHARD_STATUS_FILE {status_name=}.') + + if 'TEST_TOTAL_SHARDS' not in os.environ: + # Not using sharding? nothing more to do. + return new_instance + + total_shards = int(os.environ['TEST_TOTAL_SHARDS']) + cls._shard_index = int(os.environ['TEST_SHARD_INDEX']) + + if cls._shard_index < 0 or cls._shard_index >= total_shards: + raise RuntimeError( + 'ERROR: Bad sharding values. ' + f'index={cls._shard_index}, {total_shards=}') + + cls._shard_bucket_iterator = itertools.cycle(range(total_shards)) + return new_instance def __init__(self): super(TestLoader, self).__init__() @@ -198,8 +236,28 @@ def loadTestsFromNames(self, names, module=None): suites = [self.loadTestsFromName(name, module) for name in names] return self.suiteClass(suites) + def _getShardedTestCaseNames(self, testCaseClass): + filtered_names = [] + # We need to sort the list of tests in order to determine which tests this + # shard is responsible for; however, it's important to preserve the order + # returned by the base loader, e.g. in the case of randomized test ordering. 
+ ordered_names = self._getTestCaseNames(testCaseClass) + for testcase in sorted(ordered_names): + bucket = next(self._shard_bucket_iterator) + if bucket == self._shard_index: + filtered_names.append(testcase) + return [x for x in ordered_names if x in filtered_names] + def getTestCaseNames(self, testCaseClass): - """Return a sorted sequence of method names found within testCaseClass + """Return a sorted sequence of method names found within testCaseClass. + Or a unique sharded subset thereof if sharding is enabled. + """ + if self._shard_bucket_iterator: + return self._getShardedTestCaseNames(testCaseClass) + return self._getTestCaseNames(testCaseClass) + + def _getTestCaseNames(self, testCaseClass): + """Return a sorted sequence of all method names found within testCaseClass. """ def shouldIncludeMethod(attrname): if not attrname.startswith(self.testMethodPrefix): From 09b830ea6867f6bd44502bbd0de41be6160a9074 Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Mon, 21 Nov 2022 18:11:14 -0800 Subject: [PATCH 2/3] minor tweaks: unshard test_tools, shard on -j 2 --- Lib/test/libregrtest/main.py | 2 +- Lib/test/libregrtest/runtest_mp.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Lib/test/libregrtest/main.py b/Lib/test/libregrtest/main.py index 0a3764db46c731..2e031109ffca10 100644 --- a/Lib/test/libregrtest/main.py +++ b/Lib/test/libregrtest/main.py @@ -73,7 +73,7 @@ def __init__(self): self.tests_to_shard.add('test_concurrent_futures') self.tests_to_shard.add('test_multiprocessing_spawn') self.tests_to_shard.add('test_asyncio') - self.tests_to_shard.add('test_tools') + # Only 1 long test case #self.tests_to_shard.add('test_tools') self.tests_to_shard.add('test_multiprocessing_forkserver') self.tests_to_shard.add('test_multiprocessing_fork') self.tests_to_shard.add('test_signal') diff --git a/Lib/test/libregrtest/runtest_mp.py b/Lib/test/libregrtest/runtest_mp.py index 2b849e9573b76c..28afee89cad26f 100644 --- 
a/Lib/test/libregrtest/runtest_mp.py +++ b/Lib/test/libregrtest/runtest_mp.py @@ -235,7 +235,7 @@ def _run_process(self, test_name: str, tmp_dir: str, stdout_fh: TextIO, self.start_time = time.monotonic() if shard: - self.current_test_name = f'{test_name}-shard-{shard.number:02}/{shard.total_shards-1:02}' + self.current_test_name = f'{test_name}-subset:{shard.number}/{shard.total_shards}' else: self.current_test_name = test_name try: @@ -437,10 +437,10 @@ def __init__(self, regrtest: Regrtest) -> None: self.output: queue.Queue[QueueOutput] = queue.Queue() tests_and_shards = [] for test in self.regrtest.tests: - if self.num_procs > 2 and test in self.regrtest.tests_to_shard: + if self.num_procs > 1 and test in self.regrtest.tests_to_shard: # Split shardable tests across multiple processes to run # distinct subsets of tests within a given test module. - shards = min(self.num_procs//2+1, 8) # avoid diminishing returns + shards = min(self.num_procs*2//3+1, 10) # diminishing returns for shard_no in range(shards): tests_and_shards.append((test, ShardInfo(shard_no, shards))) else: From f887c4d0b72321f90484873cba6a926e43f8b3fe Mon Sep 17 00:00:00 2001 From: "Gregory P. Smith [Google LLC]" Date: Mon, 13 Feb 2023 19:59:56 -0800 Subject: [PATCH 3/3] add test_xmlrpc to the shard list. --- Lib/test/libregrtest/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/Lib/test/libregrtest/main.py b/Lib/test/libregrtest/main.py index fdbd3adca39de7..1c0cd1d9b36faf 100644 --- a/Lib/test/libregrtest/main.py +++ b/Lib/test/libregrtest/main.py @@ -82,6 +82,7 @@ def __init__(self): self.tests_to_shard.add('test_io') self.tests_to_shard.add('test_imaplib') self.tests_to_shard.add('test_subprocess') + self.tests_to_shard.add('test_xmlrpc') # test results self.good = []