[Core][Multimodal] Allow passing multi_modal_uuids as multimodal identifiers. (#23394)

Signed-off-by: Roger Wang <hey@rogerw.io>
Author: Roger Wang
Date: 2025-08-30 18:01:22 -07:00
Committed by: GitHub
Parent: 5b8077b8ac
Commit: 749be00a98
10 changed files with 455 additions and 54 deletions

View File

@ -13,6 +13,41 @@ To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
- `prompt`: The prompt should follow the format that is documented on HuggingFace.
- `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.multimodal.inputs.MultiModalDataDict][].
### Stable UUIDs for Caching (multi_modal_uuids)
When using multi-modal inputs, vLLM normally hashes each media item by content to enable caching across requests. You can optionally pass `multi_modal_uuids` to provide your own stable IDs for each item so caching can reuse work across requests without rehashing the raw content.
??? code

    ```python
    from vllm import LLM
    from PIL import Image

    # Qwen2.5-VL example with two images
    llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")

    prompt = "USER: <image><image>\nDescribe the differences.\nASSISTANT:"
    img_a = Image.open("/path/to/a.jpg")
    img_b = Image.open("/path/to/b.jpg")

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": {"image": [img_a, img_b]},
        # Provide stable IDs for caching.
        # Requirements (matched by this example):
        # - Include every modality present in multi_modal_data.
        # - For lists, provide the same number of entries.
        # - Use None to fall back to content hashing for that item.
        "multi_modal_uuids": {"image": ["sku-1234-a", None]},
    })

    for o in outputs:
        print(o.outputs[0].text)
    ```

!!! warning

    If both multimodal processor caching and prefix caching are disabled, user-provided `multi_modal_uuids` are ignored.
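For a single item passed outside a list, the UUID may be given as a plain string rather than a one-element list. The following is a minimal sketch reusing `llm` and `img_a` from the example above; the `sku-1234-a` ID is a placeholder:

```python
# Single image passed directly (not in a list); its UUID can be a plain string.
outputs = llm.generate({
    "prompt": "USER: <image>\nDescribe this image.\nASSISTANT:",
    "multi_modal_data": {"image": img_a},
    "multi_modal_uuids": {"image": "sku-1234-a"},
})
```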
### Image Inputs
You can pass a single image to the `'image'` field of the multi-modal dictionary, as shown in the following examples:

View File

@ -0,0 +1,229 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
from vllm.platforms.interface import UnspecifiedPlatform
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import processor as processor_mod
from vllm.v1.engine.processor import Processor
cherry_pil_image = ImageAsset("cherry_blossom").pil_image
stop_pil_image = ImageAsset("stop_sign").pil_image
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays
# Mock processor for testing
def _mk_processor(monkeypatch,
                  *,
                  mm_cache_gb: float = 4.0,
                  enable_prefix_caching: bool = True) -> Processor:
    """
    Create a Processor instance with minimal configuration suitable for unit
    tests without accessing external resources.
    """
    monkeypatch.setattr(ModelConfig,
                        "try_get_generation_config",
                        lambda self: {},
                        raising=True)
    monkeypatch.setattr(ModelConfig,
                        "__post_init__",
                        lambda self: None,
                        raising=True)
    monkeypatch.setattr(UnspecifiedPlatform,
                        "is_async_output_supported",
                        classmethod(lambda cls, enforce_eager: True),
                        raising=True)
    monkeypatch.setattr(
        ModelConfig,
        "verify_async_output_proc",
        lambda self, parallel_config, speculative_config, device_config: None,
        raising=True)
    monkeypatch.setattr(ModelConfig,
                        "verify_with_parallel_config",
                        lambda self, parallel_config: None,
                        raising=True)
    monkeypatch.setattr(processor_mod,
                        "processor_cache_from_config",
                        lambda vllm_config, mm_registry: None,
                        raising=True)
    monkeypatch.setattr(VllmConfig,
                        "__post_init__",
                        lambda self: None,
                        raising=True)

    model_config = ModelConfig(
        skip_tokenizer_init=True,
        max_model_len=128,
        mm_processor_cache_gb=mm_cache_gb,
        generation_config="vllm",
        tokenizer="dummy",
    )

    # Minimal multimodal_config to satisfy references in
    # Processor.process_inputs.
    class _MockMMConfig:

        def __init__(self, gb: float):
            self.mm_processor_cache_gb = gb

    model_config.multimodal_config = _MockMMConfig(
        mm_cache_gb)  # type: ignore[attr-defined]

    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
        device_config=DeviceConfig(device="cpu"),
    )

    # Pass tokenizer=None; InputPreprocessor handles None when
    # skip_tokenizer_init is True.
    return Processor(vllm_config, tokenizer=None)  # type: ignore[arg-type]
def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
    processor = _mk_processor(monkeypatch)

    prompt = {
        "prompt": "USER: <image>\nDescribe\nASSISTANT:",
        "multi_modal_data": {
            "image": [cherry_pil_image, stop_pil_image]
        },
        # Mismatch: 2 items but only 1 uuid provided
        "multi_modal_uuids": {
            "image": ["hash_cherry"]
        },
    }

    with pytest.raises(ValueError, match="must have same length as data"):
        processor.process_inputs(
            request_id="req-1",
            prompt=prompt,  # type: ignore[arg-type]
            params=SamplingParams(),
        )
def test_multi_modal_uuids_missing_modality_raises(monkeypatch):
    processor = _mk_processor(monkeypatch)

    prompt = {
        "prompt": "USER: <image><video>\nDescribe\nASSISTANT:",
        # Two modalities provided in data
        "multi_modal_data": {
            "image": [cherry_pil_image],
            "video": [baby_reading_np_ndarrays]
        },
        # Only image uuids provided; video missing should raise
        "multi_modal_uuids": {
            "image": ["hash_cherry"]
        },
    }

    with pytest.raises(ValueError,
                       match="must be provided if multi_modal_data"):
        processor.process_inputs(
            request_id="req-2",
            prompt=prompt,  # type: ignore[arg-type]
            params=SamplingParams(),
        )
@pytest.mark.parametrize(
    "mm_cache_gb, enable_prefix_caching",
    [
        (4.0, True),  # default behavior
        (4.0, False),  # prefix caching disabled
        (0.0, True),  # processor cache disabled
    ],
)
def test_multi_modal_uuids_accepts_none_and_passes_through(
        monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool):
    processor = _mk_processor(monkeypatch,
                              mm_cache_gb=mm_cache_gb,
                              enable_prefix_caching=enable_prefix_caching)

    # Capture the overrides passed to InputPreprocessor.preprocess
    captured: dict[str, object] = {}

    def fake_preprocess(prompt,
                        *,
                        tokenization_kwargs=None,
                        lora_request=None,
                        mm_hash_overrides=None):
        captured["mm_hash_overrides"] = mm_hash_overrides
        # Minimal processed inputs for decoder-only flow
        return {"type": "token", "prompt_token_ids": [1]}

    # Monkeypatch only the bound preprocess method on this instance
    monkeypatch.setattr(processor.input_preprocessor,
                        "preprocess",
                        fake_preprocess,
                        raising=True)

    # Use a consistent two-image scenario across all configurations
    mm_uuids = {"image": [None, "hash_stop"], "video": None}
    prompt = {
        "prompt": "USER: <image><image>\nTwo images\nASSISTANT:",
        "multi_modal_data": {
            "image": [cherry_pil_image, stop_pil_image],
            "video": baby_reading_np_ndarrays,
        },
        "multi_modal_uuids": mm_uuids,
    }

    processor.process_inputs(
        request_id="req-3",
        prompt=prompt,  # type: ignore[arg-type]
        params=SamplingParams(),
    )

    assert captured["mm_hash_overrides"] == mm_uuids
def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
    # When both the processor cache is 0 and prefix caching is disabled, the
    # processor builds overrides from the request id instead of using user
    # UUIDs.
    processor = _mk_processor(monkeypatch,
                              mm_cache_gb=0.0,
                              enable_prefix_caching=False)

    captured: dict[str, object] = {}

    def fake_preprocess(prompt,
                        *,
                        tokenization_kwargs=None,
                        lora_request=None,
                        mm_hash_overrides=None):
        captured["mm_hash_overrides"] = mm_hash_overrides
        return {"type": "token", "prompt_token_ids": [1]}

    monkeypatch.setattr(processor.input_preprocessor,
                        "preprocess",
                        fake_preprocess,
                        raising=True)

    request_id = "req-42"
    mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": "hash_video"}
    prompt = {
        "prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:",
        "multi_modal_data": {
            "image": [cherry_pil_image, stop_pil_image],
            "video": baby_reading_np_ndarrays,
        },
        "multi_modal_uuids": mm_uuids,
    }

    processor.process_inputs(
        request_id=request_id,
        prompt=prompt,  # type: ignore[arg-type]
        params=SamplingParams(),
    )

    # Expect request-id-based overrides to be passed through
    assert captured["mm_hash_overrides"] == {
        "image": [f"{request_id}-image-0", f"{request_id}-image-1"],
        "video": [f"{request_id}-video-0"],
    }

View File

@ -67,7 +67,7 @@ from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
MultiModalDataDict)
MultiModalDataDict, MultiModalUUIDDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams

View File

@ -7,7 +7,8 @@ import torch
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar
if TYPE_CHECKING:
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalInputs,
MultiModalUUIDDict)
class TextPrompt(TypedDict):
@ -30,6 +31,15 @@ class TextPrompt(TypedDict):
to pass the mm_processor_kwargs to each of them.
"""
multi_modal_uuids: NotRequired["MultiModalUUIDDict"]
"""
Optional user-specified UUIDs for multimodal items, mapped by modality.
Lists must match the number of items per modality and may contain `None`.
For `None` entries, the hasher will compute IDs automatically; non-None
entries override the default hashes for caching, and MUST be unique per
multimodal item.
"""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
@ -59,6 +69,14 @@ class TokensPrompt(TypedDict):
to pass the mm_processor_kwargs to each of them.
"""
multi_modal_uuids: NotRequired["MultiModalUUIDDict"]
"""
Optional user-specified UUIDs for multimodal items, mapped by modality.
Lists must match the number of items per modality and may contain `None`.
For `None` entries, the hasher will compute IDs automatically; non-None
entries override the default hashes for caching.
"""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
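For illustration only, a hypothetical `TokensPrompt` carrying these optional fields; the token IDs, image, and UUID below are placeholders rather than values from this change:

```python
from PIL import Image

# Hypothetical pre-tokenized prompt with one image and a user-supplied UUID.
tokens_prompt = {
    "prompt_token_ids": [101, 102, 103],  # placeholder token IDs
    "multi_modal_data": {"image": [Image.new("RGB", (224, 224))]},
    "multi_modal_uuids": {"image": ["my-stable-image-id"]},
}
```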

View File

@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
MultiModalInputs, MultiModalUUIDDict)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@ -258,7 +258,8 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> MultiModalInputs:
"""
Apply the model's multi-modal processor to a multi-modal prompt,
@ -291,7 +292,8 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> MultiModalInputs:
"""
Async version of
@ -368,7 +370,8 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs)
@ -397,7 +400,8 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs)
@ -426,7 +430,8 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
@ -462,7 +467,8 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
@ -498,7 +504,8 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> SingletonInputs:
"""
Extract the singleton inputs from a prompt.
@ -545,7 +552,8 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> SingletonInputs:
"""
Async version of
@ -684,7 +692,8 @@ class InputPreprocessor:
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> EncoderDecoderInputs:
"""
For encoder/decoder models only:
@ -759,7 +768,8 @@ class InputPreprocessor:
prompt: PromptType,
tokenization_kwargs: Optional[dict[str, Any]] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> EncoderDecoderInputs:
"""
Async version of
@ -826,7 +836,8 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> DecoderOnlyInputs:
"""
For decoder-only models:
@ -858,7 +869,8 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> DecoderOnlyInputs:
"""
Async version of
@ -879,7 +891,8 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> ProcessorInputs:
"""Preprocess the input prompt."""
if self.model_config.is_encoder_decoder:
@ -909,7 +922,8 @@ class InputPreprocessor:
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> ProcessorInputs:
"""
Async version of

View File

@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .base import MultiModalPlaceholderMap
from .hasher import MultiModalHashDict, MultiModalHasher
from .hasher import MultiModalHasher
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalKwargs,
MultiModalKwargsItems, MultiModalPlaceholderDict,
NestedTensors)
MultiModalUUIDDict, NestedTensors)
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
@ -23,12 +23,12 @@ __all__ = [
"ModalityData",
"MultiModalDataBuiltins",
"MultiModalDataDict",
"MultiModalHashDict",
"MultiModalHasher",
"MultiModalKwargs",
"MultiModalKwargsItems",
"MultiModalPlaceholderDict",
"MultiModalPlaceholderMap",
"MultiModalUUIDDict",
"NestedTensors",
"MULTIMODAL_REGISTRY",
"MultiModalRegistry",

View File

@ -3,7 +3,7 @@
import pickle
import uuid
from collections.abc import Iterable, Mapping
from collections.abc import Iterable
from typing import Union
import numpy as np
@ -16,11 +16,6 @@ from vllm.multimodal.image import convert_image_mode
logger = init_logger(__name__)
MultiModalHashDict = Mapping[str, list[str]]
"""
A dictionary containing hashes for items in each modality.
"""
class MultiModalHasher:

View File

@ -22,7 +22,8 @@ if TYPE_CHECKING:
from PIL.Image import Image
from transformers.feature_extraction_utils import BatchFeature
from .hasher import MultiModalHashDict
from .processing import MultiModalHashes
else:
torch = LazyLoader("torch", globals(), "torch")
@ -115,6 +116,16 @@ The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""
MultiModalUUIDDict: TypeAlias = Mapping[str, Union[list[Optional[str]], str]]
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.
The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""
@dataclass(frozen=True)
class PlaceholderRange:
@ -939,7 +950,7 @@ class MultiModalInputs(TypedDict):
mm_kwargs: MultiModalKwargsOptionalItems
"""Keyword arguments to be directly passed to the model after batching."""
mm_hashes: "MultiModalHashDict"
mm_hashes: "MultiModalHashes"
"""The hashes of the multi-modal data."""
mm_placeholders: "MultiModalPlaceholderDict"

View File

@ -24,7 +24,8 @@ from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs,
MultiModalKwargsItem, MultiModalKwargsItems,
MultiModalKwargsOptionalItems, PlaceholderRange)
MultiModalKwargsOptionalItems, MultiModalUUIDDict,
PlaceholderRange)
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser)
@ -1021,7 +1022,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> MultiModalInputs:
return self.apply(prompt,
mm_data,
@ -1361,24 +1363,62 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> MultiModalHashes:
        """Create MM hashes to be returned (only used in V1).

        Note: When overrides are provided via callers of `apply`,
        `_hash_mm_items` will be bypassed and the overrides will be used.
        """
        model_id = self.info.model_id

        return {
            modality: [
                MultiModalHasher.hash_kwargs(model_id=model_id,
                                             **{modality: item},
                                             **hf_processor_mm_kwargs,
                                             **tokenization_kwargs)
                for item in items
            ]
            for modality, items in mm_items.items()
        }
        hashes: MultiModalHashes = {}
        mm_hash_overrides = mm_hash_overrides or {}

        for modality, items in mm_items.items():
            if modality in mm_hash_overrides:
                mm_hashes = mm_hash_overrides[modality]
                if isinstance(mm_hashes, str):
                    mm_hashes = [mm_hashes]

                # For None entries, compute a hash; otherwise, use provided ID.
                computed: list[str] = []
                for i, item in enumerate(items):
                    mm_hash = mm_hashes[i]
                    # NOTE: Even if a mm_hash is provided, we still compute a
                    # hash if `hf_processor_mm_kwargs` or `tokenization_kwargs`
                    # are provided. This is because the processed multimodal
                    # inputs can be different depending on the processor kwargs.
                    if mm_hash is None or \
                        hf_processor_mm_kwargs or \
                        tokenization_kwargs:
                        # NOTE: use provided hash string to hash with kwargs
                        # if available for better performance.
                        item = mm_hash if mm_hash is not None else item
                        computed.append(
                            MultiModalHasher.hash_kwargs(
                                model_id=model_id,
                                **{modality: item},
                                **hf_processor_mm_kwargs,
                                **tokenization_kwargs))
                    else:
                        computed.append(mm_hash)
                hashes[modality] = computed
            else:
                hashes[modality] = [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs,
                                                 **tokenization_kwargs)
                    for item in items
                ]

        return hashes
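As a reading aid for the branches above, a simplified standalone restatement of the per-item decision; `pick_item_id` and `content_hash` are illustrative stand-ins, not vLLM APIs:

```python
from typing import Optional


def pick_item_id(user_id: Optional[str],
                 item: object,
                 has_extra_kwargs: bool,
                 content_hash=lambda x: f"hash({x!r})") -> str:
    """Simplified restatement of the per-item branch in `_hash_mm_items`."""
    if user_id is None or has_extra_kwargs:
        # Hash either the provided ID (cheaper) or the raw item content,
        # folding in any processor/tokenization kwargs in the real code.
        return content_hash(user_id if user_id is not None else item)
    # Otherwise the user-provided ID is used verbatim.
    return user_id


# A provided ID is kept as-is only when no extra kwargs are present.
assert pick_item_id("id-a", "raw-image", has_extra_kwargs=False) == "id-a"
assert pick_item_id(None, "raw-image", has_extra_kwargs=False).startswith("hash(")
assert pick_item_id("id-a", "raw-image", has_extra_kwargs=True).startswith("hash(")
```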
def _get_cache_missing_items(
self,
@ -1474,7 +1514,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
(
prompt_ids,
@ -1495,9 +1536,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
)
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs))
mm_hashes = self._hash_mm_items(mm_data_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides)
mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items,
@ -1520,7 +1562,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
"""
Apply the HF processor on the full prompt text,
@ -1538,10 +1581,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_hash_overrides=mm_hash_overrides,
)
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs))
mm_hashes = self._hash_mm_items(mm_data_items,
hf_processor_mm_kwargs,
tokenization_kwargs,
mm_hash_overrides=mm_hash_overrides)
mm_missing_data_items = self._get_cache_missing_items(
cache=cache,
@ -1742,7 +1785,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> MultiModalInputs:
"""
Process multi-modal inputs to be used in vLLM.
@ -1857,7 +1901,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Optional[Mapping[str, object]] = None,
*,
mm_hash_overrides: Optional[MultiModalHashes] = None,
mm_hash_overrides: Optional[Union[dict[str, list[str]],
MultiModalUUIDDict]] = None,
) -> MultiModalEncDecInputs:
"""
Process multi-modal inputs to be used in vLLM.

View File

@ -150,6 +150,49 @@ class Processor:
        self._validate_sampling_params(params, lora_request)
        self._validate_supported_sampling_params(params)

    def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
        """
        Validate that user-provided multi_modal_uuids align with
        multi_modal_data in the incoming request prompt(s).

        Only checks lengths; `None` entries are allowed and will be
        auto-hashed downstream.
        """

        def _validate_single_prompt(single_prompt: Union[dict, str]) -> None:
            if not isinstance(single_prompt, dict):
                return
            mm_data = single_prompt.get("multi_modal_data")
            mm_uuids = single_prompt.get("multi_modal_uuids")
            if not mm_data or not mm_uuids:
                return

            for modality, items in mm_data.items():
                if modality in mm_uuids:
                    data_len = len(items) if isinstance(items, list) else 1
                    uuid_len = len(mm_uuids[modality]) if isinstance(
                        mm_uuids[modality], list) else 1
                    if uuid_len != data_len:
                        raise ValueError(
                            f"multi_modal_uuids for modality '{modality}' "
                            "must have same length as data: got "
                            f"{uuid_len} uuids vs "
                            f"{data_len} items.")
                else:
                    raise ValueError(
                        f"multi_modal_uuids for modality '{modality}' must "
                        "be provided if multi_modal_data is provided.")

        # Handle explicit encoder/decoder prompts or singleton prompt
        if isinstance(prompt, dict) and "encoder_prompt" in prompt:
            enc = prompt.get("encoder_prompt")
            dec = prompt.get("decoder_prompt")
            if enc is not None:
                _validate_single_prompt(enc)
            if dec is not None:
                _validate_single_prompt(dec)
        else:
            _validate_single_prompt(prompt)  # type: ignore[arg-type]

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@ -289,17 +332,27 @@ class Processor:
if arrival_time is None:
arrival_time = time.time()
# Optionally generate multimodal hash overrides based on request id.
# Optionally generate multimodal hash overrides to avoid hashing
# multimodal data items by their content as their identifiers.
# NOTE: when users explicitly turn off BOTH prefix caching and input
# processing caching, no multimodal features or embeddings will be
# reused across requests, therefore hashing is no longer necessary.
# reused across requests, therefore identifying multimodal data items
# by their content is no longer necessary, and we create uuids with
# request id-modality-index as multimodal hash overrides.
if (self.model_config.multimodal_config and
self.model_config.multimodal_config.mm_processor_cache_gb == 0
and not self.cache_config.enable_prefix_caching):
mm_hash_overrides = self._maybe_build_mm_hash_overrides(
request_id, prompt)
else:
mm_hash_overrides = None
# Otherwise, use user-provided uuids as multimodal hash overrides
# if provided.
self._validate_multi_modal_uuids(prompt)
if isinstance(prompt, dict):
mm_hash_overrides = prompt.get("multi_modal_uuids")
else:
mm_hash_overrides = None
# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.
@ -317,6 +370,7 @@ class Processor:
params=params,
processed_inputs=processed_inputs,
)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
self._validate_model_inputs(processed_inputs, lora_request)