[reland] Fix flaky test_nccl_timeout (#68544)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/66882

In addition to the changes in https://github.com/pytorch/pytorch/pull/68403, accept one more error message that can be raised when a collective times out.

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68544
Reviewed By: albanD
Differential Revision: D32508706
Pulled By: rohan-varma
fbshipit-source-id: 7d41b91f547d4ad763c44cd11e7b9914b452b617
committed by: Facebook GitHub Bot
parent: 875ba3dddb
commit: 183dcdf551
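Editor's note: to make the intent of the fix concrete, here is a minimal, standalone sketch (an illustration, not code from this commit) of the error-matching pattern the diff below introduces via `_check_valid_comm_exception`: a failed collective is treated as the expected failure if its message contains any of the known NCCL abort/timeout strings. The helper name and example below are illustrative; the three message strings are taken from the diff.

# Sketch of the error-matching pattern added below (illustration only;
# the real helper lives on NcclErrorHandlingTest and takes an exception object).
VALID_NCCL_ERRORS = [
    "NCCL communicator was aborted",
    "NCCL communicator encountered error",
    "Caught collective operation timeout",
]

def is_expected_comm_failure(exc: Exception) -> bool:
    # A timed-out collective may surface either as an explicit timeout error or,
    # if the watchdog aborts the communicator first, as an abort/error message.
    message = str(exc)
    return any(expected in message for expected in VALID_NCCL_ERRORS)

# Example:
assert is_expected_comm_failure(RuntimeError("Caught collective operation timeout"))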
@@ -33,6 +33,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     requires_nccl,
+    requires_gloo,
     requires_nccl_version,
     skip_if_lt_x_gpu,
     get_timeout,
@@ -2366,15 +2367,28 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         self._run_invalid_nccl_blocking_wait_env("2147483647")
         self._run_invalid_nccl_blocking_wait_env("4294967295")
 
-    def _wait_for_comm_abort(self, process_group):
+    def _check_valid_comm_exception(self, e):
+        exception_str = str(e)
+        valid_exceptions = [
+            "NCCL communicator was aborted",
+            "NCCL communicator encountered error",
+            "Caught collective operation timeout"
+        ]
+        return any(exc in exception_str for exc in valid_exceptions)
+
+    def _wait_for_comm_abort(self, process_group, timeout=None):
         """
         Waits for the watchdog thread to abort communicators for the process group.
         """
         while True:
             try:
-                process_group.allreduce(torch.rand(10).cuda(self.rank))
+                if not timeout:
+                    process_group.allreduce(torch.rand(10).cuda(self.rank)).wait()
+                else:
+                    assert isinstance(timeout, timedelta)
+                    process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timeout)
             except Exception as e:
-                if "NCCL communicator was aborted" in str(e):
+                if self._check_valid_comm_exception(e):
                     return
                 else:
                     raise e
@@ -2382,6 +2396,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
 
     @with_nccl_blocking_wait
     @requires_nccl()
+    @requires_gloo()
     @skip_if_lt_x_gpu(3)
     def test_nccl_timeout(self):
         store = c10d.FileStore(self.file_name, self.world_size)
@@ -2390,18 +2405,28 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         process_group = c10d.ProcessGroupNCCL(
             store, self.rank, self.world_size, timeout=timedelta(seconds=10)
         )
+        # Control gloo pg used as go-ahead signal/barrier
+        # to coordinate btwn ranks.
+        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
+        failed_collective_timeout = timedelta(milliseconds=100)
         process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=5))
 
         if self.rank == 0:
             # This should timeout in about 1 second.
             # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
-                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=1))
+                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=failed_collective_timeout)
+            # Now do a barrier to tell other rank to go ahead.
+            pg_gloo.barrier().wait()
         else:
-            # Sleep to ensure timeout.
-            time.sleep(10)
-
-            self._wait_for_comm_abort(process_group)
+            # Wait on rank 0 to fail.
+            try:
+                pg_gloo.barrier().wait()
+            except Exception as e:
+                raise ValueError(f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}")
+            # Now verify communicators on this rank have
+            # been aborted by watchdog.
+            self._wait_for_comm_abort(process_group, failed_collective_timeout)
 
 
 class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
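Editor's note: for readers who want to see the coordination idea in isolation, below is a small, self-contained sketch (an illustration, not part of this commit) of the pattern the reworked test relies on: instead of sleeping for a fixed time, the non-failing ranks block on a barrier in a separate gloo "control" process group until the rank that expects the failure releases them, and only then check their post-failure state. It uses gloo on CPU, so it needs no GPUs or NCCL; the rendezvous address, port, and world size are placeholder assumptions.

# Standalone illustration of the "control process group as go-ahead signal" pattern.
import os
from datetime import timedelta

import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    # gloo works on CPU tensors, so this sketch runs without GPUs.
    dist.init_process_group(
        "gloo", rank=rank, world_size=world_size, timeout=timedelta(seconds=30)
    )
    if rank == 0:
        # In the real test, rank 0 first runs the collective that is expected to
        # time out, then releases the other ranks via this barrier.
        dist.barrier()
    else:
        # Other ranks wait on the barrier (rather than sleeping a fixed amount)
        # and only afterwards verify their post-failure state,
        # e.g. that the watchdog aborted the communicators.
        dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)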