Simplify PrecompileContext to no longer be a CacheArtifactManager (#162886)

Summary:
This diff is a large refactor of PrecompileContext that makes it considerably simpler. Instead of being a CacheArtifactManager that manages a pile of serialized bytes, it simply stores two things: dynamo cache entries and backend cache entries. When asked, it stitches them together into PrecompileCacheEntries, which are stored by DynamoCache.
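
In sketch form (condensed from the diff below; the new save_to_dynamo_cache() performs essentially this loop):

```python
from torch._dynamo.package import DynamoCache
from torch._dynamo.precompile_context import PrecompileContext

# PrecompileContext now keeps just two dicts:
#   _dynamo_cache_entries:     key -> _DynamoCacheEntry
#   _backend_artifacts_by_key: backend_id -> BackendCacheArtifact
# and, on request, stitches them into PrecompileCacheEntry(dynamo, backends):
entries, debug_info = PrecompileContext.create_cache_entries()
for key, entry in entries.items():
    DynamoCache.write(entry, key)
```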

This structure then allows us to register DynamoCache with the regular MegaCache API, instead of maintaining two separate, confusing APIs. It also lets us remove the special-case autotune cache integration, since the MegaCache API automatically stores autotune cache entries.

The intent here is that users who want caching precompile can simply use torch.compiler.save_cache_artifacts as before, just with `torch._dynamo.config.caching_precompile` set to True. They can also interact with PrecompileContext directly if they want to load only precompile entries, using PrecompileContext.create_cache_entries().
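
For example, a minimal sketch (assumes at least one frame actually compiles; save_cache_artifacts returns None otherwise):

```python
import torch

torch._dynamo.config.caching_precompile = True

@torch.compile
def fn(x):
    return x.sin() + x.cos()

fn(torch.randn(8))

# Same MegaCache entry point as before; with caching_precompile set,
# precompile (DynamoCache) entries are saved automatically.
result = torch.compiler.save_cache_artifacts()
assert result is not None
artifact_bytes, cache_info = result

# In a fresh process:
torch.compiler.load_cache_artifacts(artifact_bytes)
```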

Saving individual entries directly with DynamoCache still works as before.
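
The load side, sketched from DiskDynamoCache.load in the diff below (`fn` here is just a placeholder function):

```python
from torch._dynamo.package import CompilePackage, DynamoCache

def fn(x):
    return x + 1

# DynamoCache.load returns the stitched PrecompileCacheEntry for fn,
# or None if nothing has been saved for it.
entry = DynamoCache.load(fn)
if entry is not None:
    package = CompilePackage(fn, entry.dynamo)
    package.install(entry.backends)
```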

Test Plan:
All existing unit tests pass.

Rollback Plan:

Differential Revision: D82380307

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162886
Approved by: https://github.com/zhxchen17
James Wu
2025-09-20 01:24:37 +00:00
committed by PyTorch MergeBot
parent 8225a26835
commit bfe9e60ffb
10 changed files with 146 additions and 439 deletions

View File

@@ -3580,18 +3580,10 @@ def process_caching_precompile():
)
from torch._dynamo.precompile_context import PrecompileContext
# Serialize all callables, clear PrecompileContext
# TODO: put this under torch.compiler API once ready
serialized = PrecompileContext.serialize()
PrecompileContext.clear()
if serialized is not None:
artifacts, info = serialized
print(
f"Saving {len(info.precompile_dynamo_artifacts)} Precompile Artifact(s)..."
)
results = PrecompileContext.deserialize(artifacts)
assert results is not None
PrecompileContext.populate_caches(results)
debug_info = PrecompileContext.save_to_dynamo_cache()
print(
f"Saved {len(debug_info['dynamo'])} precompile artifacts with {len(debug_info['backends'])} backends"
)
def process_entry(rank, runner, original_dir, args):

View File

@@ -16,13 +16,10 @@ from torch._dynamo.package import CompilePackage, DiskDynamoStore, DynamoCache
from torch._dynamo.precompile_context import PrecompileContext
from torch._dynamo.testing import reduce_to_scalar_loss
from torch._functorch import config as functorch_config
from torch._inductor.mock_cache import global_stats, PatchCaches, Stats
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
skipIfRocm,
skipIfXpu,
)
from torch.testing._internal.inductor_utils import (
HAS_CUDA_AND_TRITON,
@@ -50,9 +47,7 @@ class TestPackage(torch._inductor.test_case.TestCase):
DynamoCache.clear()
PrecompileContext.clear()
def _save_and_reload(
self, expected_backends, expected_dynamo, expected_autotune=None
):
def _save_and_reload(self, expected_backends, expected_dynamo):
"""
Serializes all artifacts, clears all caches, then reloads the serialized artifact
Simulates a new process.
@@ -61,24 +56,12 @@ class TestPackage(torch._inductor.test_case.TestCase):
expected_backends: Expected number of precompile_aot_autograd_artifacts
expected_dynamo: Expected number of precompile_dynamo_artifacts
"""
serialized = PrecompileContext.serialize()
assert serialized is not None
(bytes_, cache_info) = serialized
self.assertEqual(
len(cache_info.precompile_aot_autograd_artifacts), expected_backends
)
self.assertEqual(len(cache_info.precompile_dynamo_artifacts), expected_dynamo)
if expected_autotune is not None:
self.assertEqual(len(cache_info.autotune_artifacts), expected_autotune)
debug_info = PrecompileContext.save_to_dynamo_cache()
self.assertEqual(len(debug_info["dynamo"]), expected_dynamo)
self.assertEqual(len(debug_info["backends"]), expected_backends)
torch._dynamo.reset()
DynamoCache.clear()
PrecompileContext.clear()
deserialized = PrecompileContext.deserialize(bytes_)
assert deserialized is not None
PrecompileContext.populate_caches(deserialized)
@unittest.expectedFailure # FUNCTION_MATCH guard not serializable today
def test_nn_module(self):
class MyModule(torch.nn.Module):
@@ -440,41 +423,6 @@ def add(x, y):
self.assertEqual(expected, [result1, result2])
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
@parametrize("device", ("cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
@skipIfXpu
@skipIfRocm
def test_automatic_dynamo_autotune_cache(self, device):
if device == "cuda" and not HAS_CUDA_AND_TRITON:
raise unittest.SkipTest("Requires CUDA/Triton")
if device == "xpu" and not HAS_XPU_AND_TRITON:
raise unittest.SkipTest("Requires XPU/Triton")
def fn(x, y):
return x.sin() + y
arg1 = torch.randn(3, 3, device=device)
arg2 = torch.randn(3, 3, device=device)
expected = fn(arg1, arg2).clone()
with PatchCaches():
compiled_fn1 = torch.compile(fn, mode="max-autotune")
result = compiled_fn1(arg1, arg2).clone()
self.assertEqual(expected, result)
self.assertEqual(global_stats.autotune_local, Stats(1, 0, 1))
DynamoCache.clear()
total_frames = torch._dynamo.convert_frame.FRAME_COUNTER
self._save_and_reload(
expected_backends=1, expected_dynamo=1, expected_autotune=1
)
compiled_fn1 = torch.compile(fn, mode="max-autotune")
with torch.compiler.set_stance("fail_on_recompile"):
result1 = compiled_fn1(arg1, arg2).clone()
self.assertEqual(expected, result1)
self.assertEqual(torch._dynamo.convert_frame.FRAME_COUNTER, total_frames)
self.assertEqual(global_stats.autotune_local, Stats(2, 1, 1))
@parametrize("device", ("cpu", "cuda", "xpu"))
@torch._dynamo.config.patch(caching_precompile=True)
def test_automatic_dynamo_recompiles(self, device):

View File

@@ -1,16 +1,9 @@
# Owner(s): ["module: dynamo"]
import pickle
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._functorch
from torch._dynamo.precompile_context import (
EditablePrecompileCacheArtifact,
PrecompileCacheArtifact,
PrecompileContext,
)
from torch._dynamo.precompile_context import BackendCacheArtifact, PrecompileContext
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import (
BundledAOTAutogradCacheArtifact,
@@ -49,31 +42,11 @@ class PrecompileContextTests(InductorTestCase):
result.sum().backward()
self.assertEqual(len(PrecompileContext._dynamo_cache_entries), 1)
self.assertEqual(len(PrecompileContext._backend_artifacts_by_key), 1)
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
result = PrecompileContext.serialize()
assert result is not None
serialized, cache_info = result
self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)
artifacts = PrecompileContext.deserialize(serialized)
assert artifacts is not None
deserialized = artifacts["precompile_aot_autograd"]
assert len(deserialized) == 1
entry = deserialized[0]
assert isinstance(entry, BundledAOTAutogradCacheArtifact)
entry = entry.after_deserialization()
# Now that we've serialized, there should be no new cache artifacts
self.assertEqual(
len(PrecompileContext._new_cache_artifacts["precompile_aot_autograd"]), 0
)
cache_entries, _ = PrecompileContext.create_cache_entries()
self.assertEqual(len(cache_entries), 1)
@requires_triton()
def test_serialize_by_key(self):
"""
Test that after torch.compile, PrecompileContext._new_cache_artifacts length is 1
"""
def simple_function(x):
return x.sin() + x.cos()
@@ -87,14 +60,12 @@ class PrecompileContextTests(InductorTestCase):
self.assertEqual(len(PrecompileContext._backend_artifacts_by_key), 1)
for key in PrecompileContext._backend_artifacts_by_key.keys():
result = PrecompileContext.serialize_artifact_by_key(key)
assert isinstance(result, PrecompileCacheArtifact)
assert isinstance(result, BackendCacheArtifact)
self.assertEqual(result.key, key)
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
result = PrecompileContext.serialize()
assert result is not None
_, cache_info = result
self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)
# This should still work
result, _ = PrecompileContext.create_cache_entries()
assert len(result) == 1
@requires_triton()
def test_editable(self):
@@ -114,11 +85,7 @@
self.assertEqual(len(PrecompileContext._dynamo_cache_entries), 1)
self.assertEqual(len(PrecompileContext._backend_artifacts_by_key), 1)
# Find the key for the artifact of type "precompile_aot_autograd"
key = next(
k
for k, v in PrecompileContext._backend_artifacts_by_key.items()
if isinstance(v, EditablePrecompileCacheArtifact)
)
key = next(iter(PrecompileContext._backend_artifacts_by_key))
def edit_fn(x):
x._my_private_field = 42
@@ -130,24 +97,12 @@
assert isinstance(result, BundledAOTAutogradCacheArtifact)
self.assertEqual(result.key, key)
self.assertEqual(len(PrecompileContext._new_cache_artifacts), 0)
result = PrecompileContext.serialize()
assert result is not None
artifacts, cache_info = result
self.assertEqual(len(cache_info.precompile_aot_autograd_artifacts), 1)
deserialized = PrecompileContext.deserialize(artifacts)
assert deserialized is not None
aot_autograd_artifacts = deserialized["precompile_aot_autograd"]
result, _ = PrecompileContext.create_cache_entries()
assert len(result) == 1
aot_autograd_artifacts = next(iter(result.values())).backends
assert len(aot_autograd_artifacts) == 1
entry = aot_autograd_artifacts[0]
assert isinstance(entry, BundledAOTAutogradCacheArtifact)
raw_entry = pickle.loads(entry.content)
self.assertEqual(raw_entry._my_private_field, 42)
# Now that we've serialized, there should be no new cache artifacts
self.assertEqual(
len(PrecompileContext._new_cache_artifacts["precompile_aot_autograd"]), 0
)
entry = next(iter(aot_autograd_artifacts.values())).content
self.assertEqual(entry._my_private_field, 42)
if __name__ == "__main__":

View File

@@ -182,7 +182,9 @@ class BundledAOTAutogradSerializableCallable(SerializableCallable):
deserialize_bundled_cache_entry,
)
compiled_fn = deserialize_bundled_cache_entry(data)
entry = pickle.loads(data)
compiled_fn = deserialize_bundled_cache_entry(entry)
return cls(compiled_fn)
def __call__(self, *args: Any, **kwargs: Any) -> Any:

View File

@@ -747,12 +747,11 @@ class _TorchDynamoContext:
# Create a fresh CompilePackage
self._package.initialize(fn, None, ignore_inlined_sources=False)
else:
cache_entry, backends = result
try:
self._package.initialize(
fn, cache_entry, ignore_inlined_sources=False
fn, result.dynamo, ignore_inlined_sources=False
)
self._package.install(backends)
self._package.install(result.backends)
except RuntimeError as e:
log.warning("Failed to load entry from dynamo cache: %s", e)
self._package.initialize(fn, None, ignore_inlined_sources=False)

View File

@@ -316,6 +316,7 @@ class SystemInfo:
def current(cls) -> "SystemInfo":
"""Create a SystemInfo instance with current system information."""
# Get GPU name if CUDA or XPU is available
gpu_name = None
from torch.utils._triton import get_triton_version
gpu_name, toolkit_version = None, None
@@ -406,6 +407,19 @@ class _DynamoCacheEntry:
}
@dataclasses.dataclass
class PrecompileCacheEntry:
"""
A full cache entry for caching precompile, for a single top-level torch.compile.
Consists of a _DynamoCacheEntry, which contains all the dynamo related content,
and a dict of backend content. In general, each backend value here will
be of type precompile_context.BackendCacheArtifact.
"""
dynamo: _DynamoCacheEntry
backends: dict[_BackendId, Any]
def _hash_source(source: str) -> str:
sha256_hash = hashlib.sha256()
sha256_hash.update(source.encode())
@@ -817,10 +831,8 @@ class DynamoStore(abc.ABC):
PrecompileContext,
)
pickled_result = pickle.dumps(backend)
PrecompileContext.record_artifact(
EagerCacheArtifact.type(), key=backend_id, content=pickled_result
)
result = EagerCacheArtifact(key=backend_id, content=backend)
PrecompileContext.record_artifact(result)
@abc.abstractmethod
def clear(self) -> None: ...
@@ -828,8 +840,7 @@
@abc.abstractmethod
def write(
self,
dynamo: _DynamoCacheEntry,
backends: _Backends,
cache_entry: PrecompileCacheEntry,
path: str,
) -> None:
"""
@@ -847,7 +858,7 @@
Saves a package to a given path. Grabs backends from PrecompileContext.
"""
from torch._dynamo.precompile_context import (
PrecompileCacheArtifact,
BackendCacheArtifact,
PrecompileContext,
)
@@ -858,10 +869,12 @@
raise RuntimeError(
f"Backend {backend_id} is not found in the given backends"
)
assert isinstance(serialized_backend, PrecompileCacheArtifact)
assert isinstance(serialized_backend, BackendCacheArtifact)
backend_content[backend_id] = serialized_backend
self.write(cache_entry, backend_content, key)
entry = PrecompileCacheEntry(cache_entry, backend_content)
self.write(entry, key)
def save_package(self, package: CompilePackage, key: str) -> None:
"""
@@ -872,7 +885,7 @@
self.save_cache_entry(cache_entry, key)
@abc.abstractmethod
def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]:
def read(self, path: str) -> PrecompileCacheEntry:
"""
Abstract method to read dynamo cache entry and backends from storage.
@@ -884,19 +897,18 @@
"""
...
def load_cache_entry(
self, key: str
) -> tuple[_DynamoCacheEntry, dict[_BackendId, Any]]:
from torch._dynamo.precompile_context import PrecompileContext
def load_cache_entry(self, key: str) -> PrecompileCacheEntry:
from torch._dynamo.precompile_context import (
BackendCacheArtifact,
PrecompileContext,
)
cache_entry, backend_content = self.read(key)
for backend_id, backend in backend_content.items():
PrecompileContext.record_artifact(
backend.type(), key=backend.key, content=backend.content
)
backend_content[backend_id] = backend
precompile_entry = self.read(key)
for backend in precompile_entry.backends.values():
assert isinstance(backend, BackendCacheArtifact)
PrecompileContext.record_artifact(backend)
return cache_entry, backend_content
return precompile_entry
def load_package(
self, fn: Any, key: str
@@ -904,9 +916,9 @@
"""
Loads a package from a given path and returns it plus a list of deserialized backends
"""
cache_entry, backend_content = self.load_cache_entry(key)
package = CompilePackage(fn, cache_entry)
return package, backend_content
entry = self.load_cache_entry(key)
package = CompilePackage(fn, entry.dynamo)
return package, entry.backends
class InMemoryDynamoStore(DynamoStore):
@@ -915,23 +927,22 @@
"""
def __init__(self) -> None:
self.packages: dict[str, tuple[_DynamoCacheEntry, _Backends]] = {}
self.packages: dict[str, PrecompileCacheEntry] = {}
def clear(self) -> None:
self.packages.clear()
def write(
self,
dynamo: _DynamoCacheEntry,
backends: _Backends,
entry: PrecompileCacheEntry,
path: str,
) -> None:
"""
Store the dynamo cache entry and backends in memory instead of writing to disk.
"""
self.packages[path] = (dynamo, backends)
self.packages[path] = entry
def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]:
def read(self, path: str) -> PrecompileCacheEntry:
"""
Read dynamo cache entry and backends from memory.
"""
@@ -964,34 +975,32 @@ class DiskDynamoStore(DynamoStore):
def write(
self,
dynamo: _DynamoCacheEntry,
backends: _Backends,
entry: PrecompileCacheEntry,
path: str,
) -> None:
"""
Write dynamo cache entry and backends to disk.
"""
from torch._inductor.codecache import write_atomic
path = os.path.join(self.path_prefix, path) if self.path_prefix else path
try:
os.makedirs(path, exist_ok=True)
with open(os.path.join(path, "dynamo"), "wb") as dynamo_path:
pickle.dump(dynamo, dynamo_path)
with open(os.path.join(path, "backends"), "wb") as backend_path:
pickle.dump(backends, backend_path)
pickled_content: bytes = pickle.dumps(entry)
write_atomic(os.path.join(path, "entry"), pickled_content)
except Exception as e:
raise RuntimeError(f"Failed to save package to {path}: {e}") from e
def read(self, path: str) -> tuple[_DynamoCacheEntry, _Backends]:
def read(self, path: str) -> PrecompileCacheEntry:
"""
Read dynamo cache entry and backends from disk.
"""
path = os.path.join(self.path_prefix, path) if self.path_prefix else path
try:
with open(os.path.join(path, "dynamo"), "rb") as dynamo_path:
cache_entry = pickle.load(dynamo_path)
with open(os.path.join(path, "backends"), "rb") as backend_path:
backend_content = pickle.load(backend_path)
return cache_entry, backend_content
with open(os.path.join(path, "entry"), "rb") as f:
pickled_content = f.read()
entry = pickle.loads(pickled_content)
return entry
except Exception as e:
raise RuntimeError(f"Failed to load package from path {path}: {e}") from e
@@ -1010,9 +1019,7 @@ class DiskDynamoCache(DiskDynamoStore):
logger.info("Saving CompilePackage for %s", package.source_id)
super().save_package(package, key)
def load(
self, fn: Callable[..., Any]
) -> Optional[tuple[_DynamoCacheEntry, dict[_BackendId, Any]]]:
def load(self, fn: Callable[..., Any]) -> Optional[PrecompileCacheEntry]:
"""
Loads a package from a given path and returns it plus a list of deserialized backends
"""
@@ -1039,9 +1046,8 @@
if results is None:
return None
else:
(entry, backends) = results
package = CompilePackage(fn, entry)
package.install(backends)
package = CompilePackage(fn, results.dynamo)
package.install(results.backends)
return package

View File

@@ -1,25 +1,18 @@
import copy
import json
import logging
import pickle
from abc import abstractmethod
from collections import defaultdict
from itertools import chain
from typing import Any, Callable, Generic, Optional, TypeVar, Union
from typing_extensions import override
from dataclasses import dataclass
from typing import Any, Callable, Generic, Optional, TypeVar
import torch
from torch._dynamo.package import _DynamoCacheEntry
from torch.compiler._cache import (
_serialize_single_cache,
CacheArtifact,
CacheArtifactFactory,
CacheArtifactManager,
CacheArtifactsResult,
CacheInfo,
from torch._dynamo.package import (
_BackendId,
_DynamoCacheEntry,
DynamoCache,
PrecompileCacheEntry,
)
from torch.utils._appending_byte_serializer import AppendingByteSerializer
from torch.utils._ordered_set import OrderedSet
"""
@@ -30,14 +23,12 @@ T = TypeVar("T")
logger = logging.getLogger(__name__)
class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
@dataclass
class BackendCacheArtifact(Generic[T]):
"""
Data for each cache artifact that will be serialized and deserialized by
PrecompileContext, rather than CacheArtifactManager.
T represents the deserialized type of the artifact, i.e. the return type of after_deserialization
PrecompileCacheArtifact is a frozen dataclass - you can add new serializable fields and metadata specific to your own artifacts
as needed, and use them in after_deserialization.
Represents a single serializable artifact produced by a dynamo backend.
Each BackendCacheArtifact has a key associated with it along with some
serializable content.
Example implementation:
@@ -51,13 +42,8 @@ class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
return result
"""
@override
def populate_cache(self) -> None:
raise RuntimeError("Precompile cache artifacts do not populate caches")
@override
def precompile_compatible(self) -> bool:
return True
key: str
content: Any
@abstractmethod
def after_deserialization(self) -> T:
@@ -67,63 +53,23 @@ class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
"""
...
@CacheArtifactFactory.register
class EagerCacheArtifact(PrecompileCacheArtifact[Any]):
@staticmethod
def type() -> str:
return "precompile_eager"
def after_deserialization(self) -> Any:
return pickle.loads(self.content)
class EditablePrecompileCacheArtifact(Generic[T]):
"""
A PrecompileCacheArtifact whose content isn't encoded until we call PrecompileContext.serialize()
"""
def __init__(self, artifact_type: str, content: Any, key: str) -> None:
# Deepcopy the content for now, but don't pickle it yet.
# This allows us to make changes to self.content before true serialization
self.content = copy.deepcopy(content)
self.key = key
self.artifact_type = artifact_type
def real_encode(self) -> PrecompileCacheArtifact[T]:
"""
Actually encode the object
"""
content = pickle.dumps(self.content)
artifact = CacheArtifactFactory.encode_create(
self.artifact_type, self.key, content
)
assert isinstance(artifact, PrecompileCacheArtifact)
return artifact
def edit_contents(self, edit_fn: Callable[..., Any]) -> None:
"""
Edit the content of an existing artifact
Edit the contents of the artifact.
"""
self.content = edit_fn(self.content)
@CacheArtifactFactory.register
class _DynamoCacheArtifact(PrecompileCacheArtifact[_DynamoCacheEntry]):
@staticmethod
def type() -> str:
return "precompile_dynamo"
def after_deserialization(self) -> _DynamoCacheEntry:
result = pickle.loads(self.content)
return result
class EagerCacheArtifact(BackendCacheArtifact[Any]):
def after_deserialization(self) -> Any:
return self.content
class BypassDynamoCacheEntry(Exception):
pass
class PrecompileContext(CacheArtifactManager):
class PrecompileContext:
"""
PrecompileContext is a special CacheArtifactManager for handling precompilation
It uses the same interface as CacheArtifactManager, but handles deserialization differently: instead
@@ -136,69 +82,32 @@ class PrecompileContext(CacheArtifactManager):
The following artifact types are supported by PrecompileContext:
- BundledAOTAutogradCacheArtifact
- AutotuneCacheArtifact (regular autotune results, same as Megacache)
"""
# Protected by the compile_lock
# _backend_artifacts_by_key organizes results by the key of each artifact.
# This allows us to implement serialize_by_key easily.
# On call to `serialize()`, all cache artifacts in _backend_artifacts_by_key
# are transferred to _new_cache_artifacts before serialization.
_backend_artifacts_by_key: dict[
str, Union[EditablePrecompileCacheArtifact[object], CacheArtifact]
] = {}
# Each object here must be serializable
_backend_artifacts_by_key: dict[str, BackendCacheArtifact[Any]] = {}
# On call to `serialize()`, all cache artifacts in _dynamo_cache_entries are converted
# into DynamoCacheArtifacts and added to _new_cache_artifacts for serialization
_dynamo_cache_entries: dict[str, _DynamoCacheEntry] = {}
_new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
# Keep a separate seen artifacts list to avoid unnecessary duplicates
# This list will not be cleared between serialize() calls
_seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
# When serialize() is called, artifacts are transferred from _cache_artifacts to
# internal data structure of the _serializer
# This allows us to only pay the cost of serialization if serialize() is called
_serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = (
AppendingByteSerializer(serialize_fn=_serialize_single_cache)
)
_cache_info: CacheInfo = CacheInfo()
@classmethod
def clear(cls) -> None:
cls._backend_artifacts_by_key.clear()
cls._dynamo_cache_entries.clear()
super().clear()
@override
@classmethod
def record_artifact(
cls,
artifact_type: str,
key: str,
content: Any,
editable: bool = False,
artifact: BackendCacheArtifact[Any],
) -> None:
"""
Called from each caching operation to record the artifact in this
"mega" list
Records a backend artifact to be used with dynamo cache entries
"""
artifact: Union[EditablePrecompileCacheArtifact[object], CacheArtifact]
if editable:
artifact = EditablePrecompileCacheArtifact(artifact_type, content, key)
else:
artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
# TODO: although this covers completely same artifacts, it's possible
# with AOTAutogradCacheEntries to have multiple artifacts whose keys
# (i.e. backend_ids) are different, but whose contents are equal.
# In those cases, it would be much better if we only serialize once instead
# of N times.
if artifact in cls._seen_artifacts:
return
cls._seen_artifacts.add(artifact)
cls._backend_artifacts_by_key[key] = artifact
cls._backend_artifacts_by_key[artifact.key] = copy.deepcopy(artifact)
@classmethod
def record_dynamo_cache_entry(
@@ -206,36 +115,6 @@ class PrecompileContext(CacheArtifactManager):
) -> None:
cls._dynamo_cache_entries[key] = cache_entry
@classmethod
def _save_artifacts_by_type(cls) -> None:
"""
We normally record artifacts by key, but serialization expects them to be organized
by artifact type. This function transfers artifacts from _backend_artifacts_by_key to _new_cache_artifacts
"""
for key, cache_entry in cls._dynamo_cache_entries.items():
backends = cache_entry.backend_ids
try:
for id_ in backends:
if id_ not in cls._backend_artifacts_by_key:
logger.warning(
"Bypassing %s because backend %s not found in artifacts"
)
raise BypassDynamoCacheEntry
except BypassDynamoCacheEntry:
continue
pickled_result = pickle.dumps(cache_entry)
dynamo_artifact = _DynamoCacheArtifact(key, pickled_result)
cls._new_cache_artifacts[_DynamoCacheArtifact.type()].append(
dynamo_artifact
)
# Save all the backend artifacts
for artifact in cls._backend_artifacts_by_key.values():
if isinstance(artifact, EditablePrecompileCacheArtifact):
artifact = artifact.real_encode()
cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
cls._backend_artifacts_by_key.clear()
@classmethod
def edit_artifact(cls, key: str, edit_fn: Callable[..., Any]) -> None:
"""
@@ -243,53 +122,19 @@
"""
assert key in cls._backend_artifacts_by_key, f"Key {key} not found in artifacts"
artifact = cls._backend_artifacts_by_key[key]
assert isinstance(artifact, EditablePrecompileCacheArtifact), (
"Artifact is not editable"
)
artifact.edit_contents(edit_fn)
@classmethod
def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]:
def serialize_artifact_by_key(cls, key: str) -> Optional[BackendCacheArtifact[Any]]:
"""
Serialize all backend artifacts with the given key returned in a list.
Return the backend cache artifact with the associated key
"""
result = cls._backend_artifacts_by_key.get(key, None)
if isinstance(result, EditablePrecompileCacheArtifact):
result = result.real_encode()
return result
@classmethod
def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
if not cls._dynamo_cache_entries:
return None
debug_info = cls.dump_debug_info(
cls._dynamo_cache_entries, cls._backend_artifacts_by_key
)
artifacts = json.dumps({"artifacts": debug_info})
torch._logging.trace_structured(
"artifact",
metadata_fn=lambda: {
"name": "dynamo_cache_save_contents",
"encoding": "json",
},
payload_fn=lambda: artifacts,
expect_trace_id=False,
)
cls._save_artifacts_by_type()
result = super().serialize()
assert result is not None
data, info = result
return data, info
return cls._backend_artifacts_by_key.get(key, None)
@staticmethod
def dump_debug_info(
dynamo_entries: dict[str, _DynamoCacheEntry],
backend_artifacts: dict[
str, Union[EditablePrecompileCacheArtifact[object], CacheArtifact]
],
backend_artifacts: dict[str, BackendCacheArtifact[Any]],
) -> dict[str, Any]:
"""
Return a JSON serializable debug dump of all entries in the precompile context
@@ -300,36 +145,32 @@
for key, cache_entry in dynamo_entries.items():
info = cache_entry.debug_info()
info["key"] = key
debug_info["precompile_dynamo"].append(info)
debug_info["dynamo"].append(info)
for artifact in backend_artifacts.values():
if isinstance(artifact, EditablePrecompileCacheArtifact):
debug_info[artifact.artifact_type].append(artifact.key)
else:
debug_info[artifact.__class__.type()].append(artifact.key)
debug_info["backends"].append(artifact.key)
return debug_info
@staticmethod
def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
PrecompileContext._ensure_cache_artifacts_registered()
@classmethod
def save_to_dynamo_cache(cls) -> dict[str, Any]:
precompile_cache_entries, debug_info = cls.create_cache_entries()
for key, entry in precompile_cache_entries.items():
DynamoCache.write(entry, key)
return debug_info
backend_artifacts: dict[str, Any] = {}
dynamo_entries: dict[str, _DynamoCacheEntry] = {}
cache_info = CacheInfo()
for artifact in chain(*artifacts.values()):
if artifact.type() == "autotune":
# Populate autotune cache artifacts
artifact.populate_cache()
elif artifact.type() == "precompile_dynamo":
assert isinstance(artifact, _DynamoCacheArtifact)
cache_entry: _DynamoCacheEntry = artifact.after_deserialization()
dynamo_entries[artifact.key] = cache_entry
else:
backend_artifacts[artifact.key] = artifact
cache_info.add(artifact)
@classmethod
def create_cache_entries(
cls,
) -> tuple[dict[str, PrecompileCacheEntry], dict[str, Any]]:
"""
Grabs all the cache entries in the precompile context and
stitches them together into full PrecompileCacheEntries.
"""
dynamo_entries = cls._dynamo_cache_entries
backend_artifacts = cls._backend_artifacts_by_key
num_artifacts = len(artifacts["precompile_dynamo"])
num_artifacts = len(dynamo_entries)
debug_info = PrecompileContext.dump_debug_info(
dynamo_entries, backend_artifacts
@@ -349,12 +190,13 @@
payload_fn=lambda: debug_str,
expect_trace_id=False,
)
from torch._dynamo.package import _BackendId, DynamoCache
precompile_cache_entries = {}
for key, cache_entry in dynamo_entries.items():
try:
backends = cache_entry.backend_ids
backend_content: dict[_BackendId, PrecompileCacheArtifact[Any]] = {}
backend_content: dict[_BackendId, BackendCacheArtifact[Any]] = {}
for id_ in backends:
if id_ not in backend_artifacts:
debug_str = json.dumps(
@@ -375,11 +217,13 @@
)
continue
artifact = backend_artifacts[id_]
assert isinstance(artifact, PrecompileCacheArtifact)
assert isinstance(artifact, BackendCacheArtifact)
backend_content[id_] = artifact
DynamoCache.write(cache_entry, backend_content, key)
precompile_cache_entries[key] = PrecompileCacheEntry(
dynamo=cache_entry, backends=backend_content
)
except Exception as e:
logger.warning("Failed to deserialize cache entry %s: %s", key, str(e))
logger.warning("Failed to create cache entry %s: %s", key, str(e))
error = e
data = json.dumps(
@@ -397,11 +241,4 @@
payload_fn=lambda: data,
)
continue
return cache_info
@classmethod
def _ensure_cache_artifacts_registered(cls) -> None:
from torch._functorch._aot_autograd.autograd_cache import ( # noqa: F401
BundledAOTAutogradCacheArtifact,
)
return precompile_cache_entries, debug_info

View File

@@ -22,7 +22,7 @@ from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar, Uni
from typing_extensions import override
import torch
from torch._dynamo.precompile_context import PrecompileCacheArtifact, PrecompileContext
from torch._dynamo.precompile_context import BackendCacheArtifact, PrecompileContext
from torch._dynamo.trace_rules import torch_non_c_binding_in_graph_functions
from torch._dynamo.utils import (
chromium_event_log_active,
@@ -51,7 +51,7 @@ from torch._inductor.output_code import (
OutputCode,
)
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import should_use_remote_fx_graph_cache
from torch._inductor.utils import BoxedBool, should_use_remote_fx_graph_cache
from torch._logging import LazyString
from torch._utils_internal import log_cache_bypass
from torch.compiler._cache import (
@@ -81,7 +81,6 @@ from .utils import simple_wraps
if TYPE_CHECKING:
from torch._inductor.compile_fx import _CompileFxKwargs
from torch._inductor.remote_cache import JsonDataTy, RemoteCache
from torch._inductor.utils import BoxedBool
from torch.fx.node import Node
log = logging.getLogger(__name__)
@@ -1059,15 +1058,14 @@ class AOTAutogradCacheArtifact(CacheArtifact):
return "aot_autograd"
def deserialize_bundled_cache_entry(data: bytes) -> Callable:
entry = pickle.loads(data)
def deserialize_bundled_cache_entry(entry: BundledAOTAutogradCacheEntry) -> Callable:
# In the precompile use case, guards are already serialized
# by dynamo, so we don't need to add them to the environment
entry.guards_expr = None
# TODO: this isn't exactly right, because cudagraphs needs to be a shared config
# which is set by compile_fx. But in precompile, we never actually call compile_fx
# so we don't have a place to track cudagraphs here.
cudagraphs = torch._inductor.config.triton.cudagraphs
cudagraphs = BoxedBool(torch._inductor.config.triton.cudagraphs)
boxed_forward_device_index = BoxedDeviceIndex(None)
compiled_fn = entry.wrap_post_compile(
[],
@@ -1090,14 +1088,8 @@ def deserialize_bundled_cache_entry(data: bytes) -> Callable:
return forward
@CacheArtifactFactory.register
class BundledAOTAutogradCacheArtifact(PrecompileCacheArtifact[Callable]):
@override
@staticmethod
def type():
return "precompile_aot_autograd"
@override
@dataclass
class BundledAOTAutogradCacheArtifact(BackendCacheArtifact[Callable]):
def after_deserialization(self) -> Callable:
return deserialize_bundled_cache_entry(self.content)
@@ -1375,9 +1367,9 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
# 1. because we set it to None on save 2. even if we didn't, this new run
# that cache hit has a *new* backend id associated with it.
PrecompileContext.record_artifact(
BundledAOTAutogradCacheArtifact.type(),
aot_config.precompile_backend_id,
pickled_content,
BundledAOTAutogradCacheArtifact(
aot_config.precompile_backend_id, entry
),
)
except Exception as e:
log.info("AOTAutograd cache unable to load compiled graph: %s", e)
@@ -1413,15 +1405,11 @@ class AOTAutogradCache(GuardedCache[GenericAOTAutogradCacheEntry]):
and entry.sanitized_aot_config.precompile_backend_id is not None
):
precompile_key = entry.sanitized_aot_config.precompile_backend_id
artifact = BundledAOTAutogradCacheArtifact(precompile_key, entry)
# Now that we're saving it, the precompile_backend_id field is no longer
# useful, remove it from the entry.
entry.sanitized_aot_config.precompile_backend_id = None
PrecompileContext.record_artifact(
BundledAOTAutogradCacheArtifact.type(),
precompile_key,
entry,
editable=True,
)
PrecompileContext.record_artifact(artifact)
AOTAutogradCache._write_to_local_cache(key, content)
counters["aot_autograd"]["autograd_cache_saved"] += 1
except BypassAOTAutogradCache as e:

View File

@@ -35,7 +35,6 @@ from typing import Any, Optional, TYPE_CHECKING
from typing_extensions import override
import torch
from torch._dynamo.precompile_context import PrecompileContext
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.compiler._cache import (
CacheArtifact,
@@ -302,10 +301,6 @@ class AutotuneCache:
CacheArtifactManager.record_artifact(
AutotuneCacheArtifact.type(), autotune_artifact_key, data
)
if torch._dynamo.config.caching_precompile:
PrecompileContext.record_artifact(
AutotuneCacheArtifact.type(), autotune_artifact_key, data
)
if log.isEnabledFor(logging.DEBUG):
type_str = "coordesc" if found_by_coordesc else "heuristic"
@@ -631,10 +626,6 @@ class LocalAutotuneCache(RemoteCache[JsonDataTy]):
CacheArtifactManager.record_artifact(
AutotuneCacheArtifact.type(), autotune_artifact_key, result
)
if torch._dynamo.config.caching_precompile:
PrecompileContext.record_artifact(
AutotuneCacheArtifact.type(), autotune_artifact_key, result
)
return result
@override

View File

@@ -48,9 +48,6 @@ class CacheArtifact(ABC):
def populate_cache(self) -> None:
pass
def precompile_compatible(self) -> bool:
return False
@staticmethod
def type() -> str:
"""
@@ -131,14 +128,6 @@ class CacheInfo:
def pgo_artifacts(self) -> list[str]: # type: ignore[empty-body]
...
@property
def precompile_aot_autograd_artifacts(self) -> list[str]: # type: ignore[empty-body]
...
@property
def precompile_dynamo_artifacts(self) -> list[str]: # type: ignore[empty-body]
...
def add(self, artifact: CacheArtifact) -> None:
self.artifacts[artifact.type()].append(artifact.key)