mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[jit] Support Awaitable type (#90863)
We want to make TorchRec sharded models TorchScriptable.

TorchRec sharded models use the generic types `Awaitable[W]` and `LazyAwaitable[W]` (https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/types.py#L212). In a sharded model these types are used in place of the contained type `W`, holding an initialization function that produces an object of type `W`. The moment the first attribute of `W` is requested, `LazyAwaitable[W]` calls its initialization function (on the same stack), caches the result, and from then on works transparently as an object of `W`. We can think of it as delayed object initialization.

To support this behavior in TorchScript, we propose a new TorchScript type: `Await`. In eager mode it works the same as `LazyAwaitable[W]` in TorchRec: it is dynamically typed and acts as a `W` while it is an `Await[W]`. Within TorchScript it is an `Await[W]` and can only be explicitly converted to `W` via the special function `torch.jit._awaitable_wait(aw)`. An `Await[W]` is created via another special function, `torch.jit._awaitable(func, *args)`.

The semantics are close to `torch.jit.Future`, fork, and wait, and use the same JIT mechanics (inlined fork closures), with the difference that the function is not started in parallel on fork. It is only stored as a lambda inside the IValue and is called on the same thread when `torch.jit._awaitable_wait` is invoked.

For example (more examples in this PR's `test/jit/test_await.py`):

```
def delayed(z: Tensor) -> Tensor:
    return z * 3

@torch.jit.script
def fn(x: Tensor):
    aw: Await[Tensor] = torch.jit._awaitable(delayed, 99)
    a = torch.eye(2)
    b = torch.jit._awaitable_wait(aw)
    return a + b + x
```

Function semantics:

`_awaitable(func: Callable[..., W], *args, **kwargs) -> Await[W]`
Creates an Await object that owns `args` and `kwargs`. When `_awaitable_wait` is first called, it executes `func` and takes ownership of the result. Subsequent `_awaitable_wait` calls return the result of that first function call.

`_awaitable_wait(Await[W]) -> W`
Returns the cached result of `W` if this is not the first `_awaitable_wait` call on this Await object; otherwise calls the stored function.

`_awaitable_nowait(W) -> Await[W]`
Creates a trivial `Await[W]` wrapper around the given object, to be type compliant in corner cases.

Differential Revision: [D42502706](https://our.internmc.facebook.com/intern/diff/D42502706)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90863
Approved by: https://github.com/davidberard98
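The lazy, same-thread, cache-on-first-wait semantics described above can be sketched in plain Python. This is a minimal eager-mode sketch, independent of TorchScript; `MiniAwait` and its `wait` method are hypothetical names for illustration, not the real `torch.jit` API:

```python
from typing import Callable, Generic, TypeVar

W = TypeVar("W")

class MiniAwait(Generic[W]):
    """Hypothetical sketch of Await[W]: stores func and args, runs
    func on the calling thread at the first wait, then caches."""
    def __init__(self, func: Callable[..., W], *args) -> None:
        self._func = func
        self._args = args
        self._done = False
        self._result = None

    def wait(self) -> W:
        # First wait executes the function on the same thread;
        # later waits return the cached result.
        if not self._done:
            self._result = self._func(*self._args)
            self._done = True
        return self._result

calls = []

def delayed(z: int) -> int:
    calls.append(z)
    return z * 3

aw = MiniAwait(delayed, 99)
assert calls == []            # nothing has run yet (delayed initialization)
assert aw.wait() == 297       # first wait calls delayed on this thread
assert aw.wait() == 297       # second wait hits the cache
assert calls == [99]          # delayed ran exactly once
```

Unlike a `Future`, nothing here ever runs on another thread; the call is simply deferred until the first wait.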
committed by PyTorch MergeBot
parent 53f7fb9a22
commit 2fc73622f8
```diff
@@ -39,7 +39,8 @@ import torch
 # Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised.
 import torch.distributed.rpc
 import torch.package._mangling as package_mangling
-from torch._C import Future as CFuture
+from torch._awaits import _Await
+from torch._C import _Await as CAwait, Future as CFuture
 from torch._sources import fake_range, get_source_lines_and_file, parse_def
 from torch.futures import Future

@@ -1037,6 +1038,12 @@ def is_future(ann) -> bool:
     return getattr(ann, "__origin__", None) is Future


+def is_await(ann) -> bool:
+    if ann is _Await:
+        return True
+    return getattr(ann, "__origin__", None) is _Await
+
+
 if torch.distributed.rpc.is_available():
     from torch._C._distributed_rpc import PyRRef
     from torch.distributed.rpc import RRef

@@ -1393,6 +1400,8 @@ class _TensorExtractor(pickle.Pickler):
         # the means to access a value.
         if isinstance(obj, CFuture) or is_rref_instance(obj):
             return ""
+        if isinstance(obj, CAwait):
+            return ""
         if isinstance(obj, torch.cuda.Event):
             return ""
         if isinstance(obj, threading.Thread):
```