Added a context manager that enables all futures returned by rpc_async and custom-built RPC functions to be automatically waited on (#41807)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/41807

Test Plan: Make sure the CI tests pass, including the newly written test.

Reviewed By: mrshenli

Differential Revision: D22640839

Pulled By: osandoval-fb

fbshipit-source-id: 3ff98d8e8c6e6d08575e307f05b5e159442d7216
This commit is contained in:
Oscar Sandoval
2020-10-26 12:47:32 -07:00
committed by Facebook GitHub Bot
parent 25db74bf5e
commit 58ed60c259
2 changed files with 85 additions and 2 deletions

View File

@ -141,6 +141,34 @@ def _broadcast_to_followers(sequence_id, objects_map):
states.gathered_objects = objects_map
states.proceed_signal.set()
# Per-thread storage for the futures collected by an active ``_wait_all``
# context. The ``future_list`` attribute only exists while such a context is
# open on the current thread; ``rpc_async`` checks for it with ``hasattr``.
_thread_local_var = threading.local()


@contextlib.contextmanager
def _wait_all():
    r"""
    A context manager that collects all futures returned by ``rpc_async`` and
    waits on them when the context exits, relieving the user of needing to
    explicitly call wait.

    Collection is per thread: only futures created on the thread that entered
    the context are recorded. Contexts may be nested on the same thread; an
    inner context collects and waits on its own futures, and the enclosing
    context's list is restored when the inner one exits.

    The collected futures are waited on even if the body of the ``with``
    statement raises. Any exception raised by ``torch.futures.wait_all``
    propagates to the caller.

    Example::
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> with rpc._wait_all():
        ...     fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        ...     fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        >>> # fut_1 and fut_2 are waited on
    """
    # Remember the list of an enclosing context (if any) so that nesting does
    # not drop the outer context's futures or delete its attribute.
    enclosing = getattr(_thread_local_var, "future_list", None)
    pending = []
    _thread_local_var.future_list = pending
    try:
        yield
    finally:
        try:
            torch.futures.wait_all(pending)
        finally:
            # Always restore the previous state, even when waiting fails.
            if enclosing is None:
                del _thread_local_var.future_list
            else:
                _thread_local_var.future_list = enclosing
@_require_initialized
def _all_gather(obj, timeout=UNSET_RPC_TIMEOUT):
@ -830,4 +858,7 @@ def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
"""
return _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
if hasattr(_thread_local_var, "future_list"):
_thread_local_var.future_list.append(fut)
return fut