[Core] Add generic typing to LRUCache (#3511)
@@ -4,7 +4,7 @@ import logging
 import math
 import os
 import re
-from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)
+from typing import (Callable, Dict, Hashable, List, Optional, Tuple, Type)

 import safetensors.torch
 import torch
@@ -535,14 +535,14 @@ class LoRAModelManager:
                              replacement_loras)


-class LoRALRUCache(LRUCache):
+class LoRALRUCache(LRUCache[LoRAModel]):

     def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
                                                                    None]):
         super().__init__(capacity)
         self.deactivate_lora_fn = deactivate_lora_fn

-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: LoRAModel):
         logger.debug(f"Removing LoRA. int id: {key}")
         self.deactivate_lora_fn(key)
         return super()._on_remove(key, value)
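
With the base class generic over its value type (see the vllm/utils.py hunks below), a subclass can pin the parameter to a concrete type and have checkers verify overridden hooks such as _on_remove. A minimal sketch of the same pattern, with an illustrative Item type standing in for LoRAModel:

from typing import Generic, Hashable, TypeVar

T = TypeVar("T")


class Cache(Generic[T]):
    """Stand-in for the generic LRUCache base in this commit."""

    def _on_remove(self, key: Hashable, value: T) -> None:
        pass


class Item:
    pass


class ItemCache(Cache[Item]):

    def _on_remove(self, key: Hashable, value: Item) -> None:
        # value is now checked as Item; annotating it as, say, str
        # would be flagged by mypy as an incompatible override.
        print(f"evicting {key}")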
@@ -22,27 +22,34 @@ class BaseTokenizerGroup(ABC):
         pass

     @abstractmethod
-    def encode(self, prompt: str, request_id: Optional[str],
-               lora_request: Optional[LoRARequest]) -> List[int]:
+    def encode(self,
+               prompt: str,
+               request_id: Optional[str] = None,
+               lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass

     @abstractmethod
-    async def encode_async(self, prompt: str, request_id: Optional[str],
-                           lora_request: Optional[LoRARequest]) -> List[int]:
+    async def encode_async(
+            self,
+            prompt: str,
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass

     @abstractmethod
     def get_lora_tokenizer(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass

     @abstractmethod
     async def get_lora_tokenizer_async(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass
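
The signature change makes request_id and lora_request keyword-optional, so plain tokenization no longer has to pass explicit Nones. A hypothetical call site (tokenizer_group and my_lora_request are placeholder names, not part of this diff):

# Before: both trailing arguments had to be supplied positionally.
token_ids = tokenizer_group.encode("Hello, world", None, None)

# After: they default to None and can be omitted or passed by keyword.
token_ids = tokenizer_group.encode("Hello, world")
token_ids = tokenizer_group.encode("Hello, world",
                                   lora_request=my_lora_request)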
@@ -21,10 +21,8 @@ class TokenizerGroup(BaseTokenizerGroup):
         self.enable_lora = enable_lora
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
-        if enable_lora:
-            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
-        else:
-            self.lora_tokenizers = None
+        self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
+            capacity=max_num_seqs) if enable_lora else None

     def ping(self) -> bool:
         """Check if the tokenizer group is alive."""
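
The rewritten assignment relies on the fact that subscripting a Generic class is legal at runtime, not only in annotations: LRUCache[PreTrainedTokenizer] evaluates to a typing alias that, when called, constructs an ordinary LRUCache. A small demonstration of the mechanism with a toy class:

from typing import Generic, TypeVar

T = TypeVar("T")


class Box(Generic[T]):

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity


# Subscripting yields a callable alias; calling it runs Box.__init__
# and returns a plain Box instance (the parameter is erased at runtime).
box = Box[str](capacity=4)
print(type(box))   # <class '__main__.Box'>
print(Box[str])    # __main__.Box[str]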
@@ -5,7 +5,7 @@ import subprocess
 import uuid
 import gc
 from platform import uname
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Generic
 from packaging.version import parse, Version

 import psutil
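
Generic[T] also requires a module-level type variable, and OrderedDict[Hashable, T] below requires Hashable; neither appears in the import hunk shown here, so presumably declarations along these lines exist (or are added) elsewhere in the file:

from typing import Hashable, TypeVar

T = TypeVar("T")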
@@ -53,10 +53,10 @@ class Counter:
         self.counter = 0


-class LRUCache:
+class LRUCache(Generic[T]):

     def __init__(self, capacity: int):
-        self.cache = OrderedDict()
+        self.cache = OrderedDict[Hashable, T]()
         self.capacity = capacity

     def __contains__(self, key: Hashable) -> bool:
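
OrderedDict[Hashable, T]() parameterizes collections.OrderedDict at runtime. Under PEP 585 (Python 3.9+) the subscription returns a types.GenericAlias whose parameters exist for type checkers only; instantiating it produces a plain OrderedDict:

from collections import OrderedDict
from typing import Hashable

cache = OrderedDict[Hashable, int]()
cache["a"] = 1
cache["b"] = "oops"   # accepted at runtime; a type checker flags it
print(type(cache))    # <class 'collections.OrderedDict'>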
@@ -65,10 +65,10 @@ class LRUCache:
     def __len__(self) -> int:
         return len(self.cache)

-    def __getitem__(self, key: Hashable) -> Any:
+    def __getitem__(self, key: Hashable) -> T:
         return self.get(key)

-    def __setitem__(self, key: Hashable, value: Any) -> None:
+    def __setitem__(self, key: Hashable, value: T) -> None:
         self.put(key, value)

     def __delitem__(self, key: Hashable) -> None:
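
Typing the dunder methods is what makes the indexing syntax itself checkable: for a parameterized cache, a checker infers the element type of cache[key] and rejects mismatched assignments. An illustrative snippet, assuming the LRUCache from this diff is importable (the vllm.utils path is presumed from the hunk context):

from vllm.utils import LRUCache  # path presumed, not shown in this diff

cache = LRUCache[str](capacity=8)
cache[1] = "hello"   # OK: value is a str
word = cache[1]      # inferred as str via __getitem__
cache[2] = 42        # mypy: expected "str", got "int"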
@@ -77,7 +77,9 @@ class LRUCache:
     def touch(self, key: Hashable) -> None:
         self.cache.move_to_end(key)

-    def get(self, key: Hashable, default_value: Optional[Any] = None) -> int:
+    def get(self,
+            key: Hashable,
+            default_value: Optional[T] = None) -> Optional[T]:
         if key in self.cache:
             value = self.cache[key]
             self.cache.move_to_end(key)
@@ -85,12 +87,12 @@ class LRUCache:
             value = default_value
         return value

-    def put(self, key: Hashable, value: Any) -> None:
+    def put(self, key: Hashable, value: T) -> None:
         self.cache[key] = value
         self.cache.move_to_end(key)
         self._remove_old_if_needed()

-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: T):
         pass

     def remove_oldest(self):
@@ -103,7 +105,7 @@ class LRUCache:
         while len(self.cache) > self.capacity:
             self.remove_oldest()

-    def pop(self, key: int, default_value: Optional[Any] = None) -> Any:
+    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
         run_on_remove = key in self.cache
         value = self.cache.pop(key, default_value)
         if run_on_remove:
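
Taken together, the changes let a call site pick a concrete value type once and get checked puts, gets, and pops end to end. A condensed, self-contained re-implementation of the typed cache for illustration (trimmed to the methods visible in this diff; remove_oldest is reconstructed from context and may differ in detail):

from collections import OrderedDict
from typing import Any, Generic, Hashable, Optional, TypeVar

T = TypeVar("T")


class LRUCache(Generic[T]):

    def __init__(self, capacity: int):
        self.cache = OrderedDict[Hashable, T]()
        self.capacity = capacity

    def __len__(self) -> int:
        return len(self.cache)

    def get(self,
            key: Hashable,
            default_value: Optional[T] = None) -> Optional[T]:
        if key in self.cache:
            value: Optional[T] = self.cache[key]
            self.cache.move_to_end(key)  # mark as most recently used
        else:
            value = default_value
        return value

    def put(self, key: Hashable, value: T) -> None:
        self.cache[key] = value
        self.cache.move_to_end(key)
        self._remove_old_if_needed()

    def _on_remove(self, key: Hashable, value: T):
        pass

    def remove_oldest(self):
        if not self.cache:
            return
        key, value = self.cache.popitem(last=False)  # least recently used
        self._on_remove(key, value)

    def _remove_old_if_needed(self) -> None:
        while len(self.cache) > self.capacity:
            self.remove_oldest()

    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
        run_on_remove = key in self.cache
        value = self.cache.pop(key, default_value)
        if run_on_remove:
            self._on_remove(key, value)
        return value


cache = LRUCache[str](capacity=2)
cache.put(1, "a")
cache.put(2, "b")
cache.put(3, "c")       # evicts key 1, the least recently used entry
print(cache.get(1))     # None: key 1 was evicted
print(cache.pop(3))     # 'c'; _on_remove fires for explicit pops too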