[Misc] Consolidate LRUCache implementations (#15481)

Signed-off-by: Bella kira <2374035698@qq.com>
Author: Bella kira
Date: 2025-03-27 14:43:43 +08:00
Committed by: GitHub
Parent: e1e0fd7543
Commit: f4c98b4d4c
2 changed files with 105 additions and 57 deletions

vllm/multimodal/processing.py

@@ -12,7 +12,6 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
                     TypeVar, Union, cast)

 import torch
-from cachetools import LRUCache
 from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
 from typing_extensions import assert_never
@@ -21,7 +20,7 @@ from vllm.jsontree import json_map_leaves, json_reduce_leaves
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
                                                encode_tokens)
-from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
+from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
 from .hasher import MultiModalHasher
 from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,

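With the above, vllm/multimodal/processing.py drops its direct cachetools import and gets LRUCache from vllm.utils instead; since the consolidated class subclasses cachetools.LRUCache (see the second file below), byte-sized capacities via getsizeof keep working. A minimal usage sketch — the item_size helper and the 4 GiB capacity are illustrative assumptions, not taken from this diff:

    from vllm.utils import GiB_bytes, LRUCache

    def item_size(value: bytes) -> float:
        # hypothetical sizing function: charge each entry its byte length
        return len(value)

    cache = LRUCache[str, bytes](capacity=4 * GiB_bytes, getsizeof=item_size)
    cache.put("k", b"payload")
    assert cache.get("k") == b"payload"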
vllm/utils.py

@@ -33,15 +33,17 @@ import uuid
 import warnings
 import weakref
 from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
-from collections import OrderedDict, UserDict, defaultdict
+from collections import UserDict, defaultdict
 from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
-                             Iterable, Iterator, Mapping)
+                             Iterable, Iterator, KeysView, Mapping)
 from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
+from types import MappingProxyType
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
-                    Optional, Type, TypeVar, Union)
+                    Optional, Type, TypeVar, Union, cast, overload)
 from uuid import uuid4

+import cachetools
 import cloudpickle
 import numpy as np
 import numpy.typing as npt
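The new imports back the read-only views added further down: cachetools becomes the base implementation, while MappingProxyType and KeysView expose its internal state without allowing mutation. A standalone sketch of the MappingProxyType behavior those properties rely on:

    from types import MappingProxyType

    order = {"a": None, "b": None}
    view = MappingProxyType(order)
    assert list(view) == ["a", "b"]
    try:
        view["c"] = None          # proxies are read-only
    except TypeError:
        pass                      # mutation is rejected, as expected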
@@ -173,6 +175,7 @@ U = TypeVar("U")
 _K = TypeVar("_K", bound=Hashable)
 _V = TypeVar("_V")
+_T = TypeVar("_T")


 class _Sentinel:
@@ -206,6 +209,19 @@ class Counter:
         self.counter = 0


+class _MappingOrderCacheView(UserDict[_K, _V]):
+
+    def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
+        super().__init__(data)
+        self.ordered_keys = ordered_keys
+
+    def __iter__(self) -> Iterator[_K]:
+        return iter(self.ordered_keys)
+
+    def keys(self) -> KeysView[_K]:
+        return KeysView(self.ordered_keys)
+
+
 class CacheInfo(NamedTuple):
     hits: int
     total: int
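_MappingOrderCacheView is what lets the new cache property (next hunk) present cachetools' backing dict in LRU order: iteration and keys() follow the separate order mapping, while lookups go to the data mapping. A hypothetical illustration of those semantics (the class is module-private, so this is purely for reading the diff):

    # data holds the entries; ordered_keys carries recency order (oldest first)
    data = {"b": 2, "a": 1}
    order = {"a": None, "b": None}
    view = _MappingOrderCacheView(data, order)
    assert list(view) == ["a", "b"]   # iteration follows ordered_keys
    assert view["b"] == 2             # lookups use the backing data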
@@ -218,45 +234,62 @@ class CacheInfo(NamedTuple):
         return self.hits / self.total


-class LRUCache(Generic[_K, _V]):
-    """Note: This class is not thread safe!"""
+class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):

-    def __init__(self, capacity: int) -> None:
-        self.cache = OrderedDict[_K, _V]()
+    def __init__(self,
+                 capacity: float,
+                 getsizeof: Optional[Callable[[_V], float]] = None):
+        super().__init__(capacity, getsizeof)
+
         self.pinned_items = set[_K]()
-        self.capacity = capacity

         self._hits = 0
         self._total = 0

-    def __contains__(self, key: _K) -> bool:
-        return key in self.cache
-
-    def __len__(self) -> int:
-        return len(self.cache)
-
-    def __getitem__(self, key: _K) -> _V:
-        value = self.cache[key]  # Raise KeyError if not exists
-        self.cache.move_to_end(key)
-        return value
-
-    def __setitem__(self, key: _K, value: _V) -> None:
-        self.put(key, value)
-
     def __delitem__(self, key: _K) -> None:
-        self.pop(key)
+        run_on_remove = key in self
+        value = self.__getitem__(key)
+        super().__delitem__(key)
+        if key in self.pinned_items:
+            # Todo: add warning to inform that del pinned item
+            self._unpin(key)
+        if run_on_remove:
+            self._on_remove(key, value)
+
+    @property
+    def cache(self) -> Mapping[_K, _V]:
+        """Return the internal cache dictionary in order (read-only)."""
+        return _MappingOrderCacheView(
+            self._Cache__data,  # type: ignore
+            self.order)
+
+    @property
+    def order(self) -> Mapping[_K, None]:
+        """Return the internal order dictionary (read-only)."""
+        return MappingProxyType(self._LRUCache__order)  # type: ignore

     def stat(self) -> CacheInfo:
         return CacheInfo(hits=self._hits, total=self._total)

     def touch(self, key: _K) -> None:
-        self.cache.move_to_end(key)
+        self._LRUCache__update(key)  # type: ignore

-    def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
-        value: Optional[_V]
-        if key in self.cache:
-            value = self.cache[key]
-            self.cache.move_to_end(key)
+    @overload
+    def get(self, key: _K, /) -> Optional[_V]:
+        ...
+
+    @overload
+    def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]:
+        ...
+
+    def get(self,
+            key: _K,
+            /,
+            default: Optional[Union[_V,
+                                    _T]] = None) -> Optional[Union[_V, _T]]:
+        value: Optional[Union[_V, _T]]
+        if key in self:
+            value = self.__getitem__(key)
             self._hits += 1
         else:
@@ -265,60 +298,76 @@ class LRUCache(Generic[_K, _V]):
             self._total += 1

         return value

+    @overload
+    def pop(self, key: _K) -> _V:
+        ...
+
+    @overload
+    def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]:
+        ...
+
+    def pop(self,
+            key: _K,
+            default: Optional[Union[_V,
+                                    _T]] = None) -> Optional[Union[_V, _T]]:
+        value: Optional[Union[_V, _T]]
+        if key not in self:
+            return default
+
+        value = self[key]
+        del self[key]
+        return value
+
     def put(self, key: _K, value: _V) -> None:
-        self.cache[key] = value
-        self.cache.move_to_end(key)
-        self._remove_old_if_needed()
+        self.__setitem__(key, value)

     def pin(self, key: _K) -> None:
         """
         Pins a key in the cache preventing it from being
         evicted in the LRU order.
         """
-        if key not in self.cache:
+        if key not in self:
             raise ValueError(f"Cannot pin key: {key} not in cache.")
         self.pinned_items.add(key)

     def _unpin(self, key: _K) -> None:
         """
         Unpins a key in the cache allowing it to be
         evicted in the LRU order.
         """
         self.pinned_items.remove(key)

     def _on_remove(self, key: _K, value: Optional[_V]) -> None:
         pass

     def remove_oldest(self, *, remove_pinned: bool = False) -> None:
-        if not self.cache:
+        if len(self) == 0:
             return
+        self.popitem(remove_pinned=remove_pinned)
+
+    def _remove_old_if_needed(self) -> None:
+        while self.currsize > self.capacity:
+            self.remove_oldest()
+
+    def clear(self) -> None:
+        while len(self) > 0:
+            self.remove_oldest(remove_pinned=True)
+
+    def popitem(self, remove_pinned: bool = False):
+        """Remove and return the `(key, value)` pair least recently used."""
         if not remove_pinned:
             # pop the oldest item in the cache that is not pinned
             lru_key = next(
-                (key for key in self.cache if key not in self.pinned_items),
+                (key for key in self.order if key not in self.pinned_items),
                 ALL_PINNED_SENTINEL)
             if lru_key is ALL_PINNED_SENTINEL:
                 raise RuntimeError("All items are pinned, "
                                    "cannot remove oldest from the cache.")
         else:
-            lru_key = next(iter(self.cache))
-        self.pop(lru_key)  # type: ignore
-
-    def _remove_old_if_needed(self) -> None:
-        while len(self.cache) > self.capacity:
-            self.remove_oldest()
-
-    def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
-        run_on_remove = key in self.cache
-        value = self.cache.pop(key, default)
-        # remove from pinned items
-        if key in self.pinned_items:
-            self._unpin(key)
-        if run_on_remove:
-            self._on_remove(key, value)
-        return value
-
-    def clear(self) -> None:
-        while len(self.cache) > 0:
-            self.remove_oldest(remove_pinned=True)
-        self.cache.clear()
+            lru_key = next(iter(self.order))
+        value = self.pop(cast(_K, lru_key))
+        return (lru_key, value)


 class PyObjectCache:
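Taken together, storage, eviction, and size bookkeeping now come from cachetools (reached through the name-mangled _Cache__data and _LRUCache__order attributes), while the vLLM-specific surface — pinning, stat(), touch(), remove_oldest() — sits on top. An end-to-end sketch of the behavior implied by this diff, assuming a plain count-based capacity:

    from vllm.utils import CacheInfo, LRUCache

    cache = LRUCache[str, int](capacity=2)
    cache.put("a", 1)
    cache["b"] = 2
    assert cache.get("a") == 1   # hit: bumps stats and refreshes recency
    cache.put("c", 3)            # evicts "b", the least recently used key
    cache.pin("a")               # "a" may no longer be evicted in LRU order
    cache.remove_oldest()        # removes "c"; pinned "a" survives
    assert cache.stat() == CacheInfo(hits=1, total=1)
    assert "a" in cache and "b" not in cache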