[c10d] Add support for testing SIGABRT return (#153167)

`SIGABRT` is a common return by *negative* distributed tests, which checks for effectiveness of NaN assert, watchdog throw, etc.

These errors are not detectable by traditional statements like `with self.assertRaises(RuntimeError)`.

Instead, we'd need to check for the process's return code, e.g. `SIGABRT(6)` would have a return code of -6.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153167
Approved by: https://github.com/fduwjj
This commit is contained in:
Ke Wen
2025-05-23 22:33:31 -07:00
committed by PyTorch MergeBot
parent 10c51b11ff
commit 03e102dbe8
2 changed files with 63 additions and 152 deletions

View File

@ -44,13 +44,11 @@ from torch.testing._internal.common_distributed import (
get_timeout,
init_multigpu_helper,
MultiProcessTestCase,
requires_gloo,
requires_multicast_support,
requires_nccl,
requires_nccl_version,
skip_if_lt_x_gpu,
skip_if_rocm_multiprocess,
sm_is_or_higher_than,
TEST_SKIPS,
with_dist_debug_levels,
with_nccl_blocking_wait,
@ -284,16 +282,17 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
# Need to skip return code checking for these tests since the child
# processes don't exit cleanly in some cuda versions
self.skip_return_code_checks = [
self.test_nan_assert_float16.__wrapped__,
self.test_nan_assert_float32.__wrapped__,
self.test_nan_assert_float64.__wrapped__,
self.test_nan_assert_bfloat16.__wrapped__,
self.test_nan_assert_float8_e4m3fn.__wrapped__,
self.test_nan_assert_float8_e5m2.__wrapped__,
]
# These tests are expected to throw SIGABRT(6); adding the negative sign
# bc the test return code is actually -6
self.special_return_code_checks = {
self.test_nan_assert_float16.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_float32.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_float64.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_bfloat16.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_float8_e4m3fn.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_float8_e5m2.__wrapped__: -signal.SIGABRT,
}
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
# that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
@ -534,14 +533,14 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
# confirm enable/disable flag works
backend._set_enable_nan_check(False)
pg.allreduce(nan_tensor)
# Note: using all-gather here bc some NCCL/SM version does not support
# FP8 reduction
pg._allgather_base(output, nan_tensor)
backend._set_enable_nan_check(True)
with self.assertRaises(RuntimeError):
# Note: using all-gather here bc FP8 types do not support reduce ops
# at the moment
pg._allgather_base(output, nan_tensor)
pg._allgather_base(output, nan_tensor)
dist.destroy_process_group()
# reset env
os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
@ -576,16 +575,13 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
def test_nan_check(self):
# Not expecting an error, NaN check should not make legit code fail
device = torch.device(f"cuda:{self.rank:d}")
if not sm_is_or_higher_than(device, 8, 0):
self.skipTest("bf16 requires sm >= 8.0")
os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
)
x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
x = torch.ones((10,), device=device) * self.rank
t = torch.ones(3, 4, device=device)
c10d.broadcast(x, src=0)
c10d.all_reduce(t)
c10d.barrier()
@ -2775,14 +2771,6 @@ class WorkHookTest(MultiProcessTestCase):
class NcclErrorHandlingTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
# Need to skip return code checking for these tests since the child
# processes don't exit cleanly.
self.skip_return_code_checks = [
self.test_nccl_errors_blocking_abort.__wrapped__,
self.test_nccl_errors_blocking_sigkill.__wrapped__,
self.test_nccl_errors_blocking_sigterm.__wrapped__,
self.test_nccl_errors_blocking_nonzero_exit.__wrapped__,
]
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
# that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
@ -2810,12 +2798,19 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
def _run_all_reduce(self, pg):
pg.allreduce(torch.rand(10).cuda(self.rank))
def _reduce_timeout(self):
# set heartbeat timeout to a small value so that we don't wait too long
# for things to shutdown
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "4"
os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "1000"
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
@skip_but_pass_in_sandcastle("Test does not pass when run locally")
def test_nccl_errors_nonblocking(self):
self._reduce_timeout()
# Note: we unset and restore TORCH_NCCL_ASYNC_ERROR_HANDLING for this test
# since test_c10d_common runs with async error handling by default, but this
# tests behavior when it is not enabled.
@ -2846,30 +2841,24 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
"TORCH_NCCL_ASYNC_ERROR_HANDLING"
] = prev_nccl_async_error_handling
def _test_nccl_errors_blocking(self, func):
@requires_nccl()
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
def test_nccl_errors_blocking(self):
self._reduce_timeout()
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(
store,
self.rank,
self.world_size,
timeout=timedelta(seconds=10),
)
process_group.allreduce(torch.rand(10).cuda(self.rank))
x = torch.rand(1024 * 1024).cuda(self.rank)
process_group.allreduce(x)
if self.rank == 0:
work = process_group.allreduce(torch.rand(10).cuda(self.rank))
work = process_group.allreduce(x)
with self.assertRaisesRegex(dist.DistBackendError, ""):
# It seems the error message would be different depending on
# whether the test is run on CI machine and devGPU. Skipping
# the error message check to make both sides happy.
work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
# Run some GPU operations to make sure cuda has not gotten stuck.
# It was observed cuda could get stuck if NCCL communicators were
# not properly aborted before throwing RuntimeError.
torch.rand(10).cuda(self.rank)
elif self.rank == 1:
# Clean up structures (ex: files for FileStore before going down)
del process_group
func()
def _test_barrier_error(self):
store = c10d.FileStore(self.file_name, self.world_size)
@ -2889,60 +2878,19 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
timeout=timedelta(seconds=self.op_timeout_sec)
)
@with_nccl_blocking_wait
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
def test_nccl_errors_blocking_clean_exit(self):
self._test_nccl_errors_blocking(lambda: sys.exit(0))
@with_nccl_blocking_wait
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
def test_nccl_errors_blocking_nonzero_exit(self):
self._test_nccl_errors_blocking(lambda: sys.exit(1))
@with_nccl_blocking_wait
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
@skip_but_pass_in_sandcastle(
"Frequently times out see https://github.com/pytorch/pytorch/issues/58920"
)
def test_nccl_errors_blocking_abort(self):
self._test_nccl_errors_blocking(lambda: os.abort())
@with_nccl_blocking_wait
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
def test_nccl_errors_blocking_sigkill(self):
self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL))
@with_nccl_blocking_wait
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
def test_nccl_errors_blocking_sigterm(self):
self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM))
@with_nccl_blocking_wait
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
def test_nccl_blocking_wait_with_barrier(self):
self._reduce_timeout()
self._test_barrier_error()
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
def test_nccl_non_blocking_wait_with_barrier(self):
self._reduce_timeout()
# test the barrier behavior in the non blocking wait setting
prev_nccl_async_error_handling = os.environ.get(
"TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@ -3013,6 +2961,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(3)
def test_restart_pg_after_error(self):
self._reduce_timeout()
# test the barrier behavior in the non blocking wait setting
prev_nccl_async_error_handling = os.environ.get(
"TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@ -3102,45 +3051,6 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
self._run_invalid_nccl_blocking_wait_env("2147483647")
self._run_invalid_nccl_blocking_wait_env("4294967295")
@with_nccl_blocking_wait
@requires_nccl()
@requires_gloo()
@skip_if_lt_x_gpu(3)
def test_nccl_timeout(self):
store = c10d.FileStore(self.file_name, self.world_size)
# Initialize process_group.
process_group = c10d.ProcessGroupNCCL(
store, self.rank, self.world_size, timeout=timedelta(seconds=10)
)
# Control gloo pg used as go-ahead signal/barrier
# to coordinate btwn ranks.
pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
failed_collective_timeout = timedelta(milliseconds=100)
process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
timeout=timedelta(seconds=5)
)
if self.rank == 0:
# This should timeout in about 1 second.
# Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
with self.assertRaisesRegex(
dist.DistBackendError, self.blocking_wait_error_msg
):
process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
timeout=failed_collective_timeout
)
# Now do a barrier to tell other rank to go ahead.
pg_gloo.barrier().wait()
else:
# Wait on rank 0 to fail.
try:
pg_gloo.barrier().wait()
except Exception as e:
raise ValueError(
f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}"
) from e
class NcclUserBufferRegistrationTest(MultiProcessTestCase):
def setUp(self):