mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fix bug with serialization after AOTAutogradCache hit (#165474)
Fixes #165447 On AOTAutogradCache load, the serialization function we pick is just lambda: self, because the object itself is an AOTAutogradCacheEntry. However, this isn't safe, because `wrap_post_compile` will make `self` unserializable, since it needs to load triton kernels and stuff! So instead, on AOTAutogradCache load, we preserve the bytes that were used to load the object to begin with, and return that object on a call to serialize(). This effectively makes it so that we save a copy of the pre-hydrated artifact, without needing to do an eager copy until someone actually calls `serialize`. Test Plan: Run ```py import torch class M(torch.nn.Module): def __init__(self): super().__init__() self.linear1 = torch.nn.Linear(2, 4) self.relu = torch.nn.ReLU() self.linear2 = torch.nn.Linear(4, 8) def forward(self, x): return self.linear2(self.relu(self.linear1(x))) device = "cuda" m = M().to(device) sample_inputs = (torch.randn(2, 2, device=device),) eager_out = m(*sample_inputs) with torch._dynamo.config.patch("enable_aot_compile", True): compiled_fn_path = "./m.pt" compiled_fn = torch.compile( m, fullgraph=True ).forward.aot_compile((sample_inputs, {})) compiled_fn.save_compiled_function(compiled_fn_path) torch._dynamo.reset() with torch.compiler.set_stance("fail_on_recompile"): with open(compiled_fn_path, "rb") as f: loaded_fn = torch.compiler.load_compiled_function(f) assert loaded_fn is not None compiled_out = loaded_fn(m, *sample_inputs) assert torch.allclose(eager_out, compiled_out) ``` twice, see that it succeeds. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165474 Approved by: https://github.com/yiming0416, https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
cff1b20771
commit
dd3b48e85d
@ -17,7 +17,7 @@ import time
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from copy import copy
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar, Union
|
||||
from typing_extensions import override
|
||||
@ -963,10 +963,6 @@ class GenericAOTAutogradCacheEntry(Generic[TForward, TBackward]):
|
||||
)
|
||||
|
||||
# Add serialization function back onto object
|
||||
compiled_function = SerializableCompiledFunction(
|
||||
compiled_function, lambda: self
|
||||
)
|
||||
|
||||
compiled_function, _ = post_compile(
|
||||
self.dispatch_wrappers,
|
||||
compiled_function,
|
||||
@ -1055,6 +1051,9 @@ def deserialize_bundled_cache_entry(entry: BundledAOTAutogradCacheEntry) -> Call
|
||||
# so we don't have a place to track cudagraphs here.
|
||||
cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs)
|
||||
boxed_forward_device_index = BoxedDeviceIndex(None)
|
||||
# We need to make a clean copy of the cache entry
|
||||
# in case it needs to be serialized again
|
||||
serializable_copy = deepcopy(entry)
|
||||
compiled_fn = entry.wrap_post_compile(
|
||||
[],
|
||||
entry.sanitized_aot_config,
|
||||
@ -1063,6 +1062,8 @@ def deserialize_bundled_cache_entry(entry: BundledAOTAutogradCacheEntry) -> Call
|
||||
"boxed_forward_device_index": boxed_forward_device_index,
|
||||
},
|
||||
)
|
||||
# Ensure the deserialized cache entry is still serializable
|
||||
compiled_fn = SerializableCompiledFunction(compiled_fn, lambda: serializable_copy)
|
||||
|
||||
# TODO: this ignores flat_params, which can exist
|
||||
# if inline_builtin_nn_modules=False
|
||||
@ -1155,13 +1156,19 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
|
||||
cache_key, debug_lines = autograd_cache_key(
|
||||
gm, args, aot_config, fx_config
|
||||
)
|
||||
entry: Optional[GenericAOTAutogradCacheEntry] = (
|
||||
result: Optional[tuple[GenericAOTAutogradCacheEntry, bytes]] = (
|
||||
AOTAutogradCache._lookup(
|
||||
cache_key, local, remote, args, cache_info, aot_config
|
||||
)
|
||||
)
|
||||
if entry is not None:
|
||||
if result is not None:
|
||||
(entry, pickled_content) = result
|
||||
compiled_fn = entry.wrap_post_compile(args, aot_config, fx_config)
|
||||
# Make the compiled_fn serializable, where the serialize function just
|
||||
# makes a copy of the original entry before post compile via the pickled content
|
||||
compiled_fn = SerializableCompiledFunction(
|
||||
compiled_fn, lambda: pickle.loads(pickled_content)
|
||||
)
|
||||
log.info("AOTAutograd cache hit for key %s", cache_key)
|
||||
|
||||
counters["aot_autograd"]["autograd_cache_hit"] += 1
|
||||
@ -1321,7 +1328,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
|
||||
args: list[Any],
|
||||
cache_info: dict[str, Any],
|
||||
aot_config: Optional[AOTConfig],
|
||||
) -> Optional[GenericAOTAutogradCacheEntry]:
|
||||
) -> Optional[tuple[GenericAOTAutogradCacheEntry, bytes]]:
|
||||
"""Given a key generated by AOTAutogradCachePickler, look up its location in the cache."""
|
||||
remote_cache: Optional[RemoteCache[JsonDataTy]] = None
|
||||
if remote:
|
||||
@ -1330,6 +1337,7 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
|
||||
symints = AOTAutogradCache._filter_backed_symints(args)
|
||||
hints = [hint_int(s) for s in symints]
|
||||
entry = None
|
||||
pickled_content = None
|
||||
try:
|
||||
(
|
||||
entry,
|
||||
@ -1363,7 +1371,11 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
|
||||
log.info("AOTAutograd cache unable to load compiled graph: %s", e)
|
||||
if config.strict_autograd_cache:
|
||||
raise e
|
||||
return entry
|
||||
if entry is not None:
|
||||
assert pickled_content is not None
|
||||
return (entry, pickled_content)
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _write_to_local_cache(key: str, content: bytes):
|
||||
|
@ -158,7 +158,7 @@ class CompiledArtifact:
|
||||
AOTAutogradCache,
|
||||
)
|
||||
|
||||
entry = AOTAutogradCache._lookup(
|
||||
result = AOTAutogradCache._lookup(
|
||||
key,
|
||||
local=True,
|
||||
remote=False,
|
||||
@ -167,7 +167,8 @@ class CompiledArtifact:
|
||||
aot_config=None,
|
||||
)
|
||||
|
||||
assert entry is not None
|
||||
assert result is not None
|
||||
(entry, _) = result
|
||||
|
||||
from .compile_fx import _CompileFxKwargs
|
||||
|
||||
|
Reference in New Issue
Block a user