[Model] Use merge_by_field_config for MM models (G) (#26117)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-10-03 13:38:29 +08:00
Committed by: GitHub
Parent: 711f485643
Commit: 39b643dc1a
5 changed files with 56 additions and 108 deletions
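
The common thread across these files: once a model class sets `merge_by_field_config = True`, the multimodal kwargs reach `_parse_and_validate_*` already merged per field according to its `MultiModalFieldConfig`, so per-model `flatten_bn` calls and manual type checks can be dropped. Below is a minimal, self-contained sketch of the before/after shape of such a parser; the class and helper are illustrative stand-ins, not vLLM's actual API.

# Illustrative sketch only: ImagePixelInputs and flatten_bn below are stand-ins
# for the per-model input schemas and the removed .utils.flatten_bn helper.
from typing import Optional, TypedDict, Union

import torch


class ImagePixelInputs(TypedDict):
    pixel_values: torch.Tensor


def flatten_bn(x: Union[torch.Tensor, list[torch.Tensor]]) -> torch.Tensor:
    # Stand-in for flatten_bn(..., concat=True): concatenate per-request
    # tensors along the batch dimension.
    return x if isinstance(x, torch.Tensor) else torch.cat(list(x), dim=0)


def parse_before(**kwargs: object) -> Optional[ImagePixelInputs]:
    # Before: the field may arrive as a list of per-request tensors, so the
    # model flattens and type-checks it itself.
    pixel_values = kwargs.pop("pixel_values", None)
    if pixel_values is None:
        return None
    if not isinstance(pixel_values, (torch.Tensor, list)):
        raise ValueError(f"Incorrect type of pixel values: {type(pixel_values)}")
    return ImagePixelInputs(pixel_values=flatten_bn(pixel_values))


def parse_after(**kwargs: object) -> Optional[ImagePixelInputs]:
    # After merge_by_field_config = True: the field is delivered already merged
    # into a single batched tensor, so no manual flattening is needed.
    pixel_values = kwargs.pop("pixel_values", None)
    if pixel_values is None:
        return None
    return ImagePixelInputs(pixel_values=pixel_values)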

View File

@@ -36,7 +36,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
 logger = init_logger(__name__)
@@ -289,7 +289,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
                                     processor=hf_processor)
             for size in image_sizes
         ]
-        processed_outputs["num_crops"] = torch.tensor(num_crops)
+        processed_outputs["num_patches"] = torch.tensor(num_crops) + 1
         return processed_outputs
@@ -298,12 +298,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        num_crops = hf_inputs.get("num_crops", torch.empty(0))
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))
         return dict(
             pixel_values=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_crops + 1),
-            num_crops=MultiModalFieldConfig.batched("image"),
+                "image", num_patches),
+            num_patches=MultiModalFieldConfig.batched("image"),
         )
     def _get_prompt_updates(
@@ -460,6 +460,8 @@ class Gemma3MultiModalProjector(nn.Module):
                                         dummy_inputs=Gemma3DummyInputsBuilder)
 class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                                      SupportsLoRA):
+    merge_by_field_config = True
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -526,29 +528,20 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
-        num_crops = kwargs.pop("num_crops", None)
+        num_patches = kwargs.pop("num_patches", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Gemma3 does not support image_embeds."
         if pixel_values is None:
             return None
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-        if not isinstance(num_crops, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_crops. "
-                             f"Got type: {type(num_crops)}")
         image_size = self.config.vision_config.image_size
-        return Gemma3ImagePixelInputs(
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            num_patches=flatten_bn(num_crops, concat=True) + 1,
-            resolve_bindings={
-                "h": image_size,
-                "w": image_size
-            })
+        return Gemma3ImagePixelInputs(pixel_values=pixel_values,
+                                      num_patches=num_patches,
+                                      resolve_bindings={
+                                          "h": image_size,
+                                          "w": image_size
+                                      })
     def _image_pixels_to_features(
         self,
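
For the Gemma3 hunks above: the processor now stores `num_patches` (crops plus the base image) directly, and `MultiModalFieldConfig.flat_from_sizes("image", num_patches)` declares that `pixel_values` is a flat concatenation of patches to be sliced back per image by those sizes. A small sketch of that bookkeeping, using plain `torch.split` as a stand-in for the field-config machinery and made-up sizes:

# Illustrative only: plain torch code mimicking what flat_from_sizes declares,
# not the actual vLLM field-config implementation. All sizes are dummies.
import torch

num_crops = torch.tensor([0, 3])    # two images: no crops, three crops
num_patches = num_crops + 1         # each image contributes 1 + num_crops patches

# pixel_values is stored flat: (total_patches, 3, H, W), with H = W = 64 as a dummy size.
pixel_values = torch.randn(int(num_patches.sum()), 3, 64, 64)

# Recover the per-image groups that the field config describes.
per_image = torch.split(pixel_values, num_patches.tolist())
assert [p.shape[0] for p in per_image] == num_patches.tolist()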

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union, cast
+from typing import Annotated, Any, Literal, Optional, Union, cast
 import numpy as np
 import torch
@@ -41,6 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription)
@@ -54,17 +55,28 @@ TOKENS_PER_IMAGE = 256
 TOKENS_PER_AUDIO = 188
-class Gemma3nImagePixelInputs(TypedDict):
-    pixel_values: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+class Gemma3nImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height of each patch
+        - w: Width of each patch
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
-class Gemma3nAudioInputs(TypedDict):
-    input_features: Union[torch.Tensor, list[torch.Tensor]]
-    input_features_padded: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
-    input_features_mask: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length)`"""
+class Gemma3nAudioInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - s: seq_length
+        - f: num_features
+    """
+    type: Literal["audio"] = "audio"
+    input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
+    input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]
 Gemma3nImageInputs = Gemma3nImagePixelInputs
@@ -212,9 +224,9 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
-            input_features=MultiModalFieldConfig.batched("audio"),
             input_features_padded=MultiModalFieldConfig.batched("audio"),
-            input_features_mask=MultiModalFieldConfig.batched("audio"))
+            input_features_mask=MultiModalFieldConfig.batched("audio"),
+        )
     def _get_prompt_updates(
         self,
@@ -422,6 +434,7 @@ class Gemma3nMultimodalEmbedder(nn.Module):
                                         dummy_inputs=Gemma3nDummyInputsBuilder)
 class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsTranscription):
+    merge_by_field_config = True
     supported_languages = ISO639_1_SUPPORTED_LANGS
     packed_modules_mapping = {
@@ -482,14 +495,6 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
             device=self.language_model.model.embed_tokens.weight.device,
             dtype=self.language_model.model.embed_tokens.weight.dtype)
-    @property
-    def dtype(self):
-        return next(self.parameters()).dtype
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        # TODO check if there are any
-        return data
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -499,34 +504,22 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values is None:
             return None
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-        pixel_values = flatten_bn(pixel_values, concat=True)
-        pixel_values = pixel_values.contiguous()
-        return Gemma3nImagePixelInputs(
-            pixel_values=self._validate_pixel_values(pixel_values), )
+        return Gemma3nImagePixelInputs(pixel_values=pixel_values)
     def _parse_and_validate_audio_input(
             self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
-        input_features = kwargs.pop("input_features", None)
-        if input_features is None:
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
             return None
         input_features_mask = kwargs.pop("input_features_mask", None)
         if input_features_mask is None:
             return None
-        input_features_padded = kwargs.pop("input_features_padded", None)
-        if input_features_padded is None:
-            return None
         return Gemma3nAudioInputs(
-            input_features=input_features,
-            input_features_mask=input_features_mask,
             input_features_padded=input_features_padded,
+            input_features_mask=input_features_mask,
         )
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -539,7 +532,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
                              ) and "image" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "image"] = self._parse_and_validate_image_input(**kwargs)
-            if input_key == "input_features" \
+            if input_key == "input_features_padded" \
                 and "audio" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "audio"] = self._parse_and_validate_audio_input(**kwargs)

View File

@@ -1319,6 +1319,8 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
 )
 class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
                                     SupportsLoRA, SupportsPP):
+    merge_by_field_config = True
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1381,22 +1383,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Glm4vImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -1407,11 +1393,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
             return Glm4vImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -1419,11 +1400,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             )
         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
             return Glm4vImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1440,11 +1416,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
             return Glm4vVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -1452,11 +1423,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             )
         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
             return Glm4vVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,

View File

@@ -43,7 +43,6 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .chatglm import ChatGLMBaseModel, ChatGLMModel
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import flatten_bn
 class GLMVImagePixelInputs(TensorSchema):
@@ -529,8 +528,9 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
 @MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor,
                                         info=GLM4VProcessingInfo,
                                         dummy_inputs=GLM4VDummyInputsBuilder)
-class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
-                       SupportsMultiModal):
+class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA,
+                       SupportsPP):
+    merge_by_field_config = True
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
@@ -574,14 +574,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
         pixel_values = kwargs.pop("pixel_values", None)
         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
             expected_h = expected_w = self.config.vision_config["image_size"]
             return GLMVImagePixelInputs(type="pixel_values",
-                                        data=flatten_bn(pixel_values,
-                                                        concat=True),
+                                        data=pixel_values,
                                         resolve_bindings={
                                             "h": expected_h,
                                             "w": expected_w
@@ -598,6 +593,8 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.transformer
+    get_input_embeddings = SupportsMultiModal.get_input_embeddings
     def get_multimodal_embeddings(self,
                                   **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)

View File

@@ -168,10 +168,8 @@ class GraniteSpeechMultiModalProcessor(
         # Calculate the number of audio tokens per entry in the batch;
         # This is used to split the batch back out after padding.
         audio_token_index = self.info.get_hf_config().audio_token_index
-        processed_outputs["audio_embed_sizes"] = [
-            torch.sum(indices == audio_token_index).item()
-            for indices in processed_outputs["input_ids"]
-        ]
+        processed_outputs["audio_embed_sizes"] = (
+            processed_outputs["input_ids"] == audio_token_index).sum(-1)
         return processed_outputs
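
The GraniteSpeech change above swaps a per-row Python loop for a single vectorized reduction; both count how many `audio_token_index` entries each row of `input_ids` contains, though the result is now a tensor rather than a Python list. A quick equivalence check on made-up inputs:

# Toy equivalence check; token id 7 stands in for audio_token_index.
import torch

audio_token_index = 7
input_ids = torch.tensor([[7, 7, 1, 2], [1, 7, 3, 4]])

loop_sizes = [int(torch.sum(row == audio_token_index).item()) for row in input_ids]
vectorized_sizes = (input_ids == audio_token_index).sum(-1)

assert loop_sizes == vectorized_sizes.tolist()  # both give [2, 1]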
@@ -527,6 +525,7 @@ class GraniteSpeechForConditionalGeneration(
         SupportsPP,
         SupportsLoRA,
 ):
+    merge_by_field_config = True
     packed_modules_mapping = {
         "qkv_proj": [