[jit] Support Awaitable type (#90863)

We want to make TorchRec sharded models TorchScriptable.

TorchRec sharded models uses generic types Awaitable[W] and LazyAwaitable[W] (https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L212).
In sharded model those types are used instead of contained type W, having the initialization function that produces object of type W.

At the moment when the first attribute of W is requested - `LazyAwaitable[W]` will call its initialization function (on the same stack), cache the result inside and work transparently as an object of W. So we can think about it as a delayed object initialization.

To support this behavior in TorchScript - we propose a new type to TorchScript - `Await`.
In eager mode it works the same as `LazyAwaitable[W]` in TorchRec, being dynamically typed - acting as a type `W` while it is `Await[W]`.

Within torchscript it is `Await[W]` and can be only explicitly converted to W, using special function `torch.jit.awaitable_wait(aw)`.
Creation of this `Await[W]` is done via another special function `torch.jit.awaitable(func, *args)`.

The semantic is close to `torch.jit.Future`, fork, wait and uses the same jit mechanics (inline fork Closures) with the difference that it does not start this function in parallel on fork. It only stores as a lambda inside IValue that will be called on the same thread when `torch.jit.awaitable_wait` is called.

For example (more examples in this PR `test/jit/test_await.py`)
```
      def delayed(z: Tensor) -> Tensor:
          return Tensor * 3

      @torch.jit.script
      def fn(x: Tensor):
          aw: Await[int] = torch.jit._awaitable(delayed, 99)
          a = torch.eye(2)
          b = torch.jit._awaitable_wait(aw)
          return a + b + x
```

Functions semantics:

`_awaitable(func -> Callable[Tuple[...], W], *args, **kwargs) -> Await[W]`

Creates Await object, owns args and kwargs. Once _awaitable_wait calls, executes function func and owns the result of the function. Following _awaitable_wait calls will return this result from the first function call.

`_awaitable_wait(Await[W]) -> W`
Returns either cached result of W if it is not the first _awaitable_wait call to this Await object or calls specified function if the first.

`_awaitable_nowait(W) -> Await[W]`

Creates trivial Await[W] wrapper on specified object To be type complaint for the corner cases.

Differential Revision: [D42502706](https://our.internmc.facebook.com/intern/diff/D42502706)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90863
Approved by: https://github.com/davidberard98
This commit is contained in:
Ivan Kobzarev
2023-01-27 11:04:26 -08:00
committed by PyTorch MergeBot
parent 53f7fb9a22
commit 2fc73622f8
44 changed files with 1068 additions and 19 deletions

View File

@ -39,7 +39,8 @@ import torch
# Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
import torch.distributed.rpc
import torch.package._mangling as package_mangling
from torch._C import Future as CFuture
from torch._awaits import _Await
from torch._C import _Await as CAwait, Future as CFuture
from torch._sources import fake_range, get_source_lines_and_file, parse_def
from torch.futures import Future
@ -1037,6 +1038,12 @@ def is_future(ann) -> bool:
return getattr(ann, "__origin__", None) is Future
def is_await(ann) -> bool:
if ann is _Await:
return True
return getattr(ann, "__origin__", None) is _Await
if torch.distributed.rpc.is_available():
from torch._C._distributed_rpc import PyRRef
from torch.distributed.rpc import RRef
@ -1393,6 +1400,8 @@ class _TensorExtractor(pickle.Pickler):
# the means to access a value.
if isinstance(obj, CFuture) or is_rref_instance(obj):
return ""
if isinstance(obj, CAwait):
return ""
if isinstance(obj, torch.cuda.Event):
return ""
if isinstance(obj, threading.Thread):