C10D extension to enable per-thread PG (#86348)

Move a bunch of globals to instance methods and replace all uses of them.

We move all PG-related globals under _World and use a singleton instance under _world.

This creates an undocumented extension point for injecting full control of how c10d
state behaves.

One simple hack is to change _world to an implementation that uses thread-local
storage, which enables per-thread PGs.

It almost gets DDP working; the PG is only missing an implementation of all_reduce.
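A minimal sketch of that swap, using the ThreadLocalWorld helper this PR adds under
torch.testing._internal.distributed.multi_threaded_pg (its _install_threaded_pg helper
does the same thing):

    import torch.distributed.distributed_c10d as c10d
    from torch.testing._internal.distributed.multi_threaded_pg import ThreadLocalWorld

    c10d._world = ThreadLocalWorld()  # every thread now sees its own PG state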

This enables notebook usage of PTD, which is a big deal for learning it:
https://gist.github.com/kumpera/32cb051fa26b8cad8bdf671f968dcd68

This change ensures BC by keeping the global variables around and having the default _World wrap them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86348
Approved by: https://github.com/rohan-varma
Rodrigo Kumpera
2022-10-13 22:23:28 +00:00
committed by PyTorch MergeBot
parent 66979fbfaa
commit 97abc21f2b
6 changed files with 529 additions and 68 deletions

View File

@@ -0,0 +1,45 @@
# Owner(s): ["oncall: distributed"]
import sys
import torch.distributed as dist
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
from torch.testing._internal.common_distributed import (
spawn_threads_and_init_comms,
MultiThreadedTestCase
)
from torch.testing._internal.common_utils import TestCase, run_tests
DEFAULT_WORLD_SIZE = 4
class TestObjectCollectivesWithWrapper(TestCase):
@spawn_threads_and_init_comms(world_size=4)
def test_broadcast_object_list(self):
val = 99 if dist.get_rank() == 0 else None
object_list = [val] * dist.get_world_size()
dist.broadcast_object_list(object_list=object_list)
self.assertEqual(99, object_list[0])
class TestObjectCollectivesWithBaseClass(MultiThreadedTestCase):
@property
def world_size(self):
return 4
def test_broadcast_object_list(self):
val = 99 if dist.get_rank() == 0 else None
object_list = [val] * dist.get_world_size()
print(f"{dist.get_rank()} -> {dist.get_world_size()}")
dist.broadcast_object_list(object_list=object_list)
self.assertEqual(99, object_list[0])
def test_something_else(self):
pass
if __name__ == "__main__":
run_tests()

View File

@@ -81,5 +81,8 @@ else:
# python test/test_public_bindings.py -k test_correct_module_names
# working even when USE_DISTRIBUTED=0. Feel free to add more
# stubs as necessary.
class ProcessGroup: # type: ignore[no-redef]
# We cannot define stubs directly because they confuse pyre
class _ProcessGroupStub:
pass
sys.modules["torch.distributed"].ProcessGroup = _ProcessGroupStub # type: ignore[attr-defined]

View File

@@ -258,32 +258,110 @@ class _reduce_op(object):
reduce_op = _reduce_op()
class group(object):
# DO NOT USE THESE FIELDS DIRECTLY.
# Use them through the _world object to make sure the _world override mechanism is respected.
_pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = {}
_pg_names: Dict[ProcessGroup, str] = {}
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
_group_count = 0
class _World:
"""
Container class for c10d process group state.
This is used during registration and lookup of PG state.
.. warning:: This is an experimental API intended to expose the inner workings
of c10d and is subject to change.
"""
def __init__(self):
self._default_pg = None
@property
def default_pg(self):
"""
The default ProcessGroup includes all ranks of the cluster.
This is used by c10d APIs when a ProcessGroup is needed but None is provided.
"""
return self._default_pg
@default_pg.setter
def default_pg(self, value):
self._default_pg = value
@property
def pg_map(self) -> Dict[ProcessGroup, Tuple[str, Optional[Store]]]:
"""
Cached process groups
For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
For MPI pg, it is a map from ProcessGroup to (Backend, None)
TODO don't expose the map, expose fine grained ops
"""
global _pg_map
return _pg_map
@property
def pg_names(self) -> Dict[ProcessGroup, str]:
"""
Process group's names, map from ProcessGroup to str.
TODO don't expose the map, expose fine grained ops
"""
global _pg_names
return _pg_names
@property
def pg_group_ranks(self) -> Dict[ProcessGroup, Dict[int, int]]:
"""
Process group's global rank to local rank mapping
TODO don't expose the map, expose fine grained ops
"""
global _pg_group_ranks
return _pg_group_ranks
@property
def group_count(self) -> int:
"""
Process group count for default naming.
TODO don't expose group_count, use something else instead
"""
global _group_count
return _group_count
@group_count.setter
def group_count(self, value):
"""
Count is used when computing the name of ProcessGroups when using global synchronization.
"""
global _group_count
_group_count = value
_world = _World()
"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
class _WorldMeta(type):
"""
Meta class of ``group`` and ``GroupMember`` so they
can have the class property ``WORLD``.
"""
# Points to the default PG once initialized.
WORLD: Optional[ProcessGroup] = None
@property
def WORLD(cls) -> Optional[ProcessGroup]:
return _world.default_pg
class group(object, metaclass=_WorldMeta):
pass
class GroupMember(object):
# Alias to group.WORLD for backward compatibility
WORLD = group.WORLD
class GroupMember(object, metaclass=_WorldMeta):
NON_GROUP_MEMBER = object()
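# Editor's illustrative sketch (not part of this diff): because ``WORLD`` is a property on
# the metaclass, both aliases always reflect the live default PG held by ``_world`` rather
# than a value cached at import time:
#
#     assert group.WORLD is _world.default_pg
#     assert GroupMember.WORLD is _world.default_pg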
# Cached process groups
# For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
# For MPI pg, it is a map from ProcessGroup to (Backend, None)
_pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = {}
# Process group's names, map from ProcessGroup to str
_pg_names: Dict[ProcessGroup, str] = {}
# Process group's global rank to local rank mapping
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}
# Default process group state
_default_pg_init_method = None
# Process group count for default naming
_group_count = 0
STORE_BASED_BARRIER_PREFIX = "store_based_barrier_key"
@@ -303,7 +381,7 @@ def _store_based_barrier(rank, store, timeout):
``init_process_group`` or ``new_group``. Intended to be used only with
those two methods and is not a generic alternative to ``barrier()``.
"""
store_key = "{}:{}".format(STORE_BASED_BARRIER_PREFIX, _group_count)
store_key = "{}:{}".format(STORE_BASED_BARRIER_PREFIX, _world.group_count)
store.add(store_key, 1)
logger.info("Added key: {} to store for rank: {}".format(store_key, rank))
@@ -378,9 +456,9 @@ def get_group_rank(group: ProcessGroup, global_rank: int) -> int:
"""
if group is GroupMember.WORLD:
return global_rank
if group not in _pg_group_ranks:
if group not in _world.pg_group_ranks:
raise RuntimeError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
group_ranks = _pg_group_ranks[group]
group_ranks = _world.pg_group_ranks[group]
if global_rank not in group_ranks:
raise RuntimeError(f"Global rank {global_rank} is not part of group {group}")
@@ -403,9 +481,9 @@ def get_global_rank(group: ProcessGroup, group_rank: int) -> int:
"""
if group is GroupMember.WORLD:
return group_rank
if group not in _pg_group_ranks:
if group not in _world.pg_group_ranks:
raise RuntimeError(f"Group {group} is not registered, please create group with torch.distributed.new_group API")
for rank, grp_rank in _pg_group_ranks[group].items():
for rank, grp_rank in _world.pg_group_ranks[group].items():
if grp_rank == group_rank:
return rank
raise RuntimeError(f"Group rank {group_rank} is not part of group {group}")
@@ -432,7 +510,7 @@ def get_process_group_ranks(group: ProcessGroup):
Returns:
List of global ranks ordered by group rank.
"""
return list(_pg_group_ranks[group].keys())
return list(_world.pg_group_ranks[group].keys())
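# Editor's worked example (not part of this diff): ``_world.pg_group_ranks`` maps each PG
# to a ``{global_rank: group_rank}`` dict. For a hypothetical subgroup built from global
# ranks [2, 5, 7], new_group records {2: 0, 5: 1, 7: 2}, so:
#
#     get_group_rank(pg, 5)        ->  1
#     get_global_rank(pg, 2)       ->  7
#     get_process_group_ranks(pg)  ->  [2, 5, 7]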
def _get_group_size(group):
"""
@@ -587,13 +665,12 @@ def _get_default_store():
"please make sure to call init_process_group."
)
default_pg = _get_default_group()
_, default_store = _pg_map[default_pg]
_, default_store = _world.pg_map[default_pg]
return default_store
def _update_default_pg(pg):
GroupMember.WORLD = group.WORLD = pg
_world.default_pg = pg
def get_backend(group: Optional[ProcessGroup] = None) -> str:
"""
@@ -614,7 +691,7 @@ def get_backend(group: Optional[ProcessGroup] = None) -> str:
pg = group
if _rank_not_in_group(pg):
raise RuntimeError("Invalid process group specified")
pg_store = _pg_map.get(pg, None)
pg_store = _world.pg_map.get(pg, None)
assert pg_store is not None
return pg_store[0]
@@ -698,7 +775,8 @@ def init_process_group(
on a system that supports MPI.
"""
global _pg_group_ranks
global _world
global _backend
global _default_pg_init_method
@@ -759,8 +837,8 @@
)
_update_default_pg(default_pg)
_pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore[attr-defined, index]
_backend = _pg_map[GroupMember.WORLD][0] # type: ignore[index]
_world.pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())} # type: ignore[attr-defined, index]
_backend = _world.pg_map[GroupMember.WORLD][0] # type: ignore[index]
_default_pg_init_method = init_method
# barrier at the end to ensure that once we return from this method, all
@@ -797,15 +875,13 @@ def _new_process_group_helper(
This function is called with ``group_ranks == []`` for the default group.
"""
global _pg_map
global _group_count
global _pg_names
global _world
if not group_name:
group_name = str(_group_count)
_group_count += 1
group_name = str(_world.group_count)
_world.group_count = _world.group_count + 1
if group_name in _pg_names.values():
if group_name in _world.pg_names.values():
raise RuntimeError(
"The specified group name has already been "
"created, please use a different group name"
@@ -831,8 +907,8 @@
pg = ProcessGroupMPI.create(global_ranks_in_group)
if not pg:
return GroupMember.NON_GROUP_MEMBER
_pg_map[pg] = (Backend.MPI, None)
_pg_names[pg] = group_name
_world.pg_map[pg] = (Backend.MPI, None)
_world.pg_names[pg] = group_name
else:
# If this is a subgroup (which means group_ranks is specified),
# we check if the current process is a member of the new group.
@@ -868,8 +944,8 @@
world_size=group_size,
timeout=timeout,
)
_pg_map[pg] = (Backend.GLOO, store)
_pg_names[pg] = group_name
_world.pg_map[pg] = (Backend.GLOO, store)
_world.pg_names[pg] = group_name
elif backend == Backend.NCCL:
if not is_nccl_available():
raise RuntimeError("Distributed package doesn't have NCCL " "built in")
@@ -903,8 +979,8 @@
world_size=group_size,
timeout=timeout,
)
_pg_map[pg] = (Backend.NCCL, store)
_pg_names[pg] = group_name
_world.pg_map[pg] = (Backend.NCCL, store)
_world.pg_names[pg] = group_name
elif backend == Backend.UCC and is_ucc_available():
# TODO: once UCC plugin is fully deprecated, remove
# is_ucc_available() from above elif-condition and raise
@@ -930,8 +1006,8 @@
world_size=group_size,
timeout=timeout,
)
_pg_map[pg] = (Backend.UCC, store)
_pg_names[pg] = group_name
_world.pg_map[pg] = (Backend.UCC, store)
_world.pg_names[pg] = group_name
else:
assert backend.upper() in Backend._plugins, (
f"Unknown c10d backend type {backend.upper()}"
@@ -953,8 +1029,8 @@
dist_backend_opts.global_ranks_in_group = global_ranks_in_group
pg = creator_fn(dist_backend_opts, pg_options)
_pg_map[pg] = (backend, store)
_pg_names[pg] = group_name
_world.pg_map[pg] = (backend, store)
_world.pg_names[pg] = group_name
return pg
@@ -969,11 +1045,7 @@ def destroy_process_group(group: Optional[ProcessGroup] = None):
groups including the default one will
be destroyed.
"""
global _pg_map
global _pg_names
global _pg_group_ranks
global _default_pg_init_method
global _group_count
global _world
if group == GroupMember.NON_GROUP_MEMBER:
return
@@ -984,29 +1056,28 @@
pg = group
assert pg is not None
if _pg_map.get(pg, None) is None:
if _world.pg_map.get(pg, None) is None:
raise RuntimeError("Invalid process group specified")
if group is None or group == GroupMember.WORLD:
_update_default_pg(None)
_default_pg_init_method = None
_pg_map.clear()
_pg_names.clear()
_pg_group_ranks.clear()
_world.pg_map.clear()
_world.pg_names.clear()
_world.pg_group_ranks.clear()
# when process group doesn't have an explicit name (only WORLD (default)
# process group can have an explicit name), we use global _group_counter
# process group can have an explicit name), we use global _world.group_count
# to generate the name. We need to reset the counter on destruction to
# allow consistent value to be generated when we re-create process
# groups after some trainers recover from failure
#
# We only reset this when WORLD is being destroyed because if this
# process group is in good state, we aren't dealing with failures.
_group_count = 0
_world.group_count = 0
else:
del _pg_map[pg]
del _pg_names[pg]
del _pg_group_ranks[pg]
del _world.pg_map[pg]
del _world.pg_names[pg]
del _world.pg_group_ranks[pg]
def get_rank(group: Optional[ProcessGroup] = None) -> int:
@@ -3282,10 +3353,10 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=N
A handle of distributed group that can be given to collective calls.
"""
global _pg_group_ranks
global _world
default_pg = _get_default_group()
default_backend, default_store = _pg_map[default_pg]
default_backend, default_store = _world.pg_map[default_pg]
global_rank = default_pg.rank()
global_world_size = default_pg.size()
@@ -3334,7 +3405,7 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=N
)
# Create the global rank to group rank mapping
_pg_group_ranks[pg] = {
_world.pg_group_ranks[pg] = {
global_rank: group_rank for group_rank, global_rank in enumerate(ranks)
}

View File

@@ -36,6 +36,9 @@ from torch.testing._internal.common_utils import (
sandcastle_skip_if,
sandcastle_skip,
)
from torch.testing._internal.distributed.multi_threaded_pg import (
run_with_threaded_pg
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -465,6 +468,8 @@ def cleanup_temp_dir() -> None:
if tmp_dir is not None:
tmp_dir.cleanup()
# Most tests operate with this world size
DEFAULT_WORLD_SIZE = 4
# [How does MultiProcessTestCase work?]
# Each MultiProcessTestCase instance uses 1 + `world_size()` processes, by
@@ -477,6 +482,8 @@ def cleanup_temp_dir() -> None:
# from the test instance and run it. The main process simply waits for all
# subprocesses to join.
# Most tests operate with this world size
DEFAULT_WORLD_SIZE = 4
class MultiProcessTestCase(TestCase):
MAIN_PROCESS_RANK = -1
@@ -492,7 +499,7 @@ class MultiProcessTestCase(TestCase):
@property
def world_size(self) -> int:
return 4
return DEFAULT_WORLD_SIZE
def join_or_run(self, fn):
@wraps(fn)
@@ -834,3 +841,71 @@ def tp_transports():
see https://github.com/pytorch/pytorch/issues/73885 and https://github.com/pytorch/pytorch/issues/65022
"""
return ["shm", "uv"] if has_efa() else None
def _run_test_with_mt_pg(self, timeout, world_size, callback):
failed_ranks = run_with_threaded_pg(world_size, timeout, callback)
for rank, exc_info in failed_ranks:
print(f"Rank {rank} raised:")
for line in traceback.format_exception(*exc_info):
sys.stdout.write(line)
self.assertEqual([], failed_ranks, "Some ranks failed")
def spawn_threads_and_init_comms(func=None, timeout=TIMEOUT_DEFAULT, world_size=DEFAULT_WORLD_SIZE):
"""
Wrapper to use with a test method
"""
if func is None:
return partial(spawn_threads_and_init_comms, timeout=timeout, world_size=world_size)
@wraps(func)
def wrapper(self, *args, **kwargs):
_run_test_with_mt_pg(self, timeout, world_size, lambda: func(self, *args, **kwargs))
return wrapper
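# Editor's usage sketch (not part of this diff), mirroring the new test file added in this PR:
#
#     class MyCollectiveTest(TestCase):
#         @spawn_threads_and_init_comms(world_size=4)
#         def test_broadcast_object_list(self):
#             val = 99 if dist.get_rank() == 0 else None
#             object_list = [val] * dist.get_world_size()
#             dist.broadcast_object_list(object_list=object_list)
#             self.assertEqual(99, object_list[0])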
class MultiThreadedTestCase(TestCase):
"""
Simple test runner that executes all tests with the in-proc process group.
A single instance of the TestCase object is shared by all threads.
Difference from regular test runner:
Cannot use setUp / tearDown (must use perThreadSetUp / perThreadTearDown)
Not sure what these two would be good for though.
No global state possible
How bad of a limitation is this?
"""
def __init__(self, method_name: str = "runTest") -> None:
super().__init__(method_name)
self._test_method = getattr(self, method_name, None)
setattr(self, method_name, self.threaded_run_test)
if TestCase.setUp != type(self).setUp:
raise RuntimeError(f"Test class {type(self)} overrides disabled method setUp. Use perThreadSetUp instead")
if TestCase.tearDown != type(self).tearDown:
raise RuntimeError(f"Test class {type(self)} overrides disabled method tearDown. Use perThreadTearDown instead")
def threaded_run_test(self):
self.perThreadSetUp()
try:
_run_test_with_mt_pg(
self=self,
timeout=TIMEOUT_DEFAULT,
world_size=self.world_size,
callback=self._test_method,
)
finally:
self.perThreadTearDown()
def perThreadSetUp(self):
pass
def perThreadTearDown(self):
pass
@property
def world_size(self) -> int:
raise RuntimeError("world size not implemented")
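# Editor's note (illustrative, not part of this diff): subclasses must override the
# ``world_size`` property, since the base implementation deliberately raises, e.g.:
#
#     class MyThreadedTest(MultiThreadedTestCase):
#         @property
#         def world_size(self) -> int:
#             return 2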

View File

@@ -1147,7 +1147,6 @@ class DistributedTest:
@require_world_size(4)
@skip_if_lt_x_gpu(4)
def test_3_level_hierarchical_model_averager(self):
from torch.distributed.distributed_c10d import _pg_group_ranks
rank = dist.get_rank()
world_size = dist.get_world_size()
rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
@@ -1178,8 +1177,8 @@
subgroup1 = averager.period_process_group_dict[subgroup_avg_period1]
subgroup2 = averager.period_process_group_dict[subgroup_avg_period2]
real_group_ranks_res1 = list(_pg_group_ranks[subgroup1].keys())
real_group_ranks_res2 = list(_pg_group_ranks[subgroup2].keys())
real_group_ranks_res1 = dist.get_process_group_ranks(subgroup1)
real_group_ranks_res2 = dist.get_process_group_ranks(subgroup2)
expect_group_ranks_res1 = (rank // subgroup_size1 * subgroup_size1 + np.array(list(range(subgroup_size1)))).tolist()
expect_group_ranks_res2 = (rank // subgroup_size2 * subgroup_size2 + np.array(list(range(subgroup_size2)))).tolist()
self.assertEqual(real_group_ranks_res1, expect_group_ranks_res1)

View File

@@ -0,0 +1,268 @@
import time
import sys
import queue
import threading
from dataclasses import dataclass
import torch
import torch.distributed as dist
from torch.futures import Future
from torch._C._distributed_c10d import _create_work_from_future
from torch.utils._pytree import tree_flatten
"""
TODO:
Lots of missing collectives.
Collectives validation.
Make timeouts robust by making collectives respect the test deadline.
Make tests robust by making collectives interruptible.
We need some synchronization around cleanup to ensure that timed-out ranks don't cause spurious failures.
"""
def flatten_list(lst):
return tree_flatten(lst)[0]
def ret_work(ret):
fut = Future()
fut.set_result(ret)
return _create_work_from_future(fut)
class AllGather:
def work(self, data):
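# data[rank] is the (output_tensors, input_tensor) tuple that rank passed to
# Collective.join() from allgather(); output_tensors[0] holds the world_size
# destination tensors and input_tensor is a list holding that rank's single source tensor.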
for src_rank in range(len(data)):
in_tensor_list = data[src_rank][1]
# Can't handle all_gather with multiple tensors
assert len(in_tensor_list) == 1
src_tensor = in_tensor_list[0]
for dest in data:
dest_tensor = dest[0][0][src_rank]
with torch.no_grad():
dest_tensor.copy_(src_tensor)
class Broadcast:
def __init__(self, src):
self.src = src
def work(self, data):
in_tensor_list = flatten_list(data[self.src])
for i in range(len(data)):
out_tensor_list = flatten_list(data[i])
for j in range(len(in_tensor_list)):
with torch.no_grad():
out_tensor_list[j].copy_(in_tensor_list[j])
class Collective:
def __init__(self, world_size, collective):
self._world_size = world_size
self._collective = collective
self._start_cond = threading.Condition()
self._done_cond = threading.Condition()
self._data = [None] * world_size
self._count = 0
self._done = False
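# Editor's summary of the rendezvous below: every rank deposits its data under
# _start_cond, and rank 0 waits there until all world_size entries are present (the last
# arriving rank notifies it). Rank 0 then performs the actual data movement for the whole
# world under _done_cond and notifies the other ranks, which block until _done is set.
# Every rank returns a completed Work wrapping its own data.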
def join(self, rank, data):
with self._start_cond:
self._data[rank] = data
self._count += 1
# notify rank 0
if self._count == self._world_size:
if rank > 0:
self._start_cond.notify()
if rank == 0:
while self._count < self._world_size:
self._start_cond.wait()
with self._done_cond:
# wait for rank 0 to finish
if rank > 0:
while not self._done:
self._done_cond.wait()
else:
# copy data around
self._collective.work(self._data)
self._done = True
self._done_cond.notify_all()
return ret_work(data)
class ProcessLocalGroup(dist.ProcessGroup):
_pg_lock = threading.Lock()
_pg_list = []
_count = 0
_ready = False
_coll_lock = threading.Lock()
_cur_coll = None
@classmethod
def _register(cls, pg):
with cls._pg_lock:
while len(cls._pg_list) <= pg._rank:
cls._pg_list.append(None)
cls._pg_list[pg._rank] = pg
cls._count += 1
if cls._count == pg._world:
cls._ready = True
@classmethod
def _start_coll(cls, world_size, collective):
with cls._coll_lock:
if not cls._ready:
raise Exception(f"world not ready, only {cls._count} PG's registered but world has {world_size} ranks")
if cls._cur_coll is None:
cls._cur_coll = Collective(world_size, collective)
return cls._cur_coll
@classmethod
def _end_coll(cls, collective):
# This is called racily by all ranks, so only one of them will actually clear _cur_coll
with cls._coll_lock:
if cls._cur_coll == collective:
cls._cur_coll = None
def allgather(self, output_tensors, input_tensor, options):
coll = ProcessLocalGroup._start_coll(self._world, AllGather())
res = coll.join(self._rank, (output_tensors, input_tensor))
ProcessLocalGroup._end_coll(coll)
return res
def broadcast(self, tensor_list, opts):
coll = ProcessLocalGroup._start_coll(self._world, Broadcast(opts.rootRank))
res = coll.join(self._rank, tensor_list)
ProcessLocalGroup._end_coll(coll)
return res
def __init__(self, rank, world):
super(ProcessLocalGroup, self).__init__(rank, world)
self._rank = rank
self._world = world
ProcessLocalGroup._register(self)
def size(self):
return self._world
def getBackendName(self):
return "local"
def __repr__(self):
return f"PLG w:{self._world} r:{self._rank}"
def _create_threaded_pg(prefix_store, rank, world_size, timeout):
return ProcessLocalGroup(rank, world_size)
dist.Backend.register_backend('threaded', _create_threaded_pg)
@dataclass
class WorldData:
default_pg: dist.ProcessGroup
pg_map: dict
pg_names: dict
pg_group_ranks: dict
group_count: int
class ThreadLocalWorld:
_world = threading.local()
def _get_world(self) -> WorldData:
if not hasattr(ThreadLocalWorld._world, "world"):
ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, 0)
return ThreadLocalWorld._world.world
@property
def default_pg(self):
return self._get_world().default_pg
@default_pg.setter
def default_pg(self, value):
self._get_world().default_pg = value
@property
def pg_map(self):
return self._get_world().pg_map
@property
def pg_names(self):
return self._get_world().pg_names
@property
def pg_group_ranks(self):
return self._get_world().pg_group_ranks
@property
def group_count(self) -> int:
return self._get_world().group_count
@group_count.setter
def group_count(self, value):
self._get_world().group_count = value
_old_pg_world = None
def _install_threaded_pg():
global _old_pg_world
_old_pg_world = dist.distributed_c10d._world
dist.distributed_c10d._world = ThreadLocalWorld()
return dist.distributed_c10d._world
def _uninstall_threaded_pg():
dist.distributed_c10d._world = _old_pg_world
def run_with_threaded_pg(world_size, timeout, callback):
"""
Run ``callback`` with ``world_size`` threads using the in-proc process group
"""
world = _install_threaded_pg()
def world_is_valid():
return world == dist.distributed_c10d._world
global_store = dist.HashStore()
exception_queue = queue.Queue()
def worker(rank):
if not world_is_valid():
raise TimeoutError("Invalid world")
dist.init_process_group(
backend="threaded",
rank=rank,
world_size=world_size,
store=global_store
)
try:
callback()
except BaseException as ex:
exception_queue.put((rank, sys.exc_info()))
finally:
if world_is_valid():
dist.destroy_process_group()
try:
threads = [
threading.Thread(
target=worker,
args=(rank,)
) for rank in range(world_size)
]
for thread in threads:
thread.start()
deadline = time.time() + timeout
for idx, thread in enumerate(threads):
thread.join(max(0, deadline - time.time()))
if thread.is_alive():
exception_queue.put((idx, (TimeoutError, TimeoutError(f"Rank failed to join in under {timeout} seconds"), None)))
failed_ranks = []
while not exception_queue.empty():
failure = exception_queue.get()
failed_ranks.append(failure)
return failed_ranks
finally:
_uninstall_threaded_pg()
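# Editor's usage sketch (not part of this diff): calling run_with_threaded_pg directly,
# the way _run_test_with_mt_pg in common_distributed.py does. The callback runs once per
# rank, each in its own thread backed by its own thread-local world:

import torch.distributed as dist
from torch.testing._internal.distributed.multi_threaded_pg import run_with_threaded_pg

def check_world():
    # runs inside every rank's thread after init_process_group(backend="threaded", ...)
    assert dist.get_world_size() == 4
    assert 0 <= dist.get_rank() < 4

failed_ranks = run_with_threaded_pg(world_size=4, timeout=30, callback=check_world)
assert failed_ranks == []  # list of (rank, exc_info) tuples for ranks that raised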