pickler for GraphModule (#141659)

Pickling GraphModule needs some special handling for wrapping things that normally can't be pickled - but async compile needs to pass them across a wire so we need to be able to serialize it - add some helpers to enable that.

Differential Revision: [D68921318](https://our.internmc.facebook.com/intern/diff/D68921318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141659
Approved by: https://github.com/jamesjwu
This commit is contained in:
Aaron Orenstein
2025-01-30 14:05:26 -08:00
committed by PyTorch MergeBot
parent f9227e7c33
commit 57d8278ab9
13 changed files with 1014 additions and 41 deletions

View File

@ -541,11 +541,27 @@ class _CustomViewFunc(ViewFunc[_TensorT], Generic[_TensorT]):
return self.func(new_base, symint_visitor_fn, tensor_visitor_fn)
# A callback where the device is either optional or required.
# All of these satisfy this protocol:
# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str])
# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta")
# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None)
class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
def __call__(
self, arg: Callable[[], torch.Tensor], /, *, device: Union[torch.device, str]
) -> _TensorT_cov:
...
class _MetaTensorCallbackKwargs(TypedDict, total=False):
device: Union[torch.device, str]
class _MetaTensorCallback(Protocol, Generic[_TensorT_cov]):
# A callback where the device may not be provided (is optional).
# All of these satisfy this protocol:
# def mk(arg: Callable[[], torch.Tensor], device: Union[torch.device, str] = "meta")
# def mk(arg: Callable[[], torch.Tensor], device: Optional[Union[torch.device, str]] = None)
class _MetaTensorCallbackOptDevice(Protocol, Generic[_TensorT_cov]):
def __call__(
self,
arg: Callable[[], torch.Tensor],
@ -832,11 +848,13 @@ class MetaConverter(Generic[_TensorT]):
self,
t: MetaTensorDesc,
shape_env: Optional[ShapeEnv],
callback: _MetaTensorCallback[_TensorT],
callback_: _MetaTensorCallback[_TensorT],
source: Optional[Source],
symbolic_context: Optional[SymbolicContext],
) -> _TensorT:
callback = functools.partial(callback, device=t.device)
callback: _MetaTensorCallbackOptDevice = functools.partial(
callback_, device=t.device
)
if source is None:
from torch._dynamo.source import ConstantSource
@ -981,7 +999,7 @@ class MetaConverter(Generic[_TensorT]):
symbolic_context: Optional[
torch.fx.experimental.symbolic_shapes.SymbolicContext
],
callback: _MetaTensorCallback[_TensorT],
callback: _MetaTensorCallbackOptDevice[_TensorT],
source: torch._guards.Source,
) -> _TensorT:
# We are hitting plain meta_desc tensor so actually
@ -1216,7 +1234,7 @@ class MetaConverter(Generic[_TensorT]):
shape_env: Optional[
torch.fx.experimental.symbolic_shapes.ShapeEnv
] = shape_env,
callback: _MetaTensorCallback[_TensorT] = callback,
callback: _MetaTensorCallbackOptDevice[_TensorT] = callback,
) -> torch.Tensor:
# It's possible to close over an undefined tensor (e.g. NJT's lengths).
if visited_t is None:
@ -1769,7 +1787,9 @@ class MetaConverter(Generic[_TensorT]):
# Thanks to storage resizing, it's possible to end up with a tensor
# that advertises a real size, but has a storage that actually has zero bytes.
# Need to reflect this in the generated FakeTensor.
if t.storage is not None and t.storage.size == 0:
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
if t.storage is not None and guard_size_oblivious(t.storage.size == 0):
r.untyped_storage().resize_(0)
if t.is_parameter: