from abc import abstractmethod
from collections import defaultdict
from itertools import chain
from typing import Any, Generic, Optional, TypeVar
from typing_extensions import override

from torch.compiler._cache import (
    _serialize_single_cache,
    CacheArtifact,
    CacheArtifactFactory,
    CacheArtifactManager,
    CacheArtifactsResult,
    CacheInfo,
)
from torch.utils._appending_byte_serializer import AppendingByteSerializer
from torch.utils._ordered_set import OrderedSet


"""
Classes and implementations related to precompile
"""

T = TypeVar("T")


class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
    """
    Data for each cache artifact that will be serialized and deserialized by
    PrecompileContext, rather than CacheArtifactManager.
    T represents the deserialized type of the artifact, i.e. the return type of after_deserialization

    PrecompileCacheArtifact is a frozen dataclass - you can add new serializable fields and metadata specific to your own artifacts
    as needed, and use them in after_deserialization.

    Example implementation:

    class MyPrecompileCacheArtifact(PrecompileCacheArtifact[MySerializableType]):
        my_field: int

        def after_deserialization(self) -> MySerializableType:
            result = pickle.loads(self.content)
            # Do some extra work post deserialization
            result.my_post_deserialization_function(self.my_field)
            return result
    """

    @override
    def populate_cache(self) -> None:
        raise RuntimeError("Precompile cache artifacts do not populate caches")

    @override
    def precompile_compatible(self) -> bool:
        return True

    @abstractmethod
    def after_deserialization(self) -> T:
        """
        Code to be run after reading raw byte contents from disk.
        Generally converts self.content from raw bytes back into its original form.
        """
        ...
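
# A concrete artifact class is typically registered with CacheArtifactFactory so
# that record_artifact() can encode and recreate it from its type string. A
# minimal sketch, assuming the CacheArtifactFactory.register decorator and a
# type() override as used by the artifact classes in torch.compiler._cache
# (MyEntry, MyPrecompileCacheArtifact, and the type string are hypothetical):
#
#   @CacheArtifactFactory.register
#   class MyPrecompileCacheArtifact(PrecompileCacheArtifact[MyEntry]):
#       @staticmethod
#       def type() -> str:
#           return "precompile_my_artifact"
#
#       def after_deserialization(self) -> MyEntry:
#           return pickle.loads(self.content)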


class PrecompileContext(CacheArtifactManager):
    """
    PrecompileContext is a special CacheArtifactManager for handling precompilation.
    It uses the same interface as CacheArtifactManager, but handles deserialization differently: instead
    of placing each artifact into its respective cache, it stitches all the cache artifacts for a single key
    together and places them into a global Precompile Cache.

    The following artifact types are supported by PrecompileContext:
    - BundledAOTAutogradCacheArtifact
    - CodeStateArtifact (from torch._dynamo.package once available)
    """

    # Protected by the compile_lock
    # _new_cache_artifacts_by_key organizes results by the key of each artifact.
    # This allows us to implement serialize_artifact_by_key easily.
    # On call to `serialize()`, all cache artifacts in _new_cache_artifacts_by_key
    # are transferred to _new_cache_artifacts before serialization.
    _new_cache_artifacts_by_key: dict[str, CacheArtifact] = {}
    _new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
    # Keep a separate seen artifacts list to avoid unnecessary duplicates
    # This list will not be cleared between serialize() calls
    _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
    # When serialize() is called, artifacts are transferred from _new_cache_artifacts to
    # the internal data structure of the _serializer
    # This allows us to only pay the cost of serialization if serialize() is called
    _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = (
        AppendingByteSerializer(serialize_fn=_serialize_single_cache)
    )
    _cache_info: CacheInfo = CacheInfo()

    @classmethod
    def clear(cls) -> None:
        cls._new_cache_artifacts_by_key.clear()
        super().clear()

    @override
    @classmethod
    def record_artifact(
        cls,
        artifact_type: str,
        key: str,
        content: Any,
    ) -> None:
        """
        Called from each caching operation to record the artifact in this
        "mega" list
        """
        artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
        # TODO: although this covers artifacts that are completely identical, it's possible
        # with AOTAutogradCacheEntries to have multiple artifacts whose keys
        # (i.e. backend_ids) are different, but whose contents are equal.
        # In those cases, it would be much better if we only serialize once instead
        # of N times.
        if artifact in cls._seen_artifacts:
            return

        cls._new_cache_artifacts_by_key[key] = artifact
        cls._seen_artifacts.add(artifact)

    @classmethod
    def _save_artifacts_by_type(cls) -> None:
        """
        We normally record artifacts by key, but serialization expects them to be organized
        by artifact type. This function transfers artifacts from _new_cache_artifacts_by_key to _new_cache_artifacts
        """
        for artifact in cls._new_cache_artifacts_by_key.values():
            cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
        cls._new_cache_artifacts_by_key.clear()

    @classmethod
    def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]:
        """
        Return the cache artifact recorded under the given key, if any.
        """
        return cls._new_cache_artifacts_by_key.get(key, None)

    @classmethod
    def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
        cls._save_artifacts_by_type()
        return super().serialize()

    @staticmethod
    def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
        PrecompileContext._ensure_cache_artifacts_registered()

        artifacts_by_key = {}
        cache_info = CacheInfo()
        for artifact in chain(*artifacts.values()):
            cache_info.add(artifact)
            artifacts_by_key[artifact.key] = artifact

        from torch._dynamo.package import _BackendId, DynamoCache
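
        # Stitch each dynamo entry together with the backend artifacts it
        # references and write them out as a single DynamoCache entry.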
        for dynamo_entry in artifacts["precompile_dynamo"]:
            assert isinstance(dynamo_entry, PrecompileCacheArtifact)
            cache_entry = dynamo_entry.after_deserialization()
            # Grab backends from the dynamo cache entry
            backends = cache_entry.backend_ids
            backend_content: dict[_BackendId, PrecompileCacheArtifact[Any]] = {}
            for id_ in backends:
                assert id_ in artifacts_by_key, f"Backend {id_} not found in artifacts"
                artifact = artifacts_by_key[id_]
                assert isinstance(artifact, PrecompileCacheArtifact)
                backend_content[id_] = artifact
            DynamoCache.write(cache_entry, backend_content, dynamo_entry.key)

        return cache_info

    @classmethod
    def _ensure_cache_artifacts_registered(cls) -> None:
        from torch._dynamo.package import _DynamoCacheArtifact  # noqa: F401
        from torch._functorch._aot_autograd.autograd_cache import (  # noqa: F401
            BundledAOTAutogradCacheArtifact,
        )