[c10d] Add assertRaisesRegexOnRank helper for distributed

Allow asserting that an exception is raised only on the specified rank,
but not on other ranks. Especially useful for pipeline parallelism.

ghstack-source-id: 7a27f9e128f465e52e503617914261af4dbbbb41
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126731
This commit is contained in:
Will Constable
2024-05-20 17:00:26 -07:00
parent f75b79493e
commit 3b19956cb2

View File

@ -15,7 +15,7 @@ import time
import traceback
import types
import unittest
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from enum import Enum
@ -540,6 +540,11 @@ class MultiProcessTestCase(TestCase):
return types.MethodType(wrapper, self)
def assertRaisesRegexOnRank(self, rank, expected_exception, expected_regex, *args, **kwargs):
    """Assert that a matching exception is raised on ``rank`` only.

    On the process whose ``self.rank`` equals ``rank`` this forwards every
    argument to ``self.assertRaisesRegex`` and returns its result.  On every
    other rank no exception is expected.

    Args:
        rank: Rank on which the exception is expected.
        expected_exception: Exception class (or tuple of classes) forwarded
            to ``assertRaisesRegex``.
        expected_regex: Pattern forwarded to ``assertRaisesRegex``.
        *args: Optional callable followed by its positional arguments
            (the callable form of ``assertRaisesRegex``).
        **kwargs: Keyword arguments forwarded along with ``*args``.

    Returns:
        On the matching rank, whatever ``self.assertRaisesRegex`` returns.
        On other ranks, a ``nullcontext()`` when no callable was supplied,
        otherwise ``None`` after the callable has been invoked.
    """
    if self.rank == rank:
        return self.assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)
    if args:
        # Callable form: other ranks must still execute the callable so they
        # run the same code as the target rank; any exception it raises here
        # propagates, since none is expected on this rank.
        callable_obj, *call_args = args
        callable_obj(*call_args, **kwargs)
        return None
    # Context-manager form: the ``with`` body runs with no assertion attached.
    return nullcontext()
# The main process spawns N subprocesses that run the test.
# Constructor patches current instance test method to
# assume the role of the main process and join its subprocesses,
@ -1303,6 +1308,11 @@ class MultiProcContinousTest(TestCase):
# Rendezvous file
rdvz_file: Optional[str] = None
def assertRaisesRegexOnRank(self, rank, expected_exception, expected_regex, *args, **kwargs):
    """Assert that a matching exception is raised on ``rank`` only.

    When ``self.rank`` equals ``rank`` all arguments are forwarded to
    ``self.assertRaisesRegex`` and its result is returned.  On any other
    rank no exception is expected.

    Args:
        rank: Rank on which the exception is expected.
        expected_exception: Exception class (or tuple of classes) forwarded
            to ``assertRaisesRegex``.
        expected_regex: Pattern forwarded to ``assertRaisesRegex``.
        *args: Optional callable followed by its positional arguments
            (the callable form of ``assertRaisesRegex``).
        **kwargs: Keyword arguments forwarded along with ``*args``.

    Returns:
        On the matching rank, whatever ``self.assertRaisesRegex`` returns.
        On other ranks, a ``nullcontext()`` when no callable was supplied,
        otherwise ``None`` after the callable has been invoked.
    """
    if self.rank == rank:
        return self.assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)
    if args:
        # Callable form: invoke it on non-target ranks as well, so they do not
        # silently skip the code under test; an exception raised here
        # propagates because none is expected on this rank.
        callable_obj, *call_args = args
        callable_obj(*call_args, **kwargs)
        return None
    # Context-manager form: the ``with`` body runs with no assertion attached.
    return nullcontext()
@classmethod
@abc.abstractmethod
def backend_str(cls) -> str: