Files
pytorch/torch/_dynamo/precompile_context.py
James Wu be56a8d7ac Automatically load and save dynamo entries via caching_precompile (#155913)
This PR adds a new config option, `caching_precompile`, and a `DynamoCache`, which loads and saves Dynamo Cache entries automatically. It also hooks up DynamoCache to PrecompileContext, so that we can save multiple cache entries.

When this configuration is turned on, we:
- Automatically create and initialize a CompilePackage on every torch.compile
- Automatically use BundledAOTAutogradCache
- Automatically save the CompilePackage entry to DynamoCache after every compile

You can also use `PrecompileContext.serialize()` to manually serialize everything recorded so far into a single bytes payload (see the sketch below).
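
For illustration, a minimal sketch of both flows; the flag location (`torch._dynamo.config.caching_precompile`) and the example function are assumptions, not part of this PR:

```python
# Minimal sketch, assuming the flag lives at torch._dynamo.config.caching_precompile
# (flag name taken from this PR; its exact location is an assumption).
import torch
import torch._dynamo.config as dynamo_config

dynamo_config.caching_precompile = True  # enable automatic DynamoCache save/load

@torch.compile
def fn(x):
    return x.sin() + 1

fn(torch.randn(8))  # first call compiles; the CompilePackage entry is saved

# Manual path: serialize all artifacts recorded so far.
from torch._dynamo.precompile_context import PrecompileContext

result = PrecompileContext.serialize()
if result is not None:
    blob, cache_info = result  # raw bytes plus a CacheInfo summary
```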

I've added unit tests exercising this behavior.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155913
Approved by: https://github.com/zhxchen17
2025-07-07 23:57:17 +00:00

178 lines
6.8 KiB
Python

from abc import abstractmethod
from collections import defaultdict
from itertools import chain
from typing import Any, Generic, Optional, TypeVar
from typing_extensions import override

from torch.compiler._cache import (
    _serialize_single_cache,
    CacheArtifact,
    CacheArtifactFactory,
    CacheArtifactManager,
    CacheArtifactsResult,
    CacheInfo,
)
from torch.utils._appending_byte_serializer import AppendingByteSerializer
from torch.utils._ordered_set import OrderedSet


"""
Classes and implementations related to precompile
"""

T = TypeVar("T")


class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
    """
    Data for each cache artifact that will be serialized and deserialized by
    PrecompileContext, rather than CacheArtifactManager.

    T represents the deserialized type of the artifact, i.e. the return type of
    after_deserialization.

    PrecompileCacheArtifact is a frozen dataclass - you can add new serializable
    fields and metadata specific to your own artifacts as needed, and use them
    in after_deserialization.

    Example implementation:

    class MyPrecompileCacheArtifact(PrecompileCacheArtifact[MySerializableType]):
        my_field: int

        def after_deserialization(self) -> MySerializableType:
            result = pickle.loads(self.content)
            # Do some extra work post deserialization
            result.my_post_deserialization_function(self.my_field)
            return result
    """

    @override
    def populate_cache(self) -> None:
        raise RuntimeError("Precompile cache artifacts do not populate caches")

    @override
    def precompile_compatible(self) -> bool:
        return True

    @abstractmethod
    def after_deserialization(self) -> T:
        """
        Code to be run after reading raw byte contents from disk.
        Generally converts self.content from raw bytes back into its original form.
        """
        ...


class PrecompileContext(CacheArtifactManager):
    """
    PrecompileContext is a special CacheArtifactManager for handling precompilation.
    It uses the same interface as CacheArtifactManager, but handles deserialization
    differently: instead of placing each artifact into its respective cache, it
    stitches all the cache artifacts for a single key together and places the
    result into a global precompile cache.

    The following artifact types are supported by PrecompileContext:
     - BundledAOTAutogradCacheArtifact
     - CodeStateArtifact (from torch._dynamo.package once available)
    """

    # Protected by the compile_lock.
    # _new_cache_artifacts_by_key organizes results by the key of each artifact,
    # which allows us to implement serialize_by_key easily.
    # On a call to serialize(), all cache artifacts in _new_cache_artifacts_by_key
    # are transferred to _new_cache_artifacts before serialization.
    _new_cache_artifacts_by_key: dict[str, CacheArtifact] = {}
    _new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
    # Keep a separate list of seen artifacts to avoid unnecessary duplicates.
    # This list is not cleared between serialize() calls.
    _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
    # When serialize() is called, artifacts are transferred from
    # _new_cache_artifacts to the internal data structure of _serializer.
    # This way we only pay the cost of serialization if serialize() is called.
    _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = (
        AppendingByteSerializer(serialize_fn=_serialize_single_cache)
    )
    _cache_info: CacheInfo = CacheInfo()

    @classmethod
    def clear(cls) -> None:
        cls._new_cache_artifacts_by_key.clear()
        super().clear()

    @override
    @classmethod
    def record_artifact(
        cls,
        artifact_type: str,
        key: str,
        content: Any,
    ) -> None:
        """
        Called from each caching operation to record the artifact in this
        "mega" list.
        """
        artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
        # TODO: although this covers completely identical artifacts, it's possible
        # with AOTAutogradCacheEntries to have multiple artifacts whose keys
        # (i.e. backend_ids) are different, but whose contents are equal.
        # In those cases, it would be much better if we only serialized once
        # instead of N times.
        if artifact in cls._seen_artifacts:
            return
        cls._new_cache_artifacts_by_key[key] = artifact
        cls._seen_artifacts.add(artifact)

    @classmethod
    def _save_artifacts_by_type(cls) -> None:
        """
        We normally record artifacts by key, but serialization expects them to be
        organized by artifact type. This function transfers artifacts from
        _new_cache_artifacts_by_key to _new_cache_artifacts.
        """
        for artifact in cls._new_cache_artifacts_by_key.values():
            cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
        cls._new_cache_artifacts_by_key.clear()

    @classmethod
    def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]:
        """
        Return the single cache artifact recorded under the given key, or None
        if no artifact was recorded for it.
        """
        return cls._new_cache_artifacts_by_key.get(key, None)

    @classmethod
    def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
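        # Reorganize the key-indexed artifacts into the type-indexed layout the
        # base class serializer expects, then delegate to it.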
        cls._save_artifacts_by_type()
        return super().serialize()

    @staticmethod
    def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
        PrecompileContext._ensure_cache_artifacts_registered()

        artifacts_by_key = {}
        cache_info = CacheInfo()
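        # Index every artifact by its key so each dynamo entry below can look
        # up the backend artifacts it references.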
        for artifact in chain(*artifacts.values()):
            cache_info.add(artifact)
            artifacts_by_key[artifact.key] = artifact

        from torch._dynamo.package import _BackendId, DynamoCache

        for dynamo_entry in artifacts["precompile_dynamo"]:
            assert isinstance(dynamo_entry, PrecompileCacheArtifact)
            cache_entry = dynamo_entry.after_deserialization()
            # Grab backends from the dynamo cache entry
            backends = cache_entry.backend_ids
            backend_content: dict[_BackendId, PrecompileCacheArtifact[Any]] = {}
            for id_ in backends:
                assert id_ in artifacts_by_key, f"Backend {id_} not found in artifacts"
                artifact = artifacts_by_key[id_]
                assert isinstance(artifact, PrecompileCacheArtifact)
                backend_content[id_] = artifact
            DynamoCache.write(cache_entry, backend_content, dynamo_entry.key)

        return cache_info

    @classmethod
    def _ensure_cache_artifacts_registered(cls) -> None:
        from torch._dynamo.package import _DynamoCacheArtifact  # noqa: F401
        from torch._functorch._aot_autograd.autograd_cache import (  # noqa: F401
            BundledAOTAutogradCacheArtifact,
        )