mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Added a context manager that enables all futures returned by rpc_async and custom-built RPC functions to be automatically waited on (#41807)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/41807 Test Plan: Make sure ci tests pass, including newly written test Reviewed By: mrshenli Differential Revision: D22640839 Pulled By: osandoval-fb fbshipit-source-id: 3ff98d8e8c6e6d08575e307f05b5e159442d7216
This commit is contained in:
committed by
Facebook GitHub Bot
parent
25db74bf5e
commit
58ed60c259
@ -141,6 +141,34 @@ def _broadcast_to_followers(sequence_id, objects_map):
|
||||
states.gathered_objects = objects_map
|
||||
states.proceed_signal.set()
|
||||
|
||||
# Per-thread storage. While a ``_wait_all`` context is active on a thread,
# ``future_list`` holds the futures collected for the innermost context.
_thread_local_var = threading.local()


@contextlib.contextmanager
def _wait_all():
    r"""
    A context manager that collects all futures returned by ``rpc_async`` and
    waits on them when the context manager exits, relieving the user of
    needing to explicitly call wait.

    Contexts may be nested on the same thread: each context waits only on the
    futures collected while it was the innermost one, and the enclosing
    context's list is restored when the inner context exits.

    Example::
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> with rpc._wait_all():
        >>>    fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        >>>    fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        >>> #fut_1 and fut_2 are waited on
    """
    # Remember the list installed by an enclosing ``_wait_all`` (if any) so
    # that nesting does not clobber it; ``None`` means there is no enclosing
    # context on this thread.
    enclosing_futures = getattr(_thread_local_var, "future_list", None)
    pending_futures = []
    _thread_local_var.future_list = pending_futures
    try:
        yield
    finally:
        try:
            # Wait even if the body raised, so no collected future is left
            # un-waited.
            torch.futures.wait_all(pending_futures)
        finally:
            # Always undo the thread-local state, even when waiting raises.
            if enclosing_futures is None:
                del _thread_local_var.future_list
            else:
                _thread_local_var.future_list = enclosing_futures
|
||||
|
||||
@_require_initialized
|
||||
def _all_gather(obj, timeout=UNSET_RPC_TIMEOUT):
|
||||
@ -830,4 +858,7 @@ def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
|
||||
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
|
||||
>>> rpc.shutdown()
|
||||
"""
|
||||
return _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
|
||||
fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
|
||||
if hasattr(_thread_local_var, "future_list"):
|
||||
_thread_local_var.future_list.append(fut)
|
||||
return fut
|
||||
|
Reference in New Issue
Block a user