Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Multimodal] Update Tensor schema test to cover arbitrary shape mm inputs (#22867)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
@@ -1,17 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from functools import partial
from typing import Any, Union
from unittest.mock import patch

import numpy as np
import pytest
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
                                                        UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image

from vllm.config import ModelConfig
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalKwargs)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import GiB_bytes, set_default_torch_num_threads
from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.engine.core import EngineCore as V1EngineCore

@@ -23,12 +32,64 @@ ARCH_TO_SKIP = {
    "MolmoForCausalLM": "incompatible requirements",
    "MiniMaxVL01ForConditionalGeneration": "broken model",
}
ARCH_NEEDS_EXTRAS = [
    "InternVLChatModel",
    "Idefics3ForConditionalGeneration",
    "LlavaForConditionalGeneration",
    "MiniCPMV",
    "PaliGemmaForConditionalGeneration",
]
REPO_ID_TO_SKIP = {"nm-testing/pixtral-12b-FP8-dynamic": "duplicated test"}

ImageInput = list[Image.Image]
VideoInput = Union[list[Image.Image], list[np.ndarray],
                   list[tuple[np.ndarray, dict[str, Any]]]]
AudioInput = list[tuple[np.ndarray, int]]


def _resize_data(_data: Union[Image.Image, np.ndarray],
                 size_factor: float) -> Union[Image.Image, np.ndarray]:
    assert size_factor <= 1, "Size factor must be less than 1"
    # Image input
    if isinstance(_data, Image.Image):
        W, H = _data.width, _data.height
        W, H = map(lambda x: int(x * size_factor), (W, H))
        return _data.resize((W, H))
    # Video input with PIL Images
    elif is_list_of(_data, Image.Image):
        W, H = next(iter(_data)).width, next(iter(_data)).height
        T = len(_data)
        T, W, H = map(lambda x: max(int(x * size_factor), 1), (T, W, H))
        return [d.resize((W, H)) for d in _data[:T]]
    # Video input with numpy arrays
    elif isinstance(_data, np.ndarray) and _data.ndim >= 4:
        T, H, W, C = _data.shape[-4:]
        T, H, W = map(lambda x: max(int(x * size_factor), 1), (T, H, W))
        return _data[..., :T, :H, :W, :C]
    # Audio input
    elif isinstance(_data, np.ndarray) and _data.ndim == 1:
        return _data[:int(len(_data) * size_factor)]
    raise AssertionError("This line should be unreachable.")


def resize_mm_data(
        data: Union[ImageInput, VideoInput, AudioInput],
        size_factors: tuple[float,
                            ...]) -> Union[ImageInput, VideoInput, AudioInput]:
    size_factors = size_factors[:len(data)]
    if is_list_of(data, (Image.Image, np.ndarray, list)):
        return [_resize_data(d, s) for d, s in zip(data, size_factors)]
    elif is_list_of(data, tuple):
        return [(_resize_data(d, s), meta)
                for (d, meta), s in zip(data, size_factors)]
    raise ValueError("Unsupported multimodal data type.")
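# Illustrative sketch, not part of this diff: how resize_mm_data behaves,
# assuming the helpers defined above are in scope (example inputs are made up).
images = [Image.new("RGB", (336, 336)) for _ in range(3)]
resized = resize_mm_data(images, size_factors=(1.0, 0.5, 0.25))
assert [im.size for im in resized] == [(336, 336), (168, 168), (84, 84)]
# Audio tuples are truncated rather than resized: half of the samples remain.
audio = [(np.zeros(16000, dtype=np.float32), 16000)]
assert len(resize_mm_data(audio, size_factors=(0.5, ))[0][0]) == 8000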


def create_batched_mm_kwargs(
    model_config: ModelConfig,
    processor: BaseMultiModalProcessor,
) -> MultiModalKwargs:
    size_factors: tuple[float, ...] = (1.0, 0.5, 0.25),
) -> Iterable[tuple[str, int, BatchedTensorInputs]]:
    processing_info = processor.info
    dummy_inputs = processor.dummy_inputs
    supported_mm_limits = processing_info.get_supported_mm_limits()
@@ -40,30 +101,69 @@ def create_batched_mm_kwargs(
        seq_len=model_config.max_model_len,
        mm_counts=mm_counts,
    )
    mm_data = processor_inputs.mm_data
    resized_mm_data = {
        modality: resize_mm_data(data, size_factors)
        for modality, data in mm_data.items()
    }
    # Mistral chat outputs tokens directly, rather than text prompts
    if model_config.tokenizer_mode == "mistral":
        images = resized_mm_data.get("image", [])
        request = ChatCompletionRequest(messages=[
            UserMessage(content=[
                TextChunk(text=""),
                *(ImageChunk(image=image) for image in images),
            ]),
        ])
        tokenizer = processing_info.get_tokenizer()
        res = tokenizer.mistral.encode_chat_completion(request)
        prompt = res.tokens
    else:
        prompt = processor_inputs.prompt
    mm_kwargs = processor.apply(
        prompt=processor_inputs.prompt,
        mm_data=processor_inputs.mm_data,
        prompt=prompt,
        mm_data=resized_mm_data,
        hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
        tokenization_kwargs=processor_inputs.tokenization_kwargs,
    )["mm_kwargs"]
    mm_kwargs = MultiModalKwargs.batch([mm_kwargs])
    return mm_kwargs
    items = [
        item for modality in supported_mm_limits
        for item in mm_kwargs.get_items(modality)
    ]
    return group_mm_kwargs_by_modality(items)
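# Illustrative sketch, not part of this diff: consuming the generator above,
# assuming model_config and processor are built as in the test further down.
# Each yielded entry is (modality, number of items, batched keyword tensors).
for modality, num_items, batched_kwargs in create_batched_mm_kwargs(
        model_config, processor):
    # e.g. modality == "image"; batched_kwargs is a BatchedTensorInputs dict
    # whose per-item shapes can differ because of the size_factors resizing.
    print(modality, num_items, sorted(batched_kwargs.keys()))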


def get_model_id_to_test(
        model_arch_list: Iterable[str]) -> list[tuple[str, str]]:
    filtered_results = []
    for model_arch in model_arch_list:
        model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
        if model_info.extras and model_arch in ARCH_NEEDS_EXTRAS:
            available_repos = list(
                map(lambda model_id: (model_arch, model_id),
                    [model_info.default, *model_info.extras.values()]))
            filtered_results.extend(available_repos)
        else:
            filtered_results.append((model_arch, model_info.default))
    return filtered_results
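# Illustrative sketch, not part of this diff: the parametrization produced by
# get_model_id_to_test (repo ids below are placeholders, not real values).
pairs = get_model_id_to_test(["LlavaForConditionalGeneration"])
# ARCH_NEEDS_EXTRAS entries expand to one (arch, repo) pair per checkpoint:
# [("LlavaForConditionalGeneration", "<default repo>"),
#  ("LlavaForConditionalGeneration", "<extra repo>"), ...]
# Architectures outside ARCH_NEEDS_EXTRAS keep only their default repo.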


@pytest.mark.core_model
@pytest.mark.parametrize("model_arch", list(_MULTIMODAL_EXAMPLE_MODELS.keys()))
def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
                             monkeypatch):
@pytest.mark.parametrize(
    "model_arch, model_id",
    get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()))
def test_model_tensor_schema(model_arch: str, model_id: str,
                             vllm_runner: type[VllmRunner], monkeypatch):
    if model_arch in ARCH_TO_SKIP:
        pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
    if model_id in REPO_ID_TO_SKIP:
        pytest.skip(f"Skipping {model_id} due to {REPO_ID_TO_SKIP[model_id]}")

    model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
    model_info.check_available_online(on_fail="skip")
    model_info.check_transformers_version(on_fail="skip",
                                          check_max_version=False)

    model_id = model_info.default

    hf_overrides_fn = partial(dummy_hf_overrides,
                              model_arch=model_arch,
                              exist_overrides=model_info.hf_overrides)
@@ -119,6 +219,7 @@ def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
        if model_info.v0_only:
            m.setenv("VLLM_USE_V1", "0")

        # TODO(Isotr0py): Can we avoid initializing engine?
        with (
                set_default_torch_num_threads(1),
                vllm_runner(
@@ -145,12 +246,16 @@ def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
        mm_registry = llm_engine.input_preprocessor.mm_registry

        processor = mm_registry.create_processor(model_config)
        mm_kwargs = create_batched_mm_kwargs(model_config, processor)

        def validate_model_input(model):
            for modality in ("audio", "image", "video"):
                method_name = f"_parse_and_validate_{modality}_input"
                if hasattr(model, method_name):
                    getattr(model, method_name)(**mm_kwargs)
        def validate_model_input(model, modality: str,
                                 mm_kwargs: MultiModalKwargs):
            method_name = f"_parse_and_validate_{modality}_input"
            if hasattr(model, method_name):
                getattr(model, method_name)(**mm_kwargs)

        vllm_model.apply_model(validate_model_input)
        for modality, _, mm_kwargs in create_batched_mm_kwargs(
                model_config, processor):
            valid_func = partial(validate_model_input,
                                 modality=modality,
                                 mm_kwargs=mm_kwargs)
            vllm_model.apply_model(valid_func)
@@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                    MultiModalDataDict, MultiModalFieldConfig,
                                    MultiModalKwargs, VideoItem)
@@ -44,6 +44,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.config import uses_mrope
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -112,8 +113,9 @@ class KeyeImagePixelInputs(TensorSchema):
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["pixel_values"]
    pixel_values: Annotated[torch.Tensor,
                            TensorShape("b", "np", 3, "ps", "ps")]
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
    image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]


@@ -145,8 +147,9 @@ class KeyeVideoPixelInputs(TensorSchema):
        - g: Grid dimensions (3 for t, h, w)
    """
    type: Literal["pixel_values_videos"]
    pixel_values_videos: Annotated[torch.Tensor,
                                   TensorShape("b", "np", 3, "ps", "ps")]
    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
    video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
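# Illustrative sketch, not part of this diff: what the new dynamic_dims marker
# conveys in a TensorSchema, assuming this module's imports (torch, Annotated,
# Literal, TensorSchema, TensorShape) are in scope. The "np" (patch count)
# dimension is declared free to vary, so items with differing patch counts can
# still satisfy the schema, while "b" and "ps" remain checked for consistency.
class ToyPixelInputs(TensorSchema):
    type: Literal["pixel_values"]
    pixel_values: Annotated[torch.Tensor,
                            TensorShape("b", "np", 3, "ps", "ps",
                                        dynamic_dims={"np"})]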
@@ -1295,7 +1298,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
            return None
        return quant_config

    def _validate_and_reshape_mm_tensor(self, mm_input: object,
    def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors,
                                        name: str) -> torch.Tensor:
        if not isinstance(mm_input, (torch.Tensor, list)):
            raise ValueError(f"Incorrect type of {name}. "
@@ -1310,8 +1313,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
                                 f"Got ndim: {mm_input.ndim} "
                                 f"(shape={mm_input.shape})")
            return torch.concat(list(mm_input))
        else:
            return torch.concat(mm_input)
        elif is_list_of(mm_input, torch.Tensor):
            if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
                                                          for p in mm_input):
                return mm_input
        return torch.concat(list(mm_input))
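# Illustrative sketch, not part of this diff: the two paths above on made-up
# per-image patch tensors whose first ("np") dimension differs per item.
per_image = [torch.randn(16, 588), torch.randn(4, 588)]
# Flattened path: concatenation yields one (total_patches, hidden) tensor.
assert torch.concat(per_image).shape == (20, 588)
# Pass-through path: a list whose elements are uniformly 2-D (or 4-D) is
# returned as-is so each item's shape can be handled individually downstream.
assert all(p.dim() == 2 for p in per_image)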

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[KeyeImageInputs]: