Allow RPC framework to use rank in addition to WorkerInfo and name. (#46221)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/46221

The RPC framework only allowed sending RPCs based on provided
WorkerInfo or name. When using RPC with DDP, sometimes it might just be easier
to refer to everything in terms of ranks since DDP doesn't support names yet.

As a result, support a `to` parameter in the RPC APIs which allow for
specifying a rank as well would be helpful.
ghstack-source-id: 114207172

Test Plan:
1) waitforbuildbot
2) Unit Tests

Reviewed By: mrshenli

Differential Revision: D24264989

fbshipit-source-id: 5edf5d92e2bd2f213471dfe7c74eebfa9efc9f70
This commit is contained in:
Pritam Damania
2020-10-13 17:50:07 -07:00
committed by Facebook GitHub Bot
parent e1c9aa918a
commit f89498f3f8
4 changed files with 46 additions and 9 deletions

View File

@ -324,13 +324,13 @@ def get_worker_info(worker_name=None):
return _get_current_rpc_agent().get_worker_info()
def _to_worker_info(name_or_info):
if isinstance(name_or_info, WorkerInfo):
return name_or_info
elif isinstance(name_or_info, str):
return get_worker_info(name_or_info)
def _to_worker_info(to):
if isinstance(to, WorkerInfo):
return to
elif isinstance(to, str) or isinstance(to, int):
return get_worker_info(to)
else:
raise ValueError("Cannot get WorkerInfo from name {}".format(name_or_info))
raise ValueError("Cannot get WorkerInfo from name {}".format(to))
def _rref_typeof_on_owner(rref):
@ -417,7 +417,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
are no living references to it.
Arguments:
to (str or WorkerInfo): id or name of the destination worker.
to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
func (callable): a callable function, such as Python callables, builtin
operators (e.g. :meth:`~torch.add`) and annotated
TorchScript functions.
@ -672,7 +672,7 @@ def rpc_sync(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
method is thread-safe.
Arguments:
to (str or WorkerInfo): id or name of the destination worker.
to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
func (callable): a callable function, such as Python callables, builtin
operators (e.g. :meth:`~torch.add`) and annotated
TorchScript functions.
@ -751,7 +751,7 @@ def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
:class:`~torch.futures.Future` that can be awaited on.
Arguments:
to (str or WorkerInfo): id or name of the destination worker.
to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
func (callable): a callable function, such as Python callables, builtin
operators (e.g. :meth:`~torch.add`) and annotated
TorchScript functions.