[Model] Add base class for LoRA-supported models (#5018)

Author: Cyrus Leung
Date: 2024-06-27 16:03:04 +08:00
Committed by: GitHub
Parent: d12af207d2
Commit: 96354d6a29
20 changed files with 270 additions and 75 deletions

View File

@@ -4,6 +4,9 @@ Using LoRA adapters
===================
This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model.
LoRA adapters can be used with any vLLM model that implements :class:`~vllm.model_executor.models.interfaces.SupportsLoRA`.
Adapters can be efficiently served on a per-request basis with minimal overhead. First, we download the adapter(s) and save
them locally with
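What follows is a hedged sketch of such a download, and of attaching the adapter on a per-request basis, assuming the `huggingface_hub` `snapshot_download` helper and vLLM's `LoRARequest`; the adapter repository id is illustrative.

```python
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Illustrative adapter repository; substitute your own LoRA checkpoint.
sql_lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")

llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
outputs = llm.generate(
    ["Write a SQL query that lists all users."],
    SamplingParams(max_tokens=64),
    # One adapter per request: (adapter name, unique integer id, local path).
    lora_request=LoRARequest("sql_adapter", 1, sql_lora_path),
)
```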

View File

@ -2,6 +2,7 @@ from typing import List, Optional
from typing import Sequence as GenericSequence
import torch
import torch.types
from vllm.utils import is_pin_memory_available
@@ -64,7 +65,7 @@ class LoRALayerWeights:
output_dim: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
device: torch.types.Device,
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros([input_dim, rank],
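For context, `torch.types.Device` is a union type that also admits plain strings and integers, which is consistent with the `str(device) == "cpu"` check above; a small illustrative snippet, not part of the diff:

```python
import torch
import torch.types

# torch.types.Device covers torch.device, str, int (and None), so callers can
# pass a bare "cpu" string as well as an explicit torch.device object.
dev_a: torch.types.Device = "cpu"
dev_b: torch.types.Device = torch.device("cuda", 0)
```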

View File

@ -18,6 +18,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import SupportsLoRA
from vllm.utils import LRUCache, is_pin_memory_available
logger = init_logger(__name__)
@@ -363,7 +364,7 @@ class LoRAModelManager:
def __init__(
self,
model: nn.Module,
model: SupportsLoRA,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
@@ -411,7 +412,7 @@ class LoRAModelManager:
# embeddings_indices
self.indices_len: List[Optional[int]] = [None] * 4
self.model: nn.Module = model
self.model = model
if hasattr(self.model, "supported_lora_modules"):
self.supported_lora_modules = copy.deepcopy(
self.model.supported_lora_modules)
@@ -428,7 +429,6 @@ class LoRAModelManager:
self._active_loras: Dict[int, None] = {}
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self
@property
def capacity(self) -> int:

View File

@@ -32,7 +32,8 @@ from vllm.model_executor.model_loader.weight_utils import (
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_tpu
@@ -64,12 +65,15 @@ def _get_quantization_config(
def _get_model_initialization_kwargs(
model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig]
model_class: Type[nn.Module],
lora_config: Optional[LoRAConfig],
vlm_config: Optional[VisionLanguageConfig],
) -> Dict[str, Any]:
"""Get extra kwargs for model initialization."""
extra_kwargs: Dict[str, Any] = {}
if hasattr(model_class, "supported_lora_modules"):
if supports_lora(model_class):
# lora_config=None is used to disable LoRA
extra_kwargs["lora_config"] = lora_config
elif lora_config:
raise ValueError(
@@ -77,13 +81,15 @@ def _get_model_initialization_kwargs(
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
elif issubclass(model_class, VisionLanguageModelBase):
if vision_language_config is None:
if supports_vision(model_class):
if vlm_config is None:
raise ValueError("Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint "
"or engine arguments.")
extra_kwargs["vision_language_config"] = vision_language_config
extra_kwargs["vlm_config"] = vlm_config
return extra_kwargs
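A brief usage sketch (hypothetical call site, not part of the loader) showing how the new `supports_lora` / `supports_vision` TypeGuards narrow a model class before the extra kwargs are attached:

```python
from vllm.model_executor.models.interfaces import supports_lora, supports_vision
from vllm.model_executor.models.llama import LlamaForCausalLM

if supports_lora(LlamaForCausalLM):
    # Inside this branch a type checker treats the class as Type[SupportsLoRA],
    # so LoRA-specific class attributes are known to exist.
    print(LlamaForCausalLM.packed_modules_mapping)

if not supports_vision(LlamaForCausalLM):
    # A text-only model must not receive the vision-language kwargs.
    print("no vision support; vlm_config will not be passed")
```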

View File

@@ -45,6 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from .interfaces import SupportsLoRA
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
@@ -292,7 +294,9 @@ class BaiChuanModel(nn.Module):
return hidden_states
class BaiChuanBaseForCausalLM(nn.Module):
class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = {
"W_pack": ["W_pack"],
"gate_up_proj": [
@@ -312,14 +316,17 @@ class BaiChuanBaseForCausalLM(nn.Module):
def __init__(
self,
config,
config: PretrainedConfig,
position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = BaiChuanModel(config, position_embedding, cache_config,
quant_config)

View File

@@ -28,6 +28,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA
class GLMAttention(nn.Module):
@@ -322,7 +324,9 @@ class ChatGLMModel(nn.Module):
return hidden_states
class ChatGLMForCausalLM(nn.Module):
class ChatGLMForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
@@ -345,7 +349,10 @@ class ChatGLMForCausalLM(nn.Module):
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config: ChatGLMConfig = config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length",
8192)

View File

@@ -26,7 +26,7 @@
from typing import Iterable, Optional, Tuple
import torch
from transformers import PretrainedConfig
from transformers import LlamaConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.model_executor.layers.quantization.base_config import (
@@ -55,7 +55,7 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def __init__(
self,
config: Optional[PretrainedConfig] = None,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,

View File

@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from .interfaces import SupportsLoRA
logger = init_logger(__name__)
@@ -288,7 +290,9 @@ class GemmaModel(nn.Module):
return hidden_states
class GemmaForCausalLM(nn.Module):
class GemmaForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -319,9 +323,11 @@ class GemmaForCausalLM(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config # Unused.
super().__init__()
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = GemmaModel(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)

View File

@@ -41,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from .interfaces import SupportsLoRA
class GPTBigCodeAttention(nn.Module):
@@ -230,7 +232,9 @@ class GPTBigCodeModel(nn.Module):
return hidden_states
class GPTBigCodeForCausalLM(nn.Module):
class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = {"c_attn": ["c_attn"]}
supported_lora_modules = ["c_fc", "c_proj", "wte", "lm_head", "c_attn"]
@@ -250,7 +254,10 @@ class GPTBigCodeForCausalLM(nn.Module):
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(config, cache_config, quant_config,
lora_config)

View File

@@ -0,0 +1,130 @@
from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type,
Union, overload, runtime_checkable)
from typing_extensions import TypeGuard
from vllm.config import LoRAConfig, VisionLanguageConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
@runtime_checkable
class SupportsVision(Protocol):
"""The interface required for all vision language models (VLMs)."""
supports_vision: ClassVar[Literal[True]]
def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsVisionType(Protocol):
supports_vision: Literal[True]
def __call__(self, *, vlm_config: VisionLanguageConfig) -> None:
...
@overload
def supports_vision(model: Type[object]) -> TypeGuard[Type[SupportsVision]]:
...
@overload
def supports_vision(model: object) -> TypeGuard[SupportsVision]:
...
def supports_vision(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsVision]], TypeGuard[SupportsVision]]:
if isinstance(model, type):
return isinstance(model, _SupportsVisionType)
return isinstance(model, SupportsVision)
@runtime_checkable
class SupportsLoRA(Protocol):
"""The interface required for all models that support LoRA."""
supports_lora: ClassVar[Literal[True]]
packed_modules_mapping: ClassVar[Dict[str, List[str]]]
supported_lora_modules: ClassVar[List[str]]
embedding_modules: ClassVar[Dict[str, str]]
embedding_padding_modules: ClassVar[List[str]]
# lora_config is None when LoRA is not enabled
def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
...
# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
supports_lora: Literal[True]
packed_modules_mapping: Dict[str, List[str]]
supported_lora_modules: List[str]
embedding_modules: Dict[str, str]
embedding_padding_modules: List[str]
def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
...
@overload
def supports_lora(model: Type[object]) -> TypeGuard[Type[SupportsLoRA]]:
...
@overload
def supports_lora(model: object) -> TypeGuard[SupportsLoRA]:
...
def supports_lora(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
result = _supports_lora(model)
if not result:
lora_attrs = (
"packed_modules_mapping",
"supported_lora_modules",
"embedding_modules",
"embedding_padding_modules",
)
missing_attrs = tuple(attr for attr in lora_attrs
if not hasattr(model, attr))
if getattr(model, "supports_lora", False):
if missing_attrs:
logger.warning(
"The model (%s) sets `supports_lora=True`, "
"but is missing LoRA-specific attributes: %s",
model,
missing_attrs,
)
else:
if not missing_attrs:
logger.warning(
"The model (%s) contains all LoRA-specific attributes, "
"but does not set `supports_lora=True`.", model)
return result
def _supports_lora(
model: Union[Type[object], object],
) -> Union[TypeGuard[Type[SupportsLoRA]], TypeGuard[SupportsLoRA]]:
if isinstance(model, type):
return isinstance(model, _SupportsLoRAType)
return isinstance(model, SupportsLoRA)
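To make the contract concrete, here is a self-contained sketch of a class that satisfies `SupportsLoRA`; `ToyLoRAModel` and its attribute values are purely illustrative:

```python
from typing import Optional

import torch.nn as nn

from vllm.config import LoRAConfig
from vllm.model_executor.models.interfaces import SupportsLoRA, supports_lora


class ToyLoRAModel(nn.Module, SupportsLoRA):
    # Class-level markers required by the protocol (values are illustrative).
    supports_lora = True
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    supported_lora_modules = ["qkv_proj", "o_proj", "lm_head"]
    embedding_modules = {"embed_tokens": "input_embeddings"}
    embedding_padding_modules = ["lm_head"]

    def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None:
        super().__init__()
        self.lora_config = lora_config  # None means LoRA is disabled


assert supports_lora(ToyLoRAModel)    # class-level check (via _SupportsLoRAType)
assert supports_lora(ToyLoRAModel())  # instance-level check (via SupportsLoRA)
```

Because the checks are structural, a model that sets `supports_lora = True` but omits one of the listed attributes fails the check and triggers the warning emitted by `supports_lora`.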

View File

@@ -49,6 +49,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip, print_warning_once
from .interfaces import SupportsLoRA
class LlamaMLP(nn.Module):
@@ -296,7 +298,9 @@ class LlamaModel(nn.Module):
return hidden_states
class LlamaForCausalLM(nn.Module):
class LlamaForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -336,7 +340,10 @@ class LlamaForCausalLM(nn.Module):
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.lora_config = lora_config
self.model = LlamaModel(config,
cache_config,
quant_config,

View File

@@ -20,7 +20,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import get_dummy_image_data
from vllm.sequence import SamplerOutput
from .vlm_base import VisionLanguageModelBase
from .interfaces import SupportsVision
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
@@ -86,18 +86,21 @@ LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageFeatureInputs]
@MULTIMODAL_REGISTRY.register_image_feature_input()
@MULTIMODAL_REGISTRY.register_image_pixel_input()
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class LlavaForConditionalGeneration(VisionLanguageModelBase):
class LlavaForConditionalGeneration(nn.Module, SupportsVision):
supports_vision = True
def __init__(self,
config: LlavaConfig,
vision_language_config: VisionLanguageConfig,
vlm_config: VisionLanguageConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__(vision_language_config)
super().__init__()
self.config = config
self.vlm_config = vlm_config
if self.vision_language_config.image_input_type == (
if self.vlm_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config.vision_config)
else:
@@ -122,11 +125,10 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
self.sampler = Sampler()
def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != list(
self.vision_language_config.image_input_shape[1:]):
if list(data.shape[1:]) != list(self.vlm_config.image_input_shape[1:]):
raise ValueError(
f"The expected image tensor shape is batch dimension plus "
f"{self.vision_language_config.image_input_shape[1:]}. "
f"{self.vlm_config.image_input_shape[1:]}. "
f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
@@ -139,7 +141,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
pixel_values = kwargs.pop("pixel_values", None)
image_features = kwargs.pop("image_features", None)
expected_input_type = self.vision_language_config.image_input_type
expected_input_type = self.vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if expected_input_type == ImageInputType.PIXEL_VALUES:
@@ -273,7 +275,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id)
self.vlm_config.image_token_id)
input_ids = None
else:

View File

@@ -25,8 +25,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
from vllm.sequence import SamplerOutput, SequenceData
from .interfaces import SupportsVision
from .llava import LlavaMultiModalProjector, merge_vision_embeddings
from .vlm_base import VisionLanguageModelBase
logger = init_logger(__name__)
@@ -106,19 +106,21 @@ def _image_pixel_processor(
@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor)
@MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data)
class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
supports_vision = True
def __init__(self,
config: LlavaNextConfig,
vision_language_config: VisionLanguageConfig,
vlm_config: VisionLanguageConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__(vision_language_config)
super().__init__()
# Update the type annotation from that of its superclass
self.config = config
self.vlm_config = vlm_config
if self.vision_language_config.image_input_type == (
if self.vlm_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config=config.vision_config)
else:
@@ -146,7 +148,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
torch.empty(config.text_config.hidden_size))
def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
_, num_channels, _, _ = self.vision_language_config.image_input_shape
_, num_channels, _, _ = self.vlm_config.image_input_shape
# Note that this is different from that of vLLM vision_language_config
# since the image is resized by the HuggingFace preprocessor
@@ -177,7 +179,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
image_sizes = kwargs.pop("image_sizes", None)
image_features = kwargs.pop("image_features", None)
expected_input_type = self.vision_language_config.image_input_type
expected_input_type = self.vlm_config.image_input_type
ImageInputType = VisionLanguageConfig.ImageInputType
if expected_input_type == ImageInputType.PIXEL_VALUES:
@@ -386,7 +388,7 @@ class LlavaNextForConditionalGeneration(VisionLanguageModelBase):
inputs_embeds = merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id)
self.vlm_config.image_token_id)
input_ids = None
else:

View File

@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
@@ -51,6 +52,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
from .interfaces import SupportsLoRA
class MiniCPMMoE(nn.Module):
"""A tensor-parallel MoE implementation that shards each expert
@@ -388,7 +391,9 @@ class MiniCPMModel(nn.Module):
return hidden_states
class MiniCPMForCausalLM(nn.Module):
class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -418,13 +423,16 @@ class MiniCPMForCausalLM(nn.Module):
def __init__(
self,
config,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.lora_config = lora_config
self.num_experts = getattr(self.config, "num_experts", 0)
self.quant_config = quant_config
self.model = MiniCPMModel(config,

View File

@@ -54,6 +54,8 @@ from vllm.model_executor.utils import set_weight_attrs
from vllm.sequence import SamplerOutput
from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA
class MixtralMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
@@ -472,7 +474,9 @@ class MixtralModel(nn.Module):
return hidden_states
class MixtralForCausalLM(nn.Module):
class MixtralForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
fall_back_to_pt_during_load = False
packed_modules_mapping = {
@@ -504,7 +508,10 @@ class MixtralForCausalLM(nn.Module):
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.lora_config = lora_config
self.model = MixtralModel(config,
cache_config,
quant_config,

View File

@@ -39,7 +39,7 @@ from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from transformers import PhiConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
@@ -59,11 +59,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from .interfaces import SupportsLoRA
class PhiAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
@@ -131,7 +133,7 @@ class PhiAttention(nn.Module):
class PhiMLP(nn.Module):
def __init__(self,
config: PretrainedConfig,
config: PhiConfig,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
@@ -160,7 +162,7 @@ class PhiMLP(nn.Module):
class PhiLayer(nn.Module):
def __init__(self,
config: PretrainedConfig,
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
@@ -192,7 +194,7 @@ class PhiLayer(nn.Module):
class PhiModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
@@ -229,7 +231,9 @@ class PhiModel(nn.Module):
return hidden_states
class PhiForCausalLM(nn.Module):
class PhiForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -250,14 +254,16 @@ class PhiForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
config: PhiConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
del lora_config # Unused.
super().__init__()
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = PhiModel(config, cache_config, quant_config)

View File

@@ -48,6 +48,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.utils import print_warning_once
from .interfaces import SupportsLoRA
class Qwen2MLP(nn.Module):
@@ -263,7 +265,9 @@ class Qwen2Model(nn.Module):
return hidden_states
class Qwen2ForCausalLM(nn.Module):
class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -293,7 +297,6 @@ class Qwen2ForCausalLM(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
del lora_config
# TODO (@robertgshaw2): see if this can be moved out
if (cache_config.sliding_window is not None
and hasattr(config, "max_window_layers")):
@@ -307,7 +310,10 @@ class Qwen2ForCausalLM(nn.Module):
))
super().__init__()
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = Qwen2Model(config, cache_config, quant_config)

View File

@@ -1,12 +0,0 @@
from torch import nn
from vllm.config import VisionLanguageConfig
class VisionLanguageModelBase(nn.Module):
"""Base class for all vision language models (VLMs)."""
def __init__(self, vision_language_config: VisionLanguageConfig) -> None:
super().__init__()
self.vision_language_config = vision_language_config
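With `VisionLanguageModelBase` gone, vision support is declared structurally through `SupportsVision` rather than by inheritance from a config-holding base class; a hedged sketch of the new pattern (the class name is illustrative, mirroring the LLaVA changes above):

```python
import torch.nn as nn

from vllm.config import VisionLanguageConfig
from vllm.model_executor.models.interfaces import SupportsVision, supports_vision


class ToyVLM(nn.Module, SupportsVision):
    supports_vision = True

    def __init__(self, *, vlm_config: VisionLanguageConfig) -> None:
        super().__init__()
        # The model stores the config itself instead of delegating to a base class.
        self.vlm_config = vlm_config


assert supports_vision(ToyVLM)  # picked up by the loader's TypeGuard
```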

View File

@@ -45,6 +45,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from .interfaces import SupportsLoRA
class XverseMLP(nn.Module):
@@ -266,7 +268,9 @@ class XverseModel(nn.Module):
return hidden_states
class XverseForCausalLM(nn.Module):
class XverseForCausalLM(nn.Module, SupportsLoRA):
supports_lora = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -299,10 +303,13 @@ class XverseForCausalLM(nn.Module):
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config=None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = XverseModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)

View File

@@ -22,6 +22,7 @@ from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import supports_lora
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
@@ -225,14 +226,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.model_memory_usage / float(2**30))
if self.lora_config:
assert hasattr(self.model, "supported_lora_modules"
) and self.model.supported_lora_modules, (
"Model does not support LoRA")
assert hasattr(
self.model,
"embedding_modules"), "Model does not have embedding_modules"
assert hasattr(self.model, "embedding_padding_modules"
), "Model does not have embedding_padding_modules"
assert supports_lora(self.model), "Model does not support LoRA"
self.lora_manager = LRUCacheWorkerLoRAManager(
self.scheduler_config.max_num_seqs,
self.scheduler_config.max_num_batched_tokens,
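For reference, a small sketch (hypothetical helper, not part of the diff) of the worker-side guard; `supports_lora` verifies attribute presence via the runtime-checkable protocol, whereas the previous asserts additionally required `supported_lora_modules` to be non-empty:

```python
from vllm.model_executor.models.interfaces import supports_lora


def check_lora_support(model) -> None:
    # Single protocol-based check replacing the per-attribute hasattr asserts.
    assert supports_lora(model), "Model does not support LoRA"
    # The TypeGuard narrows the type, so these attributes are known to exist;
    # an explicit non-emptiness check mirrors the old, stricter assertion.
    assert model.supported_lora_modules, "Model declares no LoRA target modules"
```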