Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Model] Use merge_by_field_config for MM models (G) (#26117)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
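In short: models that set `merge_by_field_config = True` receive their multimodal kwargs already merged and flattened according to the processor's field config, so the per-model `flatten_bn` calls and manual type checks removed below become unnecessary. A minimal sketch of the resulting pattern (simplified from the Gemma3 hunks that follow; the class name and plain-dict return type are illustrative only, not vLLM's actual code):

```python
# Sketch only: simplified from the Gemma3 change below, not the exact vLLM code.
from typing import Optional


class MyMultiModalModel:  # hypothetical model class, for illustration
    # Opt in: multimodal kwargs arrive already merged/flattened per field config,
    # so the model no longer calls flatten_bn() or re-checks tensor types itself.
    merge_by_field_config = True

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[dict]:
        pixel_values = kwargs.pop("pixel_values", None)
        num_patches = kwargs.pop("num_patches", None)
        if pixel_values is None:
            return None
        # Tensors are already shaped (num_items, ...) at this point.
        return {"pixel_values": pixel_values, "num_patches": num_patches}
```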
@@ -36,7 +36,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
 from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
 
 logger = init_logger(__name__)
@@ -289,7 +289,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
                                         processor=hf_processor)
                 for size in image_sizes
             ]
-            processed_outputs["num_crops"] = torch.tensor(num_crops)
+            processed_outputs["num_patches"] = torch.tensor(num_crops) + 1
 
         return processed_outputs
 
@@ -298,12 +298,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        num_crops = hf_inputs.get("num_crops", torch.empty(0))
+        num_patches = hf_inputs.get("num_patches", torch.empty(0))
 
         return dict(
             pixel_values=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_crops + 1),
-            num_crops=MultiModalFieldConfig.batched("image"),
+                "image", num_patches),
+            num_patches=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
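For readers unfamiliar with the field config API used above: `MultiModalFieldConfig.flat_from_sizes("image", num_patches)` marks `pixel_values` as a flat concatenation that is split back per image according to `num_patches`, while `batched("image")` marks a field with one entry per image. A plain-torch illustration of that slicing idea (not the vLLM API itself; toy sizes):

```python
import torch

# Two images produced 2 and 5 patches; pixel_values is stored flat as 7 patches.
num_patches = torch.tensor([2, 5])
pixel_values = torch.randn(int(num_patches.sum()), 3, 8, 8)  # toy patch size

# flat_from_sizes-style view: recover per-image chunks from the per-item sizes.
per_image = torch.split(pixel_values, num_patches.tolist(), dim=0)
assert [t.shape[0] for t in per_image] == [2, 5]

# batched-style fields (num_patches itself) already have one entry per image.
assert num_patches.shape[0] == 2
```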
@@ -460,6 +460,8 @@ class Gemma3MultiModalProjector(nn.Module):
                                         dummy_inputs=Gemma3DummyInputsBuilder)
 class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                                      SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -526,29 +528,20 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
-        num_crops = kwargs.pop("num_crops", None)
+        num_patches = kwargs.pop("num_patches", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Gemma3 does not support image_embeds."
         if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        if not isinstance(num_crops, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_crops. "
-                             f"Got type: {type(num_crops)}")
-
         image_size = self.config.vision_config.image_size
 
-        return Gemma3ImagePixelInputs(
-            pixel_values=flatten_bn(pixel_values, concat=True),
-            num_patches=flatten_bn(num_crops, concat=True) + 1,
-            resolve_bindings={
-                "h": image_size,
-                "w": image_size
-            })
+        return Gemma3ImagePixelInputs(pixel_values=pixel_values,
+                                      num_patches=num_patches,
+                                      resolve_bindings={
+                                          "h": image_size,
+                                          "w": image_size
+                                      })
 
     def _image_pixels_to_features(
         self,
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union, cast
+from typing import Annotated, Any, Literal, Optional, Union, cast
 
 import numpy as np
 import torch
@@ -41,6 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription)
@@ -54,17 +55,28 @@ TOKENS_PER_IMAGE = 256
 TOKENS_PER_AUDIO = 188
 
 
-class Gemma3nImagePixelInputs(TypedDict):
-    pixel_values: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+class Gemma3nImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height of each patch
+        - w: Width of each patch
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
 
 
-class Gemma3nAudioInputs(TypedDict):
-    input_features: Union[torch.Tensor, list[torch.Tensor]]
-    input_features_padded: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
-    input_features_mask: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length)`"""
+class Gemma3nAudioInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - s: seq_length
+        - f: num_features
+    """
+    type: Literal["audio"] = "audio"
+    input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
+    input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]
 
 
 Gemma3nImageInputs = Gemma3nImagePixelInputs
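The switch above from `TypedDict` to `TensorSchema` turns the docstring-only shape notes into annotations: symbolic dims such as `bn` and `s` must agree across fields, and, as the Gemma3 hunks earlier show, they can be pinned at construction time via `resolve_bindings`. A rough stand-in for that contract (illustrative only, not vLLM's TensorSchema implementation):

```python
# Stand-in sketch of the shape contract expressed by TensorSchema/TensorShape.
from dataclasses import dataclass

import torch


@dataclass
class AudioInputsSketch:
    """Mimics Gemma3nAudioInputs: ("bn", "s", "f") features plus a ("bn", "s") mask."""
    input_features_padded: torch.Tensor
    input_features_mask: torch.Tensor

    def __post_init__(self) -> None:
        bn, s, _ = self.input_features_padded.shape       # must be rank 3
        mask_bn, mask_s = self.input_features_mask.shape  # must be rank 2
        if (bn, s) != (mask_bn, mask_s):  # shared symbolic dims must agree
            raise ValueError("mask must share the (bn, s) dims of the features")


# Two audio clips, 188 frames, 128 features (toy numbers).
AudioInputsSketch(
    input_features_padded=torch.randn(2, 188, 128),
    input_features_mask=torch.ones(2, 188, dtype=torch.bool),
)
```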
@@ -212,9 +224,9 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
 
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
-            input_features=MultiModalFieldConfig.batched("audio"),
             input_features_padded=MultiModalFieldConfig.batched("audio"),
-            input_features_mask=MultiModalFieldConfig.batched("audio"))
+            input_features_mask=MultiModalFieldConfig.batched("audio"),
+        )
 
     def _get_prompt_updates(
         self,
@@ -422,6 +434,7 @@ class Gemma3nMultimodalEmbedder(nn.Module):
                                         dummy_inputs=Gemma3nDummyInputsBuilder)
 class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsTranscription):
+    merge_by_field_config = True
     supported_languages = ISO639_1_SUPPORTED_LANGS
 
     packed_modules_mapping = {
@@ -482,14 +495,6 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
             device=self.language_model.model.embed_tokens.weight.device,
             dtype=self.language_model.model.embed_tokens.weight.dtype)
 
-    @property
-    def dtype(self):
-        return next(self.parameters()).dtype
-
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        # TODO check if there are any
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -499,34 +504,22 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        pixel_values = flatten_bn(pixel_values, concat=True)
-        pixel_values = pixel_values.contiguous()
-
-        return Gemma3nImagePixelInputs(
-            pixel_values=self._validate_pixel_values(pixel_values), )
+        return Gemma3nImagePixelInputs(pixel_values=pixel_values)
 
     def _parse_and_validate_audio_input(
             self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
-        input_features = kwargs.pop("input_features", None)
-        if input_features is None:
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
             return None
 
         input_features_mask = kwargs.pop("input_features_mask", None)
         if input_features_mask is None:
             return None
 
-        input_features_padded = kwargs.pop("input_features_padded", None)
-        if input_features_padded is None:
-            return None
-
         return Gemma3nAudioInputs(
-            input_features=input_features,
-            input_features_mask=input_features_mask,
             input_features_padded=input_features_padded,
+            input_features_mask=input_features_mask,
         )
 
     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -539,7 +532,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
             ) and "image" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "image"] = self._parse_and_validate_image_input(**kwargs)
-            if input_key == "input_features" \
+            if input_key == "input_features_padded" \
                 and "audio" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "audio"] = self._parse_and_validate_audio_input(**kwargs)
@@ -1319,6 +1319,8 @@ class Glm4vMultiModalProcessor(BaseMultiModalProcessor[Glm4vProcessingInfo]):
 )
 class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
                                     SupportsLoRA, SupportsPP):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1381,22 +1383,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Glm4vImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
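For context, the `_validate_and_reshape_mm_tensor` helper deleted above collapsed either a list of per-request 2D tensors or a batched 3D tensor into a single 2D tensor; with `merge_by_field_config` the merged kwargs are expected to arrive in that flattened form already. A standalone illustration of the reshape it performed (mirroring the removed code, not a vLLM API):

```python
import torch

# grid_thw-style inputs arrive either as a list of per-request 2D tensors
# or as one batched 3D tensor.
per_request = [torch.tensor([[1, 24, 24]]),
               torch.tensor([[1, 12, 32], [1, 8, 8]])]
merged = torch.concat(per_request)                  # list case -> (3, 3)

batched = torch.stack([torch.tensor([[1, 24, 24]]),
                       torch.tensor([[1, 16, 16]])])
flattened = batched.reshape(-1, batched.shape[-1])  # (2, 1, 3) -> (2, 3)

assert merged.shape == (3, 3) and flattened.shape == (2, 3)
```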
@@ -1407,11 +1393,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
 
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return Glm4vImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -1419,11 +1400,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             )
 
         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return Glm4vImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1440,11 +1416,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
 
         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return Glm4vVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -1452,11 +1423,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
             )
 
         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return Glm4vVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,
@@ -43,7 +43,6 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
 from .chatglm import ChatGLMBaseModel, ChatGLMModel
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                          SupportsMultiModal, SupportsPP)
-from .utils import flatten_bn
 
 
 class GLMVImagePixelInputs(TensorSchema):
@@ -529,8 +528,9 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
 @MULTIMODAL_REGISTRY.register_processor(GLM4VMultiModalProcessor,
                                         info=GLM4VProcessingInfo,
                                         dummy_inputs=GLM4VDummyInputsBuilder)
-class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
-                       SupportsMultiModal):
+class GLM4VForCausalLM(ChatGLMBaseModel, SupportsMultiModal, SupportsLoRA,
+                       SupportsPP):
+    merge_by_field_config = True
 
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
@@ -574,14 +574,9 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
         pixel_values = kwargs.pop("pixel_values", None)
 
         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             expected_h = expected_w = self.config.vision_config["image_size"]
             return GLMVImagePixelInputs(type="pixel_values",
-                                        data=flatten_bn(pixel_values,
-                                                        concat=True),
+                                        data=pixel_values,
                                         resolve_bindings={
                                             "h": expected_h,
                                             "w": expected_w
@@ -598,6 +593,8 @@ class GLM4VForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
     def get_language_model(self) -> torch.nn.Module:
         return self.transformer
 
+    get_input_embeddings = SupportsMultiModal.get_input_embeddings
+
     def get_multimodal_embeddings(self,
                                   **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
@@ -168,10 +168,8 @@ class GraniteSpeechMultiModalProcessor(
         # Calculate the number of audio tokens per entry in the batch;
         # This is used to split the batch back out after padding.
         audio_token_index = self.info.get_hf_config().audio_token_index
-        processed_outputs["audio_embed_sizes"] = [
-            torch.sum(indices == audio_token_index).item()
-            for indices in processed_outputs["input_ids"]
-        ]
+        processed_outputs["audio_embed_sizes"] = (
+            processed_outputs["input_ids"] == audio_token_index).sum(-1)
 
         return processed_outputs
 
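The rewrite above replaces a per-row Python loop with a single tensor comparison: counting audio placeholder tokens per batch entry, assuming `input_ids` is already a padded 2D tensor. A quick equivalence check with toy ids (7 stands in for `audio_token_index`):

```python
import torch

AUDIO_TOKEN = 7  # stand-in for audio_token_index
input_ids = torch.tensor([[1, 7, 7, 2, 0],
                          [7, 3, 4, 7, 7]])

# Old style: one Python-level count per row.
old_sizes = [torch.sum(row == AUDIO_TOKEN).item() for row in input_ids]

# New style: a single tensor op over the whole batch.
new_sizes = (input_ids == AUDIO_TOKEN).sum(-1)

assert old_sizes == new_sizes.tolist() == [2, 3]
```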
@@ -527,6 +525,7 @@ class GraniteSpeechForConditionalGeneration(
         SupportsPP,
         SupportsLoRA,
 ):
+    merge_by_field_config = True
 
     packed_modules_mapping = {
         "qkv_proj": [