C10D extension to enable per-thread PG (#86348)
Move a bunch of globals to instance methods and replace all uses of them. We move all PG-related globals under World and use a singleton instance under _world. This creates an undocumented extension point to inject full control of how c10d state behaves. One simple hack is to change _world to an implementation that uses a threadlocal, enabling per-thread PGs. It almost gets DDP working; the PG is only missing an implementation of all_reduce. This enables notebook usage of PTD, which is a big deal for learning it: https://gist.github.com/kumpera/32cb051fa26b8cad8bdf671f968dcd68

This change ensures BC by keeping the global variables around and having the default _World wrap them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86348
Approved by: https://github.com/rohan-varma
Committed by: PyTorch MergeBot
Parent: 66979fbfaa
Commit: 97abc21f2b
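For context, here is a minimal sketch (not part of this commit) of the notebook-style, single-process usage this change enables, assuming the test-internal helpers added in this diff (run_with_threaded_pg and the "threaded" backend); these entry points live under torch.testing._internal and may change:

import torch.distributed as dist
from torch.testing._internal.distributed.multi_threaded_pg import run_with_threaded_pg

def demo():
    # Each "rank" is a thread talking to its own per-thread process group state.
    rank = dist.get_rank()
    obj = ["hello from rank 0"] if rank == 0 else [None]
    dist.broadcast_object_list(obj, src=0)
    print(f"rank {rank} received {obj[0]}")

# Runs demo() on 4 threads; returns a list of (rank, exc_info) for ranks that failed.
failed_ranks = run_with_threaded_pg(world_size=4, timeout=20, callback=demo)
assert failed_ranks == []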
test/distributed/test_multi_threaded_pg.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# Owner(s): ["oncall: distributed"]

import sys
import torch.distributed as dist

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import (
    spawn_threads_and_init_comms,
    MultiThreadedTestCase
)
from torch.testing._internal.common_utils import TestCase, run_tests

DEFAULT_WORLD_SIZE = 4

class TestObjectCollectivesWithWrapper(TestCase):
    @spawn_threads_and_init_comms(world_size=4)
    def test_broadcast_object_list(self):
        val = 99 if dist.get_rank() == 0 else None
        object_list = [val] * dist.get_world_size()

        dist.broadcast_object_list(object_list=object_list)
        self.assertEqual(99, object_list[0])

class TestObjectCollectivesWithBaseClass(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def test_broadcast_object_list(self):
        val = 99 if dist.get_rank() == 0 else None
        object_list = [val] * dist.get_world_size()
        print(f"{dist.get_rank()} -> {dist.get_world_size()}")

        dist.broadcast_object_list(object_list=object_list)
        self.assertEqual(99, object_list[0])

    def test_something_else(self):
        pass

if __name__ == "__main__":
    run_tests()
@@ -81,5 +81,8 @@ else:
    # python test/test_public_bindings.py -k test_correct_module_names
    # working even when USE_DISTRIBUTED=0. Feel free to add more
    # stubs as necessary.
    class ProcessGroup:  # type: ignore[no-redef]
    # We cannot define stubs directly because they confuse pyre

    class _ProcessGroupStub:
        pass
    sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub  # type: ignore[attr-defined]
@@ -258,32 +258,110 @@ class _reduce_op(object):
reduce_op = _reduce_op()


class group(object):
# DO NOT USE THESE FIELDS DIRECTLY.
# Use them through the _world object to make sure the _world override mechanism
_pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = {}
_pg_names: Dict[ProcessGroup, str] = {}
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
_group_count = 0

class _World:
    """
    Container class for c10d process group state.
    This is used during registration and lookup of PG state.

    .. warning:: This is an experimental API intended to expose the inner workings
       of c10d and is subject to change.
    """
    def __init__(self):
        self._default_pg = None

    @property
    def default_pg(self):
        """
        The default ProcessGroup includes all ranks of the cluster.
        This is used by c10d APIs when a ProcessGroup is needed but None is provided.
        """
        return self._default_pg

    @default_pg.setter
    def default_pg(self, value):
        self._default_pg = value

    @property
    def pg_map(self) -> Dict[ProcessGroup, Tuple[str, Optional[Store]]]:
        """
        Cached process groups
        For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
        For MPI pg, it is a map from ProcessGroup to (Backend, None)

        TODO don't expose the map, expose fine grained ops
        """
        global _pg_map
        return _pg_map

    @property
    def pg_names(self) -> Dict[ProcessGroup, str]:
        """
        Process group's names, map from ProcessGroup to str.

        TODO don't expose the map, expose fine grained ops
        """
        global _pg_names
        return _pg_names

    @property
    def pg_group_ranks(self) -> Dict[ProcessGroup, Dict[int, int]]:
        """
        Process group's global rank to local rank mapping
        TODO don't expose the map, expose fine grained ops
        """
        global _pg_group_ranks
        return _pg_group_ranks

    @property
    def group_count(self) -> int:
        """
        Process group count for default naming.

        TODO don't expose group_count, use something else instead
        """
        global _group_count
        return _group_count

    @group_count.setter
    def group_count(self, value):
        """
        Count is used when computing the name of ProcessGroups when using global synchronization.
        """
        global _group_count
        _group_count = value


_world = _World()
"""Holds the singleton instance of ``_World`` used by c10d. Experimental extension point to override it"""

class _WorldMeta(type):
    """
    Meta class of ``group`` and ``GroupMember`` so they
    can have the class property ``WORLD``.
    """
    # Points to the default PG once initialized.
    WORLD: Optional[ProcessGroup] = None
    @property
    def WORLD(cls) -> Optional[ProcessGroup]:
        return _world.default_pg

class group(object, metaclass=_WorldMeta):
    pass

class GroupMember(object):
    # Alias to group.WORLD for backward compatibility
    WORLD = group.WORLD
class GroupMember(object, metaclass=_WorldMeta):
    NON_GROUP_MEMBER = object()


# Cached process groups
# For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
# For MPI pg, it is a map from ProcessGroup to (Backend, None)
_pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = {}
# Process group's names, map from ProcessGroup to str
_pg_names: Dict[ProcessGroup, str] = {}
# Process group's global rank to local rank mapping
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}

# Default process group state
_default_pg_init_method = None

# Process group count for default naming
_group_count = 0

STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
@@ -303,7 +381,7 @@ def _store_based_barrier(rank, store, timeout):
    ``init_process_group`` or ``new_group``. Intended to be used only with
    those two methods and is not a generic alternative to ``barrier()``.
    """
    store_key = "{}:{}".format(STORE_BASED_BARRIER_PREFIX, _group_count)
    store_key = "{}:{}".format(STORE_BASED_BARRIER_PREFIX, _world.group_count)
    store.add(store_key, 1)
    logger.info("Added key: {} to store for rank: {}".format(store_key, rank))

@@ -378,9 +456,9 @@ def get_group_rank(group: ProcessGroup, global_rank: int) -> int:
    """
    if group is GroupMember.WORLD:
        return global_rank
    if group not in _pg_group_ranks:
    if group not in _world.pg_group_ranks:
        raise RuntimeError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
    group_ranks = _pg_group_ranks[group]
    group_ranks = _world.pg_group_ranks[group]
    if global_rank not in group_ranks:
        raise RuntimeError(f"Global rank {global_rank} is not part of group {group}")

@@ -403,9 +481,9 @@ def get_global_rank(group: ProcessGroup, group_rank: int) -> int:
    """
    if group is GroupMember.WORLD:
        return group_rank
    if group not in _pg_group_ranks:
    if group not in _world.pg_group_ranks:
        raise RuntimeError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
    for rank, grp_rank in _pg_group_ranks[group].items():
    for rank, grp_rank in _world.pg_group_ranks[group].items():
        if grp_rank == group_rank:
            return rank
    raise RuntimeError(f"Group rank {group_rank} is not part of group {group}")
@@ -432,7 +510,7 @@ def get_process_group_ranks(group: ProcessGroup):
    Returns:
        List of global ranks ordered by group rank.
    """
    return list(_pg_group_ranks[group].keys())
    return list(_world.pg_group_ranks[group].keys())

def _get_group_size(group):
    """

@@ -587,13 +665,12 @@ def _get_default_store():
            "please make sure to call init_process_group."
        )
    default_pg = _get_default_group()
    _, default_store = _pg_map[default_pg]
    _, default_store = _world.pg_map[default_pg]
    return default_store


def _update_default_pg(pg):
    GroupMember.WORLD = group.WORLD = pg

    _world.default_pg = pg

def get_backend(group: Optional[ProcessGroup] = None) -> str:
    """

@@ -614,7 +691,7 @@ def get_backend(group: Optional[ProcessGroup] = None) -> str:
        pg = group
    if _rank_not_in_group(pg):
        raise RuntimeError("Invalid process group specified")
    pg_store = _pg_map.get(pg, None)
    pg_store = _world.pg_map.get(pg, None)
    assert pg_store is not None
    return pg_store[0]

@@ -698,7 +775,8 @@ def init_process_group(
        on a system that supports MPI.

    """
    global _pg_group_ranks
    global _world

    global _backend
    global _default_pg_init_method

@@ -759,8 +837,8 @@
        )
        _update_default_pg(default_pg)

    _pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())}  # type: ignore[attr-defined, index]
    _backend = _pg_map[GroupMember.WORLD][0]  # type: ignore[index]
    _world.pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())}  # type: ignore[attr-defined, index]
    _backend = _world.pg_map[GroupMember.WORLD][0]  # type: ignore[index]
    _default_pg_init_method = init_method

    # barrier at the end to ensure that once we return from this method, all
@@ -797,15 +875,13 @@ def _new_process_group_helper(

    This function is called with ``group_ranks == []`` for the default group.
    """
    global _pg_map
    global _group_count
    global _pg_names
    global _world

    if not group_name:
        group_name = str(_group_count)
        _group_count += 1
        group_name = str(_world.group_count)
        _world.group_count = _world.group_count + 1

    if group_name in _pg_names.values():
    if group_name in _world.pg_names.values():
        raise RuntimeError(
            "The specified group name has already been "
            "created, please use a different group name"

@@ -831,8 +907,8 @@ def _new_process_group_helper(
        pg = ProcessGroupMPI.create(global_ranks_in_group)
        if not pg:
            return GroupMember.NON_GROUP_MEMBER
        _pg_map[pg] = (Backend.MPI, None)
        _pg_names[pg] = group_name
        _world.pg_map[pg] = (Backend.MPI, None)
        _world.pg_names[pg] = group_name
    else:
        # If this is a subgroup (which means group_ranks is specified),
        # we check if the current process is a member of the new group.

@@ -868,8 +944,8 @@
                world_size=group_size,
                timeout=timeout,
            )
            _pg_map[pg] = (Backend.GLOO, store)
            _pg_names[pg] = group_name
            _world.pg_map[pg] = (Backend.GLOO, store)
            _world.pg_names[pg] = group_name
        elif backend == Backend.NCCL:
            if not is_nccl_available():
                raise RuntimeError("Distributed package doesn't have NCCL " "built in")

@@ -903,8 +979,8 @@
                world_size=group_size,
                timeout=timeout,
            )
            _pg_map[pg] = (Backend.NCCL, store)
            _pg_names[pg] = group_name
            _world.pg_map[pg] = (Backend.NCCL, store)
            _world.pg_names[pg] = group_name
        elif backend == Backend.UCC and is_ucc_available():
            # TODO: once UCC plugin is fully deprecated, remove
            # is_ucc_available() from above elif-condition and raise

@@ -930,8 +1006,8 @@
                world_size=group_size,
                timeout=timeout,
            )
            _pg_map[pg] = (Backend.UCC, store)
            _pg_names[pg] = group_name
            _world.pg_map[pg] = (Backend.UCC, store)
            _world.pg_names[pg] = group_name
        else:
            assert backend.upper() in Backend._plugins, (
                f"Unknown c10d backend type {backend.upper()}"

@@ -953,8 +1029,8 @@
            dist_backend_opts.global_ranks_in_group = global_ranks_in_group

            pg = creator_fn(dist_backend_opts, pg_options)
            _pg_map[pg] = (backend, store)
            _pg_names[pg] = group_name
            _world.pg_map[pg] = (backend, store)
            _world.pg_names[pg] = group_name

    return pg
@@ -969,11 +1045,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
        groups including the default one will
        be destroyed.
    """
    global _pg_map
    global _pg_names
    global _pg_group_ranks
    global _default_pg_init_method
    global _group_count
    global _world

    if group == GroupMember.NON_GROUP_MEMBER:
        return

@@ -984,29 +1056,28 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
        pg = group

    assert pg is not None
    if _pg_map.get(pg, None) is None:
    if _world.pg_map.get(pg, None) is None:
        raise RuntimeError("Invalid process group specified")

    if group is None or group == GroupMember.WORLD:
        _update_default_pg(None)
        _default_pg_init_method = None
        _pg_map.clear()
        _pg_names.clear()
        _pg_group_ranks.clear()
        _world.pg_map.clear()
        _world.pg_names.clear()
        _world.pg_group_ranks.clear()

        # when process group doesn't have an explicit name (only WORLD (default)
        # process group can have an explicit name), we use global _group_counter
        # process group can have an explicit name), we use global _world.group_count
        # to generate the name. We need to reset the counter on destruction to
        # allow consistent value to be generated when we re-create process
        # groups after some trainers recover from failure
        #
        # We only reset this when WORLD is being destroyed because if this
        # process group is in good state, we aren't dealing with failures.
        _group_count = 0
        _world.group_count = 0
    else:
        del _pg_map[pg]
        del _pg_names[pg]
        del _pg_group_ranks[pg]
        del _world.pg_map[pg]
        del _world.pg_names[pg]
        del _world.pg_group_ranks[pg]


def get_rank(group: Optional[ProcessGroup] = None) -> int:
@@ -3282,10 +3353,10 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None):
        A handle of distributed group that can be given to collective calls.
    """

    global _pg_group_ranks
    global _world

    default_pg = _get_default_group()
    default_backend, default_store = _pg_map[default_pg]
    default_backend, default_store = _world.pg_map[default_pg]
    global_rank = default_pg.rank()
    global_world_size = default_pg.size()

@@ -3334,7 +3405,7 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None):
    )

    # Create the global rank to group rank mapping
    _pg_group_ranks[pg] = {
    _world.pg_group_ranks[pg] = {
        global_rank: group_rank for group_rank, global_rank in enumerate(ranks)
    }
@@ -36,6 +36,9 @@ from torch.testing._internal.common_utils import (
    sandcastle_skip_if,
    sandcastle_skip,
)
from torch.testing._internal.distributed.multi_threaded_pg import (
    run_with_threaded_pg
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -465,6 +468,8 @@ def cleanup_temp_dir() -> None:
    if tmp_dir is not None:
        tmp_dir.cleanup()

# Most tests operate with this world size
DEFAULT_WORLD_SIZE = 4

# [How does MultiProcessTestCase work?]
# Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by

@@ -477,6 +482,8 @@ def cleanup_temp_dir() -> None:
# from the test instance and run it. The main process simply waits for all
# subprocesses to join.

# Most tests operate with this world size
DEFAULT_WORLD_SIZE = 4

class MultiProcessTestCase(TestCase):
    MAIN_PROCESS_RANK = -1

@@ -492,7 +499,7 @@ class MultiProcessTestCase(TestCase):

    @property
    def world_size(self) -> int:
        return 4
        return DEFAULT_WORLD_SIZE

    def join_or_run(self, fn):
        @wraps(fn)
@@ -834,3 +841,71 @@ def tp_transports():
    see https://github.com/pytorch/pytorch/issues/73885 and https://github.com/pytorch/pytorch/issues/65022
    """
    return ["shm", "uv"] if has_efa() else None


def _run_test_with_mt_pg(self, timeout, world_size, callback):
    failed_ranks = run_with_threaded_pg(world_size, timeout, callback)
    for rank, exc_info in failed_ranks:
        print(f"Rank {rank} raised:")
        for line in traceback.format_exception(*exc_info):
            sys.stdout.write(line)
    self.assertEqual([], failed_ranks, "Some ranks failed")

def spawn_threads_and_init_comms(func=None, timeout=TIMEOUT_DEFAULT, world_size=DEFAULT_WORLD_SIZE):
    """
    Wrapper to use with a test method
    """
    if func is None:
        return partial(spawn_threads_and_init_comms, timeout=timeout, world_size=world_size)

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        _run_test_with_mt_pg(self, timeout, world_size, lambda: func(self, *args, **kwargs))

    return wrapper

class MultiThreadedTestCase(TestCase):
    """
    Simple test runner that executes all tests with the in-proc process group.

    A single instance of the TestCase object for all threads.

    Difference from regular test runner:
        Cannot use setUp / tearDown (must use perThreadSetUp / perThreadTearDown)
            Not sure what these two would be good for though.
        No global state possible
            How bad of a limitation is this?
    """

    def __init__(self, method_name: str = "runTest") -> None:
        super().__init__(method_name)
        self._test_method = getattr(self, method_name, None)
        setattr(self, method_name, self.threaded_run_test)
        if TestCase.setUp != type(self).setUp:
            raise RuntimeError(f"Test class {type(self)} overrides disabled method setUp. Use perThreadSetUp instead")
        if TestCase.tearDown != type(self).tearDown:
            raise RuntimeError(f"Test class {type(self)} overrides disabled method tearDown. Use perThreadTearDown instead")

    def threaded_run_test(self):
        self.perThreadSetUp()
        try:
            _run_test_with_mt_pg(
                self=self,
                timeout=TIMEOUT_DEFAULT,
                world_size=self.world_size,
                callback=self._test_method,
            )
        finally:
            self.perThreadTearDown()

    def perThreadSetUp(self):
        pass

    def perThreadTearDown(self):
        pass

    @property
    def world_size(self) -> int:
        raise RuntimeError("world size not implemented")
@@ -1147,7 +1147,6 @@ class DistributedTest:
        @require_world_size(4)
        @skip_if_lt_x_gpu(4)
        def test_3_level_hierarchical_model_averager(self):
            from torch.distributed.distributed_c10d import _pg_group_ranks
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)

@@ -1178,8 +1177,8 @@ class DistributedTest:
            subgroup1 = averager.period_process_group_dict[subgroup_avg_period1]
            subgroup2 = averager.period_process_group_dict[subgroup_avg_period2]

            real_group_ranks_res1 = list(_pg_group_ranks[subgroup1].keys())
            real_group_ranks_res2 = list(_pg_group_ranks[subgroup2].keys())
            real_group_ranks_res1 = dist.get_process_group_ranks(subgroup1)
            real_group_ranks_res2 = dist.get_process_group_ranks(subgroup2)
            expect_group_ranks_res1 = (rank // subgroup_size1 * subgroup_size1 + np.array(list(range(subgroup_size1)))).tolist()
            expect_group_ranks_res2 = (rank // subgroup_size2 * subgroup_size2 + np.array(list(range(subgroup_size2)))).tolist()
            self.assertEqual(real_group_ranks_res1, expect_group_ranks_res1)
torch/testing/_internal/distributed/multi_threaded_pg.py (new file, 268 lines)
@@ -0,0 +1,268 @@
import time
import sys
import queue
import threading
from dataclasses import dataclass

import torch
import torch.distributed as dist
from torch.futures import Future

from torch._C._distributed_c10d import _create_work_from_future
from torch.utils._pytree import tree_flatten

"""
TODO:
Lots of missing collectives.
Collectives validation.
Make timeout robust by making collectives respect the test deadline.
Make tests robust by making collectives interruptible.
We need some synchronization around cleanup to ensure that timed-out ranks don't cause spurious failures.
"""
def flatten_list(lst):
    return tree_flatten(lst)[0]

def ret_work(ret):
    fut = Future()
    fut.set_result(ret)
    return _create_work_from_future(fut)

class AllGather:
    def work(self, data):
        for src_rank in range(len(data)):
            in_tensor_list = data[src_rank][1]
            # Can't handle all_gather with multiple tensors
            assert len(in_tensor_list) == 1
            src_tensor = in_tensor_list[0]

            for dest in data:
                dest_tensor = dest[0][0][src_rank]
                with torch.no_grad():
                    dest_tensor.copy_(src_tensor)

class Broadcast:
    def __init__(self, src):
        self.src = src

    def work(self, data):
        in_tensor_list = flatten_list(data[self.src])
        for i in range(len(data)):
            out_tensor_list = flatten_list(data[i])
            for j in range(len(in_tensor_list)):
                with torch.no_grad():
                    out_tensor_list[j].copy_(in_tensor_list[j])

class Collective:
    def __init__(self, world_size, collective):
        self._world_size = world_size
        self._collective = collective

        self._start_cond = threading.Condition()
        self._done_cond = threading.Condition()

        self._data = [None] * world_size
        self._count = 0
        self._done = False

    def join(self, rank, data):
        with self._start_cond:
            self._data[rank] = data
            self._count += 1

            # notify rank 0
            if self._count == self._world_size:
                if rank > 0:
                    self._start_cond.notify()

            if rank == 0:
                while self._count < self._world_size:
                    self._start_cond.wait()

        with self._done_cond:
            # wait for rank 0 to finish
            if rank > 0:
                while not self._done:
                    self._done_cond.wait()
            else:
                # copy data around
                self._collective.work(self._data)
                self._done = True
                self._done_cond.notify_all()
        return ret_work(data)
class ProcessLocalGroup(dist.ProcessGroup):
    _pg_lock = threading.Lock()
    _pg_list = []
    _count = 0
    _ready = False

    _coll_lock = threading.Lock()
    _cur_coll = None

    @classmethod
    def _register(cls, pg):
        with cls._pg_lock:
            while len(cls._pg_list) <= pg._rank:
                cls._pg_list.append(None)
            cls._pg_list[pg._rank] = pg
            cls._count += 1
            if cls._count == pg._world:
                cls._ready = True

    @classmethod
    def _start_coll(cls, world_size, collective):
        with cls._coll_lock:
            if not cls._ready:
                raise Exception(f"world not ready, only {cls._count} PG's registered but world has {world_size} ranks")
            if cls._cur_coll is None:
                cls._cur_coll = Collective(world_size, collective)
            return cls._cur_coll

    @classmethod
    def _end_coll(cls, collective):
        # This is racily called by all ranks, so only one will work
        with cls._coll_lock:
            if cls._cur_coll == collective:
                cls._cur_coll = None

    def allgather(self, output_tensors, input_tensor, options):
        coll = ProcessLocalGroup._start_coll(self._world, AllGather())
        res = coll.join(self._rank, (output_tensors, input_tensor))
        ProcessLocalGroup._end_coll(coll)
        return res

    def broadcast(self, tensor_list, opts):
        coll = ProcessLocalGroup._start_coll(self._world, Broadcast(opts.rootRank))
        res = coll.join(self._rank, tensor_list)
        ProcessLocalGroup._end_coll(coll)
        return res

    def __init__(self, rank, world):
        super(ProcessLocalGroup, self).__init__(rank, world)
        self._rank = rank
        self._world = world
        ProcessLocalGroup._register(self)

    def size(self):
        return self._world

    def getBackendName(self):
        return "local"

    def __repr__(self):
        return f"PLG w:{self._world} r:{self._rank}"

def _create_threaded_pg(prefix_store, rank, world_size, timeout):
    return ProcessLocalGroup(rank, world_size)

dist.Backend.register_backend('threaded', _create_threaded_pg)
@dataclass
class WorldData:
    default_pg: dist.ProcessGroup
    pg_map: dict
    pg_names: dict
    pg_group_ranks: dict
    group_count: int

class ThreadLocalWorld:
    _world = threading.local()

    def _get_world(self) -> WorldData:
        if not hasattr(ThreadLocalWorld._world, "world"):
            ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, 0)
        return ThreadLocalWorld._world.world

    @property
    def default_pg(self):
        return self._get_world().default_pg

    @default_pg.setter
    def default_pg(self, value):
        self._get_world().default_pg = value

    @property
    def pg_map(self):
        return self._get_world().pg_map

    @property
    def pg_names(self):
        return self._get_world().pg_names

    @property
    def pg_group_ranks(self):
        return self._get_world().pg_group_ranks

    @property
    def group_count(self) -> int:
        return self._get_world().group_count

    @group_count.setter
    def group_count(self, value):
        self._get_world().group_count = value
_old_pg_world = None

def _install_threaded_pg():
    global _old_pg_world
    _old_pg_world = dist.distributed_c10d._world
    dist.distributed_c10d._world = ThreadLocalWorld()
    return dist.distributed_c10d._world

def _uninstall_threaded_pg():
    dist.distributed_c10d._world = _old_pg_world

def run_with_threaded_pg(world_size, timeout, callback):
    """
    Run ``callback`` with ``world_size`` threads using the in-proc process group
    """
    world = _install_threaded_pg()

    def world_is_valid():
        return world == dist.distributed_c10d._world

    global_store = dist.HashStore()
    exception_queue = queue.Queue()

    def worker(rank):
        if not world_is_valid():
            raise TimeoutError("Invalid world")
        dist.init_process_group(
            backend="threaded",
            rank=rank,
            world_size=world_size,
            store=global_store
        )
        try:
            callback()
        except BaseException as ex:
            exception_queue.put((rank, sys.exc_info()))
        finally:
            if world_is_valid():
                dist.destroy_process_group()

    try:
        threads = [
            threading.Thread(
                target=worker,
                args=(rank,)
            ) for rank in range(world_size)
        ]
        for thread in threads:
            thread.start()

        deadline = time.time() + timeout
        for idx, thread in enumerate(threads):
            thread.join(max(0, deadline - time.time()))
            if thread.is_alive():
                exception_queue.put((idx, (TimeoutError, TimeoutError(f"Rank failed to join in under {timeout} seconds"), None)))
        failed_ranks = []
        while not exception_queue.empty():
            failure = exception_queue.get()
            failed_ranks.append(failure)
        return failed_ranks
    finally:
        _uninstall_threaded_pg()