Revert "[c10d] Add support for testing SIGABRT return (#153167)"

This reverts commit 03e102dbe8cbffc2e42a3122b262d02f03571de7.

Reverted https://github.com/pytorch/pytorch/pull/153167 on behalf of https://github.com/malfet due to It broke lint ([comment](https://github.com/pytorch/pytorch/pull/153167#issuecomment-2907820789))
This commit is contained in:
PyTorch MergeBot
2025-05-25 13:17:27 +00:00
parent c4ef4090c5
commit 54932d865e
2 changed files with 152 additions and 63 deletions

View File

@ -44,11 +44,13 @@ from torch.testing._internal.common_distributed import (
get_timeout,
init_multigpu_helper,
MultiProcessTestCase,
requires_gloo,
requires_multicast_support,
requires_nccl,
requires_nccl_version,
skip_if_lt_x_gpu,
skip_if_rocm_multiprocess,
sm_is_or_higher_than,
TEST_SKIPS,
with_dist_debug_levels,
with_nccl_blocking_wait,
@ -282,17 +284,16 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
# These tests are expected to throw SIGABRT(6); adding the negative sign
# bc the test return code is actually -6
self.special_return_code_checks = {
self.test_nan_assert_float16.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_float32.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_float64.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_bfloat16.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_float8_e4m3fn.__wrapped__: -signal.SIGABRT,
self.test_nan_assert_float8_e5m2.__wrapped__: -signal.SIGABRT,
}
# Need to skip return code checking for these tests since the child
# processes don't exit cleanly in some cuda versions
self.skip_return_code_checks = [
self.test_nan_assert_float16.__wrapped__,
self.test_nan_assert_float32.__wrapped__,
self.test_nan_assert_float64.__wrapped__,
self.test_nan_assert_bfloat16.__wrapped__,
self.test_nan_assert_float8_e4m3fn.__wrapped__,
self.test_nan_assert_float8_e5m2.__wrapped__,
]
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
# that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
@ -533,14 +534,14 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
# confirm enable/disable flag works
backend._set_enable_nan_check(False)
# Note: using all-gather here bc some NCCL/SM version does not support
# FP8 reduction
pg._allgather_base(output, nan_tensor)
pg.allreduce(nan_tensor)
backend._set_enable_nan_check(True)
pg._allgather_base(output, nan_tensor)
with self.assertRaises(RuntimeError):
# Note: using all-gather here bc FP8 types do not support reduce ops
# at the moment
pg._allgather_base(output, nan_tensor)
dist.destroy_process_group()
# reset env
os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
@ -575,13 +576,16 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
def test_nan_check(self):
# Not expecting an error, NaN check should not make legit code fail
device = torch.device(f"cuda:{self.rank:d}")
if not sm_is_or_higher_than(device, 8, 0):
self.skipTest("bf16 requires sm >= 8.0")
os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
store = c10d.FileStore(self.file_name, self.world_size)
c10d.init_process_group(
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
)
x = torch.ones((10,), device=device) * self.rank
t = torch.ones(3, 4, device=device)
x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
c10d.broadcast(x, src=0)
c10d.all_reduce(t)
c10d.barrier()
@ -2771,6 +2775,14 @@ class WorkHookTest(MultiProcessTestCase):
class NcclErrorHandlingTest(MultiProcessTestCase):
def setUp(self):
super().setUp()
# Need to skip return code checking for these tests since the child
# processes don't exit cleanly.
self.skip_return_code_checks = [
self.test_nccl_errors_blocking_abort.__wrapped__,
self.test_nccl_errors_blocking_sigkill.__wrapped__,
self.test_nccl_errors_blocking_sigterm.__wrapped__,
self.test_nccl_errors_blocking_nonzero_exit.__wrapped__,
]
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
# that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
@ -2798,19 +2810,12 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
def _run_all_reduce(self, pg):
pg.allreduce(torch.rand(10).cuda(self.rank))
def _reduce_timeout(self):
# set heartbeat timeout to a small value so that we don't wait too long
# for things to shutdown
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "4"
os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "1000"
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
@skip_but_pass_in_sandcastle("Test does not pass when run locally")
def test_nccl_errors_nonblocking(self):
self._reduce_timeout()
# Note: we unset and restore TORCH_NCCL_ASYNC_ERROR_HANDLING for this test
# since test_c10d_common runs with async error handling by default, but this
# tests behavior when it is not enabled.
@ -2841,24 +2846,30 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
"TORCH_NCCL_ASYNC_ERROR_HANDLING"
] = prev_nccl_async_error_handling
@requires_nccl()
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
def test_nccl_errors_blocking(self):
self._reduce_timeout()
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
def _test_nccl_errors_blocking(self, func):
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(
store,
self.rank,
self.world_size,
timeout=timedelta(seconds=10),
)
x = torch.rand(1024 * 1024).cuda(self.rank)
process_group.allreduce(x)
process_group.allreduce(torch.rand(10).cuda(self.rank))
if self.rank == 0:
work = process_group.allreduce(x)
work = process_group.allreduce(torch.rand(10).cuda(self.rank))
with self.assertRaisesRegex(dist.DistBackendError, ""):
# It seems the error message would be different depending on
# whether the test is run on CI machine and devGPU. Skipping
# the error message check to make both sides happy.
work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
# Run some GPU operations to make sure cuda has not gotten stuck.
# It was observed cuda could get stuck if NCCL communicators were
# not properly aborted before throwing RuntimeError.
torch.rand(10).cuda(self.rank)
elif self.rank == 1:
# Clean up structures (ex: files for FileStore before going down)
del process_group
func()
def _test_barrier_error(self):
store = c10d.FileStore(self.file_name, self.world_size)
@ -2878,19 +2889,60 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
timeout=timedelta(seconds=self.op_timeout_sec)
)
@with_nccl_blocking_wait
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
def test_nccl_errors_blocking_clean_exit(self):
self._test_nccl_errors_blocking(lambda: sys.exit(0))
@with_nccl_blocking_wait
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
def test_nccl_errors_blocking_nonzero_exit(self):
self._test_nccl_errors_blocking(lambda: sys.exit(1))
@with_nccl_blocking_wait
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
@skip_but_pass_in_sandcastle(
"Frequently times out see https://github.com/pytorch/pytorch/issues/58920"
)
def test_nccl_errors_blocking_abort(self):
self._test_nccl_errors_blocking(lambda: os.abort())
@with_nccl_blocking_wait
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
def test_nccl_errors_blocking_sigkill(self):
self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL))
@with_nccl_blocking_wait
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
@skip_if_rocm_multiprocess
def test_nccl_errors_blocking_sigterm(self):
self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM))
@with_nccl_blocking_wait
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
def test_nccl_blocking_wait_with_barrier(self):
self._reduce_timeout()
self._test_barrier_error()
@requires_nccl()
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
@skip_if_lt_x_gpu(3)
def test_nccl_non_blocking_wait_with_barrier(self):
self._reduce_timeout()
# test the barrier behavior in the non blocking wait setting
prev_nccl_async_error_handling = os.environ.get(
"TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@ -2961,7 +3013,6 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
@skip_if_rocm_multiprocess
@skip_if_lt_x_gpu(3)
def test_restart_pg_after_error(self):
self._reduce_timeout()
# test the barrier behavior in the non blocking wait setting
prev_nccl_async_error_handling = os.environ.get(
"TORCH_NCCL_ASYNC_ERROR_HANDLING", None
@ -3051,6 +3102,45 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
self._run_invalid_nccl_blocking_wait_env("2147483647")
self._run_invalid_nccl_blocking_wait_env("4294967295")
@with_nccl_blocking_wait
@requires_nccl()
@requires_gloo()
@skip_if_lt_x_gpu(3)
def test_nccl_timeout(self):
store = c10d.FileStore(self.file_name, self.world_size)
# Initialize process_group.
process_group = c10d.ProcessGroupNCCL(
store, self.rank, self.world_size, timeout=timedelta(seconds=10)
)
# Control gloo pg used as go-ahead signal/barrier
# to coordinate btwn ranks.
pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
failed_collective_timeout = timedelta(milliseconds=100)
process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
timeout=timedelta(seconds=5)
)
if self.rank == 0:
# This should timeout in about 1 second.
# Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
with self.assertRaisesRegex(
dist.DistBackendError, self.blocking_wait_error_msg
):
process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
timeout=failed_collective_timeout
)
# Now do a barrier to tell other rank to go ahead.
pg_gloo.barrier().wait()
else:
# Wait on rank 0 to fail.
try:
pg_gloo.barrier().wait()
except Exception as e:
raise ValueError(
f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}"
) from e
class NcclUserBufferRegistrationTest(MultiProcessTestCase):
def setUp(self):

