[Bugfix] Check that number of images matches number of <|image|> tokens with mllama (#13911)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
Travis Johnson
2025-02-27 21:00:45 -07:00
committed by GitHub
parent 6c85da3a18
commit 73e0225ee9
2 changed files with 26 additions and 3 deletions

View File

@ -479,8 +479,9 @@ def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
# Regression tests for https://github.com/vllm-project/vllm/issues/10648
# Number of image tags is greater than the number of images provided
prompt = "<|begin_of_text|><|image|><|image|> Compare the two images" # noqa: E501
# Number of groups of image tokens is greater than the number of images
# provided (the whitespace between the tags is necessary)
prompt = "<|begin_of_text|><|image|> <|image|> Compare the two images" # noqa: E501
image = stop_sign
with pytest.raises(ValueError):
vllm_model.generate_greedy_logprobs([prompt],

View File

@ -54,7 +54,8 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.inputs import (MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataDict, MultiModalDataItems)
from vllm.multimodal.processing import (BaseProcessingInfo,
@ -169,6 +170,27 @@ class MllamaDummyInputsBuilder(BaseDummyInputsBuilder[MllamaProcessingInfo]):
class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
):
def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalEncDecInputs:
mm_inputs = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
# Check that the number of image tokens in the decoder prompt matches
# the number of images provided in mm_data
num_image_tokens = mm_inputs['prompt_token_ids'].count(
self.info.get_hf_config().image_token_index)
image_data = mm_data.get("image", [])
num_images = 1 if isinstance(image_data, Image) else len(image_data)
if num_image_tokens != num_images:
raise ValueError(
f"The number of image tokens ({num_image_tokens}) must be"
f" the same as the number of images ({num_images})")
return mm_inputs
def _call_hf_processor(
self,
prompt: str,