[Core][Multimodal] Allow passing multi_modal_uuids as multimodal identifiers. (#23394)
Signed-off-by: Roger Wang <hey@rogerw.io>
@@ -13,6 +13,41 @@ To input multi-modal data, follow this schema in [vllm.inputs.PromptType][]:
- `prompt`: The prompt should follow the format that is documented on HuggingFace.
- `multi_modal_data`: This is a dictionary that follows the schema defined in [vllm.multimodal.inputs.MultiModalDataDict][].

### Stable UUIDs for Caching (multi_modal_uuids)

When using multi-modal inputs, vLLM normally hashes each media item by content to enable caching across requests. You can optionally pass `multi_modal_uuids` to provide your own stable IDs for each item so caching can reuse work across requests without rehashing the raw content.

??? code

    ```python
    from vllm import LLM
    from PIL import Image

    # Qwen2.5-VL example with two images
    llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct")

    prompt = "USER: <image><image>\nDescribe the differences.\nASSISTANT:"
    img_a = Image.open("/path/to/a.jpg")
    img_b = Image.open("/path/to/b.jpg")

    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": {"image": [img_a, img_b]},
        # Provide stable IDs for caching.
        # Requirements (matched by this example):
        # - Include every modality present in multi_modal_data.
        # - For lists, provide the same number of entries.
        # - Use None to fall back to content hashing for that item.
        "multi_modal_uuids": {"image": ["sku-1234-a", None]},
    })

    for o in outputs:
        print(o.outputs[0].text)
    ```

!!! warning
    If both multimodal processor caching and prefix caching are disabled, user-provided `multi_modal_uuids` are ignored.
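For reference, a minimal sketch of a configuration under which this warning applies, assuming the processor cache size and prefix caching are controlled by the `mm_processor_cache_gb` and `enable_prefix_caching` engine arguments referenced elsewhere in this change:

```python
from vllm import LLM

# With both caches disabled, multimodal items are keyed per request
# internally, so any user-provided multi_modal_uuids are ignored.
llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    mm_processor_cache_gb=0,        # disable multimodal processor caching
    enable_prefix_caching=False,    # disable prefix caching
)
```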

### Image Inputs

You can pass a single image to the `'image'` field of the multi-modal dictionary, as shown in the following examples:
tests/v1/engine/test_processor_multi_modal_uuids.py (new file, 229 lines)
@@ -0,0 +1,229 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
from vllm.platforms.interface import UnspecifiedPlatform
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import processor as processor_mod
from vllm.v1.engine.processor import Processor

cherry_pil_image = ImageAsset("cherry_blossom").pil_image
stop_pil_image = ImageAsset("stop_sign").pil_image
baby_reading_np_ndarrays = VideoAsset("baby_reading").np_ndarrays


# Mock processor for testing
def _mk_processor(monkeypatch,
                  *,
                  mm_cache_gb: float = 4.0,
                  enable_prefix_caching: bool = True) -> Processor:
    """
    Create a Processor instance with minimal configuration suitable for unit
    tests without accessing external resources.
    """
    monkeypatch.setattr(ModelConfig,
                        "try_get_generation_config",
                        lambda self: {},
                        raising=True)
    monkeypatch.setattr(ModelConfig,
                        "__post_init__",
                        lambda self: None,
                        raising=True)
    monkeypatch.setattr(UnspecifiedPlatform,
                        "is_async_output_supported",
                        classmethod(lambda cls, enforce_eager: True),
                        raising=True)
    monkeypatch.setattr(
        ModelConfig,
        "verify_async_output_proc",
        lambda self, parallel_config, speculative_config, device_config: None,
        raising=True)
    monkeypatch.setattr(ModelConfig,
                        "verify_with_parallel_config",
                        lambda self, parallel_config: None,
                        raising=True)
    monkeypatch.setattr(processor_mod,
                        "processor_cache_from_config",
                        lambda vllm_config, mm_registry: None,
                        raising=True)

    monkeypatch.setattr(VllmConfig,
                        "__post_init__",
                        lambda self: None,
                        raising=True)

    model_config = ModelConfig(
        skip_tokenizer_init=True,
        max_model_len=128,
        mm_processor_cache_gb=mm_cache_gb,
        generation_config="vllm",
        tokenizer="dummy",
    )

    # Minimal multimodal_config to satisfy references in
    # Processor.process_inputs.
    class _MockMMConfig:

        def __init__(self, gb: float):
            self.mm_processor_cache_gb = gb

    model_config.multimodal_config = _MockMMConfig(
        mm_cache_gb)  # type: ignore[attr-defined]
    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=CacheConfig(enable_prefix_caching=enable_prefix_caching),
        device_config=DeviceConfig(device="cpu"),
    )

    # Pass tokenizer=None; InputPreprocessor handles None when
    # skip_tokenizer_init is True.
    return Processor(vllm_config, tokenizer=None)  # type: ignore[arg-type]


def test_multi_modal_uuids_length_mismatch_raises(monkeypatch):
    processor = _mk_processor(monkeypatch)

    prompt = {
        "prompt": "USER: <image>\nDescribe\nASSISTANT:",
        "multi_modal_data": {
            "image": [cherry_pil_image, stop_pil_image]
        },
        # Mismatch: 2 items but only 1 uuid provided
        "multi_modal_uuids": {
            "image": ["hash_cherry"]
        },
    }

    with pytest.raises(ValueError, match="must have same length as data"):
        processor.process_inputs(
            request_id="req-1",
            prompt=prompt,  # type: ignore[arg-type]
            params=SamplingParams(),
        )


def test_multi_modal_uuids_missing_modality_raises(monkeypatch):
    processor = _mk_processor(monkeypatch)

    prompt = {
        "prompt": "USER: <image><video>\nDescribe\nASSISTANT:",
        # Two modalities provided in data
        "multi_modal_data": {
            "image": [cherry_pil_image],
            "video": [baby_reading_np_ndarrays]
        },
        # Only image uuids provided; video missing should raise
        "multi_modal_uuids": {
            "image": ["hash_cherry"]
        },
    }

    with pytest.raises(ValueError,
                       match="must be provided if multi_modal_data"):
        processor.process_inputs(
            request_id="req-2",
            prompt=prompt,  # type: ignore[arg-type]
            params=SamplingParams(),
        )


@pytest.mark.parametrize(
    "mm_cache_gb, enable_prefix_caching",
    [
        (4.0, True),  # default behavior
        (4.0, False),  # prefix caching disabled
        (0.0, True),  # processor cache disabled
    ],
)
def test_multi_modal_uuids_accepts_none_and_passes_through(
        monkeypatch, mm_cache_gb: float, enable_prefix_caching: bool):
    processor = _mk_processor(monkeypatch,
                              mm_cache_gb=mm_cache_gb,
                              enable_prefix_caching=enable_prefix_caching)

    # Capture the overrides passed to InputPreprocessor.preprocess
    captured: dict[str, object] = {}

    def fake_preprocess(prompt,
                        *,
                        tokenization_kwargs=None,
                        lora_request=None,
                        mm_hash_overrides=None):
        captured["mm_hash_overrides"] = mm_hash_overrides
        # Minimal processed inputs for decoder-only flow
        return {"type": "token", "prompt_token_ids": [1]}

    # Monkeypatch only the bound preprocess method on this instance
    monkeypatch.setattr(processor.input_preprocessor,
                        "preprocess",
                        fake_preprocess,
                        raising=True)

    # Use a consistent two-image scenario across all configurations
    mm_uuids = {"image": [None, "hash_stop"], "video": None}
    prompt = {
        "prompt": "USER: <image><image>\nTwo images\nASSISTANT:",
        "multi_modal_data": {
            "image": [cherry_pil_image, stop_pil_image],
            "video": baby_reading_np_ndarrays,
        },
        "multi_modal_uuids": mm_uuids,
    }

    processor.process_inputs(
        request_id="req-3",
        prompt=prompt,  # type: ignore[arg-type]
        params=SamplingParams(),
    )

    assert captured["mm_hash_overrides"] == mm_uuids


def test_multi_modal_uuids_ignored_when_caching_disabled(monkeypatch):
    # When both processor cache is 0 and prefix caching disabled, the
    # processor builds overrides from request id instead of using user UUIDs.
    processor = _mk_processor(monkeypatch,
                              mm_cache_gb=0.0,
                              enable_prefix_caching=False)

    captured: dict[str, object] = {}

    def fake_preprocess(prompt,
                        *,
                        tokenization_kwargs=None,
                        lora_request=None,
                        mm_hash_overrides=None):
        captured["mm_hash_overrides"] = mm_hash_overrides
        return {"type": "token", "prompt_token_ids": [1]}

    monkeypatch.setattr(processor.input_preprocessor,
                        "preprocess",
                        fake_preprocess,
                        raising=True)

    request_id = "req-42"
    mm_uuids = {"image": ["hash_cherry", "hash_stop"], "video": "hash_video"}
    prompt = {
        "prompt": "USER: <image><image><video>\nDescribe\nASSISTANT:",
        "multi_modal_data": {
            "image": [cherry_pil_image, stop_pil_image],
            "video": baby_reading_np_ndarrays,
        },
        "multi_modal_uuids": mm_uuids,
    }

    processor.process_inputs(
        request_id=request_id,
        prompt=prompt,  # type: ignore[arg-type]
        params=SamplingParams(),
    )

    # Expect request-id-based overrides are passed through
    assert captured["mm_hash_overrides"] == {
        "image": [f"{request_id}-image-0", f"{request_id}-image-1"],
        "video": [f"{request_id}-video-0"],
    }
@@ -67,7 +67,7 @@ from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
    MultiModalDataDict)
    MultiModalDataDict, MultiModalUUIDDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
@@ -7,7 +7,8 @@ import torch
from typing_extensions import NotRequired, TypedDict, TypeIs, TypeVar

if TYPE_CHECKING:
    from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
    from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalInputs,
                                        MultiModalUUIDDict)


class TextPrompt(TypedDict):
@@ -30,6 +31,15 @@ class TextPrompt(TypedDict):
    to pass the mm_processor_kwargs to each of them.
    """

    multi_modal_uuids: NotRequired["MultiModalUUIDDict"]
    """
    Optional user-specified UUIDs for multimodal items, mapped by modality.
    Lists must match the number of items per modality and may contain `None`.
    For `None` entries, the hasher will compute IDs automatically; non-None
    entries override the default hashes for caching, and MUST be unique per
    multimodal item.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
@@ -59,6 +69,14 @@ class TokensPrompt(TypedDict):
    to pass the mm_processor_kwargs to each of them.
    """

    multi_modal_uuids: NotRequired["MultiModalUUIDDict"]
    """
    Optional user-specified UUIDs for multimodal items, mapped by modality.
    Lists must match the number of items per modality and may contain `None`.
    For `None` entries, the hasher will compute IDs automatically; non-None
    entries override the default hashes for caching.
    """

    cache_salt: NotRequired[str]
    """
    Optional cache salt to be used for prefix caching.
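The `multi_modal_uuids` field documented above applies to token prompts as well. A minimal sketch of a `TokensPrompt` that pins one UUID and lets the other fall back to content hashing; the token IDs and image paths are placeholders, not part of this change:

```python
from PIL import Image

from vllm.inputs import TokensPrompt

# Placeholder token IDs and image paths; a real prompt would come from a
# tokenizer and real media files.
prompt = TokensPrompt(
    prompt_token_ids=[1, 2, 3, 4],
    multi_modal_data={"image": [Image.open("/path/to/a.jpg"),
                                Image.open("/path/to/b.jpg")]},
    # One stable ID, one None entry that falls back to content hashing.
    multi_modal_uuids={"image": ["sku-1234-a", None]},
)
```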
@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                                    MultiModalInputs)
                                    MultiModalInputs, MultiModalUUIDDict)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

@@ -258,7 +258,8 @@ class InputPreprocessor:
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> MultiModalInputs:
        """
        Apply the model's multi-modal processor to a multi-modal prompt,
@@ -291,7 +292,8 @@ class InputPreprocessor:
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> MultiModalInputs:
        """
        Async version of
@@ -368,7 +370,8 @@ class InputPreprocessor:
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        prompt_token_ids = self._truncate_inputs(
            parsed_content["prompt_token_ids"], tokenization_kwargs)
@@ -397,7 +400,8 @@ class InputPreprocessor:
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        prompt_token_ids = self._truncate_inputs(
            parsed_content["prompt_token_ids"], tokenization_kwargs)
@@ -426,7 +430,8 @@ class InputPreprocessor:
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        prompt_text = parsed_content["prompt"]

@@ -462,7 +467,8 @@ class InputPreprocessor:
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> Union[TokenInputs, MultiModalInputs]:
        prompt_text = parsed_content["prompt"]

@@ -498,7 +504,8 @@ class InputPreprocessor:
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> SingletonInputs:
        """
        Extract the singleton inputs from a prompt.
@@ -545,7 +552,8 @@ class InputPreprocessor:
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> SingletonInputs:
        """
        Async version of
@@ -684,7 +692,8 @@ class InputPreprocessor:
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
@@ -759,7 +768,8 @@ class InputPreprocessor:
        prompt: PromptType,
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> EncoderDecoderInputs:
        """
        Async version of
@@ -826,7 +836,8 @@ class InputPreprocessor:
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> DecoderOnlyInputs:
        """
        For decoder-only models:
@@ -858,7 +869,8 @@ class InputPreprocessor:
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> DecoderOnlyInputs:
        """
        Async version of
@@ -879,7 +891,8 @@ class InputPreprocessor:
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> ProcessorInputs:
        """Preprocess the input prompt."""
        if self.model_config.is_encoder_decoder:
@@ -909,7 +922,8 @@ class InputPreprocessor:
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> ProcessorInputs:
        """
        Async version of
@@ -1,11 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .base import MultiModalPlaceholderMap
from .hasher import MultiModalHashDict, MultiModalHasher
from .hasher import MultiModalHasher
from .inputs import (BatchedTensorInputs, ModalityData, MultiModalDataBuiltins,
                     MultiModalDataDict, MultiModalKwargs,
                     MultiModalKwargsItems, MultiModalPlaceholderDict,
                     NestedTensors)
                     MultiModalUUIDDict, NestedTensors)
from .registry import MultiModalRegistry

MULTIMODAL_REGISTRY = MultiModalRegistry()
@@ -23,12 +23,12 @@ __all__ = [
    "ModalityData",
    "MultiModalDataBuiltins",
    "MultiModalDataDict",
    "MultiModalHashDict",
    "MultiModalHasher",
    "MultiModalKwargs",
    "MultiModalKwargsItems",
    "MultiModalPlaceholderDict",
    "MultiModalPlaceholderMap",
    "MultiModalUUIDDict",
    "NestedTensors",
    "MULTIMODAL_REGISTRY",
    "MultiModalRegistry",
@@ -3,7 +3,7 @@

import pickle
import uuid
from collections.abc import Iterable, Mapping
from collections.abc import Iterable
from typing import Union

import numpy as np
@@ -16,11 +16,6 @@ from vllm.multimodal.image import convert_image_mode

logger = init_logger(__name__)

MultiModalHashDict = Mapping[str, list[str]]
"""
A dictionary containing hashes for items in each modality.
"""


class MultiModalHasher:

@@ -22,7 +22,8 @@ if TYPE_CHECKING:
    from PIL.Image import Image
    from transformers.feature_extraction_utils import BatchFeature

    from .hasher import MultiModalHashDict
    from .processing import MultiModalHashes
else:
    torch = LazyLoader("torch", globals(), "torch")

@@ -115,6 +116,16 @@ The built-in modalities are defined by
[`MultiModalDataBuiltins`][vllm.multimodal.inputs.MultiModalDataBuiltins].
"""

MultiModalUUIDDict: TypeAlias = Mapping[str, Union[list[Optional[str]], str]]
"""
A dictionary containing user-provided UUIDs for items in each modality.
If a UUID for an item is not provided, its entry will be `None` and
MultiModalHasher will compute a hash for the item.

The UUID will be used to identify the item for all caching purposes
(input processing caching, embedding caching, prefix caching, etc).
"""


@dataclass(frozen=True)
class PlaceholderRange:
@@ -939,7 +950,7 @@ class MultiModalInputs(TypedDict):
    mm_kwargs: MultiModalKwargsOptionalItems
    """Keyword arguments to be directly passed to the model after batching."""

    mm_hashes: "MultiModalHashDict"
    mm_hashes: "MultiModalHashes"
    """The hashes of the multi-modal data."""

    mm_placeholders: "MultiModalPlaceholderDict"
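For illustration only, here are values that satisfy the `MultiModalUUIDDict` alias added above; the UUID strings (in particular the `video` entry) are hypothetical placeholders, not part of this change:

```python
from vllm.multimodal.inputs import MultiModalUUIDDict

# Per-modality entries may be a list with one entry per item (None falls
# back to content hashing) or a single string for a single-item modality.
uuids: MultiModalUUIDDict = {
    "image": ["sku-1234-a", None],
    "video": "marketing-clip-v2",
}
```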
@@ -24,7 +24,8 @@ from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
                     MultiModalFieldConfig, MultiModalInputs,
                     MultiModalKwargsItem, MultiModalKwargsItems,
                     MultiModalKwargsOptionalItems, PlaceholderRange)
                     MultiModalKwargsOptionalItems, MultiModalUUIDDict,
                     PlaceholderRange)
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
                    MultiModalDataParser)

@@ -1021,7 +1022,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
        mm_data: MultiModalDataDict,
        hf_processor_mm_kwargs: Mapping[str, object],
        *,
        mm_hash_overrides: Optional[MultiModalHashes] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> MultiModalInputs:
        return self.apply(prompt,
                          mm_data,
@@ -1361,24 +1363,62 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> MultiModalHashes:
        """Create MM hashes to be returned (only used in V1).


        Note: When overrides are provided via callers of `apply`,
        `_hash_mm_items` will be bypassed and the overrides will be used.
        """
        model_id = self.info.model_id

        return {
            modality: [
                MultiModalHasher.hash_kwargs(model_id=model_id,
                                             **{modality: item},
                                             **hf_processor_mm_kwargs,
                                             **tokenization_kwargs)
                for item in items
            ]
            for modality, items in mm_items.items()
        }
        hashes: MultiModalHashes = {}
        mm_hash_overrides = mm_hash_overrides or {}

        for modality, items in mm_items.items():
            if modality in mm_hash_overrides:
                mm_hashes = mm_hash_overrides[modality]
                if isinstance(mm_hashes, str):
                    mm_hashes = [mm_hashes]

                # For None entries, compute a hash; otherwise, use provided ID.
                computed: list[str] = []
                for i, item in enumerate(items):
                    mm_hash = mm_hashes[i]

                    # NOTE: Even if a mm_hash is provided, we still compute a
                    # hash if `hf_processor_mm_kwargs` or `tokenization_kwargs`
                    # are provided. This is because the processed multimodal
                    # inputs can be different depending on the processor kwargs.
                    if mm_hash is None or \
                        hf_processor_mm_kwargs or \
                        tokenization_kwargs:

                        # NOTE: use provided hash string to hash with kwargs
                        # if available for better performance.
                        item = mm_hash if mm_hash is not None else item
                        computed.append(
                            MultiModalHasher.hash_kwargs(
                                model_id=model_id,
                                **{modality: item},
                                **hf_processor_mm_kwargs,
                                **tokenization_kwargs))
                    else:
                        computed.append(mm_hash)
                hashes[modality] = computed
            else:
                hashes[modality] = [
                    MultiModalHasher.hash_kwargs(model_id=model_id,
                                                 **{modality: item},
                                                 **hf_processor_mm_kwargs,
                                                 **tokenization_kwargs)
                    for item in items
                ]

        return hashes

    def _get_cache_missing_items(
        self,
@@ -1474,7 +1514,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        mm_hash_overrides: Optional[MultiModalHashes] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        (
            prompt_ids,
@@ -1495,9 +1536,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
        )

        # Use overrides if provided; fallback to data-dependent hashing.
        mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
                     self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
                                         tokenization_kwargs))
        mm_hashes = self._hash_mm_items(mm_data_items,
                                        hf_processor_mm_kwargs,
                                        tokenization_kwargs,
                                        mm_hash_overrides=mm_hash_overrides)

        mm_prompt_updates = self._get_mm_prompt_updates(
            mm_data_items,
@@ -1520,7 +1562,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
        *,
        mm_hash_overrides: Optional[MultiModalHashes] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> tuple[list[int], MultiModalProcessingInfo, bool]:
        """
        Apply the HF processor on the full prompt text,
@@ -1538,10 +1581,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
            mm_hash_overrides=mm_hash_overrides,
        )

        # Use overrides if provided; fallback to data-dependent hashing.
        mm_hashes = (mm_hash_overrides if mm_hash_overrides is not None else
                     self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
                                         tokenization_kwargs))
        mm_hashes = self._hash_mm_items(mm_data_items,
                                        hf_processor_mm_kwargs,
                                        tokenization_kwargs,
                                        mm_hash_overrides=mm_hash_overrides)

        mm_missing_data_items = self._get_cache_missing_items(
            cache=cache,
@@ -1742,7 +1785,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        *,
        mm_hash_overrides: Optional[dict[str, list[str]]] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> MultiModalInputs:
        """
        Process multi-modal inputs to be used in vLLM.
@@ -1857,7 +1901,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Optional[Mapping[str, object]] = None,
        *,
        mm_hash_overrides: Optional[MultiModalHashes] = None,
        mm_hash_overrides: Optional[Union[dict[str, list[str]],
                                          MultiModalUUIDDict]] = None,
    ) -> MultiModalEncDecInputs:
        """
        Process multi-modal inputs to be used in vLLM.
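To make the override handling in `_hash_mm_items` above easier to follow, here is a self-contained sketch of the same resolution rule using a stand-in hash function; vLLM's `MultiModalHasher.hash_kwargs` and its keyword handling differ in detail:

```python
import hashlib
from typing import Optional


def stable_hash(**kwargs) -> str:
    # Stand-in for MultiModalHasher.hash_kwargs: a content hash over a repr.
    return hashlib.sha256(repr(sorted(kwargs.items())).encode()).hexdigest()


def resolve_hashes(modality: str,
                   items: list,
                   overrides: Optional[list[Optional[str]]],
                   processor_kwargs: dict) -> list[str]:
    if overrides is None:
        # No user UUIDs for this modality: hash every item by content.
        return [
            stable_hash(**{modality: item}, **processor_kwargs)
            for item in items
        ]

    resolved: list[str] = []
    for item, uuid in zip(items, overrides):
        if uuid is None or processor_kwargs:
            # No UUID, or extra processor/tokenization kwargs that can change
            # the processed outputs: hash, folding in the UUID when present.
            key = uuid if uuid is not None else item
            resolved.append(stable_hash(**{modality: key}, **processor_kwargs))
        else:
            # A UUID with default processing is used as the identifier as-is.
            resolved.append(uuid)
    return resolved


print(resolve_hashes("image", ["raw-bytes-a", "raw-bytes-b"],
                     ["sku-1234-a", None], {}))
```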
@@ -150,6 +150,49 @@ class Processor:
        self._validate_sampling_params(params, lora_request)
        self._validate_supported_sampling_params(params)

    def _validate_multi_modal_uuids(self, prompt: PromptType) -> None:
        """
        Validate that user-provided multi_modal_uuids align with
        multi_modal_data in the incoming request prompt(s).
        Only checks lengths; `None` entries are allowed and will be
        auto-hashed downstream.
        """

        def _validate_single_prompt(single_prompt: Union[dict, str]) -> None:
            if not isinstance(single_prompt, dict):
                return
            mm_data = single_prompt.get("multi_modal_data")
            mm_uuids = single_prompt.get("multi_modal_uuids")
            if not mm_data or not mm_uuids:
                return

            for modality, items in mm_data.items():
                if modality in mm_uuids:
                    data_len = len(items) if isinstance(items, list) else 1
                    uuid_len = len(mm_uuids[modality]) if isinstance(
                        mm_uuids[modality], list) else 1
                    if uuid_len != data_len:
                        raise ValueError(
                            f"multi_modal_uuids for modality '{modality}' "
                            "must have same length as data: got "
                            f"{uuid_len} uuids vs "
                            f"{data_len} items.")
                else:
                    raise ValueError(
                        f"multi_modal_uuids for modality '{modality}' must "
                        "be provided if multi_modal_data is provided.")

        # Handle explicit encoder/decoder prompts or singleton prompt
        if isinstance(prompt, dict) and "encoder_prompt" in prompt:
            enc = prompt.get("encoder_prompt")
            dec = prompt.get("decoder_prompt")
            if enc is not None:
                _validate_single_prompt(enc)
            if dec is not None:
                _validate_single_prompt(dec)
        else:
            _validate_single_prompt(prompt)  # type: ignore[arg-type]

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -289,17 +332,27 @@ class Processor:
        if arrival_time is None:
            arrival_time = time.time()

        # Optionally generate multimodal hash overrides based on request id.
        # Optionally generate multimodal hash overrides to avoid hashing
        # multimodal data items by their content as their identifiers.

        # NOTE: when users explicitly turn off BOTH prefix caching and input
        # processing caching, no multimodal features or embeddings will be
        # reused across requests, therefore hashing is no longer necessary.
        # reused across requests, therefore identifying multimodal data items
        # by their content is no longer necessary, and we create uuids with
        # request id-modality-index as multimodal hash overrides.
        if (self.model_config.multimodal_config and
                self.model_config.multimodal_config.mm_processor_cache_gb == 0
                and not self.cache_config.enable_prefix_caching):
            mm_hash_overrides = self._maybe_build_mm_hash_overrides(
                request_id, prompt)
        else:
            mm_hash_overrides = None
            # Otherwise, use user-provided uuids as multimodal hash overrides
            # if provided.
            self._validate_multi_modal_uuids(prompt)
            if isinstance(prompt, dict):
                mm_hash_overrides = prompt.get("multi_modal_uuids")
            else:
                mm_hash_overrides = None

        # Process inputs, which includes:
        # 1. Tokenize text prompt, with LoRA request if one exists.
@@ -317,6 +370,7 @@
            params=params,
            processed_inputs=processed_inputs,
        )

        eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)

        self._validate_model_inputs(processed_inputs, lora_request)