View File

@ -640,15 +640,7 @@ class MultiProcessTestCase(TestCase):
def setUp(self) -> None:
super().setUp()
# Used for tests that are expected to return a non-0 exit code, such as
# SIGABRT thrown by watchdog.
self.special_return_code_checks: dict = {}
# Used for tests that may return any exit code, which makes it hard to
# check. This is rare, use with caution.
self.skip_return_code_checks: list = []
self.skip_return_code_checks = [] # type: ignore[var-annotated]
self.processes = [] # type: ignore[var-annotated]
self.rank = self.MAIN_PROCESS_RANK
self.file_name = tempfile.NamedTemporaryFile(delete=False).name
@ -873,13 +865,28 @@ class MultiProcessTestCase(TestCase):
time.sleep(0.1)
elapsed_time = time.time() - start_time
self._check_return_codes(fn, elapsed_time)
if fn in self.skip_return_code_checks:
self._check_no_test_errors(elapsed_time)
else:
self._check_return_codes(elapsed_time)
finally:
# Close all pipes
for pipe in self.pid_to_pipe.values():
pipe.close()
def _check_return_codes(self, fn, elapsed_time) -> None:
def _check_no_test_errors(self, elapsed_time) -> None:
"""
Checks that we didn't have any errors thrown in the child processes.
"""
for i, p in enumerate(self.processes):
if p.exitcode is None:
raise RuntimeError(
f"Process {i} timed out after {elapsed_time} seconds"
)
self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)
def _check_return_codes(self, elapsed_time) -> None:
"""
Checks that the return codes of all spawned processes match, and skips
tests if they returned a return code indicating a skipping condition.
@ -921,11 +928,11 @@ class MultiProcessTestCase(TestCase):
raise RuntimeError(
f"Process {i} terminated or timed out after {elapsed_time} seconds"
)
# Skip the test return code check
if fn in self.skip_return_code_checks:
return
self.assertEqual(
p.exitcode,
first_process.exitcode,
msg=f"Expect process {i} exit code to match Process 0 exit code of {first_process.exitcode}, but got {p.exitcode}",
)
for skip in TEST_SKIPS.values():
if first_process.exitcode == skip.exit_code:
if IS_SANDCASTLE:
@ -941,18 +948,10 @@ class MultiProcessTestCase(TestCase):
return
else:
raise unittest.SkipTest(skip.message)
# In most cases, we expect test to return exit code 0, standing for success.
expected_return_code = 0
# In some negative tests, we expect test to return non-zero exit code,
# such as watchdog throwing SIGABRT.
if fn in self.special_return_code_checks:
expected_return_code = self.special_return_code_checks[fn]
self.assertEqual(
first_process.exitcode,
expected_return_code,
msg=f"Expected exit code {expected_return_code} but got {first_process.exitcode} for pid: {first_process.pid}",
0,
msg=f"Expected zero exit code but got {first_process.exitcode} for pid: {first_process.pid}",
)
@property