[reland] Fix flaky test_nccl_timeout (#68544)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/66882

In addition to the changes in https://github.com/pytorch/pytorch/pull/68403, add one more check for an error that can be raised when a collective times out.
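
To make the intent concrete, a minimal sketch of the kind of check this adds (illustrative names, not the exact test code): depending on whether the watchdog has already aborted the communicator, a timed-out collective can surface with different RuntimeError messages, and the test should accept any of them.

    # Minimal sketch, assuming the three messages listed in the diff below are
    # the ones a timed-out collective can produce; helper name is hypothetical.
    VALID_TIMEOUT_ERRORS = (
        "NCCL communicator was aborted",
        "NCCL communicator encountered error",
        "Caught collective operation timeout",
    )

    def is_expected_timeout_error(exc: Exception) -> bool:
        msg = str(exc)
        return any(err in msg for err in VALID_TIMEOUT_ERRORS)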

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68544

Reviewed By: albanD

Differential Revision: D32508706

Pulled By: rohan-varma

fbshipit-source-id: 7d41b91f547d4ad763c44cd11e7b9914b452b617
Author: Rohan Varma
Date: 2021-11-19 13:18:46 -08:00
Committed by: Facebook GitHub Bot
Parent: 875ba3dddb
Commit: 183dcdf551


@@ -33,6 +33,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     requires_nccl,
+    requires_gloo,
     requires_nccl_version,
     skip_if_lt_x_gpu,
     get_timeout,
@@ -2366,15 +2367,28 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         self._run_invalid_nccl_blocking_wait_env("2147483647")
         self._run_invalid_nccl_blocking_wait_env("4294967295")
-    def _wait_for_comm_abort(self, process_group):
+    def _check_valid_comm_exception(self, e):
+        exception_str = str(e)
+        valid_exceptions = [
+            "NCCL communicator was aborted",
+            "NCCL communicator encountered error",
+            "Caught collective operation timeout"
+        ]
+        return any(exc in exception_str for exc in valid_exceptions)
+    def _wait_for_comm_abort(self, process_group, timeout=None):
         """
         Waits for the watchdog thread to abort communicators for the process group.
         """
         while True:
             try:
-                process_group.allreduce(torch.rand(10).cuda(self.rank))
+                if not timeout:
+                    process_group.allreduce(torch.rand(10).cuda(self.rank)).wait()
+                else:
+                    assert isinstance(timeout, timedelta)
+                    process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timeout)
             except Exception as e:
-                if "NCCL communicator was aborted" in str(e):
+                if self._check_valid_comm_exception(e):
                     return
                 else:
                     raise e
@@ -2382,6 +2396,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
     @with_nccl_blocking_wait
     @requires_nccl()
+    @requires_gloo()
     @skip_if_lt_x_gpu(3)
     def test_nccl_timeout(self):
         store = c10d.FileStore(self.file_name, self.world_size)
@@ -2390,18 +2405,28 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         process_group = c10d.ProcessGroupNCCL(
             store, self.rank, self.world_size, timeout=timedelta(seconds=10)
         )
+        # Control gloo pg used as go-ahead signal/barrier
+        # to coordinate btwn ranks.
+        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
+        failed_collective_timeout = timedelta(milliseconds=100)
         process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=5))
         if self.rank == 0:
             # This should timeout in about 1 second.
             # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
-                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=1))
+                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=failed_collective_timeout)
+            # Now do a barrier to tell other rank to go ahead.
+            pg_gloo.barrier().wait()
         else:
-            # Sleep to ensure timeout.
-            time.sleep(10)
-        self._wait_for_comm_abort(process_group)
+            # Wait on rank 0 to fail.
+            try:
+                pg_gloo.barrier().wait()
+            except Exception as e:
+                raise ValueError(f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}")
+            # Now verify communicators on this rank have
+            # been aborted by watchdog.
+            self._wait_for_comm_abort(process_group, failed_collective_timeout)
 class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
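
A note on the reworked test_nccl_timeout above, condensed for readers skimming the diff: rank 0 issues a collective that the other ranks never join and asserts that waiting on it with a short timeout raises, then signals the remaining ranks through a side-channel gloo process group; the other ranks block on that go-ahead signal instead of a fixed sleep, and only then check that the watchdog aborted their NCCL communicators. A hedged sketch of that pattern follows, with illustrative function and variable names (it assumes a multi-process job with one CUDA device per rank and NCCL/Gloo process groups already constructed).

    from datetime import timedelta

    import torch

    def timeout_then_signal(rank, nccl_pg, gloo_pg):
        # Sketch only: nccl_pg / gloo_pg stand in for the ProcessGroupNCCL and
        # ProcessGroupGloo instances the test builds from the shared FileStore.
        short_timeout = timedelta(milliseconds=100)
        if rank == 0:
            # No other rank joins this allreduce, so the wait must fail, either
            # with an operation-timeout error or a watchdog communicator abort.
            work = nccl_pg.allreduce(torch.rand(10).cuda(rank))
            try:
                work.wait(timeout=short_timeout)
            except RuntimeError:
                pass  # expected failure
            # Go-ahead signal: tell the other ranks it is now safe to check.
            gloo_pg.barrier().wait()
        else:
            # Wait for rank 0 to fail instead of sleeping a fixed amount of time,
            # then verify the watchdog aborted this rank's communicators (the
            # test does this via _wait_for_comm_abort above).
            gloo_pg.barrier().wait()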