[Model] Use merge_by_field_config for MM models (D-F) (#26076)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-10-02 23:17:35 +08:00
Committer: GitHub
Parent: 7d6fb905d9
Commit: cc253b73d3
4 changed files with 102 additions and 180 deletions

View File

@@ -20,8 +20,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models.transformers import replace_linear_class
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, MultiModalUUIDDict,
NestedTensors)
MultiModalKwargsItems, MultiModalUUIDDict)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -40,7 +39,7 @@ from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix)
# The image token id may be various
@@ -50,15 +49,15 @@ _IMAGE_TOKEN = "<image>"
class DeepseekVL2ImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- bnp: Batch size * number of images * number of patches
- p: Number of patches
- c: Number of channels (3)
- h: Height of each image
- w: Width of each image
"""
type: Literal["pixel_values"]
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
data: Annotated[torch.Tensor,
TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})]
images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
@@ -228,12 +227,8 @@ class DeepseekVL2MultiModalProcessor(
tok_kwargs=tok_kwargs,
)
pixel_values = processed_outputs["pixel_values"]
# split pixel values into patches corresponding to each image
images_spatial_crop = processed_outputs["images_spatial_crop"]
patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]
pixel_values = pixel_values.split(patches_per_image)
processed_outputs["pixel_values"] = pixel_values
processed_outputs["num_patches"] = (
processed_outputs["images_spatial_crop"].prod(-1) + 1)
return processed_outputs
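Instead of splitting `pixel_values` inside the processor, the new code only records a per-image patch count: each image contributes `crop_w * crop_h` tiles plus one extra patch (presumably the global view). A hedged arithmetic sketch with invented crop grids:

```python
import torch

# Invented crop grids for two images; each image yields (w_tiles * h_tiles)
# local tiles plus one extra patch (presumably the global view).
images_spatial_crop = torch.tensor([[2, 3], [1, 1]])
num_patches = images_spatial_crop.prod(-1) + 1  # tensor([7, 2])

# pixel_values stays flattened as (sum(num_patches), 3, h, w); these per-image
# counts are what flat_from_sizes later uses to slice it back apart.
```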
@@ -242,8 +237,11 @@ class DeepseekVL2MultiModalProcessor(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_patches = hf_inputs.get("num_patches", torch.empty(0))
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_patches),
images_spatial_crop=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
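With that count available, `pixel_values` switches from a `batched` field to `flat_from_sizes`: the tensor stays concatenated along its first dimension and `num_patches` says where each image's slice ends. A minimal illustration of the idea, assuming only that behaviour (the helper below is illustrative, not the vLLM internals):

```python
import torch

def split_flat_by_sizes(flat: torch.Tensor, sizes: torch.Tensor) -> list[torch.Tensor]:
    # Toy stand-in for what a flat_from_sizes field conceptually does: slice a
    # tensor whose first dimension is the concatenation of all items.
    return list(flat.split(sizes.tolist(), dim=0))

flat_pixel_values = torch.randn(9, 3, 384, 384)  # 7 + 2 patches, invented sizes
num_patches = torch.tensor([7, 2])
per_image = split_flat_by_sizes(flat_pixel_values, num_patches)
assert [p.shape[0] for p in per_image] == [7, 2]
```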
@@ -318,6 +316,7 @@ class DeepseekVL2MultiModalProcessor(
info=DeepseekVL2ProcessingInfo,
dummy_inputs=DeepseekVL2DummyInputsBuilder)
class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
"language.": "language_model.",
@@ -460,37 +459,30 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
if pixel_values is not None:
expected_h = expected_w = self.vision_config.image_size
return DeepseekVL2ImagePixelInputs(type="pixel_values",
data=flatten_bn(pixel_values),
images_spatial_crop=flatten_bn(
images_spatial_crop,
concat=True),
resolve_bindings={
"h": expected_h,
"w": expected_w,
})
return DeepseekVL2ImagePixelInputs(
type="pixel_values",
data=pixel_values,
images_spatial_crop=images_spatial_crop,
resolve_bindings={
"h": expected_h,
"w": expected_w,
})
if image_embeds is not None:
return DeepseekVL2VImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _pixel_values_to_embedding(
self,
pixel_values: NestedTensors,
pixel_values: torch.Tensor,
images_spatial_crop: torch.Tensor,
) -> NestedTensors:
# Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
total_tiles = [x for x in pixel_values]
# [batch_all_tiles, 3, height, width]
total_tiles = torch.cat(total_tiles, dim=0)
) -> list[torch.Tensor]:
# [batch_all_tiles, vit_seq_len, c]
images_feature = self.vision.forward_features(total_tiles)
images_feature = self.vision.forward_features(pixel_values)
# [batch_all_tiles, hw, D]
images_embeds = self.projector(images_feature)
@@ -573,7 +565,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return vision_embeddings
def _process_image_input(
self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor:
self, image_input: DeepseekVL2ImageInputs) -> list[torch.Tensor]:
if image_input["type"] == "image_embeds":
image_data = image_input["data"]
if is_list_of(image_data, torch.Tensor):

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping
from typing import Literal, Optional, TypedDict, Union
from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
@@ -42,34 +42,38 @@ from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.dotsocr import (DotsOCRConfig,
DotsVisionConfig)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .vision import run_dp_sharded_mrope_vision_model
IMAGE_TOKEN = "<|imgpad|>"
class DotsOCRImagePixelInputs(TypedDict):
type: Literal["pixel_values", "image_grid_thw"]
pixel_values: torch.Tensor
image_grid_thw: torch.Tensor
class DotsOCRImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds", "image_grid_thw"]
image_embeds: torch.Tensor
"""Supported types:
- List[`torch.Tensor`]: A list of tensors holding all images' features.
Each tensor holds an image's features.
- `torch.Tensor`: A tensor holding all images' features
(concatenation of all images' feature tensors).
Tensor shape: `(num_image_features, hidden_size)`
- `num_image_features` varies based on
the number and resolution of the images.
- `hidden_size` must match the hidden size of language model backbone.
class DotsOCRImagePixelInputs(TensorSchema):
"""
Dimensions:
- np: The total number of patches over each image over each prompt in
the batch
- ni: Number of images
- cps: Number of channels * patch_size * patch_size
"""
type: Literal["pixel_values"]
image_grid_thw: torch.Tensor
pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
class DotsOCRImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- nf: Number of image features
- hs: Hidden size
- ni: Number of images
"""
type: Literal["image_embeds"]
image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
DotsOCRImageInputs = Union[DotsOCRImagePixelInputs,
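The `TypedDict` definitions above are replaced with `TensorSchema` classes whose `Annotated`/`TensorShape` fields describe the expected dimensions symbolically. A minimal sketch of the pattern, using a hypothetical schema (not part of vLLM) and assuming `TensorSchema` checks that same-named dimensions agree across fields:

```python
import torch
from typing import Annotated, Literal

from vllm.utils.tensor_schema import TensorSchema, TensorShape

class ToyImageInputs(TensorSchema):
    """
    Dimensions:
        - n: Number of items
        - d: Feature size
    """
    type: Literal["toy"]
    feats: Annotated[torch.Tensor, TensorShape("n", "d")]
    sizes: Annotated[torch.Tensor, TensorShape("n", 2)]

# Same-named dims ("n") must agree across fields; mismatched shapes should
# fail validation (assumed behaviour, mirroring the classes above).
inputs = ToyImageInputs(type="toy",
                        feats=torch.randn(4, 16),
                        sizes=torch.ones(4, 2, dtype=torch.long))
```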
@@ -654,6 +658,8 @@ class DotsVisionTransformer(nn.Module):
)
class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={
".attn.qkv_proj.": ".attn.qkv.",
@@ -709,22 +715,6 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
architectures=["Qwen2ForCausalLM"],
)
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return torch.concat(list(mm_input))
else:
return torch.concat(mm_input)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[DotsOCRImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@@ -735,28 +725,11 @@ class DotsOCRForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
return None
if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")
return DotsOCRImagePixelInputs(type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw)
if image_embeds is not None:
image_embeds = self._validate_and_reshape_mm_tensor(
image_embeds, "image embeds")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return DotsOCRImageEmbeddingInputs(type="image_embeds",
image_embeds=image_embeds,
image_grid_thw=image_grid_thw)

View File

@@ -25,7 +25,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import Any, Callable, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Callable, Literal, Optional, Union
import numpy as np
import torch
@@ -56,6 +56,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .ernie45_vl_moe import Ernie4_5_VLMoeForCausalLM
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -579,38 +580,38 @@ class Ernie4_5_VisionTransformer(nn.Module):
# === Vision Inputs === #
class Ernie4_5_VLImagePixelInputs(TypedDict):
class Ernie4_5_VLImagePixelInputs(TensorSchema):
"""
Dimensions:
- np: The total number of patches over each image over each prompt in
the batch
- ni: Number of images
- cps: Number of channels * patch_size * patch_size
"""
type: Literal["pixel_values"]
pixel_values: torch.Tensor
"""Shape:
`(num_patches, num_channels * patch_size * patch_size)`
"""
grid_thw: torch.Tensor
"""Shape: `(num_images, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
"""
pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
Ernie4_5_VLImageInputs = Ernie4_5_VLImagePixelInputs
class Ernie4_5_VLVideoPixelInputs(TypedDict):
class Ernie4_5_VLVideoPixelInputs(TensorSchema):
"""
Dimensions:
- np: The total number of patches over each image over each prompt in
the batch
- ni: Number of images
- cps: Number of channels * temporal_patch_size * patch_size *
patch_size
"""
type: Literal["pixel_values_videos"]
pixel_values_videos: torch.Tensor
"""Shape:
`(num_patches,
num_channels * temporal_patch_size * patch_size * patch_size)`
"""
video_grid_thw: torch.Tensor
"""Shape: `(num_videos, 3)`
This should be in `(grid_t, grid_h, grid_w)` format.
"""
pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "cps")]
video_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
Ernie4_5_VLVideoInputs = Ernie4_5_VLImagePixelInputs
Ernie4_5_VLVideoInputs = Ernie4_5_VLVideoPixelInputs
# === Vision Processor === #
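These grid-based schemas are self-describing: the first dimension of the flat `pixel_values` equals the sum over images of `t * h * w` from `image_grid_thw`. A hedged arithmetic sketch (grid values and patch size are invented for illustration):

```python
import torch

image_grid_thw = torch.tensor([[1, 16, 24], [1, 8, 8]])  # (t, h, w) per image
patches_per_image = image_grid_thw.prod(-1)              # tensor([384, 64])
total_patches = int(patches_per_image.sum())             # 448

# "cps" from the docstring above, assuming 3 channels and a 14x14 patch size.
pixel_values = torch.randn(total_patches, 3 * 14 * 14)   # shape (448, 588)
```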
@@ -1213,6 +1214,7 @@ class Ernie4_5_VLDummyInputsBuilder(
dummy_inputs=Ernie4_5_VLDummyInputsBuilder)
class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsLoRA, SupportsPP):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": [
@@ -1325,22 +1327,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
f"Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim != 3:
raise ValueError(f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return mm_input.reshape(-1, mm_input.shape[-1])
else:
return torch.concat(mm_input)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Ernie4_5_VLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@@ -1350,15 +1336,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
return None
if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values")
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw")
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}")
return Ernie4_5_VLImagePixelInputs(type="pixel_values",
pixel_values=pixel_values,
image_grid_thw=image_grid_thw)
@@ -1372,11 +1349,6 @@ class Ernie4_5_VLMoeForConditionalGeneration(nn.Module, SupportsMultiModal,
return None
if pixel_values_videos is not None:
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos, "video pixel values")
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw")
return Ernie4_5_VLVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,

View File

@@ -59,17 +59,14 @@ class FuyuImagePatchInputs(TensorSchema):
type: Literal["image_patches"] = "image_patches"
flat_data: Annotated[
torch.Tensor,
TensorShape("bnp", "fn"),
]
image_patches_flat: Annotated[torch.Tensor, TensorShape("bnp", "fn")]
patches_per_image: Annotated[list[int], TensorShape("bn")]
patches_per_image: Annotated[torch.Tensor, TensorShape("bn")]
"""
The number of total patches for each image in the batch.
This is used to split the embeddings which has the first two dimensions
flattened just like `flat_data`.
flattened just like `image_patches_flat`.
"""
@@ -174,28 +171,10 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
tok_kwargs=tok_kwargs,
)
image_patches = processed_outputs.get("image_patches")
if image_patches is not None:
images = mm_data["images"]
assert isinstance(images, list)
# Original output: (1, num_images, Pn, Px * Py * C)
# New output: (num_images, Pn, Px * Py * C)
# image_patches is a list with shape:
# (1, num_images, Pn, Px * Py * C)
# before Transformers 4.53
if isinstance(image_patches, list):
assert len(image_patches) == 1
assert (isinstance(image_patches[0], torch.Tensor)
and len(image_patches[0]) == len(images))
processed_outputs["image_patches"] = image_patches[0]
# image_patches is a tensor with shape:
# (num_images, Pn, Px * Py * C)
# after Transformers 4.53
elif isinstance(image_patches, torch.Tensor):
assert len(image_patches) == len(images)
else:
raise AssertionError("This line should be unreachable.")
image_patches = processed_outputs["image_patches"]
processed_outputs["image_patches"] = flatten_bn(image_patches)
processed_outputs["patches_per_image"] = torch.tensor(
[len(p) for p in image_patches])
return processed_outputs
@@ -218,7 +197,13 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image"))
patches_per_image = hf_inputs.get("patches_per_image", torch.empty(0))
return dict(
image_patches=MultiModalFieldConfig.flat_from_sizes(
"image", patches_per_image),
patches_per_image=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
self,
@@ -263,6 +248,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
info=FuyuProcessingInfo,
dummy_inputs=FuyuDummyInputsBuilder)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
@@ -306,29 +292,28 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
image_patches = kwargs.pop("image_patches", None)
if image_patches is not None:
image_patches_flat = flatten_bn(image_patches)
flat_data = flatten_bn(image_patches_flat, concat=True)
patches_per_image = kwargs.pop("patches_per_image", None)
return FuyuImagePatchInputs(
type="image_patches",
flat_data=flat_data,
patches_per_image=[x.size(0) for x in image_patches_flat],
resolve_bindings={"fn": self.image_feature_size},
)
if image_patches is None:
return None
return None
return FuyuImagePatchInputs(
type="image_patches",
image_patches_flat=image_patches,
patches_per_image=patches_per_image,
resolve_bindings={"fn": self.image_feature_size},
)
def _process_image_input(
self, image_input: FuyuImagePatchInputs) -> MultiModalEmbeddings:
image_patches_flat = image_input["flat_data"]
image_patches_flat = image_input["image_patches_flat"]
patches_per_image = image_input["patches_per_image"]
assert self.vision_embed_tokens is not None
vision_embeddings_flat, _ = self.vision_embed_tokens(
image_patches_flat)
return vision_embeddings_flat.split(patches_per_image, dim=0)
return vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)
def get_language_model(self) -> torch.nn.Module:
return self.language_model
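On the Fuyu side, `patches_per_image` now arrives as a tensor from the processor, hence the `.tolist()` before the split: `torch.Tensor.split` expects Python ints for per-chunk sizes. A self-contained sketch of that final split (shapes invented):

```python
import torch

patches_per_image = torch.tensor([300, 180])     # as emitted by the processor
vision_embeddings_flat = torch.randn(480, 4096)  # (sum of patches, hidden size)

# torch.Tensor.split expects Python ints, hence the .tolist() in the diff.
per_image = vision_embeddings_flat.split(patches_per_image.tolist(), dim=0)
assert [e.shape[0] for e in per_image] == [300, 180]
```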