Files
pytorch/torch/distributed/rpc_api.py
Shen Li 2486b0ba82 Add Python RRef as args and return value (#25499)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25499

See #23110 for model parallel design details, and #26759 for the RRef
protocol. This commit adds support for using RRef as Python UDF arguments
and return value. RRefs can now be shared from owner to user, from user to
owner, or from user to user.

Limitations:
1. No implicit type conversion yet. (#27099)
2. No failure handling and retry. (#26116)
3. UDF is not yet blocked until all RRefs are confirmed. (#27098)
4. Internal RRef control messages are not idempotent yet. (#26116)
5. Cannot delete RRefs correctly when there are circular dependencies. (#27096)

Main changes:

1. Added `SCRIPT_REMOTE_CALL` and `PYTHON_REMOTE_CALL` to `Message.h` to represent `dist.remote` invocations.
2. Added `SCRIPT_RREF_FETCH_CALL`, `PYTHON_RREF_FETCH_CALL`, `RREF_USER_ACCEPT`, `RREF_USER_DELETE`, `RREF_CHILD_ACCEPT`, and `RREF_FORK_REQUEST` to `Message.h` as internal RRef control messages.
3. New message request handling code is added to `functions.cpp`, and message format is added in `script_remote_call.h`, `python_remote_call.h`, and `rref_proto.h`.
4. Added a `PyRRef` type in `py_rref.h` and `py_rref.cpp` which holds a shared pointer to C++ `RRef` type. `PyRRef` wraps the C++ API and also implements RRef pickling and unpickling. RRef fork related control messages will be sent during RRef pickling/unpickling procedure.
5. Updated `RRef.h` and `RRef.cpp` accordingly to support `py::object` RRefs.
6. RRef context (reference count, etc.) is tracked in `rref_context.h` and `rref_context.cpp`.

Test Plan:
Imported from OSS

buck test mode/dev-nosan //caffe2/test:rpc_fork

Differential Revision: D17184146

Pulled By: mrshenli

fbshipit-source-id: a3a268efc087ac1ef489136ab957080382629265
2019-10-03 17:47:12 -07:00

349 lines
12 KiB
Python

#!/usr/bin/env python3
from . import invoke_rpc_builtin, invoke_rpc_python_udf
from . import invoke_remote_builtin, invoke_remote_python_udf
from . import _init_rref_context, _destroy_rref_context
from . import ProcessGroupAgent
from . import WorkerInfo
from .internal_rpc_utils import _internal_rpc_pickler, PythonUDF
from .rpc_backend_registry import is_rpc_backend_registered, init_rpc_backend
import functools
import sys
import warnings
import torch
from enum import Enum
_agent = None
def _require_initialized(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if _agent is None:
raise RuntimeError("RPC has not been initialized. "
"Call init_rpc(name) first.")
return func(*args, **kwargs)
return wrapper
def join_rpc():
    r"""
    Wait for all local and remote RPC processes to reach this point, drain
    every pending message (sends and receives), and then tear down the local
    RPC agent. Every RPC process must call this method before exiting.
    """
    global _agent
    if not _agent:
        return
    _agent.join()
    _agent = None
    _destroy_rref_context()
@_require_initialized
def sync_rpc():
    r"""
    Block until every RPC process, local and remote, has reached this call and
    has finished sending all of its pending RPCs. The synchronization happens
    at the process level, so if multiple threads are spawned only one of them
    should invoke this method at a time.
    """
    _agent.sync()
class RpcBackend(Enum):
    """Enumerates the built-in transport choices accepted by ``_init_rpc``."""
    PROCESS_GROUP = 1
# TODO: add a context manager to wrap _init_rpc and join_rpc
def _init_rpc(backend=RpcBackend.PROCESS_GROUP,
              self_name=None,
              self_rank=-1,
              init_method=None,
              num_send_recv_threads=4):
    r"""
    Create the local RPC agent for this process and register it with the RRef
    context. Must run once per process before any other RPC API is used.

    Arguments:
        backend: transport to use; defaults to ``RpcBackend.PROCESS_GROUP``.
            Other values are looked up in the RPC backend registry.
        self_name (str): name of this worker, passed to the agent.
        self_rank (int): rank of this worker. For the process-group backend it
            must match the default process group's rank unless left as -1.
        init_method (str): rendezvous URL forwarded to registered backends.
        num_send_recv_threads (int): thread count for the process-group agent.

    Raises:
        RuntimeError: on Python 2, if RPC is already initialized, if
            ``self_rank`` disagrees with the process-group rank, or if
            ``backend`` is not recognized.
    """
    if sys.version_info < (3, 0):
        raise RuntimeError("RPC package does not support Python2.")
    global _agent
    if _agent:
        raise RuntimeError("RPC is already initialized")
    if backend == RpcBackend.PROCESS_GROUP:
        from .distributed_c10d import _get_default_group

        group = _get_default_group()
        if (self_rank != -1) and (self_rank != group.rank()):
            raise RuntimeError("self_rank argument {} doesn't match pg rank {}".format(
                self_rank, group.rank()))
        # TODO: add try-except and destroy _agent in all processes if any fails.
        _agent = ProcessGroupAgent(self_name, group, num_send_recv_threads)
        _init_rref_context(_agent)
    elif is_rpc_backend_registered(backend):
        _agent = init_rpc_backend(
            backend,
            self_rank=self_rank,
            self_name=self_name,
            init_method=init_method
        )
        _init_rref_context(_agent)
    else:
        # Use one formatted message: passing ``backend`` as a second
        # positional argument would render the error text as a tuple.
        raise RuntimeError("Unrecognized RPC backend {}".format(backend))
@_require_initialized
def get_worker_info(worker_name=None):
    r"""
    Look up the WorkerInfo for a worker by name. Reusing the returned
    WorkerInfo avoids passing an expensive string to ``rpc`` on every
    invocation. A WorkerInfo carries the worker's name and its id.

    Arguments:
        worker_name (str): the string name of a worker. If ``None``, the
            WorkerInfo of the current worker is returned. (default ``None``)
    """
    if not worker_name:
        return _agent.get_worker_info()
    return _agent.get_worker_info(worker_name)
def _to_worker_info(name_or_id):
    """Normalize ``name_or_id`` (WorkerInfo or worker name) to a WorkerInfo.

    Raises:
        ValueError: if the value is neither a WorkerInfo nor a ``str``.
    """
    if isinstance(name_or_id, WorkerInfo):
        return name_or_id
    if isinstance(name_or_id, str):
        return get_worker_info(name_or_id)
    raise ValueError("Unsupported RPC worker ID type {}".format(name_or_id))
@_require_initialized
def remote(to, func, args=None, kwargs=None):
    r"""
    Make a ``remote`` call to run ``func`` on worker ``to``, and return an
    ``RRef`` to the result value immediately. Worker ``to`` will be the owner
    of the returned ``RRef``, and this worker is a user. The owner manages the
    global reference count of its ``RRef``s, and the owner ``RRef`` is only
    destructed when globally there are no living references to it.

    Arguments:
        to (int or str): id or name of the destination worker.
        func (callable): builtin functions (like ``torch.add``).
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): is a dictionary of keyword arguments for the ``func``
            invocation.

    Returns:
        A user ``RRef`` instance to the result value. Use the blocking API
        ``RRef.to_here()`` to retrieve the result value locally.

    Example::
        On worker 0:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
        >>> dist.init_rpc("worker0")
        >>> worker1 = dist.get_worker_info("worker1")
        >>> rref1 = dist.remote(worker1, torch.add, args=(torch.ones(2), 3))
        >>> rref2 = dist.remote(worker1, torch.add, args=(torch.ones(2), 1))
        >>> x = rref1.to_here() + rref2.to_here()
        >>> dist.join_rpc()

        On worker 1:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
        >>> dist.init_rpc("worker1")
        >>> dist.join_rpc()
    """
    # Fail fast on a non-callable ``func``, consistent with _invoke_rpc.
    if not callable(func):
        raise TypeError("function should be callable.")
    qualified_name = torch.jit._find_builtin(func)
    args = args if args else ()
    kwargs = kwargs if kwargs else {}
    info = _to_worker_info(to)
    if qualified_name is not None:
        # Builtin ops are dispatched by qualified name, no pickling needed.
        return invoke_remote_builtin(
            _agent, info, qualified_name, *args, **kwargs)
    else:
        (pickled_python_udf, tensors) = _internal_rpc_pickler.serialize(
            PythonUDF(func, args, kwargs))
        return invoke_remote_python_udf(
            _agent, info, pickled_python_udf, tensors)
def _invoke_rpc(to, func, args=None, kwargs=None):
    """Send ``func(*args, **kwargs)`` to worker ``to``; return the future.

    Builtins (resolved via ``torch.jit._find_builtin``) are dispatched by
    qualified name; anything else is pickled as a Python UDF.
    """
    if not callable(func):
        raise TypeError("function should be callable.")
    qualified_name = torch.jit._find_builtin(func)
    args = args or ()
    kwargs = kwargs or {}
    info = _to_worker_info(to)
    if qualified_name is None:
        pickled_python_udf, tensors = _internal_rpc_pickler.serialize(
            PythonUDF(func, args, kwargs))
        return invoke_rpc_python_udf(
            _agent, info, pickled_python_udf, tensors)
    return invoke_rpc_builtin(
        _agent, info, qualified_name, *args, **kwargs
    )
@_require_initialized
def rpc_sync(to, func, args=None, kwargs=None):
    r"""
    Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
    messages are sent and received in parallel to execution of Python code.
    This method is thread-safe.

    Arguments:
        to (int or str): id or name of the destination worker.
        func (callable): any callable function. builtin functions (like
            ``torch.add``) can be sent over RPC more efficiently.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): is a dictionary of keyword arguments for the ``func``
            invocation.

    Returns:
        Returns the result of running ``func`` on ``args`` and ``kwargs``.

    Example::
        On worker 0:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
        >>> dist.init_model_parallel("worker0")
        >>> ret = dist.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
        >>> dist.join_rpc()

        On worker 1:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
        >>> dist.init_model_parallel("worker1")
        >>> dist.join_rpc()
    """
    # Dispatch and immediately block on the returned future.
    return _invoke_rpc(to, func, args, kwargs).wait()
@_require_initialized
def rpc_async(to, func, args=None, kwargs=None):
    r"""
    Make a non-blocking RPC call to run function ``func`` on worker ``to``.
    RPC messages are sent and received in parallel to execution of Python
    code. This method is thread-safe. This method will immediately return a
    torch.distributed.FutureMessage that can be awaited on.

    Arguments:
        to (int or str): id or name of the destination worker.
        func (callable): any callable function. builtin functions (like
            ``torch.add``) can be sent over RPC more efficiently.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): is a dictionary of keyword arguments for the ``func``
            invocation.

    Returns:
        Returns a ``torch.distributed.FutureMessage`` object that can be
        waited on. When completed, the return value of ``func`` on ``args``
        and ``kwargs`` can be retrieved from the ``FutureMessage`` object.

    Example::
        On worker 0:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
        >>> dist.init_model_parallel("worker0")
        >>> worker1 = dist.get_worker_info("worker1")
        >>> fut1 = dist.rpc_async(worker1, torch.add, args=(torch.ones(2), 3))
        >>> fut2 = dist.rpc_async(worker1, min, args=(1, 2))
        >>> result = fut1.wait() + fut2.wait()
        >>> dist.join_rpc()

        On worker 1:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
        >>> dist.init_model_parallel("worker1")
        >>> dist.join_rpc()
    """
    # Hand the future straight back to the caller; no waiting here.
    return _invoke_rpc(to, func, args, kwargs)
@_require_initialized
def rpc(to, func, args=None, kwargs=None, async_call=False):
    r"""
    Make an RPC call to run function ``func`` on worker ``to``. By default, it
    blocks until the return value is locally available. RPC messages are sent
    and received in parallel to execution of Python code. This method is
    thread-safe.

    .. deprecated::
        Use :func:`rpc_sync` or :func:`rpc_async` instead.

    Arguments:
        to (int or str): id or name of the destination worker.
        func (callable): any callable function. builtin functions (like
            ``torch.add``) can be sent over RPC more efficiently.
        args (tuple): the argument tuple for the ``func`` invocation.
        kwargs (dict): is a dictionary of keyword arguments for the ``func``
            invocation.
        async_call (bool): If set to ``True``, this will be an asynchronous
            RPC, and returns a ``torch.distributed.FutureMessage`` object
            immediately. Otherwise, this RPC will block until the return
            value is locally available. (default: ``False``)

    Returns:
        If ``async_call`` is ``False``, returns the result of running ``func``
        on ``args`` and ``kwargs``. If ``async_call`` is ``True``, returns a
        ``torch.distributed.FutureMessage`` object that can be waited on. When
        completed, the return value of ``func`` on ``args`` and ``kwargs`` can
        be retrieved from the ``FutureMessage`` object.

    Example::
        Synchronous example:
        On worker 0:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
        >>> dist.init_model_parallel("worker0")
        >>> ret = dist.rpc("worker1", torch.add, args=(torch.ones(2), 3))
        >>> dist.join_rpc()

        On worker 1:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
        >>> dist.init_model_parallel("worker1")
        >>> dist.join_rpc()

        Asynchronous example:
        On worker 0:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=0, world_size=2)
        >>> dist.init_model_parallel("worker0")
        >>> worker1 = dist.get_worker_info("worker1")
        >>> fut1 = dist.rpc(worker1, torch.add, args=(torch.ones(2), 3), async_call=True)
        >>> fut2 = dist.rpc(worker1, min, args=(1, 2), async_call=True)
        >>> result = fut1.wait() + fut2.wait()
        >>> dist.join_rpc()

        On worker 1:
        >>> import torch.distributed as dist
        >>> dist.init_process_group(backend='gloo', rank=1, world_size=2)
        >>> dist.init_model_parallel("worker1")
        >>> dist.join_rpc()
    """
    # Adjacent string literals avoid the triple-quoted form, which leaked a
    # raw newline and source indentation into the user-facing warning text.
    warnings.warn(
        "dist.rpc is deprecated. Use dist.rpc_async for asynchronous "
        "calls or dist.rpc_sync for synchronous calls instead."
    )
    if async_call:
        return rpc_async(to, func, args, kwargs)
    else:
        return rpc_sync(to, func, args, kwargs)