mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Refactoring test execution and improving comments (#99467)
Sharing code between the code paths that handle test results in parallel mode and in serial mode.

Note that the original version of this code was inconsistent between the two modes: it would execute `print_to_stderr(err_message)` for every failing test that ran in parallel, but for serial tests it would only invoke `print_to_stderr(err_message)` if `continue_through_error` was also specified. By sharing code, this PR makes that behavior consistent between the two modes.

Also adding some comments.

<!-- copilot:poem -->
### <samp>🤖 Generated by Copilot at 029342c</samp>

> _Sing, O Muse, of the skillful coder who refined_
> _The PyTorch testing script, `run_test.py`, and shined_
> _A light on its obscure logic, with docstrings and comments_
> _And made it run more smoothly, with better error contents_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99467
Approved by: https://github.com/huydhn, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
6ca991cacf
commit
7546972565
@ -239,6 +239,7 @@ ROCM_BLOCKLIST = [
|
||||
"test_cuda_nvml_based_avail",
|
||||
]
|
||||
|
||||
# The tests inside these files should never be run in parallel with each other
|
||||
RUN_PARALLEL_BLOCKLIST = [
|
||||
"test_cpp_extensions_jit",
|
||||
"test_cpp_extensions_open_device_registration",
|
||||
@ -255,6 +256,8 @@ RUN_PARALLEL_BLOCKLIST = [
|
||||
"test_cuda_nvml_based_avail",
|
||||
] + FSDP_TEST
|
||||
|
||||
# Test files that should always be run serially with other test files,
|
||||
# but it's okay if the tests inside them are run in parallel with each other.
|
||||
CI_SERIAL_LIST = [
|
||||
"test_nn",
|
||||
"test_fake_tensor",
|
||||
@ -1348,14 +1351,17 @@ def main():
|
||||
)
|
||||
os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True)
|
||||
|
||||
def success_callback(err_message):
|
||||
def handle_error_messages(err_message):
|
||||
if err_message is None:
|
||||
return True
|
||||
return False
|
||||
failure_messages.append(err_message)
|
||||
print_to_stderr(err_message)
|
||||
if not options.continue_through_error:
|
||||
return True
|
||||
|
||||
def parallel_test_completion_callback(err_message):
|
||||
test_failed = handle_error_messages(err_message)
|
||||
if test_failed and not options.continue_through_error:
|
||||
pool.terminate()
|
||||
return False
|
||||
|
||||
try:
|
||||
os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
|
||||
@ -1366,7 +1372,7 @@ def main():
|
||||
pool.apply_async(
|
||||
run_test_module,
|
||||
args=(test, test_directory, options_clone),
|
||||
callback=success_callback,
|
||||
callback=parallel_test_completion_callback,
|
||||
)
|
||||
pool.close()
|
||||
pool.join()
|
||||
@ -1386,12 +1392,9 @@ def main():
|
||||
if can_run_in_pytest(test):
|
||||
options_clone.pytest = True
|
||||
err_message = run_test_module(test, test_directory, options_clone)
|
||||
if err_message is None:
|
||||
continue
|
||||
failure_messages.append(err_message)
|
||||
if not options_clone.continue_through_error:
|
||||
test_failed = handle_error_messages(err_message)
|
||||
if test_failed and not options.continue_through_error:
|
||||
raise RuntimeError(err_message)
|
||||
print_to_stderr(err_message)
|
||||
finally:
|
||||
pool.terminate()
|
||||
pool.join()
|
||||
|
||||
Reference in New Issue
Block a user