mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Following #162438, this PR generalizes the original CUDA-only check and adds an XPU check. Fixes #162939, Fixes #162938, Fixes #163032, Fixes #163045

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162951
Approved by: https://github.com/EikanWang, https://github.com/jansel
328 lines
12 KiB
Python
import copy
import dataclasses
import logging
import pickle
import platform
from abc import abstractmethod
from collections import defaultdict
from itertools import chain
from typing import Any, Callable, Generic, Optional, TypeVar, Union

from typing_extensions import override

import torch
from torch.compiler._cache import (
    _serialize_single_cache,
    CacheArtifact,
    CacheArtifactFactory,
    CacheArtifactManager,
    CacheArtifactsResult,
    CacheInfo,
)
from torch.utils._appending_byte_serializer import AppendingByteSerializer
from torch.utils._ordered_set import OrderedSet
from torch.utils._triton import get_triton_version


"""
Classes and implementations related to precompile
"""

T = TypeVar("T")

logger = logging.getLogger(__name__)


class PrecompileCacheArtifact(CacheArtifact, Generic[T]):
    """
    Data for each cache artifact that will be serialized and deserialized by
    PrecompileContext, rather than CacheArtifactManager.
    T represents the deserialized type of the artifact, i.e. the return type of
    after_deserialization.

    PrecompileCacheArtifact is a frozen dataclass - you can add new serializable
    fields and metadata specific to your own artifacts as needed, and use them
    in after_deserialization.

    Example implementation:

    class MyPrecompileCacheArtifact(PrecompileCacheArtifact[MySerializableType]):
        my_field: int

        def after_deserialization(self) -> MySerializableType:
            result = pickle.loads(self.content)
            # Do some extra work post deserialization
            result.my_post_deserialization_function(self.my_field)
            return result
    """

    @override
    def populate_cache(self) -> None:
        raise RuntimeError("Precompile cache artifacts do not populate caches")

    @override
    def precompile_compatible(self) -> bool:
        return True

    @abstractmethod
    def after_deserialization(self) -> T:
        """
        Code to be run after reading raw byte contents from disk.
        Generally converts self.content from raw bytes back into its original form.
        """
        ...


class EditablePrecompileCacheArtifact(Generic[T]):
    """
    A PrecompileCacheArtifact whose content isn't encoded until we call
    PrecompileContext.serialize()
    """

    def __init__(self, artifact_type: str, content: Any, key: str) -> None:
        # Deepcopy the content for now, but don't pickle it yet.
        # This allows us to make changes to self.content before true serialization
        self.content = copy.deepcopy(content)
        self.key = key
        self.artifact_type = artifact_type

    def real_encode(self) -> PrecompileCacheArtifact[T]:
        """
        Actually encode the object
        """
        content = pickle.dumps(self.content)
        artifact = CacheArtifactFactory.encode_create(
            self.artifact_type, self.key, content
        )
        assert isinstance(artifact, PrecompileCacheArtifact)
        return artifact

    def edit_contents(self, edit_fn: Callable[..., Any]) -> None:
        """
        Edit the content of an existing artifact
        """
        self.content = edit_fn(self.content)


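# --- Illustrative sketch (hypothetical; not part of the original file) ---
# A minimal sketch of how an editable artifact defers pickling: self.content
# stays a live object until real_encode(). Assumes "example_type" has been
# registered with CacheArtifactFactory so encode_create() can resolve it.
def _example_editable() -> PrecompileCacheArtifact[dict]:
    editable = EditablePrecompileCacheArtifact("example_type", {"calls": 0}, key="k0")
    editable.edit_contents(lambda c: {**c, "calls": c["calls"] + 1})  # pre-pickle edit
    return editable.real_encode()  # pickles self.content into a frozen artifact

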
class PrecompileContext(CacheArtifactManager):
    """
    PrecompileContext is a special CacheArtifactManager for handling precompilation.
    It uses the same interface as CacheArtifactManager, but handles deserialization
    differently: instead of placing each artifact into respective caches, it will
    stitch all the cache artifacts for a single key together and place it into a
    global Precompile Cache.

    The following artifact types are supported by PrecompileContext:
    - BundledAOTAutogradCacheArtifact
    - DynamoCodeStateArtifact
    - AutotuneCacheArtifact (regular autotune results, same as Megacache)
    """

    # Protected by the compile_lock
    # _new_cache_artifacts_by_key organizes results by the key of each artifact.
    # This allows us to implement serialize_by_key easily.
    # On call to `serialize()`, all cache artifacts in _new_cache_artifacts_by_key
    # are transferred to _new_cache_artifacts before serialization.
    _new_cache_artifacts_by_key: dict[
        str, Union[EditablePrecompileCacheArtifact[object], CacheArtifact]
    ] = {}
    _new_cache_artifacts: CacheArtifactsResult = defaultdict(list)
    # Keep a separate seen-artifacts set to avoid unnecessary duplicates.
    # This set is not cleared between serialize() calls.
    _seen_artifacts: OrderedSet[CacheArtifact] = OrderedSet()
    # When serialize() is called, artifacts are transferred from _cache_artifacts to
    # the internal data structure of the _serializer.
    # This allows us to only pay the cost of serialization if serialize() is called.
    _serializer: AppendingByteSerializer[tuple[str, list[CacheArtifact]]] = (
        AppendingByteSerializer(serialize_fn=_serialize_single_cache)
    )
    _cache_info: CacheInfo = CacheInfo()

    @classmethod
    def clear(cls) -> None:
        cls._new_cache_artifacts_by_key.clear()
        super().clear()

    @override
    @classmethod
    def record_artifact(
        cls,
        artifact_type: str,
        key: str,
        content: Any,
        editable: bool = False,
    ) -> None:
        """
        Called from each caching operation to record the artifact in this
        "mega" list
        """
        artifact: Union[EditablePrecompileCacheArtifact[object], CacheArtifact]
        if editable:
            artifact = EditablePrecompileCacheArtifact(artifact_type, content, key)
        else:
            artifact = CacheArtifactFactory.encode_create(artifact_type, key, content)
        # TODO: although this covers completely identical artifacts, it's possible
        # with AOTAutogradCacheEntries to have multiple artifacts whose keys
        # (i.e. backend_ids) are different, but whose contents are equal.
        # In those cases, it would be much better if we only serialized once instead
        # of N times.
        if artifact in cls._seen_artifacts:
            return
        cls._seen_artifacts.add(artifact)

        cls._new_cache_artifacts_by_key[key] = artifact

    @classmethod
    def _save_artifacts_by_type(cls) -> None:
        """
        We normally record artifacts by key, but serialization expects them to be
        organized by artifact type. This function transfers artifacts from
        _new_cache_artifacts_by_key to _new_cache_artifacts.
        """
        for artifact in cls._new_cache_artifacts_by_key.values():
            if isinstance(artifact, EditablePrecompileCacheArtifact):
                artifact = artifact.real_encode()
            cls._new_cache_artifacts[artifact.__class__.type()].append(artifact)
        cls._new_cache_artifacts_by_key.clear()

    @classmethod
    def edit_artifact(cls, key: str, edit_fn: Callable[..., Any]) -> None:
        """
        Edit the content of an existing artifact
        """
        assert key in cls._new_cache_artifacts_by_key, (
            f"Key {key} not found in artifacts"
        )
        artifact = cls._new_cache_artifacts_by_key[key]
        assert isinstance(artifact, EditablePrecompileCacheArtifact), (
            "Artifact is not editable"
        )
        artifact.edit_contents(edit_fn)

    @classmethod
    def serialize_artifact_by_key(cls, key: str) -> Optional[CacheArtifact]:
        """
        Serialize and return the artifact with the given key, if one exists.
        """
        result = cls._new_cache_artifacts_by_key.get(key, None)
        if isinstance(result, EditablePrecompileCacheArtifact):
            result = result.real_encode()
        return result

    @classmethod
    def serialize(cls) -> Optional[tuple[bytes, CacheInfo]]:
        cls._save_artifacts_by_type()
        # No need to serialize if there are no new dynamo compiles
        if "precompile_dynamo" not in cls._new_cache_artifacts:
            return None
        return super().serialize()

    @staticmethod
    def populate_caches(artifacts: CacheArtifactsResult) -> CacheInfo:
        PrecompileContext._ensure_cache_artifacts_registered()

        artifacts_by_key = {}
        cache_info = CacheInfo()
        for artifact in chain(*artifacts.values()):
            if artifact.type() == "autotune":
                # Populate autotune cache artifacts
                artifact.populate_cache()
            else:
                artifacts_by_key[artifact.key] = artifact
            cache_info.add(artifact)

        from torch._dynamo.package import _BackendId, DynamoCache

        for dynamo_entry in artifacts["precompile_dynamo"]:
            assert isinstance(dynamo_entry, PrecompileCacheArtifact)
            cache_entry = dynamo_entry.after_deserialization()
            # Grab backends from the dynamo cache entry
            backends = cache_entry.backend_ids
            backend_content: dict[_BackendId, PrecompileCacheArtifact[Any]] = {}
            for id_ in backends:
                assert id_ in artifacts_by_key, f"Backend {id_} not found in artifacts"
                artifact = artifacts_by_key[id_]
                assert isinstance(artifact, PrecompileCacheArtifact)
                backend_content[id_] = artifact
            DynamoCache.write(cache_entry, backend_content, dynamo_entry.key)

        return cache_info

    @classmethod
    def _ensure_cache_artifacts_registered(cls) -> None:
        from torch._dynamo.package import _DynamoCacheArtifact  # noqa: F401
        from torch._functorch._aot_autograd.autograd_cache import (  # noqa: F401
            BundledAOTAutogradCacheArtifact,
        )


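# --- Illustrative sketch (hypothetical; not part of the original file) ---
# A minimal sketch of the intended flow, pieced together from the methods
# above: callers record artifacts as they are produced, optionally edit them
# before pickling, and a later process feeds the decoded CacheArtifactsResult
# back through populate_caches(). Decoding serialized bytes back into
# artifacts lives in CacheArtifactManager and is elided here.
def _example_precompile_flow(decoded: CacheArtifactsResult) -> CacheInfo:
    # During compilation (artifact types must be registered with
    # CacheArtifactFactory):
    #   PrecompileContext.record_artifact("precompile_dynamo", key, entry, editable=True)
    #   PrecompileContext.edit_artifact(key, lambda c: c)  # tweak before pickling
    #   blob = PrecompileContext.serialize()  # None if no new dynamo compiles
    # On a later run, with artifacts decoded from the blob:
    return PrecompileContext.populate_caches(decoded)

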
@dataclasses.dataclass(frozen=True)
class SystemInfo:
    """
    System information including Python, PyTorch, and GPU details.
    This information is used to ensure compiled artifacts can only be loaded
    with compatible system configurations.
    """

    python_version: str
    torch_version: str
    toolkit_version: Optional[str]
    triton_version: Optional[tuple[int, int]]
    gpu_name: Optional[str]
    CHECK_GPUS = ("cuda", "xpu")

    @classmethod
    def current(cls) -> "SystemInfo":
        """Create a SystemInfo instance with current system information."""
        # Get GPU name if CUDA or XPU is available
        gpu_name, toolkit_version = None, None
        for device_type in cls.CHECK_GPUS:
            if getattr(torch, device_type).is_available():
                try:
                    gpu_name = getattr(torch, device_type).get_device_name()
                    toolkit_version = getattr(torch.version, device_type)
                    break
                except Exception:
                    pass

        return cls(
            python_version=platform.python_version(),
            torch_version=torch.__version__,
            toolkit_version=toolkit_version,
            triton_version=get_triton_version((0, 0)),
            gpu_name=gpu_name,
        )

    def check_compatibility(
        self, other: "SystemInfo", device_type: str = "cpu"
    ) -> None:
        """
        Check if this SystemInfo is compatible with another SystemInfo.
        Raises RuntimeError if incompatible.
        """
        if self.python_version != other.python_version:
            raise RuntimeError(
                f"Compile package was created with a different Python version: {self.python_version}"
            )

        if self.torch_version != other.torch_version:
            raise RuntimeError(
                f"Compile package was created with a different PyTorch version: {self.torch_version}"
            )

        if device_type in self.CHECK_GPUS:
            if not getattr(torch, device_type).is_available():
                raise RuntimeError(f"{device_type} is not available")

            if self.toolkit_version != other.toolkit_version:
                raise RuntimeError(
                    f"Compile package was created with a different toolkit version: {self.toolkit_version}"
                )

            if (
                other.triton_version != (0, 0)
                and self.triton_version != other.triton_version
            ):
                raise RuntimeError(
                    f"Compile package was created with a different Triton version: {self.triton_version}"
                )

            # Check GPU name if CUDA/XPU was used
            if other.gpu_name is not None and self.gpu_name != other.gpu_name:
                raise RuntimeError(
                    f"Compile package was created with a different GPU: "
                    f"cached={self.gpu_name}, current={other.gpu_name}"
                )
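

# --- Illustrative sketch (hypothetical; not part of the original file) ---
# A minimal sketch of how SystemInfo gates artifact loading: capture the
# environment at save time, persist it alongside the package, and compare it
# against the current machine on load. The pickle round trip stands in for
# however a caller actually persists it.
def _example_compat_gate(saved_info: bytes, device_type: str = "cuda") -> None:
    cached = pickle.loads(saved_info)  # SystemInfo captured at save time
    assert isinstance(cached, SystemInfo)
    # Raises RuntimeError on any Python/torch/toolkit/Triton/GPU mismatch
    cached.check_compatibility(SystemInfo.current(), device_type=device_type)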