Added a context manager enabling all futures returned by rpc_async and custom-built RPC functions to be automatically waited on (#41807)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41807

Test Plan: Make sure CI tests pass, including the newly written test.

Reviewed By: mrshenli

Differential Revision: D22640839

Pulled By: osandoval-fb

fbshipit-source-id: 3ff98d8e8c6e6d08575e307f05b5e159442d7216
Committed by: Facebook GitHub Bot
Parent: 25db74bf5e
Commit: 58ed60c259
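In practice, the behavior this commit adds looks roughly like the sketch below. It mirrors the docstring example added in the diff; the peer worker "worker1" (with its own init_rpc/shutdown calls) is assumed to be running elsewhere and is not part of this commit.

    # Usage sketch on worker 0; assumes a peer process has started "worker1".
    import torch
    import torch.distributed.rpc as rpc

    rpc.init_rpc("worker0", rank=0, world_size=2)
    dst = "worker1"

    with rpc._wait_all():
        fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
        fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
    # Both futures have been waited on when the context exits, so no explicit
    # fut_1.wait() / fut_2.wait() call is needed before using the results.
    print(fut_1.wait())  # tensor of 2s
    print(fut_2.wait())  # tensor of 2s

    rpc.shutdown()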
@@ -141,6 +141,34 @@ def _broadcast_to_followers(sequence_id, objects_map):
     states.gathered_objects = objects_map
     states.proceed_signal.set()
 
+_thread_local_var = threading.local()
+
+@contextlib.contextmanager
+def _wait_all():
+    r"""
+    A context manager that collects all futures returned by ``rpc_async`` and
+    waits them on the context manager's exit; relieving the user of needing
+    to explicitly call wait.
+
+
+    Example::
+        >>> # On worker 0:
+        >>> import torch
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> with rpc._wait_all():
+        >>>    fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
+        >>>    fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
+        >>> #fut_1 and fut_2 are waited on
+    """
+    _thread_local_var.future_list = []
+    try:
+        yield
+    finally:
+        try:
+            torch.futures.wait_all(_thread_local_var.future_list)
+        finally:
+            del _thread_local_var.future_list
 
 @_require_initialized
 def _all_gather(obj, timeout=UNSET_RPC_TIMEOUT):
@@ -830,4 +858,7 @@ def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
         >>> rpc.init_rpc("worker1", rank=1, world_size=2)
         >>> rpc.shutdown()
     """
-    return _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
+    fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
+    if hasattr(_thread_local_var, "future_list"):
+        _thread_local_var.future_list.append(fut)
+    return fut
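The mechanism is deliberately thin: _wait_all stashes a list on a threading.local(), and rpc_async appends its future to that list only when the attribute exists, so calls made outside the context, or on other threads, behave exactly as before. Below is a standalone sketch of the same pattern using plain torch.futures objects instead of RPC; the names collect_and_wait and async_add are hypothetical stand-ins and are not part of this commit.

    # Standalone illustration of the thread-local collection pattern.
    import contextlib
    import threading

    import torch

    _local = threading.local()  # plays the role of _thread_local_var

    @contextlib.contextmanager
    def collect_and_wait():
        # Analogue of _wait_all: create the per-thread list, and on exit
        # block until every collected future has completed.
        _local.future_list = []
        try:
            yield
        finally:
            try:
                torch.futures.wait_all(_local.future_list)
            finally:
                del _local.future_list

    def async_add(x, y):
        # Analogue of rpc_async: complete the future on a helper thread and
        # register it with the active collection, if one exists on this thread.
        fut = torch.futures.Future()
        threading.Thread(target=lambda: fut.set_result(x + y)).start()
        if hasattr(_local, "future_list"):
            _local.future_list.append(fut)
        return fut

    with collect_and_wait():
        fut = async_add(torch.ones(2, 2), 1)
    assert fut.done()  # already completed by the time the context exits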
@@ -15,7 +15,7 @@ import torch.distributed as dist
 import torch.distributed.rpc as rpc
 import torch.distributed.autograd as dist_autograd
 from torch.distributed.rpc import RRef, _get_debug_info, _rref_context_get_debug_info
-from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs, _use_rpc_pickler
+from torch.distributed.rpc.api import _delete_all_user_and_unforked_owner_rrefs, _use_rpc_pickler, _thread_local_var, _wait_all
 from torch.distributed.rpc.internal import (
     PythonUDF,
     RPCExecMode,
@@ -2856,6 +2856,58 @@ class RpcTest(RpcAgentTestFixture):
             torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler
         )
 
+    @dist_init
+    def test_wait_all(self):
+        with _wait_all():
+            self.assertTrue(_thread_local_var.future_list == [])
+            dst = worker_name((self.rank + 1) % self.world_size)
+            fut = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
+            self.assertTrue(len(_thread_local_var.future_list) == 1)
+            self.assertTrue(isinstance(_thread_local_var.future_list[0], torch._C.Future))
+        self.assertTrue(fut.done())
+        self.assertEqual(fut.wait(), torch.ones(2, 2) + 1)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_multiple_call(self):
+        with _wait_all():
+            self.assertTrue(_thread_local_var.future_list == [])
+            dst = worker_name((self.rank + 1) % self.world_size)
+            for i in range(20):
+                fut = rpc.rpc_async(dst, torch.add, (torch.ones(i, i), 1))
+                res = rpc.rpc_sync(dst, torch.add, (torch.ones(i, i), 1))
+                self.assertEqual(res, torch.ones(i, i) + 1)
+                self.assertEqual(fut.wait(), torch.ones(i, i) + 1)
+            self.assertTrue(len(_thread_local_var.future_list) == 20)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_timeout(self):
+        expected_error = self.get_timeout_error_regex()
+        with self.assertRaisesRegex(RuntimeError, expected_error):
+            with _wait_all():
+                self.assertTrue(_thread_local_var.future_list == [])
+                dst = worker_name((self.rank + 1) % self.world_size)
+                timeout = 0.1  # 100 ms
+                fut = rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_raise_in_user_func(self):
+        with self.assertRaises(ValueError):
+            with _wait_all():
+                self.assertTrue(_thread_local_var.future_list == [])
+                dst = worker_name((self.rank + 1) % self.world_size)
+                fut = rpc.rpc_async(dst, raise_func)
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
+    @dist_init
+    def test_wait_all_raise_in_body(self):
+        with self.assertRaises(ValueError):
+            with _wait_all():
+                raise_func()
+        self.assertFalse(hasattr(_thread_local_var, "future_list"))
+
     @dist_init
     def test_function_not_on_callee(self):
         # test that if a function does not exist on a callee, we don't crash,