[c10d] Add support for testing SIGABRT return (#153167)
`SIGABRT` is a common exit signal for *negative* distributed tests, which check the effectiveness of the NaN assert, the watchdog throw, etc. These failures are not detectable by traditional constructs like `with self.assertRaises(RuntimeError)`; instead, we need to check the process's return code. For example, a process killed by `SIGABRT(6)` has a return code of -6. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153167 Approved by: https://github.com/fduwjj
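As a minimal standalone illustration of that convention (plain Python, not part of this PR): a child process killed by a signal reports the negative signal number as its exit code, so a SIGABRT death shows up as -6.

import os
import signal
import multiprocessing as mp

def _crash():
    os.abort()  # raises SIGABRT in the child

if __name__ == "__main__":
    p = mp.Process(target=_crash)
    p.start()
    p.join()
    # multiprocessing reports signal deaths as negative exit codes
    assert p.exitcode == -signal.SIGABRT  # i.e., -6
    print(f"child exit code: {p.exitcode}")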
@@ -44,13 +44,11 @@ from torch.testing._internal.common_distributed import (
    get_timeout,
    init_multigpu_helper,
    MultiProcessTestCase,
    requires_gloo,
    requires_multicast_support,
    requires_nccl,
    requires_nccl_version,
    skip_if_lt_x_gpu,
    skip_if_rocm_multiprocess,
    sm_is_or_higher_than,
    TEST_SKIPS,
    with_dist_debug_levels,
    with_nccl_blocking_wait,
@@ -284,16 +282,17 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):

    def setUp(self):
        super().setUp()
        # Need to skip return code checking for these tests since the child
        # processes don't exit cleanly in some cuda versions
        self.skip_return_code_checks = [
            self.test_nan_assert_float16.__wrapped__,
            self.test_nan_assert_float32.__wrapped__,
            self.test_nan_assert_float64.__wrapped__,
            self.test_nan_assert_bfloat16.__wrapped__,
            self.test_nan_assert_float8_e4m3fn.__wrapped__,
            self.test_nan_assert_float8_e5m2.__wrapped__,
        ]

        # These tests are expected to throw SIGABRT(6); adding the negative sign
        # bc the test return code is actually -6
        self.special_return_code_checks = {
            self.test_nan_assert_float16.__wrapped__: -signal.SIGABRT,
            self.test_nan_assert_float32.__wrapped__: -signal.SIGABRT,
            self.test_nan_assert_float64.__wrapped__: -signal.SIGABRT,
            self.test_nan_assert_bfloat16.__wrapped__: -signal.SIGABRT,
            self.test_nan_assert_float8_e4m3fn.__wrapped__: -signal.SIGABRT,
            self.test_nan_assert_float8_e5m2.__wrapped__: -signal.SIGABRT,
        }

        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
        # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
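For orientation, here is a hedged sketch of how a multiprocess test runner could consume the two maps set up above. The real logic lives in torch/testing/_internal/common_distributed.py (MultiProcessTestCase), which this commit also modifies but whose diff is not shown on this page; the helper below and its signature are illustrative, not PyTorch's actual API.

# Illustrative only: roughly how skip_return_code_checks and
# special_return_code_checks could drive per-test exit-code validation.
def check_return_codes(self, test_fn, processes):
    if test_fn in self.skip_return_code_checks:
        return  # children may not exit cleanly; don't assert on their codes
    expected = self.special_return_code_checks.get(test_fn, 0)
    for rank, proc in enumerate(processes):
        # a child killed by SIGABRT(6) reports exitcode -6
        assert proc.exitcode == expected, (
            f"rank {rank} exited with {proc.exitcode}, expected {expected}"
        )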
@@ -534,14 +533,14 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):

        # confirm enable/disable flag works
        backend._set_enable_nan_check(False)
        pg.allreduce(nan_tensor)
        # Note: using all-gather here bc some NCCL/SM version does not support
        # FP8 reduction
        pg._allgather_base(output, nan_tensor)

        backend._set_enable_nan_check(True)
        with self.assertRaises(RuntimeError):
            # Note: using all-gather here bc FP8 types do not support reduce ops
            # at the moment
            pg._allgather_base(output, nan_tensor)
            pg._allgather_base(output, nan_tensor)
        dist.destroy_process_group()

        # reset env
        os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
@@ -576,16 +575,13 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):

    def test_nan_check(self):
        # Not expecting an error, NaN check should not make legit code fail
        device = torch.device(f"cuda:{self.rank:d}")
        if not sm_is_or_higher_than(device, 8, 0):
            self.skipTest("bf16 requires sm >= 8.0")

        os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            backend="nccl", store=store, rank=self.rank, world_size=self.world_size
        )
        x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
        t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
        x = torch.ones((10,), device=device) * self.rank
        t = torch.ones(3, 4, device=device)
        c10d.broadcast(x, src=0)
        c10d.all_reduce(t)
        c10d.barrier()
@@ -2775,14 +2771,6 @@ class WorkHookTest(MultiProcessTestCase):

class NcclErrorHandlingTest(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        # Need to skip return code checking for these tests since the child
        # processes don't exit cleanly.
        self.skip_return_code_checks = [
            self.test_nccl_errors_blocking_abort.__wrapped__,
            self.test_nccl_errors_blocking_sigkill.__wrapped__,
            self.test_nccl_errors_blocking_sigterm.__wrapped__,
            self.test_nccl_errors_blocking_nonzero_exit.__wrapped__,
        ]
        # TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
        # that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
@@ -2810,12 +2798,19 @@ class NcclErrorHandlingTest(MultiProcessTestCase):

    def _run_all_reduce(self, pg):
        pg.allreduce(torch.rand(10).cuda(self.rank))

    def _reduce_timeout(self):
        # set heartbeat timeout to a small value so that we don't wait too long
        # for things to shutdown
        os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "4"
        os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "1000"

    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    @skip_if_rocm_multiprocess
    @skip_but_pass_in_sandcastle("Test does not pass when run locally")
    def test_nccl_errors_nonblocking(self):
        self._reduce_timeout()
        # Note: we unset and restore TORCH_NCCL_ASYNC_ERROR_HANDLING for this test
        # since test_c10d_common runs with async error handling by default, but this
        # tests behavior when it is not enabled.
@@ -2846,30 +2841,24 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
            "TORCH_NCCL_ASYNC_ERROR_HANDLING"
        ] = prev_nccl_async_error_handling

    def _test_nccl_errors_blocking(self, func):
    @requires_nccl()
    @skip_if_lt_x_gpu(3)
    @skip_if_rocm_multiprocess
    def test_nccl_errors_blocking(self):
        self._reduce_timeout()
        os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
        store = c10d.FileStore(self.file_name, self.world_size)
        process_group = c10d.ProcessGroupNCCL(
            store,
            self.rank,
            self.world_size,
            timeout=timedelta(seconds=10),
        )
        process_group.allreduce(torch.rand(10).cuda(self.rank))
        x = torch.rand(1024 * 1024).cuda(self.rank)
        process_group.allreduce(x)
        if self.rank == 0:
            work = process_group.allreduce(torch.rand(10).cuda(self.rank))
            work = process_group.allreduce(x)
            with self.assertRaisesRegex(dist.DistBackendError, ""):
                # It seems the error message would be different depending on
                # whether the test is run on CI machine and devGPU. Skipping
                # the error message check to make both sides happy.
                work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
            # Run some GPU operations to make sure cuda has not gotten stuck.
            # It was observed cuda could get stuck if NCCL communicators were
            # not properly aborted before throwing RuntimeError.
            torch.rand(10).cuda(self.rank)
        elif self.rank == 1:
            # Clean up structures (ex: files for FileStore before going down)
            del process_group
            func()

    def _test_barrier_error(self):
        store = c10d.FileStore(self.file_name, self.world_size)
@@ -2889,60 +2878,19 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
                timeout=timedelta(seconds=self.op_timeout_sec)
            )

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    @skip_if_rocm_multiprocess
    def test_nccl_errors_blocking_clean_exit(self):
        self._test_nccl_errors_blocking(lambda: sys.exit(0))

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    @skip_if_rocm_multiprocess
    def test_nccl_errors_blocking_nonzero_exit(self):
        self._test_nccl_errors_blocking(lambda: sys.exit(1))

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    @skip_if_rocm_multiprocess
    @skip_but_pass_in_sandcastle(
        "Frequently times out see https://github.com/pytorch/pytorch/issues/58920"
    )
    def test_nccl_errors_blocking_abort(self):
        self._test_nccl_errors_blocking(lambda: os.abort())

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    @skip_if_rocm_multiprocess
    def test_nccl_errors_blocking_sigkill(self):
        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL))

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    @skip_if_rocm_multiprocess
    def test_nccl_errors_blocking_sigterm(self):
        self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM))

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    def test_nccl_blocking_wait_with_barrier(self):
        self._reduce_timeout()
        self._test_barrier_error()

    @requires_nccl()
    @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
    @skip_if_lt_x_gpu(3)
    def test_nccl_non_blocking_wait_with_barrier(self):
        self._reduce_timeout()
        # test the barrier behavior in the non blocking wait setting
        prev_nccl_async_error_handling = os.environ.get(
            "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
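The lambdas in this hunk map directly onto process exit statuses: sys.exit(n) yields n, while death by a fatal signal yields the negative signal number. A standalone, POSIX-only sketch (plain Python, not PyTorch code) that verifies the mapping:

import signal
import subprocess
import sys

# (snippet, expected return code) pairs mirroring the lambdas above
cases = [
    ("import sys; sys.exit(0)", 0),
    ("import sys; sys.exit(1)", 1),
    ("import os; os.abort()", -signal.SIGABRT),  # -6
    ("import os, signal; os.kill(os.getpid(), signal.SIGKILL)", -signal.SIGKILL),  # -9
    ("import os, signal; os.kill(os.getpid(), signal.SIGTERM)", -signal.SIGTERM),  # -15
]
for snippet, expected in cases:
    rc = subprocess.run([sys.executable, "-c", snippet]).returncode
    assert rc == expected, f"{snippet!r}: got {rc}, expected {expected}"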
@@ -3013,6 +2961,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
    @skip_if_rocm_multiprocess
    @skip_if_lt_x_gpu(3)
    def test_restart_pg_after_error(self):
        self._reduce_timeout()
        # test the barrier behavior in the non blocking wait setting
        prev_nccl_async_error_handling = os.environ.get(
            "TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@@ -3102,45 +3051,6 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
        self._run_invalid_nccl_blocking_wait_env("2147483647")
        self._run_invalid_nccl_blocking_wait_env("4294967295")

    @with_nccl_blocking_wait
    @requires_nccl()
    @requires_gloo()
    @skip_if_lt_x_gpu(3)
    def test_nccl_timeout(self):
        store = c10d.FileStore(self.file_name, self.world_size)

        # Initialize process_group.
        process_group = c10d.ProcessGroupNCCL(
            store, self.rank, self.world_size, timeout=timedelta(seconds=10)
        )
        # Control gloo pg used as go-ahead signal/barrier
        # to coordinate btwn ranks.
        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
        failed_collective_timeout = timedelta(milliseconds=100)
        process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
            timeout=timedelta(seconds=5)
        )

        if self.rank == 0:
            # This should timeout in about 1 second.
            # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
            with self.assertRaisesRegex(
                dist.DistBackendError, self.blocking_wait_error_msg
            ):
                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
                    timeout=failed_collective_timeout
                )
            # Now do a barrier to tell other rank to go ahead.
            pg_gloo.barrier().wait()
        else:
            # Wait on rank 0 to fail.
            try:
                pg_gloo.barrier().wait()
            except Exception as e:
                raise ValueError(
                    f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}"
                ) from e


class NcclUserBufferRegistrationTest(MultiProcessTestCase):
    def setUp(self):