mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow RPC framework to use rank in addition to WorkerInfo and name. (#46221)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46221 The RPC framework only allowed sending RPCs based on provided WorkerInfo or name. When using RPC with DDP, sometimes it might just be easier to refer to everything in terms of ranks since DDP doesn't support names yet. As a result, support a `to` parameter in the RPC APIs which allow for specifying a rank as well would be helpful. ghstack-source-id: 114207172 Test Plan: 1) waitforbuildbot 2) Unit Tests Reviewed By: mrshenli Differential Revision: D24264989 fbshipit-source-id: 5edf5d92e2bd2f213471dfe7c74eebfa9efc9f70
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e1c9aa918a
commit
f89498f3f8
@ -324,13 +324,13 @@ def get_worker_info(worker_name=None):
|
||||
return _get_current_rpc_agent().get_worker_info()
|
||||
|
||||
|
||||
def _to_worker_info(name_or_info):
|
||||
if isinstance(name_or_info, WorkerInfo):
|
||||
return name_or_info
|
||||
elif isinstance(name_or_info, str):
|
||||
return get_worker_info(name_or_info)
|
||||
def _to_worker_info(to):
|
||||
if isinstance(to, WorkerInfo):
|
||||
return to
|
||||
elif isinstance(to, str) or isinstance(to, int):
|
||||
return get_worker_info(to)
|
||||
else:
|
||||
raise ValueError("Cannot get WorkerInfo from name {}".format(name_or_info))
|
||||
raise ValueError("Cannot get WorkerInfo from name {}".format(to))
|
||||
|
||||
|
||||
def _rref_typeof_on_owner(rref):
|
||||
@ -417,7 +417,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
|
||||
are no living references to it.
|
||||
|
||||
Arguments:
|
||||
to (str or WorkerInfo): id or name of the destination worker.
|
||||
to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
|
||||
func (callable): a callable function, such as Python callables, builtin
|
||||
operators (e.g. :meth:`~torch.add`) and annotated
|
||||
TorchScript functions.
|
||||
@ -672,7 +672,7 @@ def rpc_sync(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
|
||||
method is thread-safe.
|
||||
|
||||
Arguments:
|
||||
to (str or WorkerInfo): id or name of the destination worker.
|
||||
to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
|
||||
func (callable): a callable function, such as Python callables, builtin
|
||||
operators (e.g. :meth:`~torch.add`) and annotated
|
||||
TorchScript functions.
|
||||
@ -751,7 +751,7 @@ def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
|
||||
:class:`~torch.futures.Future` that can be awaited on.
|
||||
|
||||
Arguments:
|
||||
to (str or WorkerInfo): id or name of the destination worker.
|
||||
to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
|
||||
func (callable): a callable function, such as Python callables, builtin
|
||||
operators (e.g. :meth:`~torch.add`) and annotated
|
||||
TorchScript functions.
|
||||
|
Reference in New Issue
Block a user