mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-19 10:04:58 +08:00
[c10d] Add assertRaisesRegexOnRank helper for distributed
Allow asserting that an exception is raised only on the specified rank, but not on other ranks. Useful expecially for pipeline parallelism. ghstack-source-id: 7a27f9e128f465e52e503617914261af4dbbbb41 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126731
This commit is contained in:
@ -15,7 +15,7 @@ import time
|
||||
import traceback
|
||||
import types
|
||||
import unittest
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
@ -540,6 +540,11 @@ class MultiProcessTestCase(TestCase):
|
||||
|
||||
return types.MethodType(wrapper, self)
|
||||
|
||||
def assertRaisesRegexOnRank(self, rank, expected_exception, expected_regex, *args, **kwargs):
|
||||
if self.rank == rank:
|
||||
return self.assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)
|
||||
return nullcontext()
|
||||
|
||||
# The main process spawns N subprocesses that run the test.
|
||||
# Constructor patches current instance test method to
|
||||
# assume the role of the main process and join its subprocesses,
|
||||
@ -1303,6 +1308,11 @@ class MultiProcContinousTest(TestCase):
|
||||
# Rendezvous file
|
||||
rdvz_file: Optional[str] = None
|
||||
|
||||
def assertRaisesRegexOnRank(self, rank, expected_exception, expected_regex, *args, **kwargs):
|
||||
if self.rank == rank:
|
||||
return self.assertRaisesRegex(expected_exception, expected_regex, *args, **kwargs)
|
||||
return nullcontext()
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def backend_str(cls) -> str:
|
||||
|
||||
Reference in New Issue
Block a user