Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Model] Add base class for LoRA-supported models (#5018)
@@ -4,6 +4,9 @@ Using LoRA adapters
 ===================

 This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model.
+
+LoRA adapters can be used with any vLLM model that implements :class:`~vllm.model_executor.models.interfaces.SupportsLoRA`.
+
 Adapters can be efficiently served on a per request basis with minimal overhead. First we download the adapter(s) and save
 them locally with
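The RST snippet that the sentence above introduces ("them locally with ...") is cut off at the hunk boundary. As a rough sketch of the workflow it describes, assuming the standard `huggingface_hub.snapshot_download` helper and vLLM's existing `LLM(..., enable_lora=True)` / `LoRARequest` entrypoints; the adapter repo ID, base model, and prompt below are placeholders rather than values taken from this diff:

    from huggingface_hub import snapshot_download

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # Download the adapter weights and save them locally.
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")

    # Build the engine with LoRA support enabled.
    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

    # Attach the adapter to an individual request:
    # LoRARequest(human-readable name, globally unique id, local path).
    outputs = llm.generate(
        ["Write a SQL query listing all users."],
        SamplingParams(temperature=0.0, max_tokens=64),
        lora_request=LoRARequest("sql_adapter", 1, lora_path),
    )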
@@ -2,6 +2,7 @@ from typing import List, Optional
 from typing import Sequence as GenericSequence

 import torch
+import torch.types

 from vllm.utils import is_pin_memory_available

@@ -64,7 +65,7 @@ class LoRALayerWeights:
             output_dim: int,
             rank: int,
             dtype: torch.dtype,
-            device: torch.device,
+            device: torch.types.Device,
             embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
         pin_memory = str(device) == "cpu" and is_pin_memory_available()
         lora_a = torch.zeros([input_dim, rank],
@@ -18,6 +18,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
 from vllm.lora.utils import (from_layer, from_layer_logits_processor,
                              parse_fine_tuned_lora_name, replace_submodule)
+from vllm.model_executor.models.interfaces import SupportsLoRA
 from vllm.utils import LRUCache, is_pin_memory_available

 logger = init_logger(__name__)
@@ -363,7 +364,7 @@ class LoRAModelManager:

     def __init__(
         self,
-        model: nn.Module,
+        model: SupportsLoRA,
         max_num_seqs: int,
         max_num_batched_tokens: int,
         vocab_size: int,
@@ -411,7 +412,7 @@ class LoRAModelManager:
         # embeddings_indices
         self.indices_len: List[Optional[int]] = [None] * 4

-        self.model: nn.Module = model
+        self.model = model
         if hasattr(self.model, "supported_lora_modules"):
             self.supported_lora_modules = copy.deepcopy(
                 self.model.supported_lora_modules)
@@ -428,7 +429,6 @@ class LoRAModelManager:
         self._active_loras: Dict[int, None] = {}
         self._last_mapping: Optional[LoRAMapping] = None
         self._create_lora_modules()
-        self.model.lora_manager = self

     @property
     def capacity(self) -> int:
@@ -32,7 +32,8 @@ from vllm.model_executor.model_loader.weight_utils import (
     filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
     get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
     pt_weights_iterator, safetensors_weights_iterator)
-from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
+from vllm.model_executor.models.interfaces import (supports_lora,
+                                                   supports_vision)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import is_tpu

@@ -64,12 +65,15 @@ def _get_quantization_config(

 def _get_model_initialization_kwargs(
-        model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig]
+    model_class: Type[nn.Module],
+    lora_config: Optional[LoRAConfig],
+    vlm_config: Optional[VisionLanguageConfig],
 ) -> Dict[str, Any]:
     """Get extra kwargs for model initialization."""
     extra_kwargs: Dict[str, Any] = {}
-    if hasattr(model_class, "supported_lora_modules"):
+
+    if supports_lora(model_class):
         # lora_config=None is used to disable LoRA
         extra_kwargs["lora_config"] = lora_config
     elif lora_config:
         raise ValueError(
@@ -77,13 +81,15 @@ def _get_model_initialization_kwargs(
             "but LoRA is enabled. Support for this model may "
             "be added in the future. If this is important to you, "
             "please open an issue on github.")
-    elif issubclass(model_class, VisionLanguageModelBase):
-        if vision_language_config is None:
+
+    if supports_vision(model_class):
+        if vlm_config is None:
             raise ValueError("Provide `image_input_type` and other vision "
                              "related configurations through LLM entrypoint "
                              "or engine arguments.")

-        extra_kwargs["vision_language_config"] = vision_language_config
+        extra_kwargs["vlm_config"] = vlm_config

     return extra_kwargs
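For intuition, here is a minimal standalone sketch of the dispatch pattern the loader adopts above: the structural `supports_lora` / `supports_vision` checks decide which extra kwargs a model class receives, replacing the earlier `hasattr` / `issubclass(..., VisionLanguageModelBase)` tests. The helper name `build_init_kwargs` is invented for illustration; only the interface helpers and config types come from this change:

    from typing import Any, Dict, Optional

    from vllm.config import LoRAConfig, VisionLanguageConfig
    from vllm.model_executor.models.interfaces import (supports_lora,
                                                       supports_vision)


    def build_init_kwargs(
        model_class: type,
        lora_config: Optional[LoRAConfig],
        vlm_config: Optional[VisionLanguageConfig],
    ) -> Dict[str, Any]:
        """Toy restatement of the dispatch in _get_model_initialization_kwargs."""
        extra_kwargs: Dict[str, Any] = {}

        # LoRA kwargs are only passed to classes that declare SupportsLoRA.
        if supports_lora(model_class):
            extra_kwargs["lora_config"] = lora_config

        # Vision kwargs are only passed to classes that declare SupportsVision.
        if supports_vision(model_class):
            extra_kwargs["vlm_config"] = vlm_config

        return extra_kwargs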
@@ -45,6 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsLoRA
+

 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
     closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
@@ -292,7 +294,9 @@ class BaiChuanModel(nn.Module):
         return hidden_states


-class BaiChuanBaseForCausalLM(nn.Module):
+class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "W_pack": ["W_pack"],
         "gate_up_proj": [
@@ -312,14 +316,17 @@ class BaiChuanBaseForCausalLM(nn.Module):

     def __init__(
         self,
-        config,
+        config: PretrainedConfig,
         position_embedding: str,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.model = BaiChuanModel(config, position_embedding, cache_config,
                                    quant_config)
@@ -28,6 +28,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig

+from .interfaces import SupportsLoRA
+

 class GLMAttention(nn.Module):

@@ -322,7 +324,9 @@ class ChatGLMModel(nn.Module):
         return hidden_states


-class ChatGLMForCausalLM(nn.Module):
+class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
         "dense_h_to_4h": ["dense_h_to_4h"]
@@ -345,7 +349,10 @@ class ChatGLMForCausalLM(nn.Module):
         lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
-        self.config: ChatGLMConfig = config
+
+        self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.max_position_embeddings = getattr(config, "max_sequence_length",
                                                8192)
@@ -26,7 +26,7 @@
 from typing import Iterable, Optional, Tuple

 import torch
-from transformers import PretrainedConfig
+from transformers import LlamaConfig

 from vllm.config import CacheConfig, LoRAConfig
 from vllm.model_executor.layers.quantization.base_config import (
@@ -55,7 +55,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):

     def __init__(
         self,
-        config: Optional[PretrainedConfig] = None,
+        config: LlamaConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsLoRA
+
 logger = init_logger(__name__)

@@ -288,7 +290,9 @@ class GemmaModel(nn.Module):
         return hidden_states


-class GemmaForCausalLM(nn.Module):
+class GemmaForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -319,9 +323,11 @@ class GemmaForCausalLM(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
-        del lora_config  # Unused.
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.model = GemmaModel(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsLoRA
+

 class GPTBigCodeAttention(nn.Module):

@@ -230,7 +232,9 @@ class GPTBigCodeModel(nn.Module):
         return hidden_states


-class GPTBigCodeForCausalLM(nn.Module):
+class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {"c_attn": ["c_attn"]}

     supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
@@ -250,7 +254,10 @@ class GPTBigCodeForCausalLM(nn.Module):
         lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
                                            lora_config)
vllm/model_executor/models/interfaces.py (new file, 130 lines)
@@ -0,0 +1,130 @@
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
                    Union, overload, runtime_checkable)

from typing_extensions import TypeGuard

from vllm.config import LoRAConfig, VisionLanguageConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


@runtime_checkable
class SupportsVision(Protocol):
    """The interface required for all vision language models (VLMs)."""

    supports_vision: ClassVar[Literal[True]]

    def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsVisionType(Protocol):
    supports_vision: Literal[True]

    def __call__(self, *, vlm_config: VisionLanguageConfig) -> None:
        ...


@overload
def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
    ...


@overload
def supports_vision(model: object) -> TypeGuard[SupportsVision]:
    ...


def supports_vision(
    model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
    if isinstance(model, type):
        return isinstance(model, _SupportsVisionType)

    return isinstance(model, SupportsVision)


@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""

    supports_lora: ClassVar[Literal[True]]

    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
    supported_lora_modules: ClassVar[List[str]]
    embedding_modules: ClassVar[Dict[str, str]]
    embedding_padding_modules: ClassVar[List[str]]

    # lora_config is None when LoRA is not enabled
    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    supports_lora: Literal[True]

    packed_modules_mapping: Dict[str, List[str]]
    supported_lora_modules: List[str]
    embedding_modules: Dict[str, str]
    embedding_padding_modules: List[str]

    def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        ...


@overload
def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
    result = _supports_lora(model)

    if not result:
        lora_attrs = (
            "packed_modules_mapping",
            "supported_lora_modules",
            "embedding_modules",
            "embedding_padding_modules",
        )
        missing_attrs = tuple(attr for attr in lora_attrs
                              if not hasattr(model, attr))

        if getattr(model, "supports_lora", False):
            if missing_attrs:
                logger.warning(
                    "The model (%s) sets `supports_lora=True`, "
                    "but is missing LoRA-specific attributes: %s",
                    model,
                    missing_attrs,
                )
        else:
            if not missing_attrs:
                logger.warning(
                    "The model (%s) contains all LoRA-specific attributes, "
                    "but does not set `supports_lora=True`.", model)

    return result


def _supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
    if isinstance(model, type):
        return isinstance(model, _SupportsLoRAType)

    return isinstance(model, SupportsLoRA)
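A quick illustration of how the helpers above behave, and why the private `_SupportsLoRAType` / `_SupportsVisionType` mirrors exist: `isinstance` against a `runtime_checkable` protocol only checks attribute presence, so treating the class object itself as an "instance" lets the same check work for both classes and instances. `ToyModel` and its attribute values are invented for this example; the attribute names are the ones required by `SupportsLoRA` above:

    from vllm.model_executor.models.interfaces import (supports_lora,
                                                       supports_vision)


    class ToyModel:
        # Class-level marker plus the metadata the SupportsLoRA protocol expects.
        supports_lora = True
        packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
        supported_lora_modules = ["qkv_proj", "o_proj"]
        embedding_modules = {}
        embedding_padding_modules = []


    # The overloaded helpers accept either the class or an instance.
    assert supports_lora(ToyModel)
    assert supports_lora(ToyModel())

    # No `supports_vision = True` marker, so the vision check fails.
    assert not supports_vision(ToyModel)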
@@ -49,6 +49,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 from vllm.utils import is_hip, print_warning_once

+from .interfaces import SupportsLoRA
+

 class LlamaMLP(nn.Module):

@@ -296,7 +298,9 @@ class LlamaModel(nn.Module):
         return hidden_states


-class LlamaForCausalLM(nn.Module):
+class LlamaForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -336,7 +340,10 @@ class LlamaForCausalLM(nn.Module):
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.model = LlamaModel(config,
                                 cache_config,
                                 quant_config,
@@ -20,7 +20,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.image import get_dummy_image_data
 from vllm.sequence import SamplerOutput

-from .vlm_base import VisionLanguageModelBase
+from .interfaces import SupportsVision

 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -86,18 +86,21 @@ LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
 @MULTIMODAL_REGISTRY.register_image_feature_input()
 @MULTIMODAL_REGISTRY.register_image_pixel_input()
 @MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
-class LlavaForConditionalGeneration(VisionLanguageModelBase):
+class LlavaForConditionalGeneration(nn.Module, SupportsVision):
+
+    supports_vision = True

     def __init__(self,
                  config: LlavaConfig,
-                 vision_language_config: VisionLanguageConfig,
+                 vlm_config: VisionLanguageConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
-        super().__init__(vision_language_config)
+        super().__init__()

         self.config = config
+        self.vlm_config = vlm_config

-        if self.vision_language_config.image_input_type == (
+        if self.vlm_config.image_input_type == (
                 VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
             self.vision_tower = CLIPVisionModel(config.vision_config)
         else:
@@ -122,11 +125,10 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
         self.sampler = Sampler()

     def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape[1:]) != list(
-                self.vision_language_config.image_input_shape[1:]):
+        if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]):
             raise ValueError(
                 f"The expected image tensor shape is batch dimension plus "
-                f"{self.vision_language_config.image_input_shape[1:]}. "
+                f"{self.vlm_config.image_input_shape[1:]}. "
                 f"You supplied {data.shape}. "
                 f"If you are using vLLM's entrypoint, make sure your "
                 f"supplied image input is consistent with "
@@ -139,7 +141,7 @@
         pixel_values = kwargs.pop("pixel_values", None)
         image_features = kwargs.pop("image_features", None)

-        expected_input_type = self.vision_language_config.image_input_type
+        expected_input_type = self.vlm_config.image_input_type
         ImageInputType = VisionLanguageConfig.ImageInputType

         if expected_input_type == ImageInputType.PIXEL_VALUES:
@@ -273,7 +275,7 @@

             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
-                self.vision_language_config.image_token_id)
+                self.vlm_config.image_token_id)

             input_ids = None
         else:
@@ -25,8 +25,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
 from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
 from vllm.sequence import SamplerOutput, SequenceData

+from .interfaces import SupportsVision
 from .llava import LlavaMultiModalProjector, merge_vision_embeddings
-from .vlm_base import VisionLanguageModelBase

 logger = init_logger(__name__)

@@ -106,19 +106,21 @@ def _image_pixel_processor(
 @MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor)
 @MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data)
-class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
+class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
+
+    supports_vision = True

     def __init__(self,
                  config: LlavaNextConfig,
-                 vision_language_config: VisionLanguageConfig,
+                 vlm_config: VisionLanguageConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None) -> None:
-        super().__init__(vision_language_config)
+        super().__init__()

-        # Update the type annotation from that of its superclass
         self.config = config
+        self.vlm_config = vlm_config

-        if self.vision_language_config.image_input_type == (
+        if self.vlm_config.image_input_type == (
                 VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
             self.vision_tower = CLIPVisionModel(config=config.vision_config)
         else:
@@ -146,7 +148,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
             torch.empty(config.text_config.hidden_size))

     def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
-        _, num_channels, _, _ = self.vision_language_config.image_input_shape
+        _, num_channels, _, _ = self.vlm_config.image_input_shape

         # Note that this is different from that of vLLM vision_language_config
         # since the image is resized by the HuggingFace preprocessor
@@ -177,7 +179,7 @@
         image_sizes = kwargs.pop("image_sizes", None)
         image_features = kwargs.pop("image_features", None)

-        expected_input_type = self.vision_language_config.image_input_type
+        expected_input_type = self.vlm_config.image_input_type
         ImageInputType = VisionLanguageConfig.ImageInputType

         if expected_input_type == ImageInputType.PIXEL_VALUES:
@@ -386,7 +388,7 @@

             inputs_embeds = merge_vision_embeddings(
                 input_ids, inputs_embeds, vision_embeddings,
-                self.vision_language_config.image_token_id)
+                self.vlm_config.image_token_id)

             input_ids = None
         else:
@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple

 import torch
 from torch import nn
+from transformers import PretrainedConfig

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
@@ -51,6 +52,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsLoRA
+

 class MiniCPMMoE(nn.Module):
     """A tensor-parallel MoE implementation that shards each expert
@@ -388,7 +391,9 @@ class MiniCPMModel(nn.Module):
         return hidden_states


-class MiniCPMForCausalLM(nn.Module):
+class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -418,13 +423,16 @@ class MiniCPMForCausalLM(nn.Module):

     def __init__(
         self,
-        config,
+        config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
         self.model = MiniCPMModel(config,
@@ -54,6 +54,8 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import SamplerOutput
 from vllm.utils import print_warning_once

+from .interfaces import SupportsLoRA
+

 class MixtralMoE(nn.Module):
     """A tensor-parallel MoE implementation for Mixtral that shards each expert
@@ -472,7 +474,9 @@ class MixtralModel(nn.Module):
         return hidden_states


-class MixtralForCausalLM(nn.Module):
+class MixtralForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     fall_back_to_pt_during_load = False

     packed_modules_mapping = {
@@ -504,7 +508,10 @@ class MixtralForCausalLM(nn.Module):
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.model = MixtralModel(config,
                                   cache_config,
                                   quant_config,
@@ -39,7 +39,7 @@ from typing import Iterable, List, Optional, Tuple

 import torch
 from torch import nn
-from transformers import PretrainedConfig
+from transformers import PhiConfig

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
@@ -59,11 +59,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsLoRA
+

 class PhiAttention(nn.Module):

     def __init__(self,
-                 config: PretrainedConfig,
+                 config: PhiConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
@@ -131,7 +133,7 @@
 class PhiMLP(nn.Module):

     def __init__(self,
-                 config: PretrainedConfig,
+                 config: PhiConfig,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()

@@ -160,7 +162,7 @@
 class PhiLayer(nn.Module):

     def __init__(self,
-                 config: PretrainedConfig,
+                 config: PhiConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
@@ -192,7 +194,7 @@
 class PhiModel(nn.Module):

     def __init__(self,
-                 config: PretrainedConfig,
+                 config: PhiConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
@@ -229,7 +231,9 @@
         return hidden_states


-class PhiForCausalLM(nn.Module):
+class PhiForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -250,14 +254,16 @@ class PhiForCausalLM(nn.Module):

     def __init__(
         self,
-        config: PretrainedConfig,
+        config: PhiConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ):
-        del lora_config  # Unused.
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config

         self.model = PhiModel(config, cache_config, quant_config)
@@ -48,6 +48,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput
 from vllm.utils import print_warning_once

+from .interfaces import SupportsLoRA
+

 class Qwen2MLP(nn.Module):

@@ -263,7 +265,9 @@ class Qwen2Model(nn.Module):
         return hidden_states


-class Qwen2ForCausalLM(nn.Module):
+class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -293,7 +297,6 @@ class Qwen2ForCausalLM(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
     ) -> None:
-        del lora_config
         # TODO (@robertgshaw2): see if this can be moved out
         if (cache_config.sliding_window is not None
                 and hasattr(config, "max_window_layers")):
@@ -307,7 +310,10 @@
             ))

         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.model = Qwen2Model(config, cache_config, quant_config)
vllm/model_executor/models/vlm_base.py (deleted)
@@ -1,12 +0,0 @@
from torch import nn

from vllm.config import VisionLanguageConfig


class VisionLanguageModelBase(nn.Module):
    """Base class for all vision language models (VLMs)."""

    def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
        super().__init__()

        self.vision_language_config = vision_language_config
@@ -45,6 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplerOutput

+from .interfaces import SupportsLoRA
+

 class XverseMLP(nn.Module):

@@ -266,7 +268,9 @@ class XverseModel(nn.Module):
         return hidden_states


-class XverseForCausalLM(nn.Module):
+class XverseForCausalLM(nn.Module, SupportsLoRA):
+    supports_lora = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -299,10 +303,13 @@
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config=None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
         super().__init__()
+
         self.config = config
+        self.lora_config = lora_config
+
         self.quant_config = quant_config
         self.model = XverseModel(config, cache_config, quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
@@ -22,6 +22,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
+from vllm.model_executor.models.interfaces import supports_lora
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
@@ -225,14 +226,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             self.model_memory_usage / float(2**30))

         if self.lora_config:
-            assert hasattr(self.model, "supported_lora_modules"
-                           ) and self.model.supported_lora_modules, (
-                               "Model does not support LoRA")
-            assert hasattr(
-                self.model,
-                "embedding_modules"), "Model does not have embedding_modules"
-            assert hasattr(self.model, "embedding_padding_modules"
-                           ), "Model does not have embedding_padding_modules"
+            assert supports_lora(self.model), "Model does not support LoRA"
+
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens,