from __future__ import annotations

import pickle
from abc import ABC, abstractmethod
from ast import literal_eval
from functools import cached_property
from hashlib import sha256
from os import getenv
from pathlib import Path
from tempfile import gettempdir
from threading import Lock
from typing import Any, Generic, TYPE_CHECKING, TypeVar
from typing_extensions import assert_never, override, Self

from torch.utils._filelock import FileLock


if TYPE_CHECKING:
    from concurrent.futures import Future, ThreadPoolExecutor


# TypeVars can't be recursive, so generic types that fall within
# Key or Value can't be bound properly; for example, Key should
# only take tuples of other Key types: tuple[Key, ...]. This is
# a known shortcoming of torch's typing.
Key = TypeVar("Key", str, int, tuple[Any, ...])
Value = TypeVar("Value", str, int, tuple[Any, ...], bytes, dict[Any, Any], list[Any])


class CacheError(ValueError):
    """
    Exception raised for errors encountered during cache operations.
    """


class Cache(ABC, Generic[Key, Value]):
    """
    Abstract base class for cache implementations.

    Provides the interface for cache operations.
    """

    @abstractmethod
    def get(self: Self, key: Key) -> Value | None:
        """
        Retrieve a value from the cache.

        Args:
            key (Key): The key to look up.

        Returns:
            Value | None: The cached value if present, else None.
        """

    @abstractmethod
    def insert(self: Self, key: Key, value: Value) -> bool:
        """
        Insert a value into the cache.

        Args:
            key (Key): The key to insert.
            value (Value): The value to associate with the key.

        Returns:
            bool: True if the value was inserted, False if the key already exists.
        """


class InMemoryCache(Cache[Key, Value]):
    """
    In-memory cache implementation using a dictionary and thread lock.
    """

    def __init__(self: Self) -> None:
        """
        Initialize an empty in-memory cache.
        """
        self._cache: dict[Key, Value] = {}
        self._lock: Lock = Lock()

    def get(self: Self, key: Key) -> Value | None:
        """
        Retrieve a value from the cache.

        Args:
            key (Key): The key to look up.

        Returns:
            Value | None: The cached value if present, else None.
        """
        with self._lock:
            if (value := self._cache.get(key)) is not None:
                return value
            return None

    def insert(self: Self, key: Key, value: Value) -> bool:
        """
        Insert a value into the cache.

        Args:
            key (Key): The key to insert.
            value (Value): The value to associate with the key.

        Returns:
            bool: True if the value was inserted, False if the key already exists.
        """
        with self._lock:
            if key in self._cache:
                # no overwrites for insert!
                return False
            self._cache[key] = value
            return True

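    # Illustrative usage of the two core operations (a sketch, not executed
    # as part of this module):
    #
    #   cache: InMemoryCache[str, int] = InMemoryCache()
    #   cache.insert("answer", 42)  # -> True, first write wins
    #   cache.insert("answer", 43)  # -> False, no overwrites for insert!
    #   cache.get("answer")         # -> 42
    #   cache.get("missing")        # -> None
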
    @classmethod
    def from_env_var(cls, env_var: str) -> Self:
        """
        Create an in-memory cache from an environment variable.

        Args:
            env_var (str): Name of the environment variable containing cache data.

        Returns:
            InMemoryCache: An instance populated from the environment variable.

        Raises:
            CacheError: If the environment variable is malformed or contains invalid data.
        """
        cache = cls()

        if (env_val := getenv(env_var)) is None:
            # env_var doesn't exist = empty cache
            return cache

        for kv_pair in env_val.split(";"):
            # ignore whitespace prefix/suffix
            kv_pair = kv_pair.strip()

            if not kv_pair:
                # kv_pair could be '' if env_val is '' or has a ';' suffix
                continue

            try:
                # keys and values should be comma separated
                key_bytes_repr, value_bytes_repr = kv_pair.split(",", 1)
            except ValueError as err:
                raise CacheError(
                    f"Malformed kv_pair {kv_pair!r} from env_var {env_var!r}, likely missing comma separator."
                ) from err

            # ignore whitespace prefix/suffix, again
            key_bytes_repr, value_bytes_repr = (
                key_bytes_repr.strip(),
                value_bytes_repr.strip(),
            )

            try:
                # check that key_bytes_repr is an actual, legitimate bytes encoding
                key_bytes = literal_eval(key_bytes_repr)
            except (ValueError, SyntaxError) as err:
                raise CacheError(
                    f"Malformed key_bytes_repr {key_bytes_repr!r} in kv_pair {kv_pair!r}, encoding is invalid."
                ) from err
            try:
                # check that value_bytes_repr is an actual, legitimate bytes encoding
                value_bytes = literal_eval(value_bytes_repr)
            except (ValueError, SyntaxError) as err:
                raise CacheError(
                    f"Malformed value_bytes_repr {value_bytes_repr!r} in kv_pair {kv_pair!r}, encoding is invalid."
                ) from err

            try:
                key = pickle.loads(key_bytes)
            except pickle.UnpicklingError as err:
                raise CacheError(
                    f"Malformed key_bytes_repr {key_bytes_repr!r} in kv_pair {kv_pair!r}, not un-pickle-able."
                ) from err
            try:
                value = pickle.loads(value_bytes)
            except pickle.UnpicklingError as err:
                raise CacheError(
                    f"Malformed value_bytes_repr {value_bytes_repr!r} in kv_pair {kv_pair!r}, not un-pickle-able."
                ) from err

            # true duplicates, i.e. multiple occurrences of the same key => value
            # mapping, are ok and treated as a no-op; key duplicates with differing
            # values, i.e. key => value_1 and key => value_2 where value_1 != value_2,
            # are not okay since we don't allow overwriting cached values (it's bad regardless)
            if (not cache.insert(key, value)) and (cache.get(key) != value):
                raise CacheError(
                    f"Multiple values for key {key!r} found, got {cache.get(key)!r} and {value!r}."
                )

        return cache

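    # The wire format the parser above expects, reconstructed as an
    # illustrative sketch ("CACHE_DATA" is a hypothetical variable name):
    #
    #   import os, pickle
    #
    #   kv = f"{pickle.dumps('k')!r},{pickle.dumps(1)!r}"
    #   os.environ["CACHE_DATA"] = kv  # semicolons separate additional pairs
    #   cache = InMemoryCache.from_env_var("CACHE_DATA")
    #   cache.get("k")  # -> 1
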
    @classmethod
    def from_file_path(cls, fpath: Path) -> Self:
        """
        Create an in-memory cache from a file path.

        Args:
            fpath (Path): Path to the file containing pickled cache data.

        Returns:
            InMemoryCache: An instance populated from the file.

        Raises:
            CacheError: If the file is not a valid pickled dictionary.
        """
        cache = cls()

        if not fpath.is_file():
            # fpath doesn't exist = empty cache
            return cache

        try:
            with open(fpath, "rb") as fp:
                cache._cache = pickle.load(fp)
        except pickle.UnpicklingError as err:
            raise CacheError(
                f"Failed to create cache from file path {fpath}, file contents are un-pickle-able."
            ) from err

        if not isinstance(cache._cache, dict):
            raise CacheError(
                f"Failed to create cache from file path {fpath}, file contents are not a pickled dict[Key, Value]."
            )

        return cache


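# A file suitable for InMemoryCache.from_file_path is just a pickled dict
# (an illustrative sketch; "cache.pkl" is a hypothetical file name):
#
#   import pickle
#   from pathlib import Path
#
#   Path("cache.pkl").write_bytes(pickle.dumps({"k": 1}))
#   cache = InMemoryCache.from_file_path(Path("cache.pkl"))
#   cache.get("k")  # -> 1

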
class AsyncCache(Cache[Key, Value]):
    """
    Asynchronous cache implementation using ThreadPoolExecutor.
    """

    def get_async(
        self: Self, key: Key, executor: ThreadPoolExecutor
    ) -> Future[Value | None]:
        """
        Retrieve a value from the cache asynchronously.

        Args:
            key (Key): The key to look up.
            executor (ThreadPoolExecutor): Executor for async execution.

        Returns:
            Future[Value | None]: Future for the cached value or None.
        """
        return executor.submit(self.get, key)

    def insert_async(
        self: Self, key: Key, value: Value, executor: ThreadPoolExecutor
    ) -> Future[bool]:
        """
        Insert a value into the cache asynchronously.

        Args:
            key (Key): The key to insert.
            value (Value): The value to associate with the key.
            executor (ThreadPoolExecutor): Executor for async execution.

        Returns:
            Future[bool]: Future for the result of insertion.
        """
        return executor.submit(self.insert, key, value)


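# The *_async variants just wrap the blocking get/insert in a caller-managed
# executor; an illustrative sketch using the concrete OnDiskCache defined below:
#
#   from concurrent.futures import ThreadPoolExecutor
#
#   cache: OnDiskCache[str, int] = OnDiskCache()
#   with ThreadPoolExecutor() as executor:
#       future = cache.insert_async("k", 1, executor)
#       future.result()                          # -> True
#       cache.get_async("k", executor).result()  # -> 1

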
class OnDiskCache(AsyncCache[Key, Value]):
    """
    On-disk cache implementation using files and file locks.

    Stores cache data in files on disk, with atomic operations and versioning.
    Supports custom cache directory names.

    Attributes:
        version (int): The version used for cache versioning.
        name (str): The name of the cache directory.
    """

    version: int = 0

    def __init__(self: Self, name: str | None = None) -> None:
        """
        Initialize an on-disk cache instance.

        Args:
            name (str | None, optional): The name of the cache directory. If None,
                defaults to "on_disk_cache".
        """
        self.name = name or "on_disk_cache"

    @cached_property
    def base_dir(self: Self) -> Path:
        """
        Get the base directory for the cache.

        Returns:
            Path: The base directory path for storing cache files.
        """
        return Path(gettempdir()) / "cache" / self.name

    def _fpath_from_key(self: Self, key: Key) -> Path:
        """
        Get the file path for a given key.

        Args:
            key (Key): The key to convert to a file path.

        Returns:
            Path: The file path for the key.

        Raises:
            CacheError: If the key is not pickle-able.
        """
        try:
            return self.base_dir / sha256(pickle.dumps(key)).hexdigest()[:32]
        except (AttributeError, pickle.PicklingError) as err:
            raise CacheError(
                f"Failed to get fpath for key {key!r}, key is not pickle-able."
            ) from err
        # pyrefly: ignore # bad-argument-type
        assert_never(key)

    def _flock_from_fpath(self: Self, fpath: Path) -> FileLock:
        """
        Get a file lock for a given file path.

        Args:
            fpath (Path): The file path.

        Returns:
            FileLock: The file lock for the path.
        """
        # fpath.name is a hex digest, so there are 16^4 = 65,536 potential
        # values for fpath.name[:4]; that is more than enough unique locks
        # to avoid contention overhead from shared locks, while also saving
        # our cache dir from becoming 50 percent lock files
        # pyrefly: ignore # bad-return
        return FileLock(str(fpath.parent / "locks" / fpath.name[:4]) + ".lock")

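    # Resulting on-disk layout (an illustrative sketch with shortened digests):
    #
    #   <base_dir>/3f2a9c...         # cached entry, named by sha256(key)[:32]
    #   <base_dir>/locks/3f2a.lock   # lock shared by all entries starting "3f2a"
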
    @property
    def version_prefix(self: Self) -> bytes:
        """
        Get the version prefix for the cache.

        Returns:
            bytes: The version prefix as bytes, derived from the cache version string.
        """
        return sha256(str(OnDiskCache.version).encode()).digest()[:4]

    @override
    def get(self: Self, key: Key) -> Value | None:
        """
        Retrieve a value from the cache.

        Args:
            key (Key): The key to look up.

        Returns:
            Value | None: The cached value if present and version matches, else None.

        Raises:
            CacheError: If the value is corrupted or cannot be unpickled.

        Side Effects:
            Removes stale cache files if the version prefix does not match.
        """
        fpath = self._fpath_from_key(key)
        flock = self._flock_from_fpath(fpath)

        with flock:
            if not fpath.is_file():
                return None

            value_bytes = None
            prefix_length = len(self.version_prefix)
            with open(fpath, "rb") as fp:
                if fp.read(prefix_length) == self.version_prefix:
                    value_bytes = fp.read()

            if value_bytes is None:
                # version_prefix did not match, so we can't read the stale
                # cached value; we should also remove the stale cached value,
                # so that key can be re-cached by the newer version
                fpath.unlink()
                return None

        try:
            value = pickle.loads(value_bytes)
        except pickle.UnpicklingError as err:
            raise CacheError(
                f"Failed to get key {key!r}, value is potentially corrupted (value is not un-pickle-able)."
            ) from err

        return value

    @override
    def insert(self: Self, key: Key, value: Value) -> bool:
        """
        Insert a value into the cache.

        Args:
            key (Key): The key to insert.
            value (Value): The value to associate with the key.

        Returns:
            bool: True if the value was inserted, False if the key already exists.

        Raises:
            CacheError: If the value is not pickle-able.

        Side Effects:
            Creates the cache directory if it does not exist.
        """
        fpath = self._fpath_from_key(key)
        flock = self._flock_from_fpath(fpath)
        fpath.parent.mkdir(parents=True, exist_ok=True)
        try:
            # "x" mode is exclusive creation, meaning the file will be created
            # iff the file does not already exist (atomic w/o overwrite); use
            # flock for an added atomicity guarantee and to prevent partial writes
            with flock as _, open(fpath, "xb") as fp:
                fp.write(self.version_prefix)
                pickle.dump(value, fp)
        except pickle.PicklingError as err:
            raise CacheError(
                f"Failed to insert key {key!r} with value {value!r}, value is not pickle-able."
            ) from err
        except FileExistsError:
            return False
        return True


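# Illustrative round trip (a sketch; "my_cache" is a hypothetical name and the
# exact path depends on gettempdir()):
#
#   cache: OnDiskCache[str, bytes] = OnDiskCache("my_cache")
#   cache.insert("k", b"payload")  # -> True, entry lands under <tmp>/cache/my_cache/
#   cache.get("k")                 # -> b"payload"
#
# Bumping OnDiskCache.version changes version_prefix, so entries written by an
# older version fail the prefix check and are unlinked on the next get().

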
class InductorOnDiskCache(OnDiskCache[Key, Value]):
    """
    Inductor-specific on-disk cache implementation.

    Uses a custom base directory for Inductor cache files.
    """

    def __init__(self: Self) -> None:
        """
        Initialize an inductor on-disk cache instance.

        Sets the cache directory name to "inductor_on_disk_cache".
        """
        super().__init__("inductor_on_disk_cache")

    @cached_property
    def base_dir(self: Self) -> Path:
        """
        Get the base directory for the Inductor cache.

        Returns:
            Path: The base directory path for Inductor cache files.
        """
        from torch._inductor.runtime.runtime_utils import default_cache_dir

        return Path(default_cache_dir(), "cache", self.name)