[Model] Enable BNB support for qwen2_5_omni_thinker (#24420)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-09-09 00:37:08 +08:00
committed by GitHub
parent c44797a4d6
commit 6f4a82f8b5

View File

@ -41,6 +41,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.qwen2_5_vl import (
Qwen2_5_VisionTransformer, Qwen2_5_VLImageEmbeddingInputs,
Qwen2_5_VLImageInputs, Qwen2_5_VLImagePixelInputs,
@ -66,7 +67,8 @@ from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import decode_tokens, encode_tokens
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
@ -726,7 +728,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
dummy_inputs=Qwen2_5OmniThinkerDummyInputsBuilder,
)
class Qwen2_5OmniThinkerForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP,
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
Qwen2_5OmniConditionalGenerationMixin):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
@ -734,6 +736,22 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
"thinker.model.": "language_model.model.",
"thinker.": "",
})
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"attn.qkv": [
"attn.q",
"attn.k",
"attn.v",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@ -956,3 +974,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
mapper=self.hf_to_vllm_mapper)
return loaded_weights
def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model",
connector="merger.",
tower_model=["visual.", "audio_tower."])