[Multimodal] Update Tensor schema test to cover arbitrary shape mm inputs (#22867)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Isotr0py
Date: 2025-08-16 15:44:50 +08:00
Committed by: GitHub
Parent: 6d3da472bc
Commit: cc826a202b
2 changed files with 138 additions and 27 deletions
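In short, the test no longer builds one fixed-size dummy input per modality: each dummy item is shrunk by a different size factor (1.0, 0.5, 0.25 by default), so the processor produces per-item tensors with differing shapes and the TensorSchema annotations get exercised against arbitrary-shape multimodal inputs. A rough standalone illustration of why this requires dynamic dims, with assumed patch sizes and counts (torch only, not code from this commit):

import torch

patch_size = 14
pixel_values_per_item = [
    torch.randn(1, 576, 3, patch_size, patch_size),  # dummy image at factor 1.0
    torch.randn(1, 144, 3, patch_size, patch_size),  # same image at factor 0.5
]
# The "np" (number of patches) axis differs between the two items, so a schema
# that pins every dimension would reject the batch; declaring
# dynamic_dims={"np"} on the TensorShape lets per-item variation pass validation.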

@@ -1,17 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from functools import partial
from typing import Any, Union
from unittest.mock import patch
import numpy as np
import pytest
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
from vllm.config import ModelConfig
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import GiB_bytes, set_default_torch_num_threads
from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.engine.core import EngineCore as V1EngineCore
@@ -23,12 +32,64 @@ ARCH_TO_SKIP = {
"MolmoForCausalLM": "incompatible requirements",
"MiniMaxVL01ForConditionalGeneration": "broken model",
}
ARCH_NEEDS_EXTRAS = [
"InternVLChatModel",
"Idefics3ForConditionalGeneration",
"LlavaForConditionalGeneration",
"MiniCPMV",
"PaliGemmaForConditionalGeneration",
]
REPO_ID_TO_SKIP = {"nm-testing/pixtral-12b-FP8-dynamic": "duplicated test"}
ImageInput = list[Image.Image]
VideoInput = Union[list[Image.Image], list[np.ndarray],
list[tuple[np.ndarray, dict[str, Any]]]]
AudioInput = list[tuple[np.ndarray, int]]
def _resize_data(_data: Union[Image.Image, np.ndarray],
size_factor: float) -> Union[Image.Image, np.ndarray]:
assert size_factor <= 1, "Size factor must be less than 1"
# Image input
if isinstance(_data, Image.Image):
W, H = _data.width, _data.height
W, H = map(lambda x: int(x * size_factor), (W, H))
return _data.resize((W, H))
# Video input with PIL Images
elif is_list_of(_data, Image.Image):
W, H = next(iter(_data)).width, next(iter(_data)).height
T = len(_data)
T, W, H = map(lambda x: max(int(x * size_factor), 1), (T, W, H))
return [d.resize((W, H)) for d in _data[:T]]
# Video input with numpy arrays
elif isinstance(_data, np.ndarray) and _data.ndim >= 4:
T, H, W, C = _data.shape[-4:]
T, H, W = map(lambda x: max(int(x * size_factor), 1), (T, H, W))
return _data[..., :T, :H, :W, :C]
# Audio input
elif isinstance(_data, np.ndarray) and _data.ndim == 1:
return _data[:int(len(_data) * size_factor)]
raise AssertionError("This line should be unreachable.")
def resize_mm_data(
data: Union[ImageInput, VideoInput, AudioInput],
size_factors: tuple[float,
...]) -> Union[ImageInput, VideoInput, AudioInput]:
size_factors = size_factors[:len(data)]
if is_list_of(data, (Image.Image, np.ndarray, list)):
return [_resize_data(d, s) for d, s in zip(data, size_factors)]
elif is_list_of(data, tuple):
return [(_resize_data(d, s), meta)
for (d, meta), s in zip(data, size_factors)]
raise ValueError("Unsupported multimodal data type.")
def create_batched_mm_kwargs(
model_config: ModelConfig,
processor: BaseMultiModalProcessor,
) -> MultiModalKwargs:
size_factors: tuple[float, ...] = (1.0, 0.5, 0.25),
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
processing_info = processor.info
dummy_inputs = processor.dummy_inputs
supported_mm_limits = processing_info.get_supported_mm_limits()
@@ -40,30 +101,69 @@ def create_batched_mm_kwargs(
seq_len=model_config.max_model_len,
mm_counts=mm_counts,
)
mm_data = processor_inputs.mm_data
resized_mm_data = {
modality: resize_mm_data(data, size_factors)
for modality, data in mm_data.items()
}
# Mistral chat outputs tokens directly, rather than text prompts
if model_config.tokenizer_mode == "mistral":
images = resized_mm_data.get("image", [])
request = ChatCompletionRequest(messages=[
UserMessage(content=[
TextChunk(text=""),
*(ImageChunk(image=image) for image in images),
]),
])
tokenizer = processing_info.get_tokenizer()
res = tokenizer.mistral.encode_chat_completion(request)
prompt = res.tokens
else:
prompt = processor_inputs.prompt
mm_kwargs = processor.apply(
prompt=processor_inputs.prompt,
mm_data=processor_inputs.mm_data,
prompt=prompt,
mm_data=resized_mm_data,
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
tokenization_kwargs=processor_inputs.tokenization_kwargs,
)["mm_kwargs"]
mm_kwargs = MultiModalKwargs.batch([mm_kwargs])
return mm_kwargs
items = [
item for modality in supported_mm_limits
for item in mm_kwargs.get_items(modality)
]
return group_mm_kwargs_by_modality(items)
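# Illustrative note (not part of this commit): each yielded entry groups all
# items of one modality, matching the Iterable[tuple[str, int,
# BatchedTensorInputs]] return annotation, e.g. something like
#   ("image", 3, {"pixel_values": ..., "image_grid_thw": ...})
# for three differently sized dummy images.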
def get_model_id_to_test(
model_arch_list: Iterable[str]) -> list[tuple[str, str]]:
filtered_results = []
for model_arch in model_arch_list:
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
if model_info.extras and model_arch in ARCH_NEEDS_EXTRAS:
available_repos = list(
map(lambda model_id: (model_arch, model_id),
[model_info.default, *model_info.extras.values()]))
filtered_results.extend(available_repos)
else:
filtered_results.append((model_arch, model_info.default))
return filtered_results
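# Illustrative note (not part of this commit): architectures listed in
# ARCH_NEEDS_EXTRAS expand to one (model_arch, repo_id) pair per known
# checkpoint, i.e. the default repo plus every entry in model_info.extras;
# all other architectures contribute only their default repo.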
@pytest.mark.core_model
@pytest.mark.parametrize("model_arch", list(_MULTIMODAL_EXAMPLE_MODELS.keys()))
def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
monkeypatch):
@pytest.mark.parametrize(
"model_arch, model_id",
get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()))
def test_model_tensor_schema(model_arch: str, model_id: str,
vllm_runner: type[VllmRunner], monkeypatch):
if model_arch in ARCH_TO_SKIP:
pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
if model_id in REPO_ID_TO_SKIP:
pytest.skip(f"Skipping {model_id} due to {REPO_ID_TO_SKIP[model_id]}")
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip",
check_max_version=False)
model_id = model_info.default
hf_overrides_fn = partial(dummy_hf_overrides,
model_arch=model_arch,
exist_overrides=model_info.hf_overrides)
@@ -119,6 +219,7 @@ def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
if model_info.v0_only:
m.setenv("VLLM_USE_V1", "0")
# TODO(Isotr0py): Can we avoid initializing engine?
with (
set_default_torch_num_threads(1),
vllm_runner(
@@ -145,12 +246,16 @@ def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
mm_registry = llm_engine.input_preprocessor.mm_registry
processor = mm_registry.create_processor(model_config)
mm_kwargs = create_batched_mm_kwargs(model_config, processor)
def validate_model_input(model):
for modality in ("audio", "image", "video"):
method_name = f"_parse_and_validate_{modality}_input"
if hasattr(model, method_name):
getattr(model, method_name)(**mm_kwargs)
def validate_model_input(model, modality: str,
mm_kwargs: MultiModalKwargs):
method_name = f"_parse_and_validate_{modality}_input"
if hasattr(model, method_name):
getattr(model, method_name)(**mm_kwargs)
vllm_model.apply_model(validate_model_input)
for modality, _, mm_kwargs in create_batched_mm_kwargs(
model_config, processor):
valid_func = partial(validate_model_input,
modality=modality,
mm_kwargs=mm_kwargs)
vllm_model.apply_model(valid_func)
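# Illustrative note (not part of this commit): for an "image" group, the
# apply_model call above effectively runs
#   model._parse_and_validate_image_input(**mm_kwargs)
# on each worker's model, which is typically where the TensorSchema input
# classes with dynamic dims are constructed and their shapes checked.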

@@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import (ImageItem, ModalityData,
MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs, VideoItem)
@@ -44,6 +44,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -112,8 +113,9 @@ class KeyeImagePixelInputs(TensorSchema):
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values"]
pixel_values: Annotated[torch.Tensor,
TensorShape("b", "np", 3, "ps", "ps")]
pixel_values: Annotated[
torch.Tensor,
TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
@@ -145,8 +147,9 @@ class KeyeVideoPixelInputs(TensorSchema):
- g: Grid dimensions (3 for t, h, w)
"""
type: Literal["pixel_values_videos"]
pixel_values_videos: Annotated[torch.Tensor,
TensorShape("b", "np", 3, "ps", "ps")]
pixel_values_videos: Annotated[
torch.Tensor,
TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
@@ -1295,7 +1298,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
return None
return quant_config
def _validate_and_reshape_mm_tensor(self, mm_input: object,
def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. "
@@ -1310,8 +1313,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})")
return torch.concat(list(mm_input))
else:
return torch.concat(mm_input)
elif is_list_of(mm_input, torch.Tensor):
if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
for p in mm_input):
return mm_input
return torch.concat(list(mm_input))
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[KeyeImageInputs]:
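On the Keye side, the widened _validate_and_reshape_mm_tensor helper keeps a list of per-item tensors intact when every element is 4-D (or every element is 2-D) instead of concatenating unconditionally, which is what lets items with different patch counts reach the schema check. A minimal sketch of that decision logic, with shapes assumed only for illustration (not the model's actual preprocessing output):

import torch

def keep_or_concat(mm_input: list[torch.Tensor]):
    # Mirrors the new elif branch: homogeneous 4-D (or 2-D) per-item tensors
    # are returned as-is so their leading dims may differ; anything else is
    # concatenated along dim 0.
    if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2 for p in mm_input):
        return mm_input
    return torch.concat(list(mm_input))

# Two items whose leading sizes differ survive as a list:
keep_or_concat([torch.randn(576, 3, 14, 14), torch.randn(144, 3, 14, 14)])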