[MegaCache] Make MegaCache generic to allow external plugins registration (#152977)
Implements #152976

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152977
Approved by: https://github.com/oulgen
commit bb7e30c165 (parent c31e239910), committed by PyTorch MergeBot
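The diff below is against torch/compiler/_cache.py, where these classes live. For orientation, here is a minimal sketch of what the new extension point enables for an out-of-tree backend; the class, type key, and cache layout are hypothetical illustrations, not part of this PR, and assume the registration API shown in the diff:

import os

from torch._inductor.runtime.runtime_utils import cache_dir
from torch.compiler._cache import CacheArtifact, CacheArtifactFactory


@CacheArtifactFactory.register  # registers under the "my_backend" type key
class MyBackendCacheArtifact(CacheArtifact):
    """Hypothetical artifact for an external compiler backend."""

    @staticmethod
    def type() -> str:
        # Must be unique; register() also installs a
        # CacheInfo.my_backend_artifacts property derived from this key.
        return "my_backend"

    def populate_cache(self) -> None:
        # Illustrative: write the payload back into a local on-disk cache.
        path = os.path.join(cache_dir(), "my_backend", self.key)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "wb") as f:
            f.write(self.content)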
@@ -1,14 +1,13 @@
 import copy
 import dataclasses
 import logging
 import os
+from abc import ABC, abstractmethod
+from collections import defaultdict
 from collections.abc import Generator
 from contextlib import contextmanager
-from enum import Enum
-from typing import Optional, Union
+from itertools import chain
+from typing import Any, Optional
 
-from torch._inductor.remote_cache import JsonDataTy, RemoteCacheJsonSerde
 from torch._inductor.runtime.runtime_utils import cache_dir
 from torch.utils._appending_byte_serializer import (
     AppendingByteSerializer,
     BytesReader,
@@ -20,39 +19,85 @@ from torch.utils._ordered_set import OrderedSet
 log = logging.getLogger(__name__)
 
 
-class CacheArtifactType(Enum):
-    """
-    Type of cache
-    """
-
-    INDUCTOR = 0
-    AUTOTUNE = 1
-    AOT_AUTOGRAD = 2
-    PGO = 3
-
-
 @dataclasses.dataclass(frozen=True)
-class CacheArtifact:
+class CacheArtifact(ABC):
     """
     Data for each cache artifact that will be serialized and deserialized
     """
 
-    type: CacheArtifactType
     key: str
     content: bytes = dataclasses.field(repr=False)  # Do not display potential binary
 
     @staticmethod
     def serialize(writer: BytesWriter, cls: "CacheArtifact") -> None:
-        writer.write_uint64(cls.type.value)
         writer.write_str(cls.key)
         writer.write_bytes(cls.content)
 
     @staticmethod
-    def deserialize(reader: BytesReader) -> "CacheArtifact":
-        type = reader.read_uint64()
+    def deserialize(artifact_type: str, reader: BytesReader) -> "CacheArtifact":
         key = reader.read_str()
         content = reader.read_bytes()
-        return CacheArtifact(CacheArtifactType(type), key, content)
+        return CacheArtifactFactory.create(artifact_type, key, content)
+
+    @staticmethod
+    def encode(content: Any) -> bytes:
+        assert isinstance(content, bytes), f"Expected bytes, got {type(content)}"
+        return content
+
+    @abstractmethod
+    def populate_cache(self) -> None:
+        pass
+
+    @staticmethod
+    def type() -> str:
+        """
+        Returns the type of the artifact. Must be unique across all CacheArtifact classes.
+
+        CacheArtifactFactory.register will add property method to CacheInfo based on this (def {type}_artifacts)
+        that returns all artifacts for specific cache.
+        """
+        raise RuntimeError("CacheArtifact is an abstract class, please use a subclass")
+
+
+class CacheArtifactFactory:
+    """
+    Factory for creating CacheArtifact objects based on their type
+    """
+
+    _artifact_types: dict[str, type[CacheArtifact]] = {}
+
+    @classmethod
+    def register(cls, artifact_cls: type[CacheArtifact]) -> type[CacheArtifact]:
+        artifact_type_key = artifact_cls.type()
+        assert (
+            artifact_cls.type() not in cls._artifact_types
+        ), f"Artifact of type={artifact_type_key} already registered in mega-cache artifact factory"
+        cls._artifact_types[artifact_type_key] = artifact_cls
+        setattr(
+            CacheInfo,
+            f"{artifact_type_key}_artifacts",
+            property(lambda self: self.artifacts[artifact_type_key]),
+        )
+        return artifact_cls
+
+    @classmethod
+    def _get_artifact_type(cls, artifact_type_key: str) -> type[CacheArtifact]:
+        assert (
+            artifact_type_key in cls._artifact_types
+        ), f"Artifact of type={artifact_type_key} not registered in mega-cache artifact factory"
+        return cls._artifact_types[artifact_type_key]
+
+    @classmethod
+    def create(cls, artifact_type_key: str, key: str, content: bytes) -> CacheArtifact:
+        artifact_cls = cls._get_artifact_type(artifact_type_key)
+        return artifact_cls(key, content)
+
+    @classmethod
+    def encode_create(
+        cls, artifact_type_key: str, key: str, content: Any
+    ) -> CacheArtifact:
+        artifact_cls = cls._get_artifact_type(artifact_type_key)
+        return artifact_cls(key, artifact_cls.encode(content))
 
 
 @dataclasses.dataclass
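Continuing the hypothetical plugin above, the factory round-trips payloads by type key: encode_create() runs the class's encode() hook on arbitrary content, while create() expects already-encoded bytes (this is the path deserialize() takes). A sketch:

from torch.compiler._cache import CacheArtifactFactory

# encode() of the base class just asserts bytes; a plugin can override it
# to serialize richer content (e.g. JSON) before storage.
art = CacheArtifactFactory.encode_create("my_backend", "kernel/1234", b"payload")
# Frozen dataclass equality: same class + key + content -> equal artifact.
assert CacheArtifactFactory.create("my_backend", art.key, art.content) == art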
@@ -62,36 +107,56 @@ class CacheInfo:
     instrumentation
     """
 
-    inductor_artifacts: list[str] = dataclasses.field(default_factory=list)
-    autotune_artifacts: list[str] = dataclasses.field(default_factory=list)
-    aot_autograd_artifacts: list[str] = dataclasses.field(default_factory=list)
-    pgo_artifacts: list[str] = dataclasses.field(default_factory=list)
+    artifacts: defaultdict[str, list[str]] = dataclasses.field(
+        default_factory=lambda: defaultdict(list)
+    )
+
+    # Methods set by CacheArtifactFactory.register based on CacheArtifact.type()
+    @property
+    def inductor_artifacts(self) -> list[str]:  # type: ignore[empty-body]
+        ...
+
+    @property
+    def autotune_artifacts(self) -> list[str]:  # type: ignore[empty-body]
+        ...
+
+    @property
+    def aot_autograd_artifacts(self) -> list[str]:  # type: ignore[empty-body]
+        ...
+
+    @property
+    def pgo_artifacts(self) -> list[str]:  # type: ignore[empty-body]
+        ...
 
     def add(self, artifact: CacheArtifact) -> None:
-        if artifact.type == CacheArtifactType.INDUCTOR:
-            self.inductor_artifacts.append(artifact.key)
-        elif artifact.type == CacheArtifactType.AUTOTUNE:
-            self.autotune_artifacts.append(artifact.key)
-        elif artifact.type == CacheArtifactType.AOT_AUTOGRAD:
-            self.aot_autograd_artifacts.append(artifact.key)
-        elif artifact.type == CacheArtifactType.PGO:
-            self.pgo_artifacts.append(artifact.key)
-        else:
-            log.warning(f"Unsupported artifact type {artifact.type}")  # noqa: G004
+        self.artifacts[artifact.type()].append(artifact.key)
 
     def clear(self) -> None:
-        self.inductor_artifacts.clear()
-        self.autotune_artifacts.clear()
-        self.aot_autograd_artifacts.clear()
-        self.pgo_artifacts.clear()
+        self.artifacts.clear()
 
     def empty(self) -> bool:
-        return not (
-            self.inductor_artifacts
-            or self.autotune_artifacts
-            or self.aot_autograd_artifacts
-            or self.pgo_artifacts
-        )
+        return not self.artifacts
+
+
+def _serialize_single_cache(
+    writer: BytesWriter, cls: "tuple[str, list[CacheArtifact]]"
+) -> None:
+    writer.write_str(cls[0])
+    writer.write_uint64(len(cls[1]))
+    for artifact in cls[1]:
+        CacheArtifact.serialize(writer, artifact)
+
+
+def _deserialize_single_cache(
+    reader: BytesReader,
+) -> "tuple[str, list[CacheArtifact]]":
+    artifacts = []
+    artifact_type_key = reader.read_str()
+    num_artifacts = reader.read_uint64()
+    for _ in range(num_artifacts):
+        artifacts.append(CacheArtifact.deserialize(artifact_type_key, reader))
+
+    return artifact_type_key, artifacts
 
 
 class CacheArtifactManager:
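One consequence worth spelling out: CacheInfo no longer hard-codes one list per cache. The named accessors (info.pgo_artifacts and friends) are properties installed by CacheArtifactFactory.register, and the stubs above exist only to satisfy type checkers. A small sketch, assuming the in-tree artifact classes have been imported so their register calls ran:

import torch._dynamo.pgo  # noqa: F401  -- import side effect registers PGOCacheArtifact

from torch.compiler._cache import CacheInfo

info = CacheInfo()
info.artifacts["pgo"].append("code_state_key")
# The property reads straight out of the generic defaultdict:
assert info.pgo_artifacts == ["code_state_key"]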
@@ -112,16 +177,16 @@ class CacheArtifactManager:
     """
 
     # Protected by the compile_lock
-    _new_cache_artifacts: list[CacheArtifact] = []
+    _new_cache_artifacts: defaultdict[str, list[CacheArtifact]] = defaultdict(list)
     # Keep a seperate seen artifacts list to make avoid unnecessary duplicates
     # This list will not be cleared between serialize() calls
     _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
     # When serialize() is called, artifacts are transferred from _cache_artifacts to
     # internal data structure of the _serializer
     # This allows us to only pay the cost of serialization if serialize() is called
-    _serializer: AppendingByteSerializer[CacheArtifact] = AppendingByteSerializer(
-        serialize_fn=CacheArtifact.serialize
-    )
+    _serializer: AppendingByteSerializer[
+        tuple[str, list[CacheArtifact]]
+    ] = AppendingByteSerializer(serialize_fn=_serialize_single_cache)
     _cache_info: CacheInfo = CacheInfo()
 
     @classmethod
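The pending artifacts are now grouped by type key, and the serializer writes one (type_key, artifacts) record per cache. Sketched field order of the stream that _serialize_single_cache produces and _deserialize_single_cache consumes (the framing of each field is AppendingByteSerializer's concern):

# Per record:
#   str    type_key            e.g. "inductor"
#   uint64 len(artifacts)      number of artifacts of this type
#   then, per artifact (CacheArtifact.serialize):
#     str   key
#     bytes content
#
# Unknown type keys fail loudly on read: CacheArtifact.deserialize routes
# through CacheArtifactFactory.create, which asserts the key is registered.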
@@ -139,9 +204,9 @@ class CacheArtifactManager:
         original_serializer = cls._serializer
         original_cache_info = cls._cache_info
 
-        cls._new_cache_artifacts = []
+        cls._new_cache_artifacts = defaultdict(list)
         cls._seen_artifacts = OrderedSet()
-        cls._serializer = AppendingByteSerializer(serialize_fn=CacheArtifact.serialize)
+        cls._serializer = AppendingByteSerializer(serialize_fn=_serialize_single_cache)
         cls._cache_info = CacheInfo()
         try:
             yield
@@ -154,24 +219,19 @@ class CacheArtifactManager:
     @classmethod
     def record_artifact(
         cls,
-        artifact_type: CacheArtifactType,
+        artifact_type: str,
         key: str,
-        content: Union[bytes, JsonDataTy],
+        content: Any,
     ) -> None:
         """
         Called from each caching operation to record the artifact in this
        "mega" list
         """
-        if artifact_type == CacheArtifactType.AUTOTUNE:
-            assert not isinstance(content, bytes)
-            serde = RemoteCacheJsonSerde()
-            content = serde.encode(content)
-        assert isinstance(content, bytes)
-        artifact = CacheArtifact(artifact_type, key, content)
+        artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
         if artifact in cls._seen_artifacts:
             return
         log.debug("Recording %s", str(artifact))
-        cls._new_cache_artifacts.append(artifact)
+        cls._new_cache_artifacts[artifact_type].append(artifact)
         cls._seen_artifacts.add(artifact)
 
     @classmethod
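Call sites now pass a string type key and unencoded content; the old AUTOTUNE special case moves into the autotune artifact's own encode() override. A hedged sketch of a recording call, again against the hypothetical plugin above:

from torch.compiler._cache import CacheArtifactManager

# encode_create() inside record_artifact applies MyBackendCacheArtifact.encode,
# which (like the base class) only accepts bytes in this sketch.
CacheArtifactManager.record_artifact("my_backend", "kernel/1234", b"payload")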
@@ -186,7 +246,7 @@ class CacheArtifactManager:
         """
         Converts the "mega" list into portable format
         """
-        for artifact in cls._new_cache_artifacts:
+        for artifact in chain(*cls._new_cache_artifacts.values()):
             log.debug("saving: %s", artifact)
             cls._cache_info.add(artifact)
 
@@ -199,7 +259,7 @@ class CacheArtifactManager:
         # We deep copy cls._cache_info since later compilations
         # can keep adding to cache_info
         info = copy.deepcopy(cls._cache_info)
-        cls._serializer.extend(cls._new_cache_artifacts)
+        cls._serializer.extend(cls._new_cache_artifacts.items())
         artifact_bytes = cls._serializer.to_bytes()
         cls._new_cache_artifacts.clear()
         return artifact_bytes, info
@@ -213,37 +273,36 @@ class CacheArtifactManager:
         Converts the portable format back into various filesystem caches
         """
         try:
-            artifacts = AppendingByteSerializer.to_list(
-                serialized_artifacts, deserialize_fn=CacheArtifact.deserialize
+            CacheArtifactManager._ensure_cache_artifacts_registered()
+            artifacts = dict(
+                AppendingByteSerializer.to_list(
+                    serialized_artifacts,
+                    deserialize_fn=_deserialize_single_cache,
+                )
             )
         except Exception:
             log.warning("Failed to un-pickle cache artifacts", exc_info=True)
             return None
 
-        from torch._dynamo.pgo import rewrite_cache_key_for_mega_cache, write_local_impl
-        from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
-        from torch._inductor.codecache import FxGraphCache
-        from torch._inductor.runtime.autotune_cache import _LocalAutotuneCacheBackend
-
-        autotune_cache = _LocalAutotuneCacheBackend()
-
         info = CacheInfo()
-        for artifact in artifacts:
+        for artifact in chain(*artifacts.values()):
             log.debug("writing: %s", artifact)
             info.add(artifact)
+            artifact.populate_cache()
 
-            if artifact.type == CacheArtifactType.INDUCTOR:
-                FxGraphCache._write_to_local_cache(artifact.key, artifact.content)
-            elif artifact.type == CacheArtifactType.AUTOTUNE:
-                key = os.path.join(cache_dir(), artifact.key)
-                autotune_cache._put(key, artifact.content)
-            elif artifact.type == CacheArtifactType.AOT_AUTOGRAD:
-                AOTAutogradCache._write_to_local_cache(artifact.key, artifact.content)
-            elif artifact.type == CacheArtifactType.PGO:
-                meta = write_local_impl(
-                    rewrite_cache_key_for_mega_cache(artifact.key), artifact.content
-                )
-                assert meta is not None
-            else:
-                log.warning(f"Unsupported artifact type {artifact.type}")  # noqa: G004
         return info
+
+    @staticmethod
+    def _ensure_cache_artifacts_registered() -> None:
+        """When deserializing caches in fresh process, we need to ensure that all
+        cache artifacts are registered in the cache registry. This is done by
+        simply importing all the cache artifacts already wrapped with register call.
+        """
+        from torch._dynamo.pgo import PGOCacheArtifact  # noqa: F401
+        from torch._functorch._aot_autograd.autograd_cache import (  # noqa: F401
+            AOTAutogradCacheArtifact,
+        )
+        from torch._inductor.codecache import InductorCacheArtifact  # noqa: F401
+        from torch._inductor.runtime.autotune_cache import (  # noqa: F401
+            AutotuneCacheArtifact,
+        )
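End to end, the consumer-side flow now looks roughly like the sketch below. Note the registration caveat: _ensure_cache_artifacts_registered only imports the in-tree artifact classes, so an external plugin module must itself be imported (running its register call) before deserialize() can reconstruct that plugin's artifacts. The plugin module name is hypothetical:

import my_backend_plugin  # noqa: F401  -- hypothetical module defining MyBackendCacheArtifact

from torch.compiler._cache import CacheArtifactManager

result = CacheArtifactManager.serialize()
if result is not None:
    artifact_bytes, cache_info = result
    # ...ship artifact_bytes to another process or machine...
    info = CacheArtifactManager.deserialize(artifact_bytes)
    # deserialize() calls populate_cache() on every artifact, repopulating
    # the local caches, and returns a CacheInfo describing what was written.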