mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Model][LoRA]LoRA support added for glm-4v (#10418)
Signed-off-by: B-201 <Joy25810@foxmail.com>
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
@@ -574,25 +575,8 @@ class ChatGLMModel(nn.Module):
         return hidden_states


-@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
-class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
-                         SupportsMultiModal):
-    packed_modules_mapping = {
-        "query_key_value": ["query_key_value"],
-        "dense_h_to_4h": ["dense_h_to_4h"]
-    }
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "query_key_value",
-        "dense",
-        "dense_h_to_4h",
-        "dense_4h_to_h",
-    ]
-    embedding_modules = {}
-    embedding_padding_modules = []
+class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP,
+                       SupportsMultiModal):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -692,3 +676,79 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                 weight_loader(param, combined_weight)
                 loaded_params.add(combined_name)
         return loaded_params
+
+
+class ChatGLM(ChatGLMBaseModel):
+    packed_modules_mapping = {
+        "query_key_value": ["query_key_value"],
+        "dense_h_to_4h": ["dense_h_to_4h"]
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "query_key_value",
+        "dense",
+        "dense_h_to_4h",
+        "dense_4h_to_h",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+
+class ChatGLMV(ChatGLMBaseModel):
+    packed_modules_mapping = {
+        "query_key_value": ["query_key_value"],
+        "dense_h_to_4h": ["dense_h_to_4h"],
+        "merged_proj": ["gate_proj", "dense_h_to_4h"]
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "query_key_value",
+        "dense",
+        "dense_h_to_4h",
+        "dense_4h_to_h",
+        # vision
+        "fc1",
+        "fc2",
+        "merged_proj",
+        "linear_proj"
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="transformer.encoder",
+            connector="transformer.vision.linear_proj",
+            tower_model="transformer.vision.transformer")
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
+class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
+                         SupportsMultiModal):
+    # Ensure that the LoRA support check passes when the class is not
+    # initialized, but set all these attributes to empty.
+    packed_modules_mapping = {}
+    supported_lora_modules = []
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __new__(
+        cls,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ) -> None:
+        config = vllm_config.model_config.hf_config
+        # Initialize VL
+        if hasattr(config, "visual"):
+            return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
+        # Initialize LLM
+        else:
+            return ChatGLM(vllm_config=vllm_config, prefix=prefix)
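
With this change, a glm-4v LoRA adapter can be served through the same LoRA path vLLM uses for other models. Below is a minimal offline-inference sketch, not part of this commit: the adapter path ./glm4v-lora, the rank value, and the prompt are hypothetical placeholders, and image inputs (normally passed via multi_modal_data) are omitted for brevity.

# Sketch: running glm-4v with a LoRA adapter through vLLM's offline LLM API.
# The adapter path and prompt are placeholders, not from the commit.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="THUDM/glm-4v-9b",   # base multimodal model
    trust_remote_code=True,
    enable_lora=True,          # turn on the LoRA code path used by this PR
    max_lora_rank=64,          # must be >= the adapter's rank
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

# lora_int_id distinguishes adapters when several are registered at once.
lora_request = LoRARequest("glm4v-adapter", 1, "./glm4v-lora")

outputs = llm.generate(
    ["Describe the picture in one sentence."],
    sampling_params,
    lora_request=lora_request,
)
print(outputs[0].outputs[0].text)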