mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[c10d] Add support for testing SIGABRT return (#153167)
`SIGABRT` is a common return by *negative* distributed tests, which checks for effectiveness of NaN assert, watchdog throw, etc. These errors are not detectable by traditional statements like `with self.assertRaises(RuntimeError)`. Instead, we'd need to check for the process's return code, e.g. `SIGABRT(6)` would have a return code of -6. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153167 Approved by: https://github.com/fduwjj
This commit is contained in:
@ -642,7 +642,15 @@ class MultiProcessTestCase(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
self.skip_return_code_checks = [] # type: ignore[var-annotated]
|
||||
|
||||
# Used for tests that are expected to return a non-0 exit code, such as
|
||||
# SIGABRT thrown by watchdog.
|
||||
self.special_return_code_checks: dict = {}
|
||||
|
||||
# Used for tests that may return any exit code, which makes it hard to
|
||||
# check. This is rare, use with caution.
|
||||
self.skip_return_code_checks: list = []
|
||||
|
||||
self.processes = [] # type: ignore[var-annotated]
|
||||
self.rank = self.MAIN_PROCESS_RANK
|
||||
self.file_name = tempfile.NamedTemporaryFile(delete=False).name
|
||||
@ -862,28 +870,13 @@ class MultiProcessTestCase(TestCase):
|
||||
time.sleep(0.1)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
if fn in self.skip_return_code_checks:
|
||||
self._check_no_test_errors(elapsed_time)
|
||||
else:
|
||||
self._check_return_codes(elapsed_time)
|
||||
self._check_return_codes(fn, elapsed_time)
|
||||
finally:
|
||||
# Close all pipes
|
||||
for pipe in self.pid_to_pipe.values():
|
||||
pipe.close()
|
||||
|
||||
def _check_no_test_errors(self, elapsed_time) -> None:
|
||||
"""
|
||||
Checks that we didn't have any errors thrown in the child processes.
|
||||
"""
|
||||
for i, p in enumerate(self.processes):
|
||||
if p.exitcode is None:
|
||||
raise RuntimeError(
|
||||
f"Process {i} timed out after {elapsed_time} seconds"
|
||||
)
|
||||
self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)
|
||||
|
||||
def _check_return_codes(self, elapsed_time) -> None:
|
||||
def _check_return_codes(self, fn, elapsed_time) -> None:
|
||||
"""
|
||||
Checks that the return codes of all spawned processes match, and skips
|
||||
tests if they returned a return code indicating a skipping condition.
|
||||
@ -945,10 +938,22 @@ class MultiProcessTestCase(TestCase):
|
||||
return
|
||||
else:
|
||||
raise unittest.SkipTest(skip.message)
|
||||
|
||||
# Skip the test return code check
|
||||
if fn in self.skip_return_code_checks:
|
||||
return
|
||||
|
||||
# In most cases, we expect test to return exit code 0, standing for success.
|
||||
expected_return_code = 0
|
||||
# In some negative tests, we expect test to return non-zero exit code,
|
||||
# such as watchdog throwing SIGABRT.
|
||||
if fn in self.special_return_code_checks:
|
||||
expected_return_code = self.special_return_code_checks[fn]
|
||||
|
||||
self.assertEqual(
|
||||
first_process.exitcode,
|
||||
0,
|
||||
msg=f"Expected zero exit code but got {first_process.exitcode} for pid: {first_process.pid}",
|
||||
expected_return_code,
|
||||
msg=f"Expected exit code {expected_return_code} but got {first_process.exitcode} for pid: {first_process.pid}",
|
||||
)
|
||||
|
||||
@property
|
||||
|
Reference in New Issue
Block a user