pickler for GraphModule (#141659)
Pickling a GraphModule needs some special handling to wrap things that normally can't be pickled, but async compile needs to pass them across a wire, so we need to be able to serialize them. Add some helpers to enable that.

Differential Revision: [D68921318](https://our.internmc.facebook.com/intern/diff/D68921318)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141659
Approved by: https://github.com/jamesjwu
committed by PyTorch MergeBot
parent f9227e7c33
commit 57d8278ab9
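The commit message above is terse, so here is a minimal, self-contained sketch of the general technique it refers to: a pickle.Pickler subclass whose reducer_override swaps objects that normally can't be pickled for picklable surrogates that the receiving side rebuilds. This is illustrative only; the helper names (_SurrogatePickler, _rebuild_tensor_surrogate) and the tensor-to-meta substitution are assumptions for the example, not the actual helpers added by this PR.

import io
import pickle

import torch


def _rebuild_tensor_surrogate(shape, dtype):
    # Receiving side: rebuild a placeholder on the meta device instead of
    # shipping real tensor data across the wire.
    return torch.empty(shape, dtype=dtype, device="meta")


class _SurrogatePickler(pickle.Pickler):
    def reducer_override(self, obj):
        if isinstance(obj, torch.Tensor):
            # Replace the tensor with (rebuild_fn, args) so only metadata travels.
            return _rebuild_tensor_surrogate, (tuple(obj.shape), obj.dtype)
        return NotImplemented  # fall back to default pickling for everything else


def dumps(obj) -> bytes:
    buf = io.BytesIO()
    _SurrogatePickler(buf, protocol=pickle.HIGHEST_PROTOCOL).dump(obj)
    return buf.getvalue()


payload = {"example_input": torch.randn(4, 4), "note": "plain data passes through"}
restored = pickle.loads(dumps(payload))
assert restored["example_input"].device.type == "meta"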
@@ -541,11 +541,27 @@ class _CustomViewFunc(ViewFunc[_TensorT], Generic[_TensorT]):
         return self.func(new_base, symint_visitor_fn, tensor_visitor_fn)
 
 
+# A callback where the device is either optional or required.
+# All of these satisfy this protocol:
+#   def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str])
+#   def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta")
+#   def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None)
+class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
+    def __call__(
+        self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str]
+    ) -> _TensorT_cov:
+        ...
+
+
 class _MetaTensorCallbackKwargs(TypedDict, total=False):
     device: Union[torch.device, str]
 
 
-class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
+# A callback where the device may not be provided (is optional).
+# All of these satisfy this protocol:
+#   def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta")
+#   def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None)
+class _MetaTensorCallbackOptDevice(Protocol, Generic[_TensorT_cov]):
     def __call__(
         self,
         arg: Callable[[], torch.Tensor],
@@ -832,11 +848,13 @@ class MetaConverter(Generic[_TensorT]):
         self,
         t: MetaTensorDesc,
         shape_env: Optional[ShapeEnv],
-        callback: _MetaTensorCallback[_TensorT],
+        callback_: _MetaTensorCallback[_TensorT],
         source: Optional[Source],
         symbolic_context: Optional[SymbolicContext],
     ) -> _TensorT:
-        callback = functools.partial(callback, device=t.device)
+        callback: _MetaTensorCallbackOptDevice = functools.partial(
+            callback_, device=t.device
+        )
         if source is None:
             from torch._dynamo.source import ConstantSource
 
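The hunk above renames the incoming parameter to callback_ and rebinds callback via functools.partial(callback_, device=t.device); that partial application is what turns a required-device _MetaTensorCallback into a callable satisfying the new optional-device protocol. A simplified standalone sketch of the pattern (protocol names shortened and the make_tensor callback invented for illustration):

import functools
from typing import Callable, Optional, Protocol, Union

import torch


class _RequiresDevice(Protocol):
    # Stand-in for _MetaTensorCallback: callers must pass device.
    def __call__(
        self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str]
    ) -> torch.Tensor: ...


class _OptionalDevice(Protocol):
    # Stand-in for _MetaTensorCallbackOptDevice: callers may omit device.
    def __call__(
        self,
        arg: Callable[[], torch.Tensor],
        /,
        *,
        device: Optional[Union[torch.device, str]] = None,
    ) -> torch.Tensor: ...


def make_tensor(
    arg: Callable[[], torch.Tensor], *, device: Union[torch.device, str]
) -> torch.Tensor:
    # An example callback satisfying the required-device protocol.
    return arg().to(device)


required: _RequiresDevice = make_tensor
# Binding device up front yields a callable that satisfies the optional-device
# protocol, so downstream code no longer needs to thread device through.
callback: _OptionalDevice = functools.partial(required, device="meta")
t = callback(lambda: torch.empty(2, 3))
assert t.device.type == "meta"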
@@ -981,7 +999,7 @@ class MetaConverter(Generic[_TensorT]):
         symbolic_context: Optional[
             torch.fx.experimental.symbolic_shapes.SymbolicContext
         ],
-        callback: _MetaTensorCallback[_TensorT],
+        callback: _MetaTensorCallbackOptDevice[_TensorT],
         source: torch._guards.Source,
     ) -> _TensorT:
         # We are hitting plain meta_desc tensor so actually
@@ -1216,7 +1234,7 @@ class MetaConverter(Generic[_TensorT]):
             shape_env: Optional[
                 torch.fx.experimental.symbolic_shapes.ShapeEnv
             ] = shape_env,
-            callback: _MetaTensorCallback[_TensorT] = callback,
+            callback: _MetaTensorCallbackOptDevice[_TensorT] = callback,
         ) -> torch.Tensor:
             # It's possible to close over an undefined tensor (e.g. NJT's lengths).
             if visited_t is None:
@@ -1769,7 +1787,9 @@ class MetaConverter(Generic[_TensorT]):
         # Thanks to storage resizing, it's possible to end up with a tensor
         # that advertises a real size, but has a storage that actually has zero bytes.
         # Need to reflect this in the generated FakeTensor.
-        if t.storage is not None and t.storage.size == 0:
+        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+
+        if t.storage is not None and guard_size_oblivious(t.storage.size == 0):
             r.untyped_storage().resize_(0)
 
         if t.is_parameter:
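For context on the last hunk: guard_size_oblivious evaluates a possibly-symbolic boolean while treating unbacked, size-like symbols size-obliviously (they are assumed not to be 0 or 1), and it simply passes plain Python bools through. Wrapping the t.storage.size == 0 check this way presumably avoids guarding on (or erroring over) a data-dependent storage size. A trivial, plain-bool illustration of the call:

from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

# With ordinary bools the call is a passthrough; with a SymBool built from an
# unbacked symbolic size it resolves the comparison size-obliviously instead
# of installing a guard on the exact value.
assert guard_size_oblivious(0 == 0) is True
assert guard_size_oblivious(1 == 0) is False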