[ghstack-poisoned]
This commit is contained in:
James Wu
2025-10-18 13:52:44 -07:00
parent 6dad4c8de7
commit a9e8d188d8

View File

@ -10,6 +10,7 @@ from typing_extensions import override, Self
import torch
import torch.utils._pytree as pytree
from torch._guards import TracingContext
from torch._inductor.standalone_compile import AOTCompiledArtifact
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, Tensor
from torch._subclasses.meta_utils import (
MetaConverter,
@ -422,7 +423,14 @@ class _OpPickleData:
if isinstance(op, str):
return _OpStrPickleData(op)
if isinstance(getattr(op, "__wrapped__", None), AOTCompiledArtifact):
assert hasattr(op, "__wrapped__")
artifact = op.__wrapped__
assert isinstance(artifact, AOTCompiledArtifact)
return _OpPrecompiledPickleData(artifact)
name = torch.fx.Node._pretty_print_target(op)
if isinstance(op, torch._ops.OpOverload):
return cls._pickle_op(name, _OpOverloadPickleData, options)
elif isinstance(op, torch._ops.OpOverloadPacket):
@ -503,6 +511,21 @@ class _OpOverloadPacketPickleData(_OpPickleData):
return obj
class _OpPrecompiledPickleData(_OpPickleData):
def __init__(self, artifact: AOTCompiledArtifact) -> None:
self.contents = artifact.serialize()
def unpickle(self, unpickle_state: _UnpickleState) -> object:
precompiled_artifact = AOTCompiledArtifact.deserialize(self.contents)
import functools
@functools.wraps(precompiled_artifact)
def wrapped(*args: Any) -> Any:
return precompiled_artifact(*args)
return wrapped
class _OpFunctionPickleData(_OpPickleData):
"""
Supports pickling a set of standard/common functions