Add rpc.functions.async_execution decorator for rpc_sync/rpc_async (#39216)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39216

The `rpc.functions.async_execution` decorator specifies that the
wrapped function is guaranteed to return a `torch.futures.Future`.
The decorator adds a `_wrapped_async_rpc_function` attribute to
the wrapper function. The caller retrieves this information and
then sets `isAsyncFunction` argument accordingly which is later
added to PythonCall RPC message as a field. On the callee side,
if the PythonCall carries an asynchronous function, it will cast
the function's return value to a jit::PythonFutureWrapper object,
and then install response creation and communication as a callback
on the that jit::PythonFutureWrapper.

For applications, this feature is useful when a function needs to
wait for IO or additional singaling. In those cases, marking the
user function as `rpc.functions.async_execution` will prevent it
from blocking one thread on callee for too long.

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D21779962

fbshipit-source-id: 6b6aa698bf6f91dad6ed2a7ee433df429b59e941
This commit is contained in:
Shen Li
2020-06-02 23:19:21 -07:00
committed by Facebook GitHub Bot
parent 15ad9dd30f
commit a05ef17e46
16 changed files with 367 additions and 27 deletions

View File

@ -27,6 +27,7 @@ from . import (
_set_and_start_rpc_agent,
backend_registry,
)
from .internal import (
PythonUDF,
RPCExecMode,
@ -515,6 +516,8 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
args = args if args else ()
kwargs = kwargs if kwargs else {}
is_async_fn = hasattr(func, "_wrapped_async_rpc_function")
if qualified_name is not None:
fut = _invoke_rpc_builtin(dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs)
elif isinstance(func, torch.jit.ScriptFunction):
@ -525,7 +528,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
(pickled_python_udf, tensors) = _default_pickler.serialize(
PythonUDF(func, args, kwargs)
)
fut = _invoke_rpc_python_udf(dst_worker_info, pickled_python_udf, tensors, rpc_timeout)
fut = _invoke_rpc_python_udf(dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_fn)
if should_profile:
assert torch.autograd._profiler_enabled()
assert rf is not None