mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[c10d] Add support for testing SIGABRT return (#153167)"
This reverts commit 03e102dbe8cbffc2e42a3122b262d02f03571de7. Reverted https://github.com/pytorch/pytorch/pull/153167 on behalf of https://github.com/malfet due to It broke lint ([comment](https://github.com/pytorch/pytorch/pull/153167#issuecomment-2907820789))
This commit is contained in:
@ -44,11 +44,13 @@ from torch.testing._internal.common_distributed import (
|
||||
get_timeout,
|
||||
init_multigpu_helper,
|
||||
MultiProcessTestCase,
|
||||
requires_gloo,
|
||||
requires_multicast_support,
|
||||
requires_nccl,
|
||||
requires_nccl_version,
|
||||
skip_if_lt_x_gpu,
|
||||
skip_if_rocm_multiprocess,
|
||||
sm_is_or_higher_than,
|
||||
TEST_SKIPS,
|
||||
with_dist_debug_levels,
|
||||
with_nccl_blocking_wait,
|
||||
@ -282,17 +284,16 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
# These tests are expected to throw SIGABRT(6); adding the negative sign
|
||||
# bc the test return code is actually -6
|
||||
self.special_return_code_checks = {
|
||||
self.test_nan_assert_float16.__wrapped__: -signal.SIGABRT,
|
||||
self.test_nan_assert_float32.__wrapped__: -signal.SIGABRT,
|
||||
self.test_nan_assert_float64.__wrapped__: -signal.SIGABRT,
|
||||
self.test_nan_assert_bfloat16.__wrapped__: -signal.SIGABRT,
|
||||
self.test_nan_assert_float8_e4m3fn.__wrapped__: -signal.SIGABRT,
|
||||
self.test_nan_assert_float8_e5m2.__wrapped__: -signal.SIGABRT,
|
||||
}
|
||||
# Need to skip return code checking for these tests since the child
|
||||
# processes don't exit cleanly in some cuda versions
|
||||
self.skip_return_code_checks = [
|
||||
self.test_nan_assert_float16.__wrapped__,
|
||||
self.test_nan_assert_float32.__wrapped__,
|
||||
self.test_nan_assert_float64.__wrapped__,
|
||||
self.test_nan_assert_bfloat16.__wrapped__,
|
||||
self.test_nan_assert_float8_e4m3fn.__wrapped__,
|
||||
self.test_nan_assert_float8_e5m2.__wrapped__,
|
||||
]
|
||||
|
||||
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
|
||||
# that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
|
||||
@ -533,14 +534,14 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
|
||||
# confirm enable/disable flag works
|
||||
backend._set_enable_nan_check(False)
|
||||
# Note: using all-gather here bc some NCCL/SM version does not support
|
||||
# FP8 reduction
|
||||
pg._allgather_base(output, nan_tensor)
|
||||
pg.allreduce(nan_tensor)
|
||||
|
||||
backend._set_enable_nan_check(True)
|
||||
pg._allgather_base(output, nan_tensor)
|
||||
with self.assertRaises(RuntimeError):
|
||||
# Note: using all-gather here bc FP8 types do not support reduce ops
|
||||
# at the moment
|
||||
pg._allgather_base(output, nan_tensor)
|
||||
dist.destroy_process_group()
|
||||
|
||||
# reset env
|
||||
os.environ["TORCH_NCCL_NAN_CHECK"] = "0"
|
||||
|
||||
@ -575,13 +576,16 @@ class ProcessGroupNCCLGroupTest(MultiProcessTestCase):
|
||||
def test_nan_check(self):
|
||||
# Not expecting an error, NaN check should not make legit code fail
|
||||
device = torch.device(f"cuda:{self.rank:d}")
|
||||
if not sm_is_or_higher_than(device, 8, 0):
|
||||
self.skipTest("bf16 requires sm >= 8.0")
|
||||
|
||||
os.environ["TORCH_NCCL_NAN_CHECK"] = "1"
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
c10d.init_process_group(
|
||||
backend="nccl", store=store, rank=self.rank, world_size=self.world_size
|
||||
)
|
||||
x = torch.ones((10,), device=device) * self.rank
|
||||
t = torch.ones(3, 4, device=device)
|
||||
x = torch.ones((10,), dtype=torch.bfloat16, device=device) * self.rank
|
||||
t = torch.ones(3, 4, dtype=torch.bfloat16, device=device)
|
||||
c10d.broadcast(x, src=0)
|
||||
c10d.all_reduce(t)
|
||||
c10d.barrier()
|
||||
@ -2771,6 +2775,14 @@ class WorkHookTest(MultiProcessTestCase):
|
||||
class NcclErrorHandlingTest(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
# Need to skip return code checking for these tests since the child
|
||||
# processes don't exit cleanly.
|
||||
self.skip_return_code_checks = [
|
||||
self.test_nccl_errors_blocking_abort.__wrapped__,
|
||||
self.test_nccl_errors_blocking_sigkill.__wrapped__,
|
||||
self.test_nccl_errors_blocking_sigterm.__wrapped__,
|
||||
self.test_nccl_errors_blocking_nonzero_exit.__wrapped__,
|
||||
]
|
||||
# TORCH_NCCL_BLOCKING_WAIT overrides TORCH_NCCL_ASYNC_ERROR_HANDLING hence tests
|
||||
# that use TORCH_NCCL_BLOCKING_WAIT will test it as expected.
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
|
||||
@ -2798,19 +2810,12 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
||||
def _run_all_reduce(self, pg):
|
||||
pg.allreduce(torch.rand(10).cuda(self.rank))
|
||||
|
||||
def _reduce_timeout(self):
|
||||
# set heartbeat timeout to a small value so that we don't wait too long
|
||||
# for things to shutdown
|
||||
os.environ["TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC"] = "4"
|
||||
os.environ["TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"] = "1000"
|
||||
|
||||
@requires_nccl()
|
||||
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
|
||||
@skip_if_lt_x_gpu(3)
|
||||
@skip_if_rocm_multiprocess
|
||||
@skip_but_pass_in_sandcastle("Test does not pass when run locally")
|
||||
def test_nccl_errors_nonblocking(self):
|
||||
self._reduce_timeout()
|
||||
# Note: we unset and restore TORCH_NCCL_ASYNC_ERROR_HANDLING for this test
|
||||
# since test_c10d_common runs with async error handling by default, but this
|
||||
# tests behavior when it is not enabled.
|
||||
@ -2841,24 +2846,30 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
||||
"TORCH_NCCL_ASYNC_ERROR_HANDLING"
|
||||
] = prev_nccl_async_error_handling
|
||||
|
||||
@requires_nccl()
|
||||
@skip_if_lt_x_gpu(3)
|
||||
@skip_if_rocm_multiprocess
|
||||
def test_nccl_errors_blocking(self):
|
||||
self._reduce_timeout()
|
||||
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0"
|
||||
def _test_nccl_errors_blocking(self, func):
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
process_group = c10d.ProcessGroupNCCL(
|
||||
store,
|
||||
self.rank,
|
||||
self.world_size,
|
||||
timeout=timedelta(seconds=10),
|
||||
)
|
||||
x = torch.rand(1024 * 1024).cuda(self.rank)
|
||||
process_group.allreduce(x)
|
||||
process_group.allreduce(torch.rand(10).cuda(self.rank))
|
||||
if self.rank == 0:
|
||||
work = process_group.allreduce(x)
|
||||
work = process_group.allreduce(torch.rand(10).cuda(self.rank))
|
||||
with self.assertRaisesRegex(dist.DistBackendError, ""):
|
||||
# It seems the error message would be different depending on
|
||||
# whether the test is run on CI machine and devGPU. Skipping
|
||||
# the error message check to make both sides happy.
|
||||
work.wait(timeout=timedelta(seconds=self.op_timeout_sec))
|
||||
# Run some GPU operations to make sure cuda has not gotten stuck.
|
||||
# It was observed cuda could get stuck if NCCL communicators were
|
||||
# not properly aborted before throwing RuntimeError.
|
||||
torch.rand(10).cuda(self.rank)
|
||||
elif self.rank == 1:
|
||||
# Clean up structures (ex: files for FileStore before going down)
|
||||
del process_group
|
||||
func()
|
||||
|
||||
def _test_barrier_error(self):
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
@ -2878,19 +2889,60 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
||||
timeout=timedelta(seconds=self.op_timeout_sec)
|
||||
)
|
||||
|
||||
@with_nccl_blocking_wait
|
||||
@requires_nccl()
|
||||
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
|
||||
@skip_if_lt_x_gpu(3)
|
||||
@skip_if_rocm_multiprocess
|
||||
def test_nccl_errors_blocking_clean_exit(self):
|
||||
self._test_nccl_errors_blocking(lambda: sys.exit(0))
|
||||
|
||||
@with_nccl_blocking_wait
|
||||
@requires_nccl()
|
||||
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
|
||||
@skip_if_lt_x_gpu(3)
|
||||
@skip_if_rocm_multiprocess
|
||||
def test_nccl_errors_blocking_nonzero_exit(self):
|
||||
self._test_nccl_errors_blocking(lambda: sys.exit(1))
|
||||
|
||||
@with_nccl_blocking_wait
|
||||
@requires_nccl()
|
||||
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
|
||||
@skip_if_lt_x_gpu(3)
|
||||
@skip_if_rocm_multiprocess
|
||||
@skip_but_pass_in_sandcastle(
|
||||
"Frequently times out see https://github.com/pytorch/pytorch/issues/58920"
|
||||
)
|
||||
def test_nccl_errors_blocking_abort(self):
|
||||
self._test_nccl_errors_blocking(lambda: os.abort())
|
||||
|
||||
@with_nccl_blocking_wait
|
||||
@requires_nccl()
|
||||
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
|
||||
@skip_if_lt_x_gpu(3)
|
||||
@skip_if_rocm_multiprocess
|
||||
def test_nccl_errors_blocking_sigkill(self):
|
||||
self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGKILL))
|
||||
|
||||
@with_nccl_blocking_wait
|
||||
@requires_nccl()
|
||||
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
|
||||
@skip_if_lt_x_gpu(3)
|
||||
@skip_if_rocm_multiprocess
|
||||
def test_nccl_errors_blocking_sigterm(self):
|
||||
self._test_nccl_errors_blocking(lambda: os.kill(os.getpid(), signal.SIGTERM))
|
||||
|
||||
@with_nccl_blocking_wait
|
||||
@requires_nccl()
|
||||
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
|
||||
@skip_if_lt_x_gpu(3)
|
||||
def test_nccl_blocking_wait_with_barrier(self):
|
||||
self._reduce_timeout()
|
||||
self._test_barrier_error()
|
||||
|
||||
@requires_nccl()
|
||||
@requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking")
|
||||
@skip_if_lt_x_gpu(3)
|
||||
def test_nccl_non_blocking_wait_with_barrier(self):
|
||||
self._reduce_timeout()
|
||||
# test the barrier behavior in the non blocking wait setting
|
||||
prev_nccl_async_error_handling = os.environ.get(
|
||||
"TORCH_NCCL_ASYNC_ERROR_HANDLING", None
|
||||
@ -2961,7 +3013,6 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
||||
@skip_if_rocm_multiprocess
|
||||
@skip_if_lt_x_gpu(3)
|
||||
def test_restart_pg_after_error(self):
|
||||
self._reduce_timeout()
|
||||
# test the barrier behavior in the non blocking wait setting
|
||||
prev_nccl_async_error_handling = os.environ.get(
|
||||
"TORCH_NCCL_ASYNC_ERROR_HANDLING", None
|
||||
@ -3051,6 +3102,45 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
|
||||
self._run_invalid_nccl_blocking_wait_env("2147483647")
|
||||
self._run_invalid_nccl_blocking_wait_env("4294967295")
|
||||
|
||||
@with_nccl_blocking_wait
|
||||
@requires_nccl()
|
||||
@requires_gloo()
|
||||
@skip_if_lt_x_gpu(3)
|
||||
def test_nccl_timeout(self):
|
||||
store = c10d.FileStore(self.file_name, self.world_size)
|
||||
|
||||
# Initialize process_group.
|
||||
process_group = c10d.ProcessGroupNCCL(
|
||||
store, self.rank, self.world_size, timeout=timedelta(seconds=10)
|
||||
)
|
||||
# Control gloo pg used as go-ahead signal/barrier
|
||||
# to coordinate btwn ranks.
|
||||
pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
|
||||
failed_collective_timeout = timedelta(milliseconds=100)
|
||||
process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
|
||||
timeout=timedelta(seconds=5)
|
||||
)
|
||||
|
||||
if self.rank == 0:
|
||||
# This should timeout in about 1 second.
|
||||
# Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
|
||||
with self.assertRaisesRegex(
|
||||
dist.DistBackendError, self.blocking_wait_error_msg
|
||||
):
|
||||
process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(
|
||||
timeout=failed_collective_timeout
|
||||
)
|
||||
# Now do a barrier to tell other rank to go ahead.
|
||||
pg_gloo.barrier().wait()
|
||||
else:
|
||||
# Wait on rank 0 to fail.
|
||||
try:
|
||||
pg_gloo.barrier().wait()
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}"
|
||||
) from e
|
||||
|
||||
|
||||
class NcclUserBufferRegistrationTest(MultiProcessTestCase):
|
||||
def setUp(self):
|
||||
|
@ -640,15 +640,7 @@ class MultiProcessTestCase(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
||||
# Used for tests that are expected to return a non-0 exit code, such as
|
||||
# SIGABRT thrown by watchdog.
|
||||
self.special_return_code_checks: dict = {}
|
||||
|
||||
# Used for tests that may return any exit code, which makes it hard to
|
||||
# check. This is rare, use with caution.
|
||||
self.skip_return_code_checks: list = []
|
||||
|
||||
self.skip_return_code_checks = [] # type: ignore[var-annotated]
|
||||
self.processes = [] # type: ignore[var-annotated]
|
||||
self.rank = self.MAIN_PROCESS_RANK
|
||||
self.file_name = tempfile.NamedTemporaryFile(delete=False).name
|
||||
@ -873,13 +865,28 @@ class MultiProcessTestCase(TestCase):
|
||||
time.sleep(0.1)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
self._check_return_codes(fn, elapsed_time)
|
||||
|
||||
if fn in self.skip_return_code_checks:
|
||||
self._check_no_test_errors(elapsed_time)
|
||||
else:
|
||||
self._check_return_codes(elapsed_time)
|
||||
finally:
|
||||
# Close all pipes
|
||||
for pipe in self.pid_to_pipe.values():
|
||||
pipe.close()
|
||||
|
||||
def _check_return_codes(self, fn, elapsed_time) -> None:
|
||||
def _check_no_test_errors(self, elapsed_time) -> None:
|
||||
"""
|
||||
Checks that we didn't have any errors thrown in the child processes.
|
||||
"""
|
||||
for i, p in enumerate(self.processes):
|
||||
if p.exitcode is None:
|
||||
raise RuntimeError(
|
||||
f"Process {i} timed out after {elapsed_time} seconds"
|
||||
)
|
||||
self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)
|
||||
|
||||
def _check_return_codes(self, elapsed_time) -> None:
|
||||
"""
|
||||
Checks that the return codes of all spawned processes match, and skips
|
||||
tests if they returned a return code indicating a skipping condition.
|
||||
@ -921,11 +928,11 @@ class MultiProcessTestCase(TestCase):
|
||||
raise RuntimeError(
|
||||
f"Process {i} terminated or timed out after {elapsed_time} seconds"
|
||||
)
|
||||
|
||||
# Skip the test return code check
|
||||
if fn in self.skip_return_code_checks:
|
||||
return
|
||||
|
||||
self.assertEqual(
|
||||
p.exitcode,
|
||||
first_process.exitcode,
|
||||
msg=f"Expect process {i} exit code to match Process 0 exit code of {first_process.exitcode}, but got {p.exitcode}",
|
||||
)
|
||||
for skip in TEST_SKIPS.values():
|
||||
if first_process.exitcode == skip.exit_code:
|
||||
if IS_SANDCASTLE:
|
||||
@ -941,18 +948,10 @@ class MultiProcessTestCase(TestCase):
|
||||
return
|
||||
else:
|
||||
raise unittest.SkipTest(skip.message)
|
||||
|
||||
# In most cases, we expect test to return exit code 0, standing for success.
|
||||
expected_return_code = 0
|
||||
# In some negative tests, we expect test to return non-zero exit code,
|
||||
# such as watchdog throwing SIGABRT.
|
||||
if fn in self.special_return_code_checks:
|
||||
expected_return_code = self.special_return_code_checks[fn]
|
||||
|
||||
self.assertEqual(
|
||||
first_process.exitcode,
|
||||
expected_return_code,
|
||||
msg=f"Expected exit code {expected_return_code} but got {first_process.exitcode} for pid: {first_process.pid}",
|
||||
0,
|
||||
msg=f"Expected zero exit code but got {first_process.exitcode} for pid: {first_process.pid}",
|
||||
)
|
||||
|
||||
@property
|
||||
|
Reference in New Issue
Block a user