cleanup: remove adapter commons (#25045)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Simon Mo
Date: 2025-09-17 09:46:29 -07:00
Committed by: GitHub
Parent: 4b946d693e
Commit: 4aa8c7b047
11 changed files with 89 additions and 330 deletions

pyproject.toml

@@ -115,7 +115,6 @@ follow_imports = "silent"
# move the directory here and remove it from tools/mypy.sh
files = [
"vllm/*.py",
"vllm/adapter_commons",
"vllm/assets",
"vllm/entrypoints",
"vllm/core",

vllm/adapter_commons/layers.py

@@ -1,16 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
@dataclass
class AdapterMapping:
# Per every token in input_ids:
index_mapping: tuple[int, ...]
# Per sampled token:
prompt_mapping: tuple[int, ...]
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
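The entire deleted class is a tuple-normalization pattern: accept any sequence at the call site, store an immutable tuple. A minimal standalone sketch of that behavior (class name hypothetical), which `LoRAMapping` now inlines per the vllm/lora/layers.py diff below:

```python
from dataclasses import dataclass


@dataclass
class Mapping:
    # Accepts any sequence of ints but stores an immutable tuple,
    # so instances compare and hash consistently.
    index_mapping: tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)


m = Mapping(index_mapping=[1, 2, 3])  # a list works at the call site
assert m.index_mapping == (1, 2, 3)   # stored as a tuple
```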

vllm/adapter_commons/models.py

@@ -1,106 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, TypeVar
from torch import nn
from vllm.logger import init_logger
from vllm.utils import LRUCache
logger = init_logger(__name__)
class AdapterModel(ABC):
def __init__(self, model_id=None):
self.id = model_id
@abstractmethod
def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
# Common initialization code
# Load weights or embeddings from local checkpoint
raise NotImplementedError("Subclasses must implement this method.")
T = TypeVar('T')
class AdapterLRUCache(LRUCache[int, T]):
def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn
def _on_remove(self, key: int, value: Optional[T]):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)
class AdapterModelManager(ABC):
def __init__(
self,
model: nn.Module,
):
"""Create a AdapterModelManager and adapter for a given model.
Args:
model: the model to be adapted.
"""
self.model: nn.Module = model
self._registered_adapters: dict[int, Any] = {}
# Dict instead of a Set for compatibility with LRUCache.
self._active_adapters: dict[int, None] = {}
self.adapter_type = 'Adapter'
self._last_mapping = None
def __len__(self) -> int:
return len(self._registered_adapters)
@property
@abstractmethod
def adapter_slots(self) -> int:
raise NotImplementedError
@property
@abstractmethod
def capacity(self) -> int:
raise NotImplementedError
@abstractmethod
def activate_adapter(self, adapter_id: int) -> bool:
raise NotImplementedError
@abstractmethod
def deactivate_adapter(self, adapter_id: int) -> bool:
raise NotImplementedError
@abstractmethod
def add_adapter(self, adapter: Any) -> bool:
raise NotImplementedError
@abstractmethod
def set_adapter_mapping(self, mapping: Any) -> None:
raise NotImplementedError
@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
raise NotImplementedError
@abstractmethod
def remove_all_adapters(self) -> None:
raise NotImplementedError
@abstractmethod
def get_adapter(self, adapter_id: int) -> Optional[Any]:
raise NotImplementedError
@abstractmethod
def list_adapters(self) -> dict[int, Any]:
raise NotImplementedError
@abstractmethod
def pin_adapter(self, adapter_id: int) -> bool:
raise NotImplementedError
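The one piece with real behavior here is `AdapterLRUCache`: it subclasses vLLM's `LRUCache` and overrides the removal hook so an adapter is deactivated the moment it is evicted. A self-contained sketch of that evict-callback pattern, using a plain `OrderedDict` in place of the vLLM-internal `LRUCache`:

```python
from collections import OrderedDict
from typing import Callable


class EvictingCache:
    """LRU cache that fires a callback on eviction, mirroring
    AdapterLRUCache's deactivate-on-evict behavior."""

    def __init__(self, capacity: int, on_evict: Callable[[int], None]):
        self.capacity = capacity
        self.on_evict = on_evict
        self._data: OrderedDict[int, object] = OrderedDict()

    def put(self, key: int, value: object) -> None:
        self._data[key] = value
        self._data.move_to_end(key)  # mark as most recently used
        if len(self._data) > self.capacity:
            old_key, _ = self._data.popitem(last=False)  # evict LRU entry
            self.on_evict(old_key)   # e.g. free the adapter's GPU slot


cache = EvictingCache(capacity=2, on_evict=lambda k: print(f"deactivate {k}"))
cache.put(1, "lora-a")
cache.put(2, "lora-b")
cache.put(3, "lora-c")  # evicts key 1 and prints "deactivate 1"
```

As the vllm/lora/models.py diff below shows, this class is less deleted than moved: it reappears verbatim as a LoRA-local helper.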

vllm/adapter_commons/request.py

@@ -1,26 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
class AdapterRequest(ABC):
"""
Base class for adapter requests.
"""
@property
@abstractmethod
def adapter_id(self) -> int:
raise NotImplementedError
def __post_init__(self) -> None:
if self.adapter_id < 1:
raise ValueError(f"id must be > 0, got {self.adapter_id}")
def __eq__(self, value: object) -> bool:
return isinstance(
value, self.__class__) and self.adapter_id == value.adapter_id
def __hash__(self) -> int:
return hash(self.adapter_id)
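`AdapterRequest`'s one substantive decision is that identity is the integer id alone: `__eq__` and `__hash__` ignore every other field, so duplicate requests collapse in sets. A quick standalone illustration (class name hypothetical):

```python
class Request:
    def __init__(self, adapter_id: int):
        if adapter_id < 1:
            raise ValueError(f"id must be > 0, got {adapter_id}")
        self.adapter_id = adapter_id

    def __eq__(self, other: object) -> bool:
        # Identity is the id alone; any payload fields are ignored.
        return (isinstance(other, self.__class__)
                and self.adapter_id == other.adapter_id)

    def __hash__(self) -> int:
        return hash(self.adapter_id)


# Two requests with the same id dedupe to a single set entry.
assert len({Request(7), Request(7), Request(8)}) == 2
```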

vllm/adapter_commons/utils.py

@@ -1,93 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Callable, Optional
## model functions
def deactivate_adapter(adapter_id: int, active_adapters: dict[int, None],
deactivate_func: Callable) -> bool:
if adapter_id in active_adapters:
deactivate_func(adapter_id)
active_adapters.pop(adapter_id)
return True
return False
def add_adapter(adapter: Any, registered_adapters: dict[int, Any],
capacity: int, add_func: Callable) -> bool:
if adapter.id not in registered_adapters:
if len(registered_adapters) >= capacity:
raise RuntimeError('No free adapter slots.')
add_func(adapter)
registered_adapters[adapter.id] = adapter
return True
return False
def set_adapter_mapping(mapping: Any, last_mapping: Any,
set_mapping_func: Callable) -> Any:
if last_mapping != mapping:
set_mapping_func(mapping)
return mapping
return last_mapping
def remove_adapter(adapter_id: int, registered_adapters: dict[int, Any],
deactivate_func: Callable) -> bool:
deactivate_func(adapter_id)
return bool(registered_adapters.pop(adapter_id, None))
def list_adapters(registered_adapters: dict[int, Any]) -> dict[int, Any]:
return dict(registered_adapters)
def get_adapter(adapter_id: int,
registered_adapters: dict[int, Any]) -> Optional[Any]:
return registered_adapters.get(adapter_id)
## worker functions
def set_active_adapters_worker(requests: set[Any], mapping: Optional[Any],
apply_adapters_func,
set_adapter_mapping_func) -> None:
apply_adapters_func(requests)
set_adapter_mapping_func(mapping)
def add_adapter_worker(adapter_request: Any, list_adapters_func,
load_adapter_func, add_adapter_func,
activate_adapter_func) -> bool:
if adapter_request.adapter_id in list_adapters_func():
return False
loaded_adapter = load_adapter_func(adapter_request)
loaded = add_adapter_func(loaded_adapter)
activate_adapter_func(loaded_adapter.id)
return loaded
def apply_adapters_worker(adapter_requests: set[Any], list_adapters_func,
adapter_slots: int, remove_adapter_func,
add_adapter_func) -> None:
models_that_exist = list_adapters_func()
models_map = {
adapter_request.adapter_id: adapter_request
for adapter_request in adapter_requests if adapter_request
}
if len(models_map) > adapter_slots:
raise RuntimeError(
f"Number of requested models ({len(models_map)}) is greater "
f"than the number of GPU model slots "
f"({adapter_slots}).")
new_models = set(models_map)
models_to_add = new_models - models_that_exist
models_to_remove = models_that_exist - new_models
for adapter_id in models_to_remove:
remove_adapter_func(adapter_id)
for adapter_id in models_to_add:
add_adapter_func(models_map[adapter_id])
def list_adapters_worker(adapter_manager_list_adapters_func) -> set[int]:
return set(adapter_manager_list_adapters_func())
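The heart of this file is `apply_adapters_worker`, which reconciles the worker's loaded adapters against the requested set with two set differences: unload what is loaded but no longer requested, then load what is requested but missing. A minimal standalone sketch of that reconciliation (names hypothetical):

```python
from typing import Callable


def reconcile(requested: set[int], loaded: set[int], slots: int,
              remove: Callable[[int], None],
              add: Callable[[int], None]) -> None:
    # Refuse over-subscription up front instead of thrashing the cache.
    if len(requested) > slots:
        raise RuntimeError(
            f"Requested {len(requested)} adapters but only {slots} slots.")
    for adapter_id in loaded - requested:   # stale: unload
        remove(adapter_id)
    for adapter_id in requested - loaded:   # missing: load
        add(adapter_id)


loaded: set[int] = {1, 2}
reconcile({2, 3}, loaded, slots=4, remove=loaded.discard, add=loaded.add)
assert loaded == {2, 3}
```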

vllm/adapter_commons/worker_manager.py

@@ -1,39 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Any, Optional
import torch
class AbstractWorkerManager(ABC):
def __init__(self, device: torch.device):
self.device = device
@property
@abstractmethod
def is_enabled(self) -> bool:
raise NotImplementedError
@abstractmethod
def set_active_adapters(self, requests: set[Any],
mapping: Optional[Any]) -> None:
raise NotImplementedError
@abstractmethod
def add_adapter(self, adapter_request: Any) -> bool:
raise NotImplementedError
@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
raise NotImplementedError
@abstractmethod
def remove_all_adapters(self) -> None:
raise NotImplementedError
@abstractmethod
def list_adapters(self) -> set[int]:
raise NotImplementedError

vllm/lora/layers.py

@@ -1,17 +1,22 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
import torch.nn as nn
from vllm.adapter_commons.layers import AdapterMapping
@dataclass
class LoRAMapping(AdapterMapping):
class LoRAMapping:
index_mapping: tuple[int, ...]
prompt_mapping: tuple[int, ...]
is_prefill: bool = False
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
def _get_lora_device(base_layer: nn.Module) -> torch.device:
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34

vllm/lora/models.py

@@ -4,18 +4,13 @@
import math
import os
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
from typing import Callable, Optional, TypeVar, Union
import regex as re
import safetensors.torch
import torch
from torch import nn
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
get_adapter, list_adapters,
remove_adapter, set_adapter_mapping)
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
@@ -33,10 +28,25 @@ from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.model_executor.utils import get_packed_modules_mapping
from vllm.utils import is_pin_memory_available
from vllm.utils import LRUCache, is_pin_memory_available
logger = init_logger(__name__)
T = TypeVar("T")
class AdapterLRUCache(LRUCache[int, T]):
def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn
def _on_remove(self, key: int, value: Optional[T]):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)
_GLOBAL_LORA_ID = 0
@@ -57,7 +67,7 @@ def is_moe_model(model: nn.Module) -> bool:
return False
class LoRAModel(AdapterModel):
class LoRAModel:
"""A LoRA fine-tuned model."""
def __init__(
@@ -313,7 +323,7 @@ class LoRAModel(AdapterModel):
weights_mapper=weights_mapper)
class LoRAModelManager(AdapterModelManager):
class LoRAModelManager:
"""A manager that manages multiple LoRA-fine-tuned models."""
def __init__(
@@ -336,6 +346,11 @@ class LoRAModelManager(AdapterModelManager):
vocab_size: the vocab size of the model.
lora_config: the LoRA configuration.
"""
self.model: SupportsLoRA = model
self._registered_adapters: dict[int, LoRAModel] = {}
# Dict instead of a set for compatibility with LRUCache.
self._active_adapters: dict[int, None] = {}
self.adapter_type = "LoRA"
self.lora_config = lora_config
self.device = device
self.max_num_seqs = max_num_seqs
@@ -347,9 +362,8 @@
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras)
super().__init__(model)
max_loras=self.lora_config.max_loras,
)
self.supported_lora_modules = get_supported_lora_modules(self.model)
assert self.supported_lora_modules, "No supported LoRA modules found in"
@@ -370,7 +384,9 @@
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self
self.adapter_type = 'LoRA'
def __len__(self) -> int:
return len(self._registered_adapters)
@property
def capacity(self) -> int:
@@ -669,28 +685,39 @@
return lora_model.get_lora(org_module_name)
def deactivate_adapter(self, adapter_id: int) -> bool:
return deactivate_adapter(adapter_id, self._active_adapters,
self._deactivate_adapter)
if adapter_id not in self._active_adapters:
return False
self._deactivate_adapter(adapter_id)
self._active_adapters.pop(adapter_id, None)
return True
def add_adapter(self, adapter: LoRAModel) -> bool:
logger.debug("Adding lora. Model id: %d, "
"int id: %d", adapter.id, adapter.id)
return add_adapter(adapter, self._registered_adapters, self.capacity,
self._add_adapter)
if adapter.id in self._registered_adapters:
return False
if len(self._registered_adapters) >= self.capacity:
raise RuntimeError("No free adapter slots.")
self._add_adapter(adapter)
return True
def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
self._set_adapter_mapping)
if self._last_mapping != mapping:
self._set_adapter_mapping(mapping)
self._last_mapping = mapping
def remove_adapter(self, adapter_id: int) -> bool:
return remove_adapter(adapter_id, self._registered_adapters,
self.deactivate_adapter)
self.deactivate_adapter(adapter_id)
if adapter_id not in self._registered_adapters:
return False
self._registered_adapters.pop(adapter_id, None)
return True
def list_adapters(self) -> dict[int, Any]:
return list_adapters(self._registered_adapters)
def list_adapters(self) -> dict[int, LoRAModel]:
return dict(self._registered_adapters)
def get_adapter(self, adapter_id: int) -> Optional[Any]:
return get_adapter(adapter_id, self._registered_adapters)
def get_adapter(self, adapter_id: int) -> Optional[LoRAModel]:
return self._registered_adapters.get(adapter_id)
class LoRALRUCache(AdapterLRUCache[LoRAModel]):
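
Note that the inlined `set_adapter_mapping` preserves the old helper's memoization: `_set_adapter_mapping` only runs when the mapping actually changed, a field-wise equality check that LoRAMapping's tuple normalization keeps stable. The pattern in isolation:

```python
from typing import Any, Callable, Optional


class MappingHolder:
    def __init__(self, apply_fn: Callable[[Any], None]):
        self._apply = apply_fn
        self._last: Optional[Any] = None

    def set_mapping(self, mapping: Any) -> None:
        # Skip the update entirely when the mapping is unchanged.
        if self._last != mapping:
            self._apply(mapping)
            self._last = mapping


calls: list[Any] = []
holder = MappingHolder(calls.append)
holder.set_mapping((1, 2))
holder.set_mapping((1, 2))  # no-op: identical mapping
holder.set_mapping((3, 4))
assert calls == [(1, 2), (3, 4)]
```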

vllm/lora/request.py

@@ -6,8 +6,6 @@ from typing import Optional
import msgspec
from vllm.adapter_commons.request import AdapterRequest
class LoRARequest(
msgspec.Struct,
@@ -24,8 +22,6 @@ class LoRARequest(
lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM.
"""
__metaclass__ = AdapterRequest
lora_name: str
lora_int_id: int
lora_path: str = ""
@@ -35,6 +31,8 @@
tensorizer_config_dict: Optional[dict] = None
def __post_init__(self):
if self.lora_int_id < 1:
raise ValueError(f"id must be > 0, got {self.lora_int_id}")
if self.lora_local_path:
warnings.warn(
"The 'lora_local_path' attribute is deprecated "

vllm/lora/worker_manager.py

@@ -6,11 +6,6 @@ from typing import Any, Literal, Optional, Union
import torch
from vllm.adapter_commons.utils import (add_adapter_worker,
apply_adapters_worker,
list_adapters_worker,
set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.models import (LoRAModel, LoRAModelManager,
@@ -22,7 +17,7 @@ from vllm.lora.utils import get_adapter_absolute_path
logger = init_logger(__name__)
class WorkerLoRAManager(AbstractWorkerManager):
class WorkerLoRAManager:
"""WorkerLoRAManager that manages LoRA models on the worker side.
Every request, the requested LoRAs will be loaded (unless they are already
@@ -51,7 +46,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
self.vocab_size = vocab_size
self.lora_config = lora_config
self.max_position_embeddings = max_position_embeddings
super().__init__(device)
self.device = device
# Lazily initialized by create_lora_manager.
self._adapter_manager: LoRAModelManager
@@ -164,19 +159,34 @@
def set_active_adapters(self, requests: set[Any],
mapping: Optional[Any]) -> None:
set_active_adapters_worker(requests, mapping, self._apply_adapters,
self._adapter_manager.set_adapter_mapping)
self._apply_adapters(requests)
if mapping is not None:
self._adapter_manager.set_adapter_mapping(mapping)
def _apply_adapters(self, adapter_requests: set[Any]) -> None:
apply_adapters_worker(adapter_requests, self.list_adapters,
self._adapter_manager.adapter_slots,
self.remove_adapter, self.add_adapter)
existing_adapters = self.list_adapters()
models_map = {
adapter_request.adapter_id: adapter_request
for adapter_request in adapter_requests if adapter_request
}
if len(models_map) > self._adapter_manager.adapter_slots:
raise RuntimeError(
f"Number of requested models ({len(models_map)}) is greater "
"than the number of GPU model slots "
f"({self._adapter_manager.adapter_slots}).")
requested_ids = set(models_map)
for adapter_id in existing_adapters - requested_ids:
self.remove_adapter(adapter_id)
for adapter_id in requested_ids - existing_adapters:
self.add_adapter(models_map[adapter_id])
def add_adapter(self, adapter_request: Any) -> bool:
return add_adapter_worker(adapter_request, self.list_adapters,
self._load_adapter,
self._adapter_manager.add_adapter,
self._adapter_manager.activate_adapter)
if adapter_request.adapter_id in self.list_adapters():
return False
loaded_adapter = self._load_adapter(adapter_request)
loaded = self._adapter_manager.add_adapter(loaded_adapter)
self._adapter_manager.activate_adapter(loaded_adapter.id)
return loaded
def remove_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.remove_adapter(adapter_id)
@@ -185,7 +195,7 @@
self._adapter_manager.remove_all_adapters()
def list_adapters(self) -> set[int]:
return list_adapters_worker(self._adapter_manager.list_adapters)
return set(self._adapter_manager.list_adapters())
class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
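
The inlined `WorkerLoRAManager.add_adapter` keeps the old helper's sequence: skip ids already present, then load, register, and activate. A standalone sketch of that control flow (callback names hypothetical):

```python
from typing import Callable


def add_adapter(request_id: int,
                known: Callable[[], set[int]],
                load: Callable[[int], dict],
                register: Callable[[dict], bool],
                activate: Callable[[int], bool]) -> bool:
    if request_id in known():    # already on this worker: nothing to do
        return False
    adapter = load(request_id)   # e.g. read LoRA weights from disk
    added = register(adapter)    # place into the manager's table
    activate(request_id)         # claim a GPU slot for the new adapter
    return added


registry: dict[int, dict] = {}


def _register(adapter: dict) -> bool:
    registry[adapter["id"]] = adapter
    return True


assert add_adapter(5, lambda: set(registry), lambda i: {"id": i},
                   _register, lambda i: True)      # loaded and activated
assert not add_adapter(5, lambda: set(registry), lambda i: {"id": i},
                       _register, lambda i: True)  # duplicate id: no-op
```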