[BE] Refactoring test execution and improving comments (#99467)

Shares the code that handles test results between parallel and serial mode.

Note that the original code behaved inconsistently between the two modes: it executed `print_to_stderr(err_message)` for every failing test that ran in parallel, but for serial tests it only invoked `print_to_stderr(err_message)` when `continue_through_error` was also specified. By sharing code, this PR makes that behavior consistent between the two modes.
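
For concreteness, here is a condensed sketch of the unified flow after this change. The function and option names are taken from the diff below, but the surrounding scaffolding (pool setup, the serial loop, etc.) is elided:

```python
def handle_error_messages(err_message):
    # Shared by both modes: returns True iff the test failed.
    if err_message is None:
        return False
    failure_messages.append(err_message)
    print_to_stderr(err_message)  # now happens in both modes
    return True

# Parallel mode: invoked by pool.apply_async for each finished test.
def parallel_test_completion_callback(err_message):
    test_failed = handle_error_messages(err_message)
    if test_failed and not options.continue_through_error:
        pool.terminate()

# Serial mode: called inline after each test module runs.
test_failed = handle_error_messages(err_message)
if test_failed and not options.continue_through_error:
    raise RuntimeError(err_message)
```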

Also adds some comments.

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at 029342c</samp>

> _Sing, O Muse, of the skillful coder who refined_
> _The PyTorch testing script, `run_test.py`, and shined_
> _A light on its obscure logic, with docstrings and comments_
> _And made it run more smoothly, with better error contents_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99467
Approved by: https://github.com/huydhn, https://github.com/malfet
Commit: 7546972565
Parent: 6ca991cacf
Author: Zain Rizvi
Date: 2023-04-19 19:29:03 +00:00
Committed by: PyTorch MergeBot
2 changed files with 17 additions and 11 deletions

test/run_test.py

@@ -239,6 +239,7 @@ ROCM_BLOCKLIST = [
     "test_cuda_nvml_based_avail",
 ]
 
+# The tests inside these files should never be run in parallel with each other
 RUN_PARALLEL_BLOCKLIST = [
     "test_cpp_extensions_jit",
     "test_cpp_extensions_open_device_registration",
@@ -255,6 +256,8 @@ RUN_PARALLEL_BLOCKLIST = [
     "test_cuda_nvml_based_avail",
 ] + FSDP_TEST
 
+# Test files that should always be run serially with other test files,
+# but it's okay if the tests inside them are run in parallel with each other.
 CI_SERIAL_LIST = [
     "test_nn",
     "test_fake_tensor",
@@ -1348,14 +1351,17 @@ def main():
     )
     os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
 
-    def success_callback(err_message):
+    def handle_error_messages(err_message):
         if err_message is None:
-            return True
+            return False
         failure_messages.append(err_message)
         print_to_stderr(err_message)
-        if not options.continue_through_error:
+        return True
+
+    def parallel_test_completion_callback(err_message):
+        test_failed = handle_error_messages(err_message)
+        if test_failed and not options.continue_through_error:
             pool.terminate()
-        return False
 
     try:
         os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
@@ -1366,7 +1372,7 @@ def main():
             pool.apply_async(
                 run_test_module,
                 args=(test, test_directory, options_clone),
-                callback=success_callback,
+                callback=parallel_test_completion_callback,
             )
         pool.close()
         pool.join()
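
For readers unfamiliar with the `apply_async` callback pattern used above, here is a self-contained toy version. The dummy `run_test_module` and test names are stand-ins; only the pool/callback wiring mirrors `run_test.py`:

```python
import multiprocessing


def run_test_module(test):
    # Stand-in: return an error message on failure, None on success.
    return f"{test} failed" if test == "test_b" else None


if __name__ == "__main__":
    failure_messages = []
    continue_through_error = False
    pool = multiprocessing.Pool(processes=2)

    def handle_error_messages(err_message):
        if err_message is None:
            return False
        failure_messages.append(err_message)
        print(err_message)
        return True

    def parallel_test_completion_callback(err_message):
        # apply_async invokes this in the main process with each result.
        if handle_error_messages(err_message) and not continue_through_error:
            pool.terminate()  # first failure stops the remaining workers

    for test in ["test_a", "test_b", "test_c"]:
        pool.apply_async(
            run_test_module,
            args=(test,),
            callback=parallel_test_completion_callback,
        )
    pool.close()
    pool.join()
```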
@@ -1386,12 +1392,9 @@ def main():
             if can_run_in_pytest(test):
                 options_clone.pytest = True
             err_message = run_test_module(test, test_directory, options_clone)
-            if err_message is None:
-                continue
-            failure_messages.append(err_message)
-            if not options_clone.continue_through_error:
+            test_failed = handle_error_messages(err_message)
+            if test_failed and not options.continue_through_error:
                 raise RuntimeError(err_message)
-            print_to_stderr(err_message)
     finally:
         pool.terminate()
         pool.join()