Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
non-fb impls + unit tests (#164722)
Test Plan:

```
buck test fbcode//mode/opt caffe2/test/inductor:caching
```

Differential Revision: D83714692
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164722
Approved by: https://github.com/NikhilAPatel, https://github.com/adamomainz
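For orientation, here is a minimal sketch of the API surface these tests exercise; every name comes from the diff below, but the snippet itself is illustrative rather than part of the commit:

```python
from torch._inductor.runtime.caching import implementations as impls

# In-memory backend: keys are pickled for hashing, values are stored as-is.
cache = impls._InMemoryCacheImpl()
with cache.lock():                     # lock() yields a context manager
    assert cache.get("key") is None    # a miss is reported as None
    assert cache.insert("key", 42)     # first insert succeeds...
    assert not cache.insert("key", 7)  # ...and inserts never overwrite
    hit = cache.get("key")             # a hit is wrapped in a Hit dataclass
    assert hit is not None and hit.value == 42
```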
commit ed6156e3ea (parent d40a9bfb8d), committed by PyTorch MergeBot
@@ -6,16 +6,23 @@ import pickle
 from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError, wait
 from contextlib import contextmanager
 from itertools import combinations
-from pathlib import Path
 from random import Random
 from shutil import rmtree
 from threading import Lock
-from typing import Generator, Sequence, Union
-from typing_extensions import Self, TypeVar
+from typing import Any, Generator, Sequence, TYPE_CHECKING, Union
+from typing_extensions import TypeVar
 from unittest.mock import patch

 from filelock import FileLock

-from torch._inductor.runtime.caching import config, context, exceptions, locks, utils
+from torch._inductor.runtime.caching import (
+    config,
+    context,
+    exceptions,
+    implementations as impls,
+    locks,
+    utils,
+)
 from torch._inductor.test_case import run_tests, TestCase
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
@@ -23,6 +30,10 @@ from torch.testing._internal.common_utils import (
 )


+if TYPE_CHECKING:
+    from pathlib import Path
+
+
 class TestMixin:
     @property
     def random_string(self) -> str:
@@ -344,25 +355,236 @@ class ExceptionsTest(TestCase):
         )


+@instantiate_parametrized_tests
+class ImplementationsTest(TestMixin, TestCase):
+    impl_typenames: list[str] = [
+        "_InMemoryCacheImpl",
+        "_OnDiskCacheImpl",
+    ]
+    cls_id: int = Random().randint(0, 2**32)
+
+    @classmethod
+    def sub_dir(cls) -> str:
+        return f"testing-impls-instance-{cls.cls_id}"
+
+    @classmethod
+    def setUpClass(cls) -> None:
+        rmtree(
+            impls._OnDiskCacheImpl(sub_dir=cls.sub_dir())._cache_dir, ignore_errors=True
+        )
+
+    @classmethod
+    def tearDownClass(cls) -> None:
+        rmtree(
+            impls._OnDiskCacheImpl(sub_dir=cls.sub_dir())._cache_dir, ignore_errors=True
+        )
+
+    def impl_from_typename(self, impl_typename: str) -> impls._CacheImpl:
+        if impl_typename == "_OnDiskCacheImpl":
+            return impls._OnDiskCacheImpl(
+                sub_dir=f"{self.sub_dir()}/rng-{self.random_string[4:]}",
+            )
+        else:
+            return getattr(impls, impl_typename)()
+
+    def assert_key_in(self, key: Any, impl: impls._CacheImpl) -> None:
+        self.assertTrue(impl.get(key) is not None)
+
+    def assert_key_not_in(self, key: Any, impl: impls._CacheImpl) -> None:
+        self.assertTrue(impl.get(key) is None)
+
+    def assert_key_value_inserted_in(
+        self, key: Any, value: Any, impl: impls._CacheImpl
+    ) -> None:
+        self.assertTrue(impl.insert(key, value))
+
+    def assert_key_value_not_inserted_in(
+        self, key: Any, value: Any, impl: impls._CacheImpl
+    ) -> None:
+        self.assertFalse(impl.insert(key, value))
+
+    def assert_key_has_value_in(
+        self, key: Any, value: Any, impl: impls._CacheImpl
+    ) -> None:
+        self.assertTrue(((get := impl.get(key)) is not None) and (get.value == value))
+
+    @parametrize("impl_typename", impl_typenames)
+    def test_get(self, impl_typename: str) -> None:
+        """Test cache get operation returns cache miss for non-existent keys.
+
+        Verifies that both in-memory and on-disk cache implementations correctly
+        handle get operations for keys that do not exist in the cache. This test
+        ensures that the cache properly returns a cache miss (hit=False) when
+        attempting to retrieve a non-existent key.
+
+        Args:
+            impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
+        """
+        impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
+        with impl.lock():
+            self.assert_key_not_in(self.random_string, impl)
+
+    @parametrize("impl_typename", impl_typenames)
+    def test_insert(self, impl_typename: str) -> None:
+        """Test cache insert operation successfully stores and retrieves key-value pairs.
+
+        Verifies that both in-memory and on-disk cache implementations correctly
+        handle insert operations for new key-value pairs. This test ensures that:
+        1. Keys initially don't exist in the cache (cache miss)
+        2. Insert operations succeed for new keys
+        3. The stored value can be retrieved correctly after insertion
+
+        Args:
+            impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
+        """
+        impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
+        with impl.lock():
+            key: str = self.random_string
+            self.assert_key_not_in(key, impl)
+            value: str = self.random_string
+            self.assert_key_value_inserted_in(key, value, impl)
+            self.assert_key_has_value_in(key, value, impl)
+
+    @parametrize("impl_typename", impl_typenames)
+    def test_insert_will_not_overwrite(self, impl_typename: str) -> None:
+        """Test cache insert operation does not overwrite existing keys.
+
+        Verifies that both in-memory and on-disk cache implementations correctly
+        handle insert operations for keys that already exist in the cache. This test
+        ensures that:
+        1. Keys initially don't exist in the cache (cache miss)
+        2. First insert operation succeeds for new keys
+        3. Subsequent insert operations with the same key fail (inserted=False)
+        4. The original value is preserved and not overwritten
+
+        Args:
+            impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
+        """
+        impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
+        with impl.lock():
+            key: str = self.random_string
+            self.assert_key_not_in(key, impl)
+            value: str = self.random_string
+            self.assert_key_value_inserted_in(key, value, impl)
+            self.assert_key_value_not_inserted_in(key, self.random_string, impl)
+            self.assert_key_has_value_in(key, value, impl)
+
+    @parametrize("impl_typename", impl_typenames)
+    def test_key_encoding(self, impl_typename: str) -> None:
+        """Test that cache implementations properly handle non-serializable keys.
+
+        Verifies that both in-memory and on-disk cache implementations correctly
+        raise KeyPicklingError when attempting to insert keys that cannot be
+        pickled (such as lambda functions). This ensures proper error handling
+        for invalid key types that would break the caching system.
+
+        Args:
+            impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
+        """
+        impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
+        with impl.lock():
+            with self.assertRaises(exceptions.KeyPicklingError):
+                impl.insert(lambda: None, None)
+
+    @parametrize("impl_typename", impl_typenames)
+    def test_value_encoding(self, impl_typename: str) -> None:
+        """Test that on-disk cache implementations properly handle non-serializable values.
+
+        Verifies that on-disk cache implementations correctly raise ValuePicklingError
+        when attempting to insert values that cannot be pickled (such as lambda functions).
+        This test only applies to on-disk implementations since in-memory caches don't
+        require serialization. Ensures proper error handling for invalid value types.
+
+        Args:
+            impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
+        """
+        impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
+        with impl.lock():
+            if isinstance(impl, impls._OnDiskCacheImpl):
+                with self.assertRaises(exceptions.ValuePicklingError):
+                    impl.insert(None, lambda: None)
+
+    @parametrize("impl_typename", impl_typenames)
+    def test_value_decoding(self, impl_typename: str) -> None:
+        """Test that on-disk cache implementations properly handle corrupted cached values.
+
+        Verifies that on-disk cache implementations correctly raise ValueUnPicklingError
+        when attempting to retrieve values from cache files that contain corrupted or
+        invalid pickled data. This test ensures proper error handling when cached data
+        becomes corrupted on disk. Only applies to on-disk implementations since
+        in-memory caches don't involve serialization/deserialization.
+
+        Args:
+            impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
+        """
+        impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
+        with impl.lock():
+            if isinstance(impl, impls._OnDiskCacheImpl):
+                key: str = self.random_string
+                self.assert_key_not_in(key, impl)
+                fpath: Path = impl._fpath_from_key(key)
+                with open(fpath, "xb") as fp:
+                    impl._write_version_header(fp)
+                    fp.write(b"foo")
+                with self.assertRaises(exceptions.ValueUnPicklingError):
+                    impl.get(key)
+
+    @parametrize("impl_typename", impl_typenames)
+    def test_version_mismatch(self, impl_typename: str) -> None:
+        """Test that on-disk cache implementations properly handle version mismatches.
+
+        Verifies that on-disk cache implementations correctly handle cached data when
+        the cache version changes. This test ensures that:
+        1. Data can be stored and retrieved with the current version
+        2. When version changes, previously cached data becomes inaccessible (cache miss)
+        3. New data can be stored with the new version
+        4. After version change, old cached data remains inaccessible
+
+        This version checking mechanism prevents corruption and compatibility issues
+        when cache formats change between software versions. Only applies to on-disk
+        implementations since in-memory caches don't persist across version changes.
+
+        Args:
+            impl_typename: The cache implementation type to test ("_InMemoryCacheImpl" or "_OnDiskCacheImpl")
+        """
+        impl: impls._CacheImpl = self.impl_from_typename(impl_typename)
+        with impl.lock():
+            if isinstance(impl, impls._OnDiskCacheImpl):
+                key: str = self.random_string
+                self.assert_key_not_in(key, impl)
+                value: str = self.random_string
+                self.assert_key_value_inserted_in(key, value, impl)
+                self.assert_key_has_value_in(key, value, impl)
+                with patch.object(
+                    impls._OnDiskCacheImpl, "_version", impl._version + 1
+                ):
+                    self.assert_key_not_in(key, impl)
+                    self.assert_key_value_inserted_in(key, value, impl)
+                    self.assert_key_has_value_in(key, value, impl)
+                self.assert_key_not_in(key, impl)
+                self.assert_key_value_inserted_in(key, value, impl)
+                self.assert_key_has_value_in(key, value, impl)
+
+
 @instantiate_parametrized_tests
 class LocksTest(TestMixin, TestCase):
     T = TypeVar("T")

     @contextmanager
-    def executor(self: Self) -> Generator[ThreadPoolExecutor, None, None]:
+    def executor(self) -> Generator[ThreadPoolExecutor, None, None]:
         executor: ThreadPoolExecutor = ThreadPoolExecutor()
         try:
             yield executor
         finally:
             executor.shutdown()

-    def is_lock(self: Self, lock_or_flock: Union[Lock, FileLock]) -> bool:
+    def is_lock(self, lock_or_flock: Union[Lock, FileLock]) -> bool:
         return hasattr(lock_or_flock, "locked")

-    def is_flock(self: Self, lock_or_flock: Union[Lock, FileLock]) -> bool:
+    def is_flock(self, lock_or_flock: Union[Lock, FileLock]) -> bool:
         return hasattr(lock_or_flock, "is_locked")

-    def lock_or_flock_locked(self: Self, lock_or_flock: Union[Lock, FileLock]) -> bool:
+    def lock_or_flock_locked(self, lock_or_flock: Union[Lock, FileLock]) -> bool:
         if self.is_lock(lock_or_flock):
             return lock_or_flock.locked()
         elif self.is_flock(lock_or_flock):
@@ -370,13 +592,13 @@ class LocksTest(TestMixin, TestCase):
         else:
             raise NotImplementedError

-    def test_BLOCKING(self: Self) -> None:
+    def test_BLOCKING(self) -> None:
         self.assertEqual(locks._BLOCKING, -1.0)

-    def test_NON_BLOCKING(self: Self) -> None:
+    def test_NON_BLOCKING(self) -> None:
         self.assertEqual(locks._NON_BLOCKING, 0.0)

-    def test_BLOCKING_WITH_TIMEOUT(self: Self) -> None:
+    def test_BLOCKING_WITH_TIMEOUT(self) -> None:
         self.assertGreater(locks._BLOCKING_WITH_TIMEOUT, 0.0)

     @patch.object(locks, "_BLOCKING_WITH_TIMEOUT", 1.0)
@@ -386,7 +608,7 @@
     @parametrize("acquisition_mode", ["safe", "unsafe"])
     @parametrize("release", ["unlocked", "never", "before_timeout", "after_timeout"])
     def test_acquire_with_timeout(
-        self: Self,
+        self,
         lock_typename: str,
         lock_timeout: str,
         acquisition_mode: str,
@@ -440,7 +662,10 @@
         self.assertFalse(self.lock_or_flock_locked(lock_or_flock))

         assert lock_typename in ["Lock", "FileLock"]
-        flock_fpath: Path = Path("/tmp/LocksTest") / self.random_string
+        flock_fpath: Path = (
+            impls._OnDiskCacheImpl()._cache_dir
+            / f"testing-locks-instance-{self.random_string}.lock"
+        )
         lock_or_flock: Union[Lock, FileLock] = (
             Lock() if lock_typename == "Lock" else FileLock(str(flock_fpath))
         )
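Between the two files, a condensed sketch of the lock-acquisition pattern LocksTest exercises; it assumes only the helpers visible in this diff (locks._acquire_lock_with_timeout and the timeout sentinels) and is illustrative rather than part of the commit:

```python
from threading import Lock

from torch._inductor.runtime.caching import locks

# Sentinel timeouts asserted by LocksTest:
#   locks._BLOCKING == -1.0            (block indefinitely)
#   locks._NON_BLOCKING == 0.0         (try-acquire, fail fast)
#   locks._BLOCKING_WITH_TIMEOUT > 0.0 (bounded wait)
lock = Lock()
with locks._acquire_lock_with_timeout(lock, locks._BLOCKING_WITH_TIMEOUT):
    pass  # critical section; expiry behavior is what
          # test_acquire_with_timeout's "before_timeout"/"after_timeout"
          # release modes cover
```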
torch/_inductor/runtime/caching/implementations.py | 418 lines (new file)
@@ -0,0 +1,418 @@
"""Cache implementation classes for PyTorch Inductor runtime caching.

This module provides concrete implementations of caching backends including
in-memory, on-disk, and remote caching strategies. Each implementation follows
the abstract _CacheImpl interface and provides thread-safe operations with
appropriate locking mechanisms.
"""

from abc import ABC, abstractmethod
from contextlib import _GeneratorContextManager, contextmanager
from dataclasses import dataclass
from hashlib import sha256
from io import BufferedReader, BufferedWriter
from os import PathLike
from pathlib import Path
from threading import Lock
from typing import Any, Callable, Generator, Optional
from typing_extensions import override, TypeAlias

from filelock import FileLock

from . import locks, utils


_LockContextManager: TypeAlias = (
    Generator[None, None, None] | _GeneratorContextManager[None, None, None]
)


@dataclass
class Hit:
    """Result wrapper for hits on cache get operations.

    Allows distinguishing between a cache miss and a cached None value.

    Attributes:
        value: The cached value.
    """

    value: Any


class Miss:
    """Sentinel class representing a cache miss.

    Used to distinguish between a cached None value and a cache miss
    when None is a valid cached value.
    """


# Singleton instance for cache miss sentinel
miss = Miss()


class _CacheImpl(ABC):
    """Abstract base class for cache implementations.

    This class defines the interface that all cache implementations must follow.
    It provides thread-safe operations through a locking mechanism and supports
    both get and insert operations.

    Note: We don't use generics here as doing so would require that the interfaces
    know which k/v types the implementation can work with. Instead, we leave that
    determination up to the implementation itself and require that the interfaces
    handle any potential errors from invalid k/v types being passed to the
    implementation.
    """

    def __init__(self) -> None:
        """Initialize the cache implementation with a threading lock."""
        self._lock: Lock = Lock()

    @property
    def lock(self) -> Callable[[Optional[float]], _LockContextManager]:
        """Get a context manager for acquiring the cache lock.

        Locking of the cache is not done by the implementation itself, but by the
        interface that uses it. The interface may want to hold the lock for longer
        than a single cache operation, for example when dealing with multiple
        cache implementations at once, so we leave that decision up to the interface.

        Args:
            timeout: Optional timeout in seconds (float) for acquiring the lock.

        Returns:
            A callable that returns a context manager for the lock.
        """

        def _lock_with_timeout(
            timeout: Optional[float] = None,
        ) -> _LockContextManager:
            return locks._acquire_lock_with_timeout(self._lock, timeout)

        return _lock_with_timeout

    @abstractmethod
    def get(self, key: Any) -> Optional[Hit]:
        """Retrieve a value from the cache.

        Args:
            key: The key to look up in the cache.

        Returns:
            A Hit object on cache hit where Hit.value is the cached value,
            or None on cache miss.
        """

    @abstractmethod
    def insert(self, key: Any, value: Any) -> bool:
        """Insert a key-value pair into the cache.

        Args:
            key: The key to insert.
            value: The value to associate with the key.

        Returns:
            True if the insertion was successful, False if not inserted.
        """


class _InMemoryCacheImpl(_CacheImpl):
    """In-memory cache implementation using a dictionary.

    This implementation stores key-value pairs in a Python dictionary,
    with keys being pickled for consistent hashing. It provides fast
    access but is limited by available memory and process lifetime.
    """

    def __init__(self) -> None:
        """Initialize the in-memory cache with an empty dictionary."""
        super().__init__()
        self._memory: dict[bytes, Any] = {}

    @override
    def get(self, key: Any) -> Optional[Hit]:
        """Retrieve a value from the in-memory cache.

        Args:
            key: The key to look up. Will be pickled for storage.

        Returns:
            A Hit object on cache hit where Hit.value is the cached value,
            or None on cache miss.
        """
        pickled_key: bytes = utils._try_pickle_key(key)
        if (value := self._memory.get(pickled_key, miss)) is not miss:
            return Hit(value=value)
        return None

    @override
    def insert(self, key: Any, value: Any) -> bool:
        """Insert a key-value pair into the in-memory cache.

        Args:
            key: The key to insert. Will be pickled for storage.
            value: The value to associate with the key.

        Returns:
            True if the insertion was successful (key was new),
            False if not inserted (key already existed).
        """
        pickled_key: bytes = utils._try_pickle_key(key)
        if pickled_key not in self._memory:
            self._memory[pickled_key] = value
            return True
        return False


class _OnDiskCacheImpl(_CacheImpl):
    """On-disk cache implementation using file system storage.

    This implementation stores cached data as files on disk, with version
    headers to handle cache invalidation. It uses file locking to ensure
    thread safety across processes and provides persistent storage that
    survives process restarts.

    Attributes:
        _version: Version number for cache format compatibility.
        _version_header_length: Length of the version header in bytes.
    """

    _version: int = 0
    _version_header_length: int = 4

    def __init__(self, sub_dir: Optional[PathLike[str]] = None) -> None:
        """Initialize the on-disk cache with a specified subdirectory.

        Args:
            sub_dir: Subdirectory name within the cache directory.
                Defaults to empty string if not specified.
        """
        self._cache_dir: Path = self._base_dir / (sub_dir or "")
        self._flock: FileLock = FileLock(str(self._cache_dir / "dir.lock"))

    @property
    def _base_dir(self) -> Path:
        """Get the base directory for cache storage.

        Returns:
            Path to the base cache directory, derived from Inductor's default
            cache dir; per-instance subdirectories are nested beneath it.
        """
        from torch._inductor.runtime.runtime_utils import default_cache_dir

        return Path(default_cache_dir(), "cache")

    def _fpath_from_key(self, key: Any) -> Path:
        """Generate a file path from a cache key.

        Args:
            key: The cache key to convert to a file path.

        Returns:
            A Path object representing the file location for this key.
        """
        pickled_key: bytes = utils._try_pickle_key(key)
        return self._cache_dir / sha256(pickled_key).hexdigest()[:32]

    @classmethod
    def _version_header(cls) -> bytes:
        """Generate the version header bytes.

        Returns:
            A byte string representing the current cache version header.
        """
        return sha256(str(cls._version).encode()).digest()[: cls._version_header_length]

    def _version_header_matches(self, fp: BufferedReader) -> bool:
        """Check if the file's version header matches the current version.

        Args:
            fp: File pointer positioned at the start of the file.

        Returns:
            True if the version header matches, False otherwise.
        """
        return fp.read(self._version_header_length) == self._version_header()

    def _write_version_header(self, fp: BufferedWriter) -> None:
        """Write the version header to a file.

        Args:
            fp: File pointer where the version header should be written.
        """
        fp.write(self._version_header())

    @override
    @property
    def lock(self) -> Callable[[Optional[float]], _LockContextManager]:
        """Get a context manager for acquiring the file lock.

        Uses file locking to ensure thread safety across processes.

        Args:
            timeout: Optional timeout in seconds (float) for acquiring the file lock.

        Returns:
            A callable that returns a context manager for the file lock.
        """

        def _lock_with_timeout(
            timeout: Optional[float] = None,
        ) -> _LockContextManager:
            return locks._acquire_flock_with_timeout(self._flock, timeout)

        return _lock_with_timeout

    @override
    def get(self, key: Any) -> Optional[Hit]:
        """Retrieve a value from the on-disk cache.

        Args:
            key: The key to look up in the cache.

        Returns:
            A Hit object on cache hit where Hit.value is the cached value,
            or None on cache miss or version mismatch.
        """
        fpath: Path = self._fpath_from_key(key)

        if not fpath.is_file():
            return None

        pickled_value: Optional[bytes] = None
        with open(fpath, "rb") as fp:
            if self._version_header_matches(fp):
                pickled_value = fp.read()

        if not pickled_value:
            # If pickled_value is still None even though the file exists, the
            # version header did not match. The handling here is a matter of
            # preference; we choose to remove entries whose version header does
            # not match so that the key can be re-cached later under the
            # current version.
            fpath.unlink()
            return None

        return Hit(value=utils._try_unpickle_value(pickled_value))

    @override
    def insert(self, key: Any, value: Any) -> bool:
        """Insert a key-value pair into the on-disk cache.

        Args:
            key: The key to insert.
            value: The value to associate with the key.

        Returns:
            True if successfully inserted, False if the key already exists
            with a valid version.
        """
        fpath: Path = self._fpath_from_key(key)
        fpath.parent.mkdir(parents=True, exist_ok=True)

        r_fp, w_fp, inserted = None, None, False
        try:
            w_fp = open(fpath, "xb")
        except FileExistsError:
            is_stale: bool = False
            with open(fpath, "rb") as r_fp:
                is_stale = not self._version_header_matches(r_fp)

            if is_stale:
                # Same story as in get(): the version header doesn't match,
                # so we remove the old entry to make room for the new k/v pair.
                fpath.unlink()
                w_fp = open(fpath, "xb")
            else:
                w_fp = None
        finally:
            if w_fp:
                try:
                    pickled_value: bytes = utils._try_pickle_value(value)
                    self._write_version_header(w_fp)
                    w_fp.write(pickled_value)
                    inserted = True
                finally:
                    w_fp.close()

        return inserted


try:
    from .fb.implementations import _RemoteCacheImpl
except ModuleNotFoundError:

    class _RemoteCacheImpl(_CacheImpl):  # type: ignore[no-redef]
        """Fallback remote cache implementation for non-Facebook environments.

        This is a no-op implementation that always raises NotImplementedError.
        The actual remote cache implementation is provided in the `.fb` module
        for Facebook-specific environments.

        Attributes:
            _version: Version number for cache format compatibility.
            has_strong_consistency: Whether the remote cache provides strong
                consistency guarantees.
        """

        _version: int = 0
        has_strong_consistency: bool = False

        def __init__(self) -> None:
            """Initialize the fallback remote cache implementation.

            Note: We don't need to initialize any form of lock since this
            implementation provides a pseudo-lock context manager.
            """

        @override
        @property
        def lock(self) -> Callable[[Optional[float]], _LockContextManager]:
            """Get a pseudo lock that does nothing.

            Most remote cache implementations don't have the ability to implement
            any form of locking, so we provide a no-op pseudo-lock for consistency
            with the interface.

            Args:
                timeout: Optional timeout in seconds (float). Ignored in this
                    implementation.

            Returns:
                A callable that returns a no-op context manager.
            """

            @contextmanager
            def pseudo_lock(
                timeout: Optional[float] = None,
            ) -> Generator[None, None, None]:
                yield

            return pseudo_lock

        @override
        def get(self, key: Any) -> Optional[Hit]:
            """Raise NotImplementedError for remote cache get operations.

            Args:
                key: The key to look up (ignored).

            Raises:
                NotImplementedError: Always raised as this is a fallback implementation.
            """
            raise NotImplementedError

        @override
        def insert(self, key: Any, value: Any) -> bool:
            """Raise NotImplementedError for remote cache insert operations.

            Args:
                key: The key to insert (ignored).
                value: The value to insert (ignored).

            Raises:
                NotImplementedError: Always raised as this is a fallback implementation.
            """
            raise NotImplementedError
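To tie the two files together, a condensed sketch of the version-header behavior that test_version_mismatch verifies; the sub_dir name is arbitrary and the snippet is illustrative rather than part of the commit:

```python
from unittest.mock import patch

from torch._inductor.runtime.caching import implementations as impls

disk = impls._OnDiskCacheImpl(sub_dir="version-demo")  # hypothetical sub_dir
with disk.lock():
    assert disk.insert("k", "v")
    hit = disk.get("k")
    assert hit is not None and hit.value == "v"
    # Bumping _version invalidates the entry: get() sees a header mismatch,
    # reports a miss, and unlinks the file so the key can be re-cached.
    with patch.object(
        impls._OnDiskCacheImpl, "_version", impls._OnDiskCacheImpl._version + 1
    ):
        assert disk.get("k") is None
        assert disk.insert("k", "v2")
```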