import copy
import json
import logging
from abc import abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Callable, Generic, Optional, TypeVar

import torch
from torch._dynamo.package import (
    _BackendId,
    _DynamoCacheEntry,
    DynamoCache,
    PrecompileCacheEntry,
)


"""
Classes and implementations related to precompile.
"""

T = TypeVar("T")

logger = logging.getLogger(__name__)


@dataclass
class BackendCacheArtifact(Generic[T]):
    """
    Represents a single serializable backend artifact from a dynamo backend.
    Each BackendCacheArtifact has a key associated with it, along with some
    serializable content.

    Example implementation:

    class MyBackendCacheArtifact(BackendCacheArtifact[MySerializableType]):
        my_field: int

        def after_deserialization(self) -> MySerializableType:
            result = pickle.loads(self.content)
            # Do some extra work post deserialization
            result.my_post_deserialization_function(self.my_field)
            return result
    """

    key: str
    content: Any

    @abstractmethod
    def after_deserialization(self) -> T:
        """
        Code to be run after reading raw byte contents from disk.
        Generally converts self.content from raw bytes back into its original form.
        """
        ...

    def edit_contents(self, edit_fn: Callable[..., Any]) -> None:
        """
        Edit the contents of the artifact in place by applying edit_fn to it.
        """
        self.content = edit_fn(self.content)


class EagerCacheArtifact(BackendCacheArtifact[Any]):
    def after_deserialization(self) -> Any:
        # Eager artifacts are already in their final form; no decoding needed.
        return self.content


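# A minimal sketch (not part of this module) of a concrete artifact that pickles
# its payload; the name `MyPickledArtifact` is hypothetical:
#
#     import pickle
#
#     class MyPickledArtifact(BackendCacheArtifact[torch.Tensor]):
#         def after_deserialization(self) -> torch.Tensor:
#             # self.content holds the raw pickled bytes read from disk
#             return pickle.loads(self.content)
#
#     art = MyPickledArtifact(key="backend_0", content=pickle.dumps(torch.ones(2)))
#     assert torch.equal(art.after_deserialization(), torch.ones(2))

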
class BypassDynamoCacheEntry(Exception):
    pass


class PrecompileContext:
    """
    PrecompileContext is a special CacheArtifactManager for handling precompilation.
    It uses the same interface as CacheArtifactManager, but handles deserialization
    differently: instead of placing each artifact into its respective cache, it
    stitches all the cache artifacts for a single key together and places the
    result into a global precompile cache.

    PrecompileContext has two main portions: _dynamo_cache_entries and
    _backend_artifacts_by_key. When saving, PrecompileContext.serialize()
    serializes all dynamo cache entries along with any backend cache artifacts
    needed to save those entries.

    The following artifact types are supported by PrecompileContext:
    - BundledAOTAutogradCacheArtifact
    """

    # Protected by the compile_lock.
    # _backend_artifacts_by_key organizes results by the key of each artifact.
    # Each object here must be serializable.
    _backend_artifacts_by_key: dict[_BackendId, BackendCacheArtifact[Any]] = {}

    # On a call to `serialize()`, all cache artifacts in _dynamo_cache_entries are
    # converted into DynamoCacheArtifacts and added to _new_cache_artifacts for
    # serialization.
    _dynamo_cache_entries: dict[str, _DynamoCacheEntry] = {}

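    # A sketch of the intended data model (keys are illustrative): a dynamo cache
    # entry recorded under "model_key" may reference backend ids such as
    # "__compiled_fn_1"; each referenced _BackendId must have a matching
    # BackendCacheArtifact recorded in _backend_artifacts_by_key before
    # create_cache_entries() can stitch the two together.
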
    @classmethod
    def clear(cls) -> None:
        cls._backend_artifacts_by_key.clear()
        cls._dynamo_cache_entries.clear()

    @classmethod
    def record_artifact(
        cls,
        artifact: BackendCacheArtifact[Any],
    ) -> None:
        """
        Records a backend artifact to be used with dynamo cache entries.
        """
        # Deep copy so that later mutation of the caller's artifact cannot
        # affect what gets serialized.
        cls._backend_artifacts_by_key[_BackendId(artifact.key)] = copy.deepcopy(
            artifact
        )

    @classmethod
    def record_dynamo_cache_entry(
        cls, cache_entry: _DynamoCacheEntry, key: str
    ) -> None:
        cls._dynamo_cache_entries[key] = cache_entry

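    # For example (hypothetical key and payload; real recording happens inside
    # dynamo's compile path):
    #
    #     PrecompileContext.record_artifact(
    #         EagerCacheArtifact(key="__compiled_fn_1", content=b"payload")
    #     )
    #     PrecompileContext.record_dynamo_cache_entry(cache_entry, key="model_key")
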
    @classmethod
    def edit_artifact(cls, key: str, edit_fn: Callable[..., Any]) -> None:
        """
        Edit the content of an existing artifact.
        """
        assert key in cls._backend_artifacts_by_key, f"Key {key} not found in artifacts"
        artifact = cls._backend_artifacts_by_key[_BackendId(key)]
        artifact.edit_contents(edit_fn)

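    # For example, a recorded payload can be rewritten in place (hypothetical
    # key; assumes the artifact's content is bytes):
    #
    #     import zlib
    #     PrecompileContext.edit_artifact("__compiled_fn_1", zlib.compress)
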
    @classmethod
    def serialize_artifact_by_key(cls, key: str) -> Optional[BackendCacheArtifact[Any]]:
        """
        Return the backend cache artifact with the associated key, or None if
        no such artifact has been recorded.
        """
        return cls._backend_artifacts_by_key.get(_BackendId(key), None)

    @staticmethod
    def dump_debug_info(
        dynamo_entries: dict[str, _DynamoCacheEntry],
        backend_artifacts: dict[_BackendId, BackendCacheArtifact[Any]],
    ) -> dict[str, Any]:
        """
        Return a JSON-serializable debug dump of all entries in the precompile context.
        Called in serialize() before serialization, and in populate_caches() after
        deserialization.
        """
        # Collect debug information, grouped by entry type
        debug_info: defaultdict[str, list[Any]] = defaultdict(list)
        for key, cache_entry in dynamo_entries.items():
            info = cache_entry.debug_info()
            info["key"] = key
            debug_info["dynamo"].append(info)

        for artifact in backend_artifacts.values():
            debug_info["backends"].append(artifact.key)

        return debug_info

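    # The resulting dump has roughly this shape (exact dynamo fields depend on
    # what _DynamoCacheEntry.debug_info() returns):
    #
    #     {
    #         "dynamo": [{"key": "model_key", ...}],
    #         "backends": ["__compiled_fn_1"],
    #     }
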
    @classmethod
    def save_to_dynamo_cache(cls) -> dict[str, Any]:
        """
        Stitch all recorded entries into PrecompileCacheEntries, write them to
        DynamoCache, and return the debug info for the saved entries.
        """
        precompile_cache_entries, debug_info = cls.create_cache_entries()
        for key, entry in precompile_cache_entries.items():
            DynamoCache.write(entry, key)
        return debug_info

    @classmethod
    def create_cache_entries(
        cls,
    ) -> tuple[dict[str, PrecompileCacheEntry], dict[str, Any]]:
        """
        Grabs all the cache entries in the precompile context and
        stitches them together into full PrecompileCacheEntries.
        """
        dynamo_entries = cls._dynamo_cache_entries
        backend_artifacts = cls._backend_artifacts_by_key

        num_entries = len(dynamo_entries)

        debug_info = PrecompileContext.dump_debug_info(
            dynamo_entries, backend_artifacts
        )
        debug_str = json.dumps(
            {
                "num_entries": num_entries,
                "artifacts": debug_info,
            },
        )
        torch._logging.trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "dynamo_cache_entries",
                "encoding": "json",
            },
            payload_fn=lambda: debug_str,
            expect_trace_id=False,
        )

        precompile_cache_entries = {}

        for key, cache_entry in dynamo_entries.items():
            try:
                result = PrecompileCacheEntry.from_cache_entry(
                    cache_entry, backend_artifacts
                )
                if result is not None:
                    precompile_cache_entries[key] = result
            except Exception as e:
                # Log the failure to both the standard logger and structured
                # trace logs, then skip this entry rather than failing the save.
                logger.warning("Failed to create cache entry %s", key, exc_info=True)
                data = json.dumps(
                    {
                        "key": key,
                        "error": str(e),
                    }
                )
                torch._logging.trace_structured(
                    "artifact",
                    metadata_fn=lambda: {
                        "name": "dynamo_cache_exception",
                        "encoding": "json",
                    },
                    payload_fn=lambda: data,
                )
                continue
        return precompile_cache_entries, debug_info
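
# End to end, the save path looks roughly like this (a sketch; in practice
# torch.compile's internals drive the recording):
#
#     PrecompileContext.record_dynamo_cache_entry(cache_entry, key="model_key")
#     PrecompileContext.record_artifact(my_backend_artifact)
#     debug_info = PrecompileContext.save_to_dynamo_cache()  # writes via DynamoCache
#     PrecompileContext.clear()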