[c10d] Add support for testing SIGABRT return (#153167)

`SIGABRT` is a common return by *negative* distributed tests, which checks for effectiveness of NaN assert, watchdog throw, etc.

These errors are not detectable by traditional statements like `with self.assertRaises(RuntimeError)`.

Instead, we'd need to check for the process's return code, e.g. `SIGABRT(6)` would have a return code of -6.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153167
Approved by: https://github.com/fduwjj
This commit is contained in:
Ke Wen
2025-05-08 10:26:30 -07:00
committed by PyTorch MergeBot
parent 561a11aa68
commit 499a76b844
2 changed files with 43 additions and 41 deletions

View File

@ -642,7 +642,15 @@ class MultiProcessTestCase(TestCase):
def setUp(self) -> None:
super().setUp()
self.skip_return_code_checks = [] # type: ignore[var-annotated]
# Used for tests that are expected to return a non-0 exit code, such as
# SIGABRT thrown by watchdog.
self.special_return_code_checks: dict = {}
# Used for tests that may return any exit code, which makes it hard to
# check. This is rare, use with caution.
self.skip_return_code_checks: list = []
self.processes = [] # type: ignore[var-annotated]
self.rank = self.MAIN_PROCESS_RANK
self.file_name = tempfile.NamedTemporaryFile(delete=False).name
@ -862,28 +870,13 @@ class MultiProcessTestCase(TestCase):
time.sleep(0.1)
elapsed_time = time.time() - start_time
if fn in self.skip_return_code_checks:
self._check_no_test_errors(elapsed_time)
else:
self._check_return_codes(elapsed_time)
self._check_return_codes(fn, elapsed_time)
finally:
# Close all pipes
for pipe in self.pid_to_pipe.values():
pipe.close()
def _check_no_test_errors(self, elapsed_time) -> None:
"""
Checks that we didn't have any errors thrown in the child processes.
"""
for i, p in enumerate(self.processes):
if p.exitcode is None:
raise RuntimeError(
f"Process {i} timed out after {elapsed_time} seconds"
)
self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)
def _check_return_codes(self, elapsed_time) -> None:
def _check_return_codes(self, fn, elapsed_time) -> None:
"""
Checks that the return codes of all spawned processes match, and skips
tests if they returned a return code indicating a skipping condition.
@ -945,10 +938,22 @@ class MultiProcessTestCase(TestCase):
return
else:
raise unittest.SkipTest(skip.message)
# Skip the test return code check
if fn in self.skip_return_code_checks:
return
# In most cases, we expect test to return exit code 0, standing for success.
expected_return_code = 0
# In some negative tests, we expect test to return non-zero exit code,
# such as watchdog throwing SIGABRT.
if fn in self.special_return_code_checks:
expected_return_code = self.special_return_code_checks[fn]
self.assertEqual(
first_process.exitcode,
0,
msg=f"Expected zero exit code but got {first_process.exitcode} for pid: {first_process.pid}",
expected_return_code,
msg=f"Expected exit code {expected_return_code} but got {first_process.exitcode} for pid: {first_process.pid}",
)
@property