[Core] Add generic typing to LRUCache (#3511)

Author: Nick Hill
Date: 2024-03-20 00:36:09 -07:00
Committed by: GitHub
Parent: 9474e89ba4
Commit: 4ad521d8b5
4 changed files with 29 additions and 22 deletions
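The net effect of the commit: LRUCache is parameterized over its value type, so subclasses and call sites see a concrete type instead of Any. A minimal usage sketch (illustrative, not part of the diff; assumes vllm.utils as of this commit):

    from vllm.utils import LRUCache

    # Subscripting the generic class binds the value type for the type checker;
    # at runtime this instantiates LRUCache exactly as before.
    cache = LRUCache[str](capacity=2)
    cache.put(1, "a")
    cache.put(2, "b")
    cache.put(3, "c")    # capacity exceeded: key 1 (least recently used) is evicted
    hit = cache.get(2)   # inferred as Optional[str]
    miss = cache.get(1)  # None: key 1 was evicted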

vllm/lora/models.py

@@ -4,7 +4,7 @@ import logging
 import math
 import os
 import re
-from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)
+from typing import (Callable, Dict, Hashable, List, Optional, Tuple, Type)
 
 import safetensors.torch
 import torch
@@ -535,14 +535,14 @@ class LoRAModelManager:
                             replacement_loras)
 
 
-class LoRALRUCache(LRUCache):
+class LoRALRUCache(LRUCache[LoRAModel]):
 
     def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
                                                                    None]):
         super().__init__(capacity)
         self.deactivate_lora_fn = deactivate_lora_fn
 
-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: LoRAModel):
         logger.debug(f"Removing LoRA. int id: {key}")
         self.deactivate_lora_fn(key)
         return super()._on_remove(key, value)
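The subclass now pins the cache's type parameter, so the _on_remove eviction hook receives a concretely typed value rather than Any. The same pattern in isolation (hypothetical names, illustrative only):

    from typing import Callable, Hashable
    from vllm.utils import LRUCache

    class EvictionNotifyingCache(LRUCache[bytes]):  # hypothetical subclass
        def __init__(self, capacity: int, on_evict: Callable[[Hashable], None]):
            super().__init__(capacity)
            self.on_evict = on_evict

        def _on_remove(self, key: Hashable, value: bytes):
            # value is typed as bytes here, matching LRUCache[bytes]
            self.on_evict(key)
            return super()._on_remove(key, value)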

vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py

@@ -22,27 +22,34 @@ class BaseTokenizerGroup(ABC):
         pass
 
     @abstractmethod
-    def encode(self, prompt: str, request_id: Optional[str],
-               lora_request: Optional[LoRARequest]) -> List[int]:
+    def encode(self,
+               prompt: str,
+               request_id: Optional[str] = None,
+               lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass
 
     @abstractmethod
-    async def encode_async(self, prompt: str, request_id: Optional[str],
-                           lora_request: Optional[LoRARequest]) -> List[int]:
+    async def encode_async(
+            self,
+            prompt: str,
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass
 
     @abstractmethod
     def get_lora_tokenizer(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass
 
     @abstractmethod
     async def get_lora_tokenizer_async(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         """Get a tokenizer for a LoRA request."""
         pass
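With request_id and lora_request now defaulted to None, non-LoRA callers can drop the explicit placeholder arguments. Hypothetical call sites, before and after (tokenizer_group stands in for any BaseTokenizerGroup instance):

    # before: every caller had to pass both optional arguments
    token_ids = tokenizer_group.encode("Hello, world!", None, None)

    # after: the common case needs only the prompt
    token_ids = tokenizer_group.encode("Hello, world!")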

vllm/transformers_utils/tokenizer_group/tokenizer_group.py

@@ -21,10 +21,8 @@ class TokenizerGroup(BaseTokenizerGroup):
         self.enable_lora = enable_lora
         self.max_input_length = max_input_length
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
-        if enable_lora:
-            self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
-        else:
-            self.lora_tokenizers = None
+        self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
+            capacity=max_num_seqs) if enable_lora else None
 
     def ping(self) -> bool:
         """Check if the tokenizer group is alive."""

vllm/utils.py

@@ -5,7 +5,7 @@ import subprocess
 import uuid
 import gc
 from platform import uname
-from typing import List, Tuple, Union
+from typing import List, Tuple, Union, Generic
 from packaging.version import parse, Version
 
 import psutil
@@ -53,10 +53,10 @@ class Counter:
         self.counter = 0
 
 
-class LRUCache:
+class LRUCache(Generic[T]):
 
     def __init__(self, capacity: int):
-        self.cache = OrderedDict()
+        self.cache = OrderedDict[Hashable, T]()
         self.capacity = capacity
 
     def __contains__(self, key: Hashable) -> bool:
def __contains__(self, key: Hashable) -> bool:
@@ -65,10 +65,10 @@ class LRUCache:
     def __len__(self) -> int:
         return len(self.cache)
 
-    def __getitem__(self, key: Hashable) -> Any:
+    def __getitem__(self, key: Hashable) -> T:
         return self.get(key)
 
-    def __setitem__(self, key: Hashable, value: Any) -> None:
+    def __setitem__(self, key: Hashable, value: T) -> None:
         self.put(key, value)
 
     def __delitem__(self, key: Hashable) -> None:
@@ -77,7 +77,9 @@ class LRUCache:
     def touch(self, key: Hashable) -> None:
         self.cache.move_to_end(key)
 
-    def get(self, key: Hashable, default_value: Optional[Any] = None) -> int:
+    def get(self,
+            key: Hashable,
+            default_value: Optional[T] = None) -> Optional[T]:
         if key in self.cache:
             value = self.cache[key]
             self.cache.move_to_end(key)
@@ -85,12 +87,12 @@ class LRUCache:
             value = default_value
         return value
 
-    def put(self, key: Hashable, value: Any) -> None:
+    def put(self, key: Hashable, value: T) -> None:
         self.cache[key] = value
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
-    def _on_remove(self, key: Hashable, value: Any):
+    def _on_remove(self, key: Hashable, value: T):
         pass
 
     def remove_oldest(self):
@@ -103,7 +105,7 @@ class LRUCache:
         while len(self.cache) > self.capacity:
             self.remove_oldest()
 
-    def pop(self, key: int, default_value: Optional[Any] = None) -> Any:
+    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
         run_on_remove = key in self.cache
         value = self.cache.pop(key, default_value)
         if run_on_remove:
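Taken together, the utils.py hunks yield roughly the class below. This is a self-contained, abridged sketch of the post-commit LRUCache, not the verbatim file; in particular, the diff adds Generic to the typing import but does not show the T = TypeVar("T") definition, which is assumed to exist elsewhere in vllm/utils.py:

    from collections import OrderedDict
    from typing import Generic, Hashable, Optional, TypeVar

    T = TypeVar("T")  # assumed present in vllm/utils.py; not shown in this diff

    class LRUCache(Generic[T]):

        def __init__(self, capacity: int):
            self.cache: OrderedDict[Hashable, T] = OrderedDict()
            self.capacity = capacity

        def get(self,
                key: Hashable,
                default_value: Optional[T] = None) -> Optional[T]:
            if key in self.cache:
                value: Optional[T] = self.cache[key]
                self.cache.move_to_end(key)  # a hit makes the entry most recent
            else:
                value = default_value
            return value

        def put(self, key: Hashable, value: T) -> None:
            self.cache[key] = value
            self.cache.move_to_end(key)
            while len(self.cache) > self.capacity:  # evict least recently used
                old_key, old_value = self.cache.popitem(last=False)
                self._on_remove(old_key, old_value)

        def _on_remove(self, key: Hashable, value: T) -> None:
            pass  # subclasses (e.g. LoRALRUCache above) hook evictions here

    # Quick check of the LRU behavior:
    cache = LRUCache[int](capacity=1)
    cache.put("a", 1)
    cache.put("b", 2)            # evicts "a"
    assert cache.get("b") == 2
    assert cache.get("a") is None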