Fix bug with serialization after AOTAutogradCache hit (#165474)

Fixes #165447

On AOTAutogradCache load, the serialization function we pick is just `lambda: self`, because the object itself is an `AOTAutogradCacheEntry`. However, this isn't safe: `wrap_post_compile` makes `self` unserializable, since it loads Triton kernels and other runtime state that can't be pickled!
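
A minimal sketch of the failure mode, using `CacheEntry` and `post_compile` as simplified stand-ins rather than the real AOTAutograd types:

```py
import pickle

class CacheEntry:
    def post_compile(self):
        # Hydration attaches runtime state (a stand-in for loaded
        # Triton kernels) that pickle cannot serialize.
        self.kernel = lambda x: x

entry = CacheEntry()
serialize = lambda: entry  # the buggy choice: just hand back self
entry.post_compile()       # hydration mutates the entry in place
try:
    pickle.dumps(serialize())
except Exception as e:
    print(f"serialize() is now broken: {e}")  # can't pickle the lambda
```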

So instead, on AOTAutogradCache load, we preserve the bytes that were used to load the object in the first place, and `serialize()` deserializes a fresh entry from those bytes. This effectively saves a copy of the pre-hydrated artifact, without paying for an eager copy until someone actually calls `serialize`.
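
Sketched with the same stand-in names, the fix keeps the raw bytes from the cache load around and deserializes lazily:

```py
import pickle

class CacheEntry:
    def post_compile(self):
        self.kernel = lambda x: x  # stand-in for unpicklable runtime state

# The bytes as they were read from the cache, before any hydration.
pickled_content = pickle.dumps(CacheEntry())

entry = pickle.loads(pickled_content)
entry.post_compile()  # hydrate this copy for execution

# serialize() rebuilds a clean, pre-hydration entry only on demand.
serialize = lambda: pickle.loads(pickled_content)
pickle.dumps(serialize())  # round-trips successfully
```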

Test Plan:

Run

```py
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(2, 4)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(4, 8)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

device = "cuda"
m = M().to(device)
sample_inputs = (torch.randn(2, 2, device=device),)
eager_out = m(*sample_inputs)

with torch._dynamo.config.patch("enable_aot_compile", True):
    compiled_fn_path = "./m.pt"
    compiled_fn = torch.compile(
        m,
        fullgraph=True
    ).forward.aot_compile((sample_inputs, {}))

    compiled_fn.save_compiled_function(compiled_fn_path)
    torch._dynamo.reset()
    with torch.compiler.set_stance("fail_on_recompile"):
        with open(compiled_fn_path, "rb") as f:
            loaded_fn = torch.compiler.load_compiled_function(f)

assert loaded_fn is not None

compiled_out = loaded_fn(m, *sample_inputs)

assert torch.allclose(eager_out, compiled_out)
```

twice, and verify that it succeeds both times. The second run hits the AOTAutogradCache, so it exercises the serialization path fixed here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165474
Approved by: https://github.com/yiming0416, https://github.com/zhxchen17
commit dd3b48e85d (parent cff1b20771)
Author: James Wu
Date: 2025-10-14 14:24:23 -07:00
Committed by: PyTorch MergeBot
2 files changed, 24 insertions(+), 11 deletions(-)

First changed file:

@@ -17,7 +17,7 @@ import time
 import traceback
 from abc import ABC, abstractmethod
 from collections.abc import Callable
-from copy import copy
+from copy import copy, deepcopy
 from dataclasses import dataclass
 from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar, Union
 from typing_extensions import override
@@ -963,10 +963,6 @@ class GenericAOTAutogradCacheEntry(Generic[TForward, TBackward]):
         )
-        # Add serialization function back onto object
-        compiled_function = SerializableCompiledFunction(
-            compiled_function, lambda: self
-        )
         compiled_function, _ = post_compile(
             self.dispatch_wrappers,
             compiled_function,
@@ -1055,6 +1051,9 @@ def deserialize_bundled_cache_entry(entry: BundledAOTAutogradCacheEntry) -> Call
     # so we don't have a place to track cudagraphs here.
     cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs)
     boxed_forward_device_index = BoxedDeviceIndex(None)
+    # We need to make a clean copy of the cache entry
+    # in case it needs to be serialized again
+    serializable_copy = deepcopy(entry)
     compiled_fn = entry.wrap_post_compile(
         [],
         entry.sanitized_aot_config,
@@ -1063,6 +1062,8 @@ def deserialize_bundled_cache_entry(entry: BundledAOTAutogradCacheEntry) -> Call
             "boxed_forward_device_index": boxed_forward_device_index,
         },
     )
+    # Ensure the deserialized cache entry is still serializable
+    compiled_fn = SerializableCompiledFunction(compiled_fn, lambda: serializable_copy)
     # TODO: this ignores flat_params, which can exist
     # if inline_builtin_nn_modules=False
@@ -1155,13 +1156,19 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
             cache_key, debug_lines = autograd_cache_key(
                 gm, args, aot_config, fx_config
            )
-            entry: Optional[GenericAOTAutogradCacheEntry] = (
+            result: Optional[tuple[GenericAOTAutogradCacheEntry, bytes]] = (
                 AOTAutogradCache._lookup(
                     cache_key, local, remote, args, cache_info, aot_config
                 )
             )
-            if entry is not None:
+            if result is not None:
+                (entry, pickled_content) = result
                 compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config)
+                # Make the compiled_fn serializable, where the serialize function just
+                # makes a copy of the original entry before post compile via the pickled content
+                compiled_fn = SerializableCompiledFunction(
+                    compiled_fn, lambda: pickle.loads(pickled_content)
+                )
                 log.info("AOTAutograd cache hit for key %s", cache_key)
                 counters["aot_autograd"]["autograd_cache_hit"] += 1
@@ -1321,7 +1328,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
         args: list[Any],
         cache_info: dict[str, Any],
         aot_config: Optional[AOTConfig],
-    ) -> Optional[GenericAOTAutogradCacheEntry]:
+    ) -> Optional[tuple[GenericAOTAutogradCacheEntry, bytes]]:
         """Given a key generated by AOTAutogradCachePickler, look up its location in the cache."""
         remote_cache: Optional[RemoteCache[JsonDataTy]] = None
         if remote:
@@ -1330,6 +1337,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
         symints = AOTAutogradCache._filter_backed_symints(args)
         hints = [hint_int(s) for s in symints]
         entry = None
+        pickled_content = None
         try:
             (
                 entry,
@@ -1363,7 +1371,11 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
             log.info("AOTAutograd cache unable to load compiled graph: %s", e)
             if config.strict_autograd_cache:
                 raise e
-        return entry
+        if entry is not None:
+            assert pickled_content is not None
+            return (entry, pickled_content)
+        else:
+            return None

     @staticmethod
     def _write_to_local_cache(key: str, content: bytes):

Second changed file:

@@ -158,7 +158,7 @@ class CompiledArtifact:
            AOTAutogradCache,
        )
-        entry = AOTAutogradCache._lookup(
+        result = AOTAutogradCache._lookup(
             key,
             local=True,
             remote=False,
@@ -167,7 +167,8 @@
             aot_config=None,
         )
-        assert entry is not None
+        assert result is not None
+        (entry, _) = result

         from .compile_fx import _CompileFxKwargs