[Model] Use merge_by_field_config for MM models (H-L) (#26230)

commit 59a85c366e
parent 119f00630b
Author: Cyrus Leung
Date:   2025-10-05 11:54:17 +08:00
Committed by: GitHub
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
6 changed files with 29 additions and 161 deletions
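
The change repeated across the files below: each model class gains a merge_by_field_config = True attribute, which (in broad terms) tells the engine to merge incoming multimodal kwargs according to the processor's MultiModalFieldConfig before they reach the model. Tensors then arrive already flattened along the batch dimension, so the hand-written flatten_bn calls and _validate_and_reshape_mm_tensor helpers can be deleted, with shape checking handled declaratively by TensorSchema. A minimal before/after sketch (a standalone simplification, not the exact vLLM code; this flatten_bn is a stand-in reimplementation and both classes are hypothetical):

import torch

def flatten_bn(x, *, concat: bool = False):
    # Stand-in for vLLM's helper: collapse per-request nesting into a
    # single leading dimension, concatenating variable-length entries.
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)
    return torch.cat(x) if concat else x

class OldStyleModel:
    def _parse_and_validate_image_input(self, **kwargs):
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None
        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")
        # Manual flattening of the (batch, num_items, ...) nesting.
        return {"type": "pixel_values",
                "pixel_values": flatten_bn(pixel_values, concat=True)}

class NewStyleModel:
    # Opt in: kwargs are merged per the processor's MultiModalFieldConfig
    # before reaching the model, so they arrive pre-flattened.
    merge_by_field_config = True

    def _parse_and_validate_image_input(self, **kwargs):
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None
        # A TensorSchema subclass would validate dtype/shape here.
        return {"type": "pixel_values", "pixel_values": pixel_values}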

examples/offline_inference/vision_language_multi_image.py (View File)

@@ -548,7 +548,7 @@ def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
     engine_args = EngineArgs(
         model=model_name,
         trust_remote_code=True,
-        max_model_len=8192,
+        max_model_len=32768,
         max_num_seqs=5,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
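
For context, these example configs are consumed by constructing an LLM from the dataclass fields; a minimal usage sketch (the checkpoint name and image count are assumed for illustration):

from vllm import LLM, EngineArgs

engine_args = EngineArgs(
    model="Kwai-Keye/Keye-VL-1_5-8B",  # assumed checkpoint for this loader
    trust_remote_code=True,
    max_model_len=32768,  # raised so multi-image prompts fit in context
    max_num_seqs=5,
    limit_mm_per_prompt={"image": 2},  # e.g. two input images
)
llm = LLM(**vars(engine_args))  # EngineArgs is a dataclass; unpack its fields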

vllm/model_executor/models/idefics3.py (View File)

@@ -53,7 +53,7 @@ from .idefics2_vision_model import (
 # yapf: enable
 from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
 from .llama import LlamaModel
-from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix
@@ -67,7 +67,7 @@ class Idefics3ImagePixelInputs(TensorSchema):
     """
     type: Literal["pixel_values"]
     pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
-    pixel_attention_mask: torch.Tensor
+    pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bnp", "h", "w")]
     num_patches: Annotated[torch.Tensor, TensorShape("bn")]
@@ -569,6 +569,8 @@ class Idefics3Model(nn.Module):
                                         dummy_inputs=Idefics3DummyInputsBuilder)
 class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -621,37 +623,21 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None
 
         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
             return Idefics3ImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
+                data=image_embeds,
             )
 
         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             pixel_attention_mask = kwargs.pop("pixel_attention_mask")
-            if not isinstance(pixel_attention_mask, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel_attention_mask. "
-                                 f"Got type: {type(pixel_attention_mask)}")
-
             num_patches = kwargs.pop("num_patches")
-            if not isinstance(num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of num_patches. "
-                                 f"Got type: {type(num_patches)}")
 
             expected_h = expected_w = self.config.vision_config.image_size
             return Idefics3ImagePixelInputs(
                 type="pixel_values",
-                pixel_values=flatten_bn(pixel_values, concat=True),
-                pixel_attention_mask=flatten_bn(pixel_attention_mask,
-                                                concat=True),
-                num_patches=flatten_bn(num_patches, concat=True),
+                pixel_values=pixel_values,
+                pixel_attention_mask=pixel_attention_mask,
+                num_patches=num_patches,
                 resolve_bindings={
                     "h": expected_h,
                     "w": expected_w

vllm/model_executor/models/keye.py (View File)

@@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems, VideoItem)
@@ -42,7 +42,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -100,8 +99,7 @@ def smart_resize(
 class KeyeImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -110,7 +108,7 @@ class KeyeImagePixelInputs(TensorSchema):
     type: Literal["pixel_values"]
     pixel_values: Annotated[
         torch.Tensor,
-        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
@@ -134,8 +132,7 @@ KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
 class KeyeVideoPixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -144,7 +141,7 @@ class KeyeVideoPixelInputs(TensorSchema):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Annotated[
         torch.Tensor,
-        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
@@ -1258,6 +1255,8 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
 class BaseKeyeModule(nn.Module):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1524,28 +1523,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
                      prefix: str = "") -> nn.Module:
         return Projector(text_config, vision_config, quant_config, prefix)
 
-    def _validate_and_reshape_mm_tensor(
-            self, mm_input: NestedTensors,
-            name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim == 5:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        elif is_list_of(mm_input, torch.Tensor):
-            if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
-                                                          for p in mm_input):
-                return mm_input
-        return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KeyeImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -1556,11 +1533,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None
 
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return KeyeImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -1568,11 +1540,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             )
 
         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return KeyeImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1589,13 +1556,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None
 
         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos,
-                "video pixel values",
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return KeyeVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -1603,11 +1563,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             )
 
         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return KeyeVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,
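
The recurring docstring edit in this file, where separate b and np dims collapse into bnp, reflects the merged layout: patches from every item in the batch are concatenated along dim 0 instead of being stacked into a padded batch, leaving one dynamic leading dim. A small illustration with made-up sizes:

import torch

ps = 14  # assumed patch size
# Two images contributing 16 and 9 patches respectively.
per_image = [torch.zeros(n, 3, ps, ps) for n in (16, 9)]
pixel_values = torch.cat(per_image, dim=0)  # (25, 3, ps, ps), i.e. ("bnp", 3, "ps", "ps")
image_grid_thw = torch.tensor([[1, 4, 4],
                               [1, 3, 3]])  # (ni, 3): patches = t * h * w per image
assert pixel_values.shape[0] == int(image_grid_thw.prod(dim=-1).sum())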

vllm/model_executor/models/keye_vl1_5.py (View File)

@@ -18,7 +18,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalFieldConfig,
                                     MultiModalKwargsItems, VideoItem)
@@ -100,8 +100,7 @@ def get_num_patches(grid_thw: torch.Tensor,
 class KeyeVL1_5ImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -111,7 +110,7 @@ class KeyeVL1_5ImagePixelInputs(TensorSchema):
     pixel_values: Annotated[
         torch.Tensor,
-        TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
@@ -137,8 +136,7 @@ KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs,
 class KeyeVL1_5VideoPixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -147,7 +145,7 @@ class KeyeVL1_5VideoPixelInputs(TensorSchema):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Annotated[
         torch.Tensor,
-        TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
     num_frames: torch.Tensor
@@ -483,24 +481,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
         self.merge_size = config.vision_config.spatial_merge_size
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-    def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors,
-                                        expected_dim: int, name: str):
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == expected_dim:
-                return mm_input
-            elif mm_input.ndim == expected_dim + 1:
-                return mm_input.reshape(-1, *mm_input.shape[2:])
-            else:
-                raise ValueError(
-                    f"{name} should be {expected_dim}D or "
-                    f"batched {expected_dim}D tensor."
-                    f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})")
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -511,11 +491,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None
 
         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, expected_dim=4, name="image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, expected_dim=2, name="image grid_thw")
-
             return KeyeVL1_5ImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -523,11 +498,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             )
 
         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, expected_dim=2, name="image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, expected_dim=2, name="image grid_thw")
-
             return KeyeVL1_5ImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -545,17 +515,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None
 
         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos,
-                expected_dim=4,
-                name="video pixel values",
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, expected_dim=2, name="video grid_thw")
-            num_frames = self._validate_and_reshape_mm_tensor(
-                num_frames, expected_dim=1, name="video num frames")
-
             return KeyeVL1_5VideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -563,11 +522,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
                 num_frames=num_frames)
 
         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, expected_dim=2, name="video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, expected_dim=2, name="video grid_thw")
-
             return KeyeVL1_5VideoEmbeddingInputs(type="video_embeds",
                                                  video_embeds=video_embeds,
                                                  video_grid_thw=video_grid_thw,

vllm/model_executor/models/kimi_vl.py (View File)

@@ -283,6 +283,7 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
                                         dummy_inputs=KimiVLDummyInputsBuilder)
 class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
+    merge_by_field_config = True
 
     supports_encoder_tp_data = True
@@ -342,23 +343,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                                 config.vocab_size, logit_scale)
         self.media_placeholder: int = self.config.media_placeholder_token_id
 
-    # ref: qwen2_vl.py
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KimiVLImageInputs]:
         # image input type must be pixel values now
@@ -368,21 +352,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values is None:
             return None
 
-        image_grid_hws = self._validate_and_reshape_mm_tensor(
-            image_grid_hws, "image grid hws")
-
-        # pixel_values may have complex shapes
-        num_channels = 3
-        patch_size = self.config.vision_config.patch_size
-        if isinstance(pixel_values, list):
-            pixel_values = torch.cat([
-                x.reshape(-1, num_channels, patch_size, patch_size)
-                for x in pixel_values
-            ])
-        else:
-            pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
-                                                patch_size)
-        pixel_values = pixel_values.to(self.vision_tower.dtype)
         return KimiVLImagePixelInputs(
             type="pixel_values",
             pixel_values=pixel_values,

vllm/utils/tensor_schema.py (View File)

@@ -164,7 +164,9 @@ class TensorSchema:
             if len(actual_shape) != len(expected_shape):
                 raise ValueError(f"{field_name} has rank {len(actual_shape)} "
-                                 f"but expected {len(expected_shape)}")
+                                 f"but expected {len(expected_shape)}. "
+                                 f"Expected shape: {expected_shape}, "
+                                 f"but got {actual_shape}")
 
             for i, dim in enumerate(expected_shape):
                 if dim in dynamic_dims:
@@ -172,7 +174,9 @@ class TensorSchema:
                 elif isinstance(dim, int):
                     if actual_shape[i] != dim:
                         raise ValueError(f"{field_name} dim[{i}] expected "
-                                         f"{dim}, got {actual_shape[i]}")
+                                         f"{dim}, got {actual_shape[i]}. "
+                                         f"Expected shape: {expected_shape}, "
+                                         f"but got {actual_shape}")
                 elif isinstance(dim, str):
                     if dim in shape_env:
                         if actual_shape[i] != shape_env[dim]:
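
Both message changes have the same effect: a failed check now reports the full expected and actual shapes instead of only the offending rank or dim. A hypothetical int-dim mismatch, for a field declared with TensorShape("bnp", 3, "h", "w") that receives 4 channels:

field_name = "pixel_values"
expected_shape = ("bnp", 3, "h", "w")
actual_shape = (25, 4, 364, 364)
i, dim = 1, 3
print(f"{field_name} dim[{i}] expected "
      f"{dim}, got {actual_shape[i]}. "
      f"Expected shape: {expected_shape}, "
      f"but got {actual_shape}")
# pixel_values dim[1] expected 3, got 4.
# Expected shape: ('bnp', 3, 'h', 'w'), but got (25, 4, 364, 364)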