mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-16 23:44:53 +08:00
Revert " C10D extension to enable per-thread PG (#86348)"
This reverts commit 97abc21f2bda38e73de2a86da7f43c8126930681.
Reverted https://github.com/pytorch/pytorch/pull/86348 on behalf of https://github.com/huydhn due to Sorry for reverting your PR but it breaks macos tests 97abc21f2b
This commit is contained in:
@ -36,9 +36,6 @@ from torch.testing._internal.common_utils import (
|
||||
sandcastle_skip_if,
|
||||
sandcastle_skip,
|
||||
)
|
||||
from torch.testing._internal.distributed.multi_threaded_pg import (
|
||||
run_with_threaded_pg
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -468,8 +465,6 @@ def cleanup_temp_dir() -> None:
|
||||
if tmp_dir is not None:
|
||||
tmp_dir.cleanup()
|
||||
|
||||
# Most tests operate with thi worldsize
|
||||
DEFAULT_WORLD_SIZE = 4
|
||||
|
||||
# [How does MultiProcessTestCase work?]
|
||||
# Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by
|
||||
@ -482,8 +477,6 @@ DEFAULT_WORLD_SIZE = 4
|
||||
# from the test instance and run it. The main process simply waits for all
|
||||
# subprocesses to join.
|
||||
|
||||
# Most tests operate with thi worldsize
|
||||
DEFAULT_WORLD_SIZE = 4
|
||||
|
||||
class MultiProcessTestCase(TestCase):
|
||||
MAIN_PROCESS_RANK = -1
|
||||
@ -499,7 +492,7 @@ class MultiProcessTestCase(TestCase):
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
return DEFAULT_WORLD_SIZE
|
||||
return 4
|
||||
|
||||
def join_or_run(self, fn):
|
||||
@wraps(fn)
|
||||
@ -841,71 +834,3 @@ def tp_transports():
|
||||
see https://github.com/pytorch/pytorch/issues/73885 and https://github.com/pytorch/pytorch/issues/65022
|
||||
"""
|
||||
return ["shm", "uv"] if has_efa() else None
|
||||
|
||||
|
||||
def _run_test_with_mt_pg(self, timeout, world_size, callback):
|
||||
failed_ranks = run_with_threaded_pg(world_size, timeout, callback)
|
||||
for rank, exc_info in failed_ranks:
|
||||
print(f"Rank {rank} raised:")
|
||||
for line in traceback.format_exception(*exc_info):
|
||||
sys.stdout.write(line)
|
||||
self.assertEqual([], failed_ranks, "Some ranks failed")
|
||||
|
||||
def spawn_threads_and_init_comms(func=None, timeout=TIMEOUT_DEFAULT, world_size=DEFAULT_WORLD_SIZE):
|
||||
"""
|
||||
Wrapper to use with a test method
|
||||
"""
|
||||
if func is None:
|
||||
return partial(spawn_threads_and_init_comms, timeout=timeout, world_size=world_size)
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
_run_test_with_mt_pg(self, timeout, world_size, lambda: func(self, *args, **kwargs))
|
||||
|
||||
|
||||
return wrapper
|
||||
|
||||
class MultiThreadedTestCase(TestCase):
|
||||
"""
|
||||
Simple test runner that executes all tests with the in-proc process group.
|
||||
|
||||
A single instance of the TestCase object for all threads.
|
||||
|
||||
Difference from regular test runner:
|
||||
Cannot use setUp / tearDown (must use perThreadSetup / perThreadShutdown)
|
||||
Not sure what these two would be good for though.
|
||||
No global state possible
|
||||
How bad of a limitation is this?
|
||||
"""
|
||||
|
||||
def __init__(self, method_name: str = "runTest") -> None:
|
||||
super().__init__(method_name)
|
||||
self._test_method = getattr(self, method_name, None)
|
||||
setattr(self, method_name, self.threaded_run_test)
|
||||
if TestCase.setUp != type(self).setUp:
|
||||
raise RuntimeError(f"Test class {type(self)} overrides disabled method setUp. Use perThreadSetUp instead")
|
||||
if TestCase.tearDown != type(self).tearDown:
|
||||
raise RuntimeError(f"Test class {type(self)} overrides disabled method tearDown. Use perThreadTearDown instead")
|
||||
|
||||
|
||||
def threaded_run_test(self):
|
||||
self.perThreadSetUp()
|
||||
try:
|
||||
_run_test_with_mt_pg(
|
||||
self=self,
|
||||
timeout=TIMEOUT_DEFAULT,
|
||||
world_size=self.world_size,
|
||||
callback=self._test_method,
|
||||
)
|
||||
finally:
|
||||
self.perThreadTearDown()
|
||||
|
||||
def perThreadSetUp(self):
|
||||
pass
|
||||
|
||||
def perThreadTearDown(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def world_size(self) -> int:
|
||||
raise RuntimeError("world size not implemented")
|
||||
|
||||
Reference in New Issue
Block a user