Mirror of https://github.com/pytorch/pytorch.git, synced 2025-11-05 00:14:54 +08:00
support torch script call over rpc (#30063)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/30063

This diff makes the following changes:

1. Provides a new set of private Python RPC APIs that accept an annotated TorchScript call, which can then be serialized, deserialized, and executed in C++ without the GIL. These private APIs will be bound to JIT in the future, and they differ from the public APIs because the future JIT-bound private APIs will accept a qualified_name rather than a callable. They are subject to deprecation once JIT supports treating a TorchScript function as a JIT type. Also, these APIs require the TorchScript function to be defined and annotated by users in Python land; it cannot be a script class/module constructor or a class/module method.
2. Allows the public RPC APIs to accept an annotated TorchScript call and execute the same code path the private APIs run on. Therefore, if users invoke an annotated TorchScript call over RPC, that call is also serialized, deserialized, and executed in C++ without the GIL.
3. The private APIs call a newly defined C++ function that serializes, deserializes, and executes the RPC TorchScript call in C++ land. This C++ function returns an ivalue::Future, so that in a follow-up diff it can be called when these private APIs are bound to JIT.
4. script_call.cpp/.h and request_callback_impl.cpp are refactored so that TorchScript calls and builtin calls can share the same message type and code.
5. Refactors deserializeResponse() and adds a new utility to deserialize a response to an IValue.

ghstack-source-id: 96638829

Test Plan: unit test

Differential Revision: D18482934

fbshipit-source-id: bd82a0d820c47a8e45b2e7c616eca06573f7d7ea
Committed by: Facebook Github Bot
Parent: 5f1a881cb8
Commit: dbd737158b
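For orientation before reading the diff: the sketch below condenses the public-API usage described in the doctests added by this change. It is a minimal sketch of the worker-0 side only, assuming a second process runs the matching "worker1" init_rpc/shutdown calls and that the master address/port environment variables are already exported as the docstrings below describe; my_script_add mirrors the example function used in those doctests.

# Minimal sketch (worker-0 side). Assumes a "worker1" process runs the matching
# init_rpc("worker1", rank=1, world_size=2) / shutdown() pair and that the
# rendezvous environment variables are already set.
import torch
import torch.distributed.rpc as rpc


@torch.jit.script
def my_script_add(t1: torch.Tensor, t2: int) -> torch.Tensor:
    return torch.add(t1, t2)


if __name__ == "__main__":
    rpc.init_rpc("worker0", rank=0, world_size=2)

    # Passing a @torch.jit.script function to the public APIs now takes the
    # TorchScript path: the call is serialized, deserialized, and executed in
    # C++ without the GIL.
    ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))

    fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
    ret_async = fut.wait()

    rpc.shutdown()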
@@ -7,6 +7,7 @@ from . import (
     _invoke_remote_python_udf,
     _invoke_rpc_builtin,
     _invoke_rpc_python_udf,
+    _invoke_rpc_script,
     _start_rpc_agent,
     backend_registry,
 )
@@ -24,6 +25,7 @@ import sys
 import torch
 import torch.distributed as dist
 
+from torch._jit_internal import _qualified_name
 
 _agent = None
 # NB: Ignoring RRef leaks during shutdown. Without this, applications have to
@@ -326,8 +328,8 @@ def rpc_sync(to, func, args=None, kwargs=None):
 
     Arguments:
         to (str or WorkerInfo): id or name of the destination worker.
-        func (callable): any callable function. builtin functions (like
-            :meth:`torch.add`) can be sent over RPC more efficiently.
+        func (callable): any callable function. builtin or annotated TorchScript
+            functions (like :meth:`torch.add`) can be sent over RPC more efficiently.
         args (tuple): the argument tuple for the ``func`` invocation.
         kwargs (dict): is a dictionary of keyword arguments for the ``func``
             invocation.
@@ -356,9 +358,32 @@ def rpc_sync(to, func, args=None, kwargs=None):
         >>> import torch.distributed.rpc as rpc
         >>> rpc.init_rpc("worker1", rank=1, world_size=2)
         >>> rpc.shutdown()
+
+        If invoking an annotated TorchScript function, then run the following
+        code in two different processes:
+
+        >>> # On worker 0:
+        >>> @torch.jit.script
+        >>> def my_script_add(t1, t2):
+        >>>     return torch.add(t1, t2)
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
+        >>> rpc.shutdown()
+
+        >>> # On worker 1:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+        >>> rpc.shutdown()
+
     """
-    fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs)
-    return fut.wait()
+    # If invoking an annotated TorchScript function,
+    # call the internal API _rpc_sync_torchscript()
+    if isinstance(func, torch.jit.ScriptFunction):
+        return _rpc_sync_torchscript(to, _qualified_name(func), args, kwargs)
+    else:
+        fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs)
+        return fut.wait()
 
 
 @_require_initialized
@@ -371,8 +396,8 @@ def rpc_async(to, func, args=None, kwargs=None):
 
     Arguments:
         to (str or WorkerInfo): id or name of the destination worker.
-        func (callable): any callable function. builtin functions (like
-            :meth:`torch.add`) can be sent over RPC more efficiently.
+        func (callable): any callable function. builtin or annotated TorchScript
+            functions (like :meth:`torch.add`) can be sent over RPC more efficiently.
         args (tuple): the argument tuple for the ``func`` invocation.
         kwargs (dict): is a dictionary of keyword arguments for the ``func``
             invocation.
@@ -405,6 +430,146 @@ def rpc_async(to, func, args=None, kwargs=None):
         >>> import torch.distributed.rpc as rpc
         >>> rpc.init_rpc("worker1", rank=1, world_size=2)
         >>> rpc.shutdown()
+
+        If invoking an annotated TorchScript function, then run the following
+        code in two different processes:
+
+        >>> # On worker 0:
+        >>> @torch.jit.script
+        >>> def my_script_add(t1, t2):
+        >>>     return torch.add(t1, t2)
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
+        >>> ret = fut.wait()
+        >>> rpc.shutdown()
+
+        >>> # On worker 1:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+        >>> rpc.shutdown()
     """
-    fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs)
+    # If invoking an annotated TorchScript function,
+    # call the internal API _rpc_async_torchscript()
+    if isinstance(func, torch.jit.ScriptFunction):
+        fut = _rpc_async_torchscript(to, _qualified_name(func), args, kwargs)
+    else:
+        fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs)
     return fut
+
+
+# All below private APIs are for making an rpc torch script call that can be
+# serialized, deserialized and executed in C++ without GIL.
+# These APIs will be bound to JIT and can be called in torch script
+# function/class/module in the future. But since JIT does not support a torch
+# script function to be a jit type yet, the future bound APIs can only accept
+# the qualified_name of the function as arg, which is why these APIs are made
+# private and different from the above public rpc APIs.
+# Because JIT does not support a torch script function to be a jit type, right now
+# these APIs can only accept a torch script call to a user-annotated
+# torchscript function; they do not accept an annotated torchscript class name or
+# script module class name or their class method name right now.
+@_require_initialized
+def _rpc_sync_torchscript(to, qualified_name, args=None, kwargs=None):
+    r"""
+    Make a blocking RPC call to run TorchScript function ``func`` on worker ``to``.
+    RPC messages are sent and received in parallel to execution of Python code. This
+    method is thread-safe.
+
+    Arguments:
+        to (str): name of the destination worker.
+        qualified_name (str): qualified name of a python function annotated with
+            @torch.jit.script (like ``moduleName::torchScriptFuncName``);
+            can be sent over RPC more efficiently.
+        args (tuple): the argument tuple for the ``func`` invocation.
+        kwargs (dict): is a dictionary of keyword arguments for the ``func``
+            invocation.
+
+    Returns:
+        Returns the result of running ``func`` on ``args`` and ``kwargs``.
+
+    Example::
+        Make sure that ``MASTER_ADDRESS`` and ``MASTER_PORT`` are set properly
+        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
+        API for more details. For example,
+
+        >>> export MASTER_ADDRESS=localhost
+        >>> export MASTER_PORT=5678
+
+        Then run the following code in two different processes:
+
+        >>> # On worker 0:
+        >>> @torch.jit.script
+        >>> def my_script_add(t1, t2):
+        >>>     return torch.add(t1, t2)
+        >>> import torch.distributed.rpc as rpc
+        >>> from torch._jit_internal import _qualified_name
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> ret = rpc._rpc_sync_torchscript("worker1", _qualified_name(my_script_add), args=(torch.ones(2), 3))
+        >>> rpc.shutdown()
+
+        >>> # On worker 1:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+        >>> rpc.shutdown()
+    """
+    args = args if args else ()
+    kwargs = kwargs if kwargs else {}
+    fut = _invoke_rpc_script(to, qualified_name, *args, **kwargs)
+    return fut.wait()
+
+
+@_require_initialized
+def _rpc_async_torchscript(to, qualified_name, args=None, kwargs=None):
+    r"""
+    Make a non-blocking RPC call to run TorchScript function ``func`` on worker ``to``.
+    RPC messages are sent and received in parallel to execution of Python code. This
+    method is thread-safe. This method will immediately return a
+    _pyFuture that can be awaited on.
+
+    Arguments:
+        to (str): name of the destination worker.
+        qualified_name (str): qualified name of a python function annotated with
+            @torch.jit.script (like ``moduleName::torchScriptFuncName``);
+            can be sent over RPC more efficiently.
+        args (tuple): the argument tuple for the ``func`` invocation.
+        kwargs (dict): is a dictionary of keyword arguments for the ``func``
+            invocation.
+
+    Returns:
+        Returns a _pyFuture object that can be waited
+        on. When completed, the return value of ``func`` on ``args`` and
+        ``kwargs`` can be retrieved from the _pyFuture object.
+
+    Example::
+        Make sure that ``MASTER_ADDRESS`` and ``MASTER_PORT`` are set properly
+        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
+        API for more details. For example,
+
+        >>> export MASTER_ADDRESS=localhost
+        >>> export MASTER_PORT=5678
+
+        Then run the following code in two different processes:
+
+        >>> # On worker 0:
+        >>> @torch.jit.script
+        >>> def my_script_add(t1, t2):
+        >>>     return torch.add(t1, t2)
+        >>> import torch.distributed.rpc as rpc
+        >>> from torch._jit_internal import _qualified_name
+        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
+        >>> fut = rpc._rpc_async_torchscript("worker1", _qualified_name(my_script_add), args=(torch.ones(2), 3))
+        >>> ret = fut.wait()
+        >>> rpc.shutdown()
+
+        >>> # On worker 1:
+        >>> import torch.distributed.rpc as rpc
+        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
+        >>> rpc.shutdown()
+    """
+    args = args if args else ()
+    kwargs = kwargs if kwargs else {}
+    fut = _invoke_rpc_script(to, qualified_name, *args, **kwargs)
+    return fut