Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-04 16:04:58 +08:00)
	[reland] Fix flaky test_nccl_timeout (#68544)
Summary: Fixes https://github.com/pytorch/pytorch/issues/66882

In addition to the changes in https://github.com/pytorch/pytorch/pull/68403, this adds one more error check that can be raised when a collective times out.

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/68544
Reviewed By: albanD
Differential Revision: D32508706
Pulled By: rohan-varma
fbshipit-source-id: 7d41b91f547d4ad763c44cd11e7b9914b452b617
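The "one more error check" is the _check_valid_comm_exception helper in the diff below: as the patch's own comment notes, the watchdog may abort timed-out work and surface a NCCL error instead of an operation-timeout error, so the test must accept any of three exception messages. As a standalone illustration, here is a minimal sketch of that pattern. The three message strings come straight from the patch; the free function is_valid_comm_exception and the demo under __main__ are illustrative additions, not part of the test suite.

# Minimal sketch of the broadened error check, adapted from the diff below.
# The message strings are taken from the patch; this free function and the
# demo under __main__ are illustrative, not part of the test suite.

VALID_COMM_EXCEPTIONS = [
    "NCCL communicator was aborted",
    "NCCL communicator encountered error",
    "Caught collective operation timeout",
]

def is_valid_comm_exception(e: Exception) -> bool:
    # True if the exception text matches any message that a timed-out or
    # watchdog-aborted NCCL collective is expected to produce.
    exception_str = str(e)
    return any(exc in exception_str for exc in VALID_COMM_EXCEPTIONS)

if __name__ == "__main__":
    # The old check accepted only "NCCL communicator was aborted", so a
    # failure surfacing as "Caught collective operation timeout" was
    # re-raised and made the test flaky.
    assert is_valid_comm_exception(RuntimeError("Caught collective operation timeout"))
    assert not is_valid_comm_exception(RuntimeError("some unrelated CUDA failure"))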
Committed by: Facebook GitHub Bot
Parent: 875ba3dddb
Commit: 183dcdf551
@@ -33,6 +33,7 @@ from torch.nn.parallel import DistributedDataParallel
 from torch.testing._internal.common_distributed import (
     MultiProcessTestCase,
     requires_nccl,
+    requires_gloo,
     requires_nccl_version,
     skip_if_lt_x_gpu,
     get_timeout,
@@ -2366,15 +2367,28 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         self._run_invalid_nccl_blocking_wait_env("2147483647")
         self._run_invalid_nccl_blocking_wait_env("4294967295")
 
-    def _wait_for_comm_abort(self, process_group):
+    def _check_valid_comm_exception(self, e):
+        exception_str = str(e)
+        valid_exceptions = [
+            "NCCL communicator was aborted",
+            "NCCL communicator encountered error",
+            "Caught collective operation timeout"
+        ]
+        return any(exc in exception_str for exc in valid_exceptions)
+
+    def _wait_for_comm_abort(self, process_group, timeout=None):
         """
         Waits for the watchdog thread to abort communicators for the process group.
         """
         while True:
             try:
-                process_group.allreduce(torch.rand(10).cuda(self.rank))
+                if not timeout:
+                    process_group.allreduce(torch.rand(10).cuda(self.rank)).wait()
+                else:
+                    assert isinstance(timeout, timedelta)
+                    process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timeout)
             except Exception as e:
-                if "NCCL communicator was aborted" in str(e):
+                if self._check_valid_comm_exception(e):
                     return
                 else:
                     raise e
@@ -2382,6 +2396,7 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
 
     @with_nccl_blocking_wait
     @requires_nccl()
+    @requires_gloo()
     @skip_if_lt_x_gpu(3)
     def test_nccl_timeout(self):
         store = c10d.FileStore(self.file_name, self.world_size)
@@ -2390,18 +2405,28 @@ class NcclErrorHandlingTest(MultiProcessTestCase):
         process_group = c10d.ProcessGroupNCCL(
             store, self.rank, self.world_size, timeout=timedelta(seconds=10)
         )
+        # Control gloo pg used as go-ahead signal/barrier
+        # to coordinate btwn ranks.
+        pg_gloo = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
+        failed_collective_timeout = timedelta(milliseconds=100)
         process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=5))
 
         if self.rank == 0:
             # This should timeout in about 1 second.
             # Watchdog may abort timed out work resulting in NCCL error instead of operation timed out.
             with self.assertRaisesRegex(RuntimeError, self.blocking_wait_error_msg):
-                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=timedelta(seconds=1))
+                process_group.allreduce(torch.rand(10).cuda(self.rank)).wait(timeout=failed_collective_timeout)
+            # Now do a barrier to tell other rank to go ahead.
+            pg_gloo.barrier().wait()
         else:
-            # Sleep to ensure timeout.
-            time.sleep(10)
-
-            self._wait_for_comm_abort(process_group)
+            # Wait on rank 0 to fail.
+            try:
+                pg_gloo.barrier().wait()
+            except Exception as e:
+                raise ValueError(f"Rank {self.rank} barrier timed out waiting for rank 0 with error: {str(e)}")
+            # Now verify communicators on this rank have
+            # been aborted by watchdog.
+            self._wait_for_comm_abort(process_group, failed_collective_timeout)
 
 
 class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
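The other half of the fix replaces sleep-based synchronization (time.sleep(10) on the non-zero rank) with an explicit go-ahead barrier on a gloo control process group, plus a bounded wait in _wait_for_comm_abort so a hang cannot stall the test. Below is a runnable single-process sketch of that coordination pattern, with threading.Barrier standing in for pg_gloo.barrier(); the rank0/rank1 functions and the 30-second barrier timeout are stand-ins for illustration, not the real c10d API.

# Single-process sketch of the go-ahead coordination, assuming a
# threading.Barrier(2) as a stand-in for the two-rank gloo barrier.
import threading

go_ahead = threading.Barrier(2, timeout=30)  # stand-in for pg_gloo.barrier()

def rank0():
    # ... run the collective that is expected to time out, under
    # assertRaisesRegex, as in the diff above ...
    go_ahead.wait()  # tell the other rank the failure has happened

def rank1():
    go_ahead.wait()  # wait on rank 0 to fail, instead of time.sleep(10)
    # ... now verify the watchdog aborted the communicators, using a bounded
    # wait (failed_collective_timeout in the diff) rather than waiting forever ...

threads = [threading.Thread(target=rank0), threading.Thread(target=rank1)]
for t in threads:
    t.start()
for t in threads:
    t.join()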