non-fb impls + unit tests (#164722)

Test Plan:
```
buck test fbcode//mode/opt caffe2/test/inductor:caching
```
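
For reference outside of fbcode, the suite should map onto the OSS test tree; the exact path below is an assumption based on the buck target name:

```
python test/inductor/test_caching.py
```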

Differential Revision: D83714692

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164722
Approved by: https://github.com/NikhilAPatel, https://github.com/adamomainz
Author: Nicolas Macchioni
Date: 2025-10-09 05:10:57 +00:00
Committed by: PyTorch MergeBot
Parent: d40a9bfb8d
Commit: ed6156e3ea

2 changed files with 656 additions and 13 deletions


@@ -6,16 +6,23 @@ import pickle
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError, wait
from contextlib import contextmanager
from itertools import combinations
from pathlib import Path
from random import Random
from shutil import rmtree
from threading import Lock
from typing import Generator, Sequence, Union
from typing_extensions import Self, TypeVar
from typing import Any, Generator, Sequence, TYPE_CHECKING, Union
from typing_extensions import TypeVar
from unittest.mock import patch
from filelock import FileLock
from torch._inductor.runtime.caching import config, context, exceptions, locks, utils
from torch._inductor.runtime.caching import (
config,
context,
exceptions,
implementations as impls,
locks,
utils,
)
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@@ -23,6 +30,10 @@ from torch.testing._internal.common_utils import (
)
if TYPE_CHECKING:
from pathlib import Path
class TestMixin:
@property
def random_string(self) -> str:
@@ -344,25 +355,236 @@ class ExceptionsTest(TestCase):
)
@instantiate_parametrized_tests
class ImplementationsTest(TestMixin, TestCase):
impl_typenames: list[str] = [
"_InMemoryCacheImpl",
"_OnDiskCacheImpl",
]
cls_id: int = Random().randint(0, 2**32)
@classmethod
def sub_dir(cls) -> str:
return f"testing-impls-instance-{cls.cls_id}"
@classmethod
def setUpClass(cls) -> None:
rmtree(
impls._OnDiskCacheImpl(sub_dir=cls.sub_dir())._cache_dir, ignore_errors=True
)
@classmethod
def tearDownClass(cls) -> None:
rmtree(
impls._OnDiskCacheImpl(sub_dir=cls.sub_dir())._cache_dir, ignore_errors=True
)
def impl_from_typename(self, impl_typename: str) -> impls._CacheImpl:
if impl_typename == "_OnDiskCacheImpl":
return impls._OnDiskCacheImpl(
sub_dir=f"{self.sub_dir()}/rng-{self.random_string[4:]}",
)
else:
return getattr(impls, impl_typename)()
def assert_key_in(self, key: Any, impl: impls._CacheImpl) -> None:
self.assertTrue(impl.get(key) is not None)
def assert_key_not_in(self, key: Any, impl: impls._CacheImpl) -> None:
self.assertTrue(impl.get(key) is None)
def assert_key_value_inserted_in(
self, key: Any, value: Any, impl: impls._CacheImpl
) -> None:
self.assertTrue(impl.insert(key, value))
def assert_key_value_not_inserted_in(
self, key: Any, value: Any, impl: impls._CacheImpl
) -> None:
self.assertFalse(impl.insert(key, value))
def assert_key_has_value_in(
self, key: Any, value: Any, impl: impls._CacheImpl
) -> None:
self.assertTrue(((get := impl.get(key)) is not None) and (get.value == value))
@parametrize("impl_typename", impl_typenames)
def test_get(self, impl_typename: str) -> None:
"""Test cache get operation returns cache miss for non-existent keys.
Verifies that both in-memory and on-disk cache implementations correctly
handle get operations for keys that do not exist in the cache. This test
ensures that the cache properly returns a cache miss (hit=False) when
attempting to retrieve a non-existent key.
Args:
impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
"""
impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
with impl.lock():
self.assert_key_not_in(self.random_string, impl)
@parametrize("impl_typename", impl_typenames)
def test_insert(self, impl_typename: str) -> None:
"""Test cache insert operation successfully stores and retrieves key-value pairs.
Verifies that both in-memory and on-disk cache implementations correctly
handle insert operations for new key-value pairs. This test ensures that:
1. Keys initially don't exist in the cache (cache miss)
2. Insert operations succeed for new keys
3. The stored value can be retrieved correctly after insertion
Args:
impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
"""
impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
with impl.lock():
key: str = self.random_string
self.assert_key_not_in(key, impl)
value: str = self.random_string
self.assert_key_value_inserted_in(key, value, impl)
self.assert_key_has_value_in(key, value, impl)
@parametrize("impl_typename", impl_typenames)
def test_insert_will_not_overwrite(self, impl_typename: str) -> None:
"""Test cache insert operation does not overwrite existing keys.
Verifies that both in-memory and on-disk cache implementations correctly
handle insert operations for keys that already exist in the cache. This test
ensures that:
1. Keys initially don't exist in the cache (cache miss)
2. First insert operation succeeds for new keys
3. Subsequent insert operations with the same key fail (inserted=False)
4. The original value is preserved and not overwritten
Args:
impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
"""
impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
with impl.lock():
key: str = self.random_string
self.assert_key_not_in(key, impl)
value: str = self.random_string
self.assert_key_value_inserted_in(key, value, impl)
self.assert_key_value_not_inserted_in(key, self.random_string, impl)
self.assert_key_has_value_in(key, value, impl)
@parametrize("impl_typename", impl_typenames)
def test_key_encoding(self, impl_typename: str) -> None:
"""Test that cache implementations properly handle non-serializable keys.
Verifies that both in-memory and on-disk cache implementations correctly
raise KeyPicklingError when attempting to insert keys that cannot be
pickled (such as lambda functions). This ensures proper error handling
for invalid key types that would break the caching system.
Args:
impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
"""
impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
with impl.lock():
with self.assertRaises(exceptions.KeyPicklingError):
impl.insert(lambda: None, None)
@parametrize("impl_typename", impl_typenames)
def test_value_encoding(self, impl_typename: str) -> None:
"""Test that on-disk cache implementations properly handle non-serializable values.
Verifies that on-disk cache implementations correctly raise ValuePicklingError
when attempting to insert values that cannot be pickled (such as lambda functions).
This test only applies to on-disk implementations since in-memory caches don't
require serialization. Ensures proper error handling for invalid value types.
Args:
impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
"""
impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
with impl.lock():
if isinstance(impl, impls._OnDiskCacheImpl):
with self.assertRaises(exceptions.ValuePicklingError):
impl.insert(None, lambda: None)
@parametrize("impl_typename", impl_typenames)
def test_value_decoding(self, impl_typename: str) -> None:
"""Test that on-disk cache implementations properly handle corrupted cached values.
Verifies that on-disk cache implementations correctly raise ValueUnPicklingError
when attempting to retrieve values from cache files that contain corrupted or
invalid pickled data. This test ensures proper error handling when cached data
becomes corrupted on disk. Only applies to on-disk implementations since
in-memory caches don't involve serialization/deserialization.
Args:
impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
"""
impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
with impl.lock():
if isinstance(impl, impls._OnDiskCacheImpl):
key: str = self.random_string
self.assert_key_not_in(key, impl)
fpath: Path = impl._fpath_from_key(key)
with open(fpath, "xb") as fp:
impl._write_version_header(fp)
fp.write(b"foo")
with self.assertRaises(exceptions.ValueUnPicklingError):
impl.get(key)
@parametrize("impl_typename", impl_typenames)
def test_version_mismatch(self, impl_typename: str) -> None:
"""Test that on-disk cache implementations properly handle version mismatches.
Verifies that on-disk cache implementations correctly handle cached data when
the cache version changes. This test ensures that:
1. Data can be stored and retrieved with the current version
2. When version changes, previously cached data becomes inaccessible (cache miss)
3. New data can be stored with the new version
4. After version change, old cached data remains inaccessible
This version checking mechanism prevents corruption and compatibility issues
when cache formats change between software versions. Only applies to on-disk
implementations since in-memory caches don't persist across version changes.
Args:
impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
"""
impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
with impl.lock():
if isinstance(impl, impls._OnDiskCacheImpl):
key: str = self.random_string
self.assert_key_not_in(key, impl)
value: str = self.random_string
self.assert_key_value_inserted_in(key, value, impl)
self.assert_key_has_value_in(key, value, impl)
with patch.object(
impls._OnDiskCacheImpl, "_version", impl._version + 1
):
self.assert_key_not_in(key, impl)
self.assert_key_value_inserted_in(key, value, impl)
self.assert_key_has_value_in(key, value, impl)
self.assert_key_not_in(key, impl)
self.assert_key_value_inserted_in(key, value, impl)
self.assert_key_has_value_in(key, value, impl)
@instantiate_parametrized_tests
class LocksTest(TestMixin, TestCase):
T = TypeVar("T")
@contextmanager
def executor(self: Self) -> Generator[ThreadPoolExecutor, None, None]:
def executor(self) -> Generator[ThreadPoolExecutor, None, None]:
executor: ThreadPoolExecutor = ThreadPoolExecutor()
try:
yield executor
finally:
executor.shutdown()
def is_lock(self: Self, lock_or_flock: Union[Lock, FileLock]) -> bool:
def is_lock(self, lock_or_flock: Union[Lock, FileLock]) -> bool:
return hasattr(lock_or_flock, "locked")
def is_flock(self: Self, lock_or_flock: Union[Lock, FileLock]) -> bool:
def is_flock(self, lock_or_flock: Union[Lock, FileLock]) -> bool:
return hasattr(lock_or_flock, "is_locked")
def lock_or_flock_locked(self: Self, lock_or_flock: Union[Lock, FileLock]) -> bool:
def lock_or_flock_locked(self, lock_or_flock: Union[Lock, FileLock]) -> bool:
if self.is_lock(lock_or_flock):
return lock_or_flock.locked()
elif self.is_flock(lock_or_flock):
@@ -370,13 +592,13 @@ class LocksTest(TestMixin, TestCase):
else:
raise NotImplementedError
def test_BLOCKING(self: Self) -> None:
def test_BLOCKING(self) -> None:
self.assertEqual(locks._BLOCKING, -1.0)
def test_NON_BLOCKING(self: Self) -> None:
def test_NON_BLOCKING(self) -> None:
self.assertEqual(locks._NON_BLOCKING, 0.0)
def test_BLOCKING_WITH_TIMEOUT(self: Self) -> None:
def test_BLOCKING_WITH_TIMEOUT(self) -> None:
self.assertGreater(locks._BLOCKING_WITH_TIMEOUT, 0.0)
@patch.object(locks, "_BLOCKING_WITH_TIMEOUT", 1.0)
@@ -386,7 +608,7 @@ class LocksTest(TestMixin, TestCase):
@parametrize("acquisition_mode", ["safe", "unsafe"])
@parametrize("release", ["unlocked", "never", "before_timeout", "after_timeout"])
def test_acquire_with_timeout(
self: Self,
self,
lock_typename: str,
lock_timeout: str,
acquisition_mode: str,
@@ -440,7 +662,10 @@ class LocksTest(TestMixin, TestCase):
self.assertFalse(self.lock_or_flock_locked(lock_or_flock))
assert lock_typename in ["Lock", "FileLock"]
flock_fpath: Path = Path("/tmp/LocksTest") / self.random_string
flock_fpath: Path = (
impls._OnDiskCacheImpl()._cache_dir
/ f"testing-locks-instance-{self.random_string}.lock"
)
lock_or_flock: Union[Lock, FileLock] = (
Lock() if lock_typename == "Lock" else FileLock(str(flock_fpath))
)


@@ -0,0 +1,418 @@
"""Cache implementation classes for PyTorch Inductor runtime caching.
This module provides concrete implementations of caching backends including
in-memory, on-disk, and remote caching strategies. Each implementation follows
the abstract _CacheImpl interface and provides thread-safe operations with
appropriate locking mechanisms.
"""
from abc import ABC, abstractmethod
from contextlib import _GeneratorContextManager, contextmanager
from dataclasses import dataclass
from hashlib import sha256
from io import BufferedReader, BufferedWriter
from os import PathLike
from pathlib import Path
from threading import Lock
from typing import Any, Callable, Generator, Optional
from typing_extensions import override, TypeAlias
from filelock import FileLock
from . import locks, utils
_LockContextManager: TypeAlias = (
Generator[None, None, None] | _GeneratorContextManager[None, None, None]
)
@dataclass
class Hit:
"""Result wrapper for hits on cache get operations.
Allows distinguishing between a cache miss and a cached None value.
Attributes:
value: The cached value.
"""
value: Any
class Miss:
"""Sentinel class representing a cache miss.
Used to distinguish between a cached None value and a cache miss
when None is a valid cached value.
"""
# Singleton instance for cache miss sentinel
miss = Miss()
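
The wrapper/sentinel split is what lets callers cache `None` as a real value; a minimal sketch, using the `_InMemoryCacheImpl` defined below:

```python
cache = _InMemoryCacheImpl()

# None is a legitimate cached value: a hit wraps it, a miss returns bare None.
cache.insert("key", None)
hit = cache.get("key")
assert hit is not None      # Hit object: the key exists
assert hit.value is None    # ...and its cached value happens to be None

assert cache.get("absent") is None  # true miss: no Hit wrapper at all
```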
class _CacheImpl(ABC):
"""Abstract base class for cache implementations.
This class defines the interface that all cache implementations must follow.
It provides thread-safe operations through a locking mechanism and supports
both get and insert operations.
Note: We don't use generics here as doing so would require that the interfaces
know which k/v types the implementation can work with. Instead, we leave that
determination up to the implementation itself and require that the interfaces
handle any potential errors from invalid k/v types being passed to the
implementation.
"""
def __init__(self) -> None:
"""Initialize the cache implementation with a threading lock."""
self._lock: Lock = Lock()
@property
def lock(self) -> Callable[[Optional[float]], _LockContextManager]:
"""Get a context manager for acquiring the cache lock.
Locking of the cache is not done by the implementation itself, but by the
interface that uses it. The interface may want to hold the lock for longer
than a single cache operation, for example when dealing with multiple
cache implementations at once, so we leave that decision up to the interface.
Args:
timeout: Optional timeout in seconds (float) for acquiring the lock.
Returns:
A callable that returns a context manager for the lock.
"""
def _lock_with_timeout(
timeout: Optional[float] = None,
) -> _LockContextManager:
return locks._acquire_lock_with_timeout(self._lock, timeout)
return _lock_with_timeout
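
Because `lock` is a property that returns a callable, call sites stay uniform across backends; a minimal usage sketch, assuming any concrete implementation (here the in-memory one defined below):

```python
impl = _InMemoryCacheImpl()

# Blocking acquire with no timeout.
with impl.lock():
    impl.insert("k", "v")

# With a timeout (in seconds); presumably raises a caching exception if
# the lock cannot be acquired in time.
with impl.lock(0.5):
    hit = impl.get("k")
```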
@abstractmethod
def get(self, key: Any) -> Optional[Hit]:
"""Retrieve a value from the cache.
Args:
key: The key to look up in the cache.
Returns:
A Hit object on cache hit where Hit.value is the cached value,
or None on cache miss.
"""
@abstractmethod
def insert(self, key: Any, value: Any) -> bool:
"""Insert a key-value pair into the cache.
Args:
key: The key to insert.
value: The value to associate with the key.
Returns:
True if the insertion was successful, False if not inserted.
"""
class _InMemoryCacheImpl(_CacheImpl):
"""In-memory cache implementation using a dictionary.
This implementation stores key-value pairs in a Python dictionary,
with keys being pickled for consistent hashing. It provides fast
access but is limited by available memory and process lifetime.
"""
def __init__(self) -> None:
"""Initialize the in-memory cache with an empty dictionary."""
super().__init__()
self._memory: dict[bytes, Any] = {}
@override
def get(self, key: Any) -> Optional[Hit]:
"""Retrieve a value from the in-memory cache.
Args:
key: The key to look up. Will be pickled for storage.
Returns:
A Hit object on cache hit where Hit.value is the cached value,
or None on cache miss.
"""
pickled_key: bytes = utils._try_pickle_key(key)
if (value := self._memory.get(pickled_key, miss)) is not miss:
return Hit(value=value)
return None
@override
def insert(self, key: Any, value: Any) -> bool:
"""Insert a key-value pair into the in-memory cache.
Args:
key: The key to insert. Will be pickled for storage.
value: The value to associate with the key.
Returns:
True if the insertion was successful (key was new),
False if not inserted (key already existed).
"""
pickled_key: bytes = utils._try_pickle_key(key)
if pickled_key not in self._memory:
self._memory[pickled_key] = value
return True
return False
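
The no-overwrite contract of `insert` is easy to see in isolation; a quick sketch mirroring what the tests assert:

```python
cache = _InMemoryCacheImpl()

assert cache.insert("k", 1) is True    # first insert: key was new
assert cache.insert("k", 2) is False   # second insert: no-op, not an error

hit = cache.get("k")
assert hit is not None and hit.value == 1  # original value preserved
```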
class _OnDiskCacheImpl(_CacheImpl):
"""On-disk cache implementation using file system storage.
This implementation stores cached data as files on disk, with version
headers to handle cache invalidation. It uses file locking to ensure
thread safety across processes and provides persistent storage that
survives process restarts.
Attributes:
_version: Version number for cache format compatibility.
_version_header_length: Length of the version header in bytes.
"""
_version: int = 0
_version_header_length: int = 4
def __init__(self, sub_dir: Optional[PathLike[str]] = None) -> None:
"""Initialize the on-disk cache with a specified subdirectory.
Args:
sub_dir: Subdirectory name within the cache directory.
Defaults to empty string if not specified.
"""
self._cache_dir: Path = self._base_dir / (sub_dir or "")
self._flock: FileLock = FileLock(str(self._cache_dir / "dir.lock"))
@property
def _base_dir(self) -> Path:
"""Get the base directory for cache storage.
Returns:
Path to the base cache directory, rooted at the default inductor
cache dir; the per-instance subdirectory is appended in __init__.
"""
from torch._inductor.runtime.runtime_utils import default_cache_dir
return Path(default_cache_dir(), "cache")
def _fpath_from_key(self, key: Any) -> Path:
"""Generate a file path from a cache key.
Args:
key: The cache key to convert to a file path.
Returns:
A Path object representing the file location for this key.
"""
pickled_key: bytes = utils._try_pickle_key(key)
return self._cache_dir / sha256(pickled_key).hexdigest()[:32]
@classmethod
def _version_header(cls) -> bytes:
"""Generate the version header bytes.
Returns:
A byte string representing the current cache version header.
"""
return sha256(str(cls._version).encode()).digest()[: cls._version_header_length]
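
For concreteness, with the default `_version = 0` the header is the first 4 bytes of `sha256(b"0")`; a quick standalone check:

```python
from hashlib import sha256

# Header = first `_version_header_length` (4) bytes of the hashed version.
header = sha256(str(0).encode()).digest()[:4]
assert len(header) == 4

# Bumping the version changes the header, which is what invalidates
# stale on-disk entries on the next get/insert.
assert header != sha256(str(1).encode()).digest()[:4]
```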
def _version_header_matches(self, fp: BufferedReader) -> bool:
"""Check if the file's version header matches the current version.
Args:
fp: File pointer positioned at the start of the file.
Returns:
True if the version header matches, False otherwise.
"""
return fp.read(self._version_header_length) == self._version_header()
def _write_version_header(self, fp: BufferedWriter) -> None:
"""Write the version header to a file.
Args:
fp: File pointer where the version header should be written.
"""
fp.write(self._version_header())
@override
@property
def lock(self) -> Callable[[Optional[float]], _LockContextManager]:
"""Get a context manager for acquiring the file lock.
Uses file locking to ensure thread safety across processes.
Args:
timeout: Optional timeout in seconds (float) for acquiring the file lock.
Returns:
A callable that returns a context manager for the file lock.
"""
def _lock_with_timeout(
timeout: Optional[float] = None,
) -> _LockContextManager:
return locks._acquire_flock_with_timeout(self._flock, timeout)
return _lock_with_timeout
@override
def get(self, key: Any) -> Optional[Hit]:
"""Retrieve a value from the on-disk cache.
Args:
key: The key to look up in the cache.
Returns:
A Hit object on cache hit where Hit.value is the cached value,
or None on cache miss or version mismatch.
"""
fpath: Path = self._fpath_from_key(key)
if not fpath.is_file():
return None
pickled_value: Optional[bytes] = None
with open(fpath, "rb") as fp:
if self._version_header_matches(fp):
pickled_value = fp.read()
if not pickled_value:
# If pickled_value is still None even though the file exists, then the
# version header did not match. The behavior here is up to preference;
# we choose to remove entries whose version header does not match so
# that the key can be re-cached later under the correct version.
fpath.unlink()
return None
return Hit(value=utils._try_unpickle_value(pickled_value))
@override
def insert(self, key: Any, value: Any) -> bool:
"""Insert a key-value pair into the on-disk cache.
Args:
key: The key to insert.
value: The value to associate with the key.
Returns:
True if successfully inserted, False if the key already exists
with a valid version.
"""
fpath: Path = self._fpath_from_key(key)
fpath.parent.mkdir(parents=True, exist_ok=True)
r_fp, w_fp, inserted = None, None, False
try:
w_fp = open(fpath, "xb")
except FileExistsError:
is_stale: bool = False
with open(fpath, "rb") as r_fp:
is_stale = not self._version_header_matches(r_fp)
if is_stale:
# Same story as above: the version header doesn't match, so we remove
# the old entry so that the new k/v pair can be cached.
fpath.unlink()
w_fp = open(fpath, "xb")
else:
w_fp = None
finally:
if w_fp:
try:
pickled_value: bytes = utils._try_pickle_value(value)
self._write_version_header(w_fp)
w_fp.write(pickled_value)
inserted = True
finally:
w_fp.close()
return inserted
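
Unlike the in-memory backend, two instances pointing at the same `sub_dir` share state through the filesystem; a minimal sketch (note this writes under the real default cache dir, so a throwaway `sub_dir` is used, mirroring what the tests do):

```python
a = _OnDiskCacheImpl(sub_dir="demo")
b = _OnDiskCacheImpl(sub_dir="demo")

with a.lock():
    a.insert("k", "v")

# A separate instance sees the entry because it lives on disk.
with b.lock():
    hit = b.get("k")
    assert hit is not None and hit.value == "v"
```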
try:
from .fb.implementations import _RemoteCacheImpl
except ModuleNotFoundError:
class _RemoteCacheImpl(_CacheImpl): # type: ignore[no-redef]
"""Fallback remote cache implementation for non-Facebook environments.
This is a no-op implementation that always raises NotImplementedError.
The actual remote cache implementation is provided in the `.fb` module
for Facebook-specific environments.
Attributes:
_version: Version number for cache format compatibility.
has_strong_consistency: Whether the remote cache provides strong
consistency guarantees.
"""
_version: int = 0
has_strong_consistency: bool = False
def __init__(self) -> None:
"""Initialize the fallback remote cache implementation.
Note: We don't need to initialize any form of lock since this
implementation provides a pseudo-lock context manager.
"""
@override
@property
def lock(self) -> Callable[[Optional[float]], _LockContextManager]:
"""Get a pseudo lock that does nothing.
Most remote cache implementations have no way to provide any form of
locking, so we supply a no-op pseudo-lock for consistency with the
interface.
Args:
timeout: Optional timeout in seconds (float). Ignored in this implementation.
Returns:
A callable that returns a no-op context manager.
"""
@contextmanager
def pseudo_lock(
timeout: Optional[float] = None,
) -> Generator[None, None, None]:
yield
return pseudo_lock
@override
def get(self, key: Any) -> Optional[Hit]:
"""Raise NotImplementedError for remote cache get operations.
Args:
key: The key to look up (ignored).
Raises:
NotImplementedError: Always raised as this is a fallback implementation.
"""
raise NotImplementedError
@override
def insert(self, key: Any, value: Any) -> bool:
"""Raise NotImplementedError for remote cache insert operations.
Args:
key: The key to insert (ignored).
value: The value to insert (ignored).
Raises:
NotImplementedError: Always raised as this is a fallback implementation.
"""
raise NotImplementedError
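
One way a caller might degrade gracefully when the remote backend is only available in fb environments (a sketch; the real `.fb` implementation would not raise):

```python
remote = _RemoteCacheImpl()

try:
    with remote.lock():  # no-op pseudo-lock in the fallback
        remote.get("k")
except NotImplementedError:
    # Outside fb environments the remote backend is unavailable;
    # fall back to a local implementation instead.
    impl = _InMemoryCacheImpl()
```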