[Core] Use key-only cache for BaseMultiModalProcessor (#23018)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-08-27 14:19:13 +08:00; committed by GitHub
parent 8dbf6ed7be
commit 69244e67e6
29 changed files with 954 additions and 394 deletions

View File

@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
If you run out of CPU RAM, try the following options:
- (Multi-modal models only) you can set the size of the multi-modal processor cache by setting the `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process).
- (Multi-modal models only) you can set the size of the multi-modal cache by setting the `mm_processor_cache_gb` engine argument (default 4 GiB).
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
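For example, a minimal sketch that combines both options (the values are illustrative and should be tuned to your machine):

```python
import os

from vllm import LLM

# Limit the KV cache of the CPU backend to 2 GiB (CPU backend only).
os.environ["VLLM_CPU_KVCACHE_SPACE"] = "2"

# Shrink the multi-modal processor cache to 2 GiB (0 disables it entirely).
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          mm_processor_cache_gb=2)
```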
## Multi-modal input limits

View File

@ -204,20 +204,33 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
to avoid CPU resource exhaustion.
!!! note
[Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled
because it requires a one-to-one correspondence between API and engine core processes.
API server scale-out disables [multi-modal IPC caching](#ipc-caching)
because it requires a one-to-one correspondence between API and engine core processes.
This does not impact [multi-modal processor caching](#processor-caching).
## Multi-Modal Caching
### Processor Cache
By default, the multi-modal processor cache is enabled to avoid repeatedly processing
the same multi-modal inputs via Hugging Face `AutoProcessor`,
Multi-modal caching avoids repeated transfer or processing of the same multi-modal data,
which commonly occurs in multi-turn conversations.
You can adjust the size of the cache by setting the value of `mm_processor_cache_gb`
(default 4 GiB per API process + 4 GiB per engine core process).
If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`.
### Processor Caching
Multi-modal processor caching is automatically enabled
to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`.
### IPC Caching
Multi-modal IPC caching is automatically enabled when
there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes,
to avoid repeatedly transferring the same multi-modal inputs between them.
### Configuration
You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB).
If you do not benefit much from the cache, you can disable both IPC
and processor caching completely via `mm_processor_cache_gb=0`.
Examples:
@ -230,3 +243,16 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
mm_processor_cache_gb=0)
```
### Cache Placement
Based on the configuration, the contents of the multi-modal caches on `P0` and `P1` are as follows:
| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory |
|-------------------|-------------|------------|------------|-------------|
| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` |
| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` |
| ❌ | ❌ | N/A | N/A | `0` |
K: Stores the hashes of multi-modal items
V: Stores the processed tensor data of multi-modal items
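For instance, with the default `mm_processor_cache_gb=4` and IPC caching enabled across `data_parallel_size=2`, the caches can occupy up to `4 GiB * 2 = 8 GiB` in total; with IPC caching disabled and `api_server_count=4`, the upper bound becomes `4 GiB * 4 = 16 GiB`.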

View File

@ -14,8 +14,9 @@ from PIL import Image
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
cached_tokenizer_from_config,
encode_tokens)
@ -63,6 +64,8 @@ def _test_processing_correctness(
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
# Ensure that the cache can fit all of the data
mm_processor_cache_gb=2048,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
@ -71,8 +74,7 @@ def _test_processing_correctness(
model_config,
tokenizer=cached_tokenizer_from_config(model_config),
)
# Ensure that it can fit all of the data
cache = ProcessingCache(capacity_gb=2048)
cache = MultiModalProcessorOnlyCache(model_config)
processing_info = factories.info(ctx)
supported_mm_limits = processing_info.get_supported_mm_limits()

View File

@ -1,32 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import pytest
import torch
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.multimodal.cache import (MultiModalCache,
MultiModalProcessorCacheItem,
MultiModalProcessorCacheItemMetadata,
processor_cache_from_config,
receiver_cache_from_config)
from vllm.multimodal.hasher import MultiModalHasher
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalSharedField)
from vllm.multimodal.processing import PromptInsertion
from vllm.multimodal.registry import MultiModalRegistry
def _dummy_elem(modality: str, key: str, size: int):
def _dummy_elem(
modality: str,
key: str,
size: int,
*,
rng: Optional[np.random.RandomState] = None,
):
if rng is None:
data = torch.empty((size, ), dtype=torch.int8)
else:
data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8))
return MultiModalFieldElem(
modality=modality,
key=key,
data=torch.empty((size, ), dtype=torch.int8),
data=data,
field=MultiModalSharedField(1),
)
def _dummy_item(modality: str, size_by_key: dict[str, int]):
def _dummy_item(
modality: str,
size_by_key: dict[str, int],
*,
rng: Optional[np.random.RandomState] = None,
):
return MultiModalKwargsItem.from_elems([
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
_dummy_elem(modality, key, size, rng=rng)
for key, size in size_by_key.items()
])
def _dummy_items(size_by_key_modality: dict[str, dict[str, int]]):
def _dummy_items(
size_by_key_modality: dict[str, dict[str, int]],
*,
rng: Optional[np.random.RandomState] = None,
):
return MultiModalKwargsItems.from_seq([
_dummy_item(modality, size_by_key)
_dummy_item(modality, size_by_key, rng=rng)
for modality, size_by_key in size_by_key_modality.items()
])
@ -48,5 +80,139 @@ def test_cache_item_size(item, expected_size):
cache[""] = item
assert cache.currsize == expected_size
cache[""] = MultiModalCacheItemMetadata.wraps(item)
prompt_update = PromptInsertion("dummy", "target", "insertion") \
.resolve(0)
cache[""] = MultiModalProcessorCacheItem(item, [prompt_update])
assert cache.currsize == expected_size
cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update])
assert cache.currsize == expected_size
def _create_vllm_config(
*,
mm_processor_cache_gb: float,
enable_ipc: bool,
):
return VllmConfig(
model_config=ModelConfig(mm_processor_cache_gb=mm_processor_cache_gb),
parallel_config=ParallelConfig(
data_parallel_size=1 if enable_ipc else 2),
)
def _compare_caches(
config_0: VllmConfig,
config_1: VllmConfig,
*,
item_capacity: int = 8,
hit_rate: float = 0.5,
max_items_per_iter: int = 3,
is_cached_calls_per_iter: int,
n_iter: int = 100,
seed: int = 0,
):
mm_registry = MultiModalRegistry()
cache_0_p0 = processor_cache_from_config(config_0, mm_registry)
cache_0_p1 = receiver_cache_from_config(config_0, mm_registry)
cache_1_p0 = processor_cache_from_config(config_1, mm_registry)
cache_1_p1 = receiver_cache_from_config(config_1, mm_registry)
cache_size_gb = max(
config_0.model_config.mm_processor_cache_gb,
config_1.model_config.mm_processor_cache_gb,
)
item_size_gb = int(cache_size_gb / item_capacity)
rng = np.random.RandomState(seed)
all_items = [
_dummy_item("item", {"key": item_size_gb}, rng=rng)
for _ in range(int(item_capacity / hit_rate))
]
all_hashes = [
MultiModalHasher.hash_kwargs(item=item.get_data())
for item in all_items
]
# Should not be used since there is nothing to convert to text
prompt_update = PromptInsertion("dummy", "target", "insertion")
for it in range(n_iter):
num_items_to_select = rng.randint(0, max_items_per_iter)
item_idxs_to_select = rng.choice(len(all_items), num_items_to_select)
selected_items = [all_items[idx] for idx in item_idxs_to_select]
selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select]
if cache_0_p0 is None:
cache_0_p0_out = selected_items
else:
for _ in range(is_cached_calls_per_iter):
cache_0_p0.is_cached(selected_hashes)
cache_0_p0_out = [
item for item, _ in cache_0_p0.get_and_update(
[(item, prompt_update.content) for item in selected_items],
selected_hashes,
)
]
if cache_1_p0 is None:
cache_1_p0_out = selected_items
else:
for _ in range(is_cached_calls_per_iter):
cache_1_p0.is_cached(selected_hashes)
cache_1_p0_out = [
item for item, _ in cache_1_p0.get_and_update(
[(item, prompt_update.content) for item in selected_items],
selected_hashes,
)
]
if cache_0_p1 is None:
cache_0_p1_out = cache_0_p0_out
else:
cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out,
selected_hashes)
if cache_1_p1 is None:
cache_1_p1_out = cache_1_p0_out
else:
cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out,
selected_hashes)
assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}"
@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3])
def test_ipc_enable_disable_consistency(is_cached_calls_per_iter):
cache_size_gb = 1 / (1 << 20)
vllm_config_ipc_enabled = _create_vllm_config(
mm_processor_cache_gb=cache_size_gb,
enable_ipc=True,
)
vllm_config_ipc_disabled = _create_vllm_config(
mm_processor_cache_gb=0,
enable_ipc=False,
)
vllm_config_cache_disabled = _create_vllm_config(
mm_processor_cache_gb=cache_size_gb,
enable_ipc=True,
)
_compare_caches(
vllm_config_ipc_enabled,
vllm_config_ipc_disabled,
is_cached_calls_per_iter=is_cached_calls_per_iter,
)
_compare_caches(
vllm_config_ipc_disabled,
vllm_config_cache_disabled,
is_cached_calls_per_iter=is_cached_calls_per_iter,
)
_compare_caches(
vllm_config_cache_disabled,
vllm_config_ipc_enabled,
is_cached_calls_per_iter=is_cached_calls_per_iter,
)

View File

@ -437,7 +437,7 @@ class ModelConfig:
from `AutoProcessor.from_pretrained`. The available overrides depend on the
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
"""
mm_processor_cache_gb: int = 4
mm_processor_cache_gb: float = 4
"""The size (in GiB) of the multi-modal processor cache, which is used to
avoid re-processing past multi-modal inputs.
@ -884,12 +884,6 @@ class ModelConfig:
return None
def set_mm_processor_cache_gb(self, value: int) -> None:
mm_config = self.get_multimodal_config()
self.mm_processor_cache_gb = value
mm_config.mm_processor_cache_gb = value
def _get_encoder_config(self):
return get_sentence_transformer_tokenizer_config(
self.model, self.revision)
@ -1697,22 +1691,6 @@ class ModelConfig:
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None
@property
def enable_mm_processor_cache(self) -> bool:
"""Whether the multi-modal processor cache should be enabled."""
mm_config = self.multimodal_config
if mm_config is None:
return False
return mm_config.mm_processor_cache_gb > 0
def get_mm_input_cache_gb(self) -> int:
mm_config = self.multimodal_config
if mm_config is None:
return 0
return envs.VLLM_MM_INPUT_CACHE_GIB
@property
def is_cross_encoder(self) -> bool:
return (self._model_info.supports_cross_encoding
@ -2561,7 +2539,7 @@ class MultiModalConfig:
`{"num_crops": 4}`.
"""
mm_processor_cache_gb: int = 4
mm_processor_cache_gb: float = 4
"""
The size (in GiB) of the multi-modal processor cache, which is used to

View File

@ -351,7 +351,7 @@ class EngineArgs:
mm_processor_kwargs: Optional[Dict[str, Any]] = \
MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
# LoRA fields
@ -1293,18 +1293,6 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls,
)
if model_config.is_multimodal_model:
dp_supports_mm_processor_cache = (self.data_parallel_size == 1
or data_parallel_external_lb)
if (not dp_supports_mm_processor_cache
and model_config.mm_processor_cache_gb > 0):
logger.warning(
"Multi-modal processor cache is disabled because "
"it is not compatible with data parallelism when "
"there does not exist a one-to-one correspondance "
"between API and engine core processes.")
model_config.set_mm_processor_cache_gb(0)
speculative_config = self.create_speculative_config(
target_model_config=model_config,
target_parallel_config=parallel_config,

View File

@ -36,6 +36,7 @@ from vllm.logits_process import get_bad_words_logits_processors
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory)
@ -250,9 +251,13 @@ class LLMEngine:
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
mm_registry)
self.input_preprocessor = InputPreprocessor(
self.model_config,
self.tokenizer,
mm_registry,
mm_processor_cache=processor_only_cache_from_config(
self.model_config, mm_registry),
)
self.model_executor = executor_class(vllm_config=vllm_config)
@ -840,8 +845,8 @@ class LLMEngine:
def reset_mm_cache(self) -> bool:
"""Reset the multi-modal cache."""
return self.input_preprocessor.mm_registry.reset_processor_cache(
self.model_config)
self.input_preprocessor.clear_cache()
return True
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices."""

View File

@ -11,6 +11,7 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ -32,12 +33,14 @@ class InputPreprocessor:
model_config: ModelConfig,
tokenizer: Optional[TokenizerGroup],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
) -> None:
super().__init__()
self.model_config = model_config
self.tokenizer = tokenizer
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None:
@ -261,8 +264,11 @@ class InputPreprocessor:
"""
tokenizer = self._get_mm_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
mm_processor = self.mm_registry.create_processor(
self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
@ -286,8 +292,12 @@ class InputPreprocessor:
"""
tokenizer = await self._get_mm_tokenizer_async(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
mm_processor = self.mm_registry.create_processor(
self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
@ -860,3 +870,7 @@ class InputPreprocessor:
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
)
def clear_cache(self) -> None:
if self.mm_processor_cache is not None:
self.mm_processor_cache.clear_cache()

View File

@ -223,20 +223,26 @@ class InputRegistry:
The model is identified by ``model_config``.
"""
# Avoid circular import
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.sequence import SequenceData
if not model_config.is_multimodal_model:
seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
return DummyData(seq_data=seq_data)
cache = processor_only_cache_from_config(model_config, mm_registry)
# Encoder dummy data does not contain multi-modal data
if is_encoder_data:
enc_data = mm_registry.get_encoder_dummy_data(
model_config, seq_len)
enc_data = mm_registry.get_encoder_dummy_data(model_config,
seq_len,
cache=cache)
seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
return DummyData(seq_data=seq_data)
dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len)
dec_data = mm_registry.get_decoder_dummy_data(model_config,
seq_len,
cache=cache)
return DummyData(
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),

View File

@ -33,12 +33,13 @@ from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
@ -367,7 +368,7 @@ def _build_hcxvision_hf_processor(
info: HCXVisionProcessingInfo,
dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo],
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
if isinstance(info, HCXVisionProcessingInfo):
return HCXVisionMultiModalProcessor(

View File

@ -22,14 +22,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
@ -394,7 +394,7 @@ def _build_llava_or_pixtral_hf_processor(
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
if isinstance(info, PixtralHFProcessingInfo):
return PixtralHFMultiModalProcessor(

View File

@ -58,7 +58,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
VideoItem, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
PromptUpdate, PromptUpdateDetails,
ResolvedPromptUpdate, _seq2text)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
@ -744,6 +745,43 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
for modality, pattern in placeholders
]
def _recompute_cached_prompt_update(
self,
cached_update: ResolvedPromptUpdate,
new_item_idx: int,
) -> ResolvedPromptUpdate:
new_update = super()._recompute_cached_prompt_update(
cached_update,
new_item_idx,
)
if cached_update.modality == "image":
tokenizer = self.info.get_tokenizer()
image_processor = self.info.get_image_processor()
version = self.info.get_model_version()
text = _seq2text(tokenizer, cached_update.content.full)
prev_item_idx = cached_update.item_idx
if version == (2, 0) or version == (2, 5):
im_start = image_processor.im_start_token
im_end = image_processor.im_end_token
else:
im_start = image_processor.im_id_start
im_end = image_processor.im_id_end
new_update = new_update.with_content(
PromptUpdateDetails.select_text(
text.replace(
f"{im_start}{prev_item_idx}{im_end}",
f"{im_start}{new_item_idx}{im_end}",
1,
),
"<unk>",
))
return new_update
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,

View File

@ -22,14 +22,14 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@ -322,7 +322,7 @@ def _build_mistral3_processor(
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
assert isinstance(info, Mistral3ProcessingInfo)
return Mistral3MultiModalProcessor(

View File

@ -41,7 +41,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
PromptReplacement, PromptUpdate)
PromptReplacement, PromptUpdate,
ResolvedPromptUpdate)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
@ -440,6 +441,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
)
]
def _recompute_cached_prompt_update(
self,
cached_update: ResolvedPromptUpdate,
new_item_idx: int,
) -> ResolvedPromptUpdate:
new_update = super()._recompute_cached_prompt_update(
cached_update,
new_item_idx,
)
if cached_update.modality == "image":
hf_processor = self.info.get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
new_update = new_update.with_target(image_tokens[new_item_idx])
return new_update
def _apply_prompt_updates(
self,
token_ids: list[int],

View File

@ -27,7 +27,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems,
MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
PromptUpdate, ResolvedPromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@ -850,6 +850,25 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
),
]
def _recompute_cached_prompt_update(
self,
cached_update: ResolvedPromptUpdate,
new_item_idx: int,
) -> ResolvedPromptUpdate:
new_update = super()._recompute_cached_prompt_update(
cached_update,
new_item_idx,
)
if cached_update.modality == "image":
image_tokens: list[str] = self.info.image_tokens # type: ignore
new_update = new_update.with_target(image_tokens[new_item_idx])
elif cached_update.modality == "audio":
audio_tokens: list[str] = self.info.audio_tokens # type: ignore
new_update = new_update.with_target(audio_tokens[new_item_idx])
return new_update
@MULTIMODAL_REGISTRY.register_processor(
Phi4MMMultiModalProcessor,

View File

@ -25,12 +25,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llava import LlavaDummyInputsBuilder
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
@ -332,7 +333,7 @@ def _build_tarsier_hf_processor(
info: _I_Tarsier,
dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier],
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
if isinstance(info, TarsierProcessingInfo):
return TarsierMultiModalProcessor(

View File

@ -1,11 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TypeVar, Union
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
import torch
from typing_extensions import TypeAlias, override
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache
@ -15,24 +16,67 @@ from .inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem, MultiModalKwargsItems,
NestedTensors)
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from .processing import ResolvedPromptUpdate
from .registry import MultiModalRegistry
logger = init_logger(__name__)
@dataclass
class MultiModalCacheItemMetadata:
size: int
class MultiModalProcessorCacheItem:
"""
The data to store inside `MultiModalProcessorOnlyCache`.
@classmethod
def wraps(cls, value: "MultiModalCacheValue"):
return cls(size=MultiModalCache.get_item_size(value))
Args:
item: The processed tensor data corresponding to a multi-modal item.
prompt_updates: The prompt updates corresponding to `item`.
"""
def __init__(
self,
item: MultiModalKwargsItem,
prompt_updates: Sequence["ResolvedPromptUpdate"],
) -> None:
super().__init__()
self.item = item
self.prompt_updates = prompt_updates
class MultiModalProcessorCacheItemMetadata:
"""
The metadata to store inside `MultiModalProcessorSenderCache`.
Args:
item: The processed tensor data corresponding to a multi-modal item.
Since P1 already stores the tensor data, we only store its size
metadata in P0 to reduce memory usage. The size metadata is still
needed to keep the same cache eviction policy as P1.
prompt_updates: The prompt updates corresponding to `item`.
This needs to stay on P0 because for some models, they are
dependent on the processed tensor data (cached on P1).
"""
def __init__(
self,
item: MultiModalKwargsItem,
prompt_updates: Sequence["ResolvedPromptUpdate"],
) -> None:
super().__init__()
self.item_size = MultiModalCache.get_item_size(item)
self.prompt_updates = prompt_updates
MultiModalCacheValue = Union[
MultiModalProcessorCacheItem,
MultiModalProcessorCacheItemMetadata,
MultiModalKwargsItems,
MultiModalKwargsItem,
MultiModalKwargs,
Mapping[str, NestedTensors],
MultiModalCacheItemMetadata,
]
_V = TypeVar("_V", bound=MultiModalCacheValue)
@ -47,8 +91,10 @@ class MultiModalCache:
*,
debug: bool = False,
) -> int:
if isinstance(leaf, MultiModalFieldElem):
return cls.get_item_size(leaf.data) # type: ignore
if isinstance(leaf, MultiModalProcessorCacheItem):
return cls.get_leaf_size(leaf.item)
if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
return leaf.item_size
# These are not subclasses of dict
if isinstance(leaf, MultiModalKwargsItems):
@ -58,13 +104,13 @@ class MultiModalCache:
if isinstance(leaf, MultiModalKwargs):
return cls.get_item_size(leaf.data) # type: ignore
if isinstance(leaf, MultiModalFieldElem):
return cls.get_item_size(leaf.data) # type: ignore
# sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor):
return leaf.nbytes
if isinstance(leaf, MultiModalCacheItemMetadata):
return leaf.size
return sys.getsizeof(leaf)
@classmethod
@ -98,3 +144,332 @@ class MultiModalCache:
GiB_bytes * capacity_gb,
getsizeof=lambda x: cls.get_item_size(x, debug=debug),
)
_I = TypeVar("_I", contravariant=True)
_O = TypeVar("_O", covariant=True)
class BaseMultiModalCache(ABC, Generic[_I, _O]):
"""
Abstract base class to read/write multi-modal items from cache.
The idea of multi-modal caching is based on having a client and server
where the client executes in the frontend process (=P0) and
the server in the core process (=P1). The data flow is as follows:
```
is_cached() x N get_and_update()
P0: From API -----------------> -----------------> To P1
get_and_update()
P1: From P0 -----------------> To model
```
`is_cached()` can be called any number of times in P0. However,
`get_and_update()` must be called in P0 and P1 one after another
so that their cache eviction order remains the same.
This ensures that the keys in P0 and P1 caches are mirrored,
allowing us to determine whether a key is cached in P1 by looking
up the P0 cache, without having to communicate with P1.
"""
@abstractmethod
def get_and_update_item(
self,
mm_item: _I,
mm_hash: str,
) -> _O:
"""
Possibly update a multi-modal item based on whether it is
in the underlying cache.
This update is done out-of-place and updates the cache eviction order.
Args:
mm_item: The multi-modal item to update.
mm_hash: The hash of `mm_item`.
Returns:
The updated multi-modal item.
"""
raise NotImplementedError
def get_and_update(
self,
mm_items: Sequence[_I],
mm_hashes: list[str],
) -> list[_O]:
"""
Possibly update a sequence of multi-modal items based on whether they
are in the underlying cache.
This update is done out-of-place and updates the cache eviction order.
Args:
mm_items: The multi-modal items to update.
mm_hashes: The hash of each item in `mm_items`.
Returns:
A new list of updated multi-modal items.
"""
assert len(mm_items) == len(mm_hashes)
return [
self.get_and_update_item(mm_item, mm_hash)
for mm_item, mm_hash in zip(mm_items, mm_hashes)
]
@abstractmethod
def clear_cache(self) -> None:
"""Clear the underlying cache."""
raise NotImplementedError
MultiModalProcessorCacheInItem: TypeAlias = \
Optional[tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]]]
MultiModalProcessorCacheOutItem: TypeAlias = \
tuple[Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"]]
class BaseMultiModalProcessorCache(
BaseMultiModalCache[MultiModalProcessorCacheInItem,
MultiModalProcessorCacheOutItem]):
"""The required interface for caches on P0."""
@abstractmethod
def is_cached_item(self, mm_hash: str) -> bool:
"""
Check whether a multi-modal item is
in the underlying cache.
This **DOES NOT** update the cache eviction order.
Args:
mm_hash: The hash of the item to check.
Returns:
`True` if the item is cached, otherwise `False`.
"""
raise NotImplementedError
def is_cached(self, mm_hashes: list[str]) -> list[bool]:
"""
Check whether a sequence of multi-modal items are
in the underlying cache.
This **DOES NOT** update the cache eviction order.
Args:
mm_hashes: The hash of each item to check.
Returns:
For each item, `True` if the item is cached, otherwise `False`.
"""
return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
"""
The cache which is used on P0 when IPC caching is disabled.
How to update each item:
- If the item is in the cache, replace the input with the cached item.
- If the item is not in the cache, store that item (which includes
tensor data and metadata) into the cache, and return the input.
"""
def __init__(self, model_config: "ModelConfig") -> None:
super().__init__()
mm_config = model_config.get_multimodal_config()
self._cache = MultiModalCache.get_lru_cache(
mm_config.mm_processor_cache_gb,
MultiModalProcessorCacheItem,
)
@override
def is_cached_item(self, mm_hash: str) -> bool:
return mm_hash in self._cache
@override
def get_and_update_item(
self,
mm_item: MultiModalProcessorCacheInItem,
mm_hash: str,
) -> MultiModalProcessorCacheOutItem:
if (cached_item := self._cache.get(mm_hash)) is not None:
return cached_item.item, cached_item.prompt_updates
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)
return mm_item
@override
def clear_cache(self) -> None:
self._cache.clear()
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
"""
The cache which is used on P0 when IPC caching is enabled.
How to update each item:
- If the item is already in the cache, clear the input to avoid
unnecessary IPC.
- If the item is not in the cache, store the metadata of that item so
that the eviction policy remains the same as the cache on P1,
and return the input.
By only storing the metadata, we avoid keeping the data itself in
memory inside P0.
"""
def __init__(self, model_config: "ModelConfig") -> None:
super().__init__()
mm_config = model_config.get_multimodal_config()
self._cache = MultiModalCache.get_lru_cache(
mm_config.mm_processor_cache_gb,
MultiModalProcessorCacheItemMetadata,
)
@override
def is_cached_item(self, mm_hash: str) -> bool:
return mm_hash in self._cache
@override
def get_and_update_item(
self,
mm_item: MultiModalProcessorCacheInItem,
mm_hash: str,
) -> MultiModalProcessorCacheOutItem:
if (cached_item := self._cache.get(mm_hash)) is not None:
return None, cached_item.prompt_updates
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)
return mm_item
@override
def clear_cache(self) -> None:
self._cache.clear()
def _enable_processor_cache(
model_config: "ModelConfig",
mm_registry: "MultiModalRegistry",
) -> bool:
if not mm_registry.supports_multimodal_inputs(model_config):
return False
mm_config = model_config.get_multimodal_config()
return mm_config.mm_processor_cache_gb > 0
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
parallel_config = vllm_config.parallel_config
supports_ipc_cache = (parallel_config.data_parallel_size == 1
or parallel_config.data_parallel_external_lb)
return supports_ipc_cache
def processor_cache_from_config(
vllm_config: "VllmConfig",
mm_registry: "MultiModalRegistry",
) -> Optional[BaseMultiModalProcessorCache]:
"""Return a `BaseMultiModalProcessorCache`, if enabled."""
model_config = vllm_config.model_config
if not _enable_processor_cache(model_config, mm_registry):
return None
if not _enable_ipc_cache(vllm_config):
return MultiModalProcessorOnlyCache(model_config)
return MultiModalProcessorSenderCache(model_config)
def processor_only_cache_from_config(
model_config: "ModelConfig",
mm_registry: "MultiModalRegistry",
):
"""Return a `MultiModalProcessorOnlyCache`, if enabled."""
if not _enable_processor_cache(model_config, mm_registry):
return None
return MultiModalProcessorOnlyCache(model_config)
class BaseMultiModalReceiverCache(
BaseMultiModalCache[Optional[MultiModalKwargsItem],
MultiModalKwargsItem]):
"""The required interface for caches on P1."""
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
"""
The cache which is used on P1 when IPC caching is enabled.
How to update each item:
- If the item is in the cache, replace the input with the cached item.
- If the item is not in the cache, store that item (which includes tensor
data) into the cache, and return the input.
"""
def __init__(self, model_config: "ModelConfig") -> None:
super().__init__()
mm_config = model_config.get_multimodal_config()
self._cache = MultiModalCache.get_lru_cache(
mm_config.mm_processor_cache_gb,
MultiModalKwargsItem,
)
@override
def get_and_update_item(
self,
mm_item: Optional[MultiModalKwargsItem],
mm_hash: str,
) -> MultiModalKwargsItem:
if (cached_item := self._cache.get(mm_hash)) is not None:
return cached_item
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
self._cache[mm_hash] = mm_item
return mm_item
@override
def clear_cache(self) -> None:
self._cache.clear()
def receiver_cache_from_config(
vllm_config: "VllmConfig",
mm_registry: "MultiModalRegistry",
) -> Optional[BaseMultiModalReceiverCache]:
"""Return a `BaseMultiModalReceiverCache`, if enabled."""
model_config = vllm_config.model_config
if not _enable_processor_cache(model_config, mm_registry):
return None
if not _enable_ipc_cache(vllm_config):
return None
return MultiModalReceiverCache(model_config)
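As a minimal sketch of how these factories pair up (the `VllmConfig` values below follow the `_compare_caches` test earlier in this commit, and the model name is only an assumption):

```python
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.multimodal.cache import (processor_cache_from_config,
                                   receiver_cache_from_config)
from vllm.multimodal.registry import MultiModalRegistry

# Assumed setup: a multi-modal model with the default 4 GiB cache and a
# single engine core, so both processor caching and IPC caching apply.
vllm_config = VllmConfig(
    model_config=ModelConfig(model="Qwen/Qwen2.5-VL-3B-Instruct",
                             mm_processor_cache_gb=4),
    parallel_config=ParallelConfig(data_parallel_size=1),
)
mm_registry = MultiModalRegistry()

# P0 (API process): key-only sender cache that mirrors what P1 holds.
p0_cache = processor_cache_from_config(vllm_config, mm_registry)
# P1 (engine core process): receiver cache that stores the tensor data.
p1_cache = receiver_cache_from_config(vllm_config, mm_registry)

# Per the BaseMultiModalCache contract above, `is_cached()` may be called any
# number of times on P0, while `get_and_update()` must run once on P0 and then
# once on P1 for the same items so that both eviction orders stay in sync.
```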

View File

@ -7,11 +7,11 @@ from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import partial
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
Union, cast, final)
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union,
cast, final)
import numpy as np
from typing_extensions import NotRequired, TypeAlias, deprecated
from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated
from vllm.utils import LazyLoader, full_groupby, is_list_of
from vllm.utils.jsontree import JSONTree, json_map_leaves
@ -668,7 +668,15 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
return {key: elem.data for key, elem in self.items()}
class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]):
_I = TypeVar(
"_I",
MultiModalKwargsItem,
Optional[MultiModalKwargsItem],
default=MultiModalKwargsItem,
)
class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
"""
A dictionary of
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
@ -714,27 +722,37 @@ class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]):
items_by_modality = full_groupby(items, key=lambda x: x.modality)
return MultiModalKwargsItems(items_by_modality)
def __getitem__(self, modality: str):
def __getitem__(self, modality: str) -> Sequence[_I]:
if modality not in self:
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {set(self.keys())}")
return super().__getitem__(modality)
return super().__getitem__(modality) # type: ignore[return-value]
def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
for items in self.values():
for item in items:
for modality, items in self.items():
for i, item in enumerate(items):
if item is None:
raise RuntimeError("Cannot build data from empty "
f"mm_items[{modality}][{i}]")
for key, elem in item.items():
elems_by_key[key].append(elem)
return MultiModalKwargs({
key:
elems[0].field.reduce_data(elems, pin_memory=pin_memory)
for key, elems in elems_by_key.items() if len(elems) > 0
for key, elems in elems_by_key.items()
})
MultiModalKwargsOptionalItems: TypeAlias = Union[
MultiModalKwargsItems[MultiModalKwargsItem],
MultiModalKwargsItems[Optional[MultiModalKwargsItem]],
]
class MultiModalKwargs(UserDict[str, NestedTensors]):
"""
A dictionary that represents the keyword arguments to
@ -898,7 +916,7 @@ class MultiModalInputs(TypedDict):
token_type_ids: NotRequired[list[int]]
"""The token type IDs of the prompt."""
mm_kwargs: MultiModalKwargsItems
mm_kwargs: MultiModalKwargsOptionalItems
"""Keyword arguments to be directly passed to the model after batching."""
mm_hashes: "MultiModalHashDict"

View File

@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
Sequence)
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from enum import Enum
from functools import lru_cache
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
@ -20,12 +20,11 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens)
from vllm.utils import flatten_2d_lists, full_groupby
from .cache import MultiModalCache
from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs,
MultiModalKwargsItem, MultiModalKwargsItems,
PlaceholderRange)
MultiModalKwargsOptionalItems, PlaceholderRange)
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser)
@ -34,6 +33,7 @@ if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from .cache import BaseMultiModalProcessorCache
from .profiling import BaseDummyInputsBuilder
logger = init_logger(__name__)
@ -557,6 +557,15 @@ class ResolvedPromptUpdate:
return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx)
def with_target(self, target: UpdateTarget):
return replace(self, target=target)
def with_content(self, content: PromptUpdateInfo):
if not isinstance(content, PromptUpdateDetails):
content = PromptUpdateDetails.from_seq(content)
return replace(self, content=content)
class _TokenMatch(NamedTuple):
start_idx: int
@ -865,21 +874,6 @@ def find_mm_placeholders(
return dict(full_groupby_modality(it))
class ProcessingCache(MultiModalCache):
def __init__(self, capacity_gb: float) -> None:
super().__init__()
self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
self.get = self._cache.get
self.put = self._cache.put
self.reset = self._cache.clear
_CacheItemOrHash = Union[MultiModalKwargsItem, str]
class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing."""
@ -982,7 +976,7 @@ For an item `MultiModalPromptUpdates[k][i]`,
class MultiModalProcessingInfo(NamedTuple):
kwargs: MultiModalKwargsItems
kwargs: MultiModalKwargsOptionalItems
hashes: MultiModalHashes
prompt_updates: MultiModalPromptUpdates
@ -994,11 +988,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Not to be confused with `transformers.ProcessorMixin`.
"""
def __init__(self,
info: _I,
dummy_inputs: "BaseDummyInputsBuilder[_I]",
*,
cache: Optional[ProcessingCache] = None) -> None:
def __init__(
self,
info: _I,
dummy_inputs: "BaseDummyInputsBuilder[_I]",
*,
cache: Optional["BaseMultiModalProcessorCache"] = None,
) -> None:
super().__init__()
self.info = info
@ -1355,32 +1351,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return prompt_ids, mm_processed_data, False
def _get_cache_missing_items(
self,
cache: ProcessingCache,
mm_data_items: MultiModalDataItems,
mm_hashes: MultiModalHashes,
) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]:
mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = {
modality: [(h if (v := cache.get(h)) is None else v)
for h in hashes]
for modality, hashes in mm_hashes.items()
}
mm_missing_idxs = {
modality: [
idx for idx, item_or_hash in enumerate(items_or_hashes)
if isinstance(item_or_hash, str)
]
for modality, items_or_hashes in mm_cache_items_or_hashes.items()
}
mm_missing_data = {
modality: [mm_data_items[modality][idx] for idx in idxs]
for modality, idxs in mm_missing_idxs.items()
}
return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data)
def _hash_mm_items(
self,
mm_items: MultiModalDataItems,
@ -1401,28 +1371,92 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
for modality, items in mm_items.items()
}
def _get_cache_missing_items(
self,
cache: "BaseMultiModalProcessorCache",
mm_data_items: MultiModalDataItems,
mm_hashes: MultiModalHashes,
) -> MultiModalDataItems:
mm_is_cached = {
modality: cache.is_cached(hashes)
for modality, hashes in mm_hashes.items()
}
mm_missing_idxs = {
modality: [
idx for idx, item_is_cached in enumerate(items_is_cached)
if not item_is_cached
]
for modality, items_is_cached in mm_is_cached.items()
}
mm_missing_data = {
modality: [mm_data_items[modality][idx] for idx in idxs]
for modality, idxs in mm_missing_idxs.items()
}
return self._to_mm_items(mm_missing_data)
def _recompute_cached_prompt_update(
self,
cached_update: ResolvedPromptUpdate,
new_item_idx: int,
) -> ResolvedPromptUpdate:
"""
Override this if other attributes of `ResolvedPromptUpdate`
also need to be recomputed after retrieving from the cache.
"""
return replace(cached_update, item_idx=new_item_idx)
def _merge_mm_kwargs(
self,
cache: ProcessingCache,
mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
cache: "BaseMultiModalProcessorCache",
mm_hashes: MultiModalHashes,
mm_missing_kwargs: MultiModalKwargsItems,
) -> MultiModalKwargsItems:
mm_missing_prompt_updates: MultiModalPromptUpdates,
) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]:
# Need to calculate this at the beginning to avoid skipping cache logic
# for subsequently repeated items in the same modality
mm_is_cached = {
modality: cache.is_cached(hashes)
for modality, hashes in mm_hashes.items()
}
mm_missing_next_idx = defaultdict[str, int](lambda: 0)
merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
for modality, items_or_hashes in mm_cache_items_or_hashes.items():
for item_or_hash in items_or_hashes:
if isinstance(item_or_hash, str):
kw_item = mm_missing_kwargs[modality][
mm_missing_next_idx[modality]]
cache.put(item_or_hash, kw_item)
merged_kwargs = defaultdict[str,
list[Optional[MultiModalKwargsItem]]](list)
merged_prompt_updates = defaultdict[
str, list[Sequence[ResolvedPromptUpdate]]](list)
for modality, hashes in mm_hashes.items():
missing_kwargs = mm_missing_kwargs.get(modality, [])
missing_prompt_updates = mm_missing_prompt_updates.get(
modality, [])
for item_idx, item_hash in enumerate(hashes):
kwargs: Optional[MultiModalKwargsItem]
if not mm_is_cached[modality][item_idx]:
missing_next_idx = mm_missing_next_idx[modality]
kwargs = missing_kwargs[missing_next_idx]
updates = missing_prompt_updates[missing_next_idx]
mm_missing_next_idx[modality] += 1
item = kwargs, updates
else:
kw_item = item_or_hash
item = None
merged_items[modality].append(kw_item)
kwargs, updates = cache.get_and_update_item(item, item_hash)
return MultiModalKwargsItems(merged_items)
merged_kwargs[modality].append(kwargs)
merged_prompt_updates[modality].append([
self._recompute_cached_prompt_update(update, item_idx)
for update in updates
])
mm_kwargs = MultiModalKwargsItems(merged_kwargs)
mm_prompt_updates = dict(merged_prompt_updates)
return mm_kwargs, mm_prompt_updates
def _apply_hf_processor(
self,
@ -1490,10 +1524,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs)
(
mm_cache_items_or_hashes,
mm_missing_data_items,
) = self._get_cache_missing_items(
mm_missing_data_items = self._get_cache_missing_items(
cache=cache,
mm_data_items=mm_data_items,
mm_hashes=mm_hashes,
@ -1520,16 +1552,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs),
)
mm_kwargs = self._merge_mm_kwargs(
cache,
mm_cache_items_or_hashes=mm_cache_items_or_hashes,
mm_missing_kwargs=mm_missing_kwargs,
mm_missing_prompt_updates = self._get_mm_prompt_updates(
mm_missing_data_items,
hf_processor_mm_kwargs,
mm_missing_kwargs,
)
mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items,
hf_processor_mm_kwargs,
mm_kwargs,
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
cache,
mm_hashes=mm_hashes,
mm_missing_kwargs=mm_missing_kwargs,
mm_missing_prompt_updates=mm_missing_prompt_updates,
)
mm_info = MultiModalProcessingInfo(
@ -1614,7 +1647,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _validate_mm_kwargs(
self,
mm_kwargs: MultiModalKwargsItems,
mm_kwargs: MultiModalKwargsOptionalItems,
mm_item_counts: Mapping[str, int],
) -> None:
for modality, item_count in mm_item_counts.items():
@ -1655,7 +1688,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self,
mm_items: MultiModalDataItems,
prompt_ids: list[int],
mm_kwargs: MultiModalKwargsItems,
mm_kwargs: MultiModalKwargsOptionalItems,
mm_prompt_updates: MultiModalPromptUpdates,
is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:

View File

@ -13,7 +13,7 @@ import vllm.envs as envs
from vllm.logger import init_logger
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalKwargsItems,
MultiModalInputs, MultiModalKwargsOptionalItems,
MultiModalPlaceholderDict)
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
EncDecMultiModalProcessor)
@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple):
"""Dummy data used for profiling."""
prompt_token_ids: list[int]
multi_modal_data: MultiModalKwargsItems
multi_modal_data: MultiModalKwargsOptionalItems
multi_modal_placeholders: MultiModalPlaceholderDict

View File

@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
import torch.nn as nn
@ -13,8 +12,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
cached_tokenizer_from_config)
from vllm.utils import ClassRegistry
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
ProcessingCache)
from .cache import (BaseMultiModalProcessorCache,
processor_only_cache_from_config)
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
DummyEncoderData, MultiModalProfiler)
@ -65,7 +65,7 @@ class MultiModalProcessorFactory(Protocol[_I]):
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor[_I]:
...
@ -80,20 +80,13 @@ class _ProcessorFactories(Generic[_I]):
self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
):
info = self.info(ctx)
dummy_inputs_builder = self.dummy_inputs(info)
return self.processor(info, dummy_inputs_builder, cache=cache)
# Make sure a different cache is used for each model config
# NOTE: ModelConfig is not hashable so it cannot be passed directly
@lru_cache(maxsize=1)
def _get_processor_cache(model_id: str, capacity_gb: int):
return ProcessingCache(capacity_gb) if capacity_gb > 0 else None
class MultiModalRegistry:
"""
A registry that dispatches data processing according to the model.
@ -103,31 +96,6 @@ class MultiModalRegistry:
self._processor_factories = ClassRegistry[nn.Module,
_ProcessorFactories]()
def _get_processor_cache(self, model_config: "ModelConfig"):
model_id = model_config.model
capacity_gb = model_config.mm_processor_cache_gb
return _get_processor_cache(model_id, capacity_gb)
def reset_processor_cache(self, model_config: "ModelConfig") -> bool:
"""Reset the multi-modal processing cache."""
if processor_cache := self._get_processor_cache(model_config):
processor_cache.reset()
return True # Success
def enable_mm_input_cache(self, model_config: "ModelConfig") -> bool:
"""Whether the multi-modal input cache should be enabled.
NOTE: This is put under MultiModalRegistry on purpose to respect
text-only mode for multimodal models.
"""
if not self.supports_multimodal_inputs(model_config):
return False
mm_config = model_config.get_multimodal_config()
return mm_config.mm_processor_cache_gb > 0
def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
"""
Checks if the model supports multimodal inputs.
@ -157,6 +125,8 @@ class MultiModalRegistry:
def get_max_tokens_per_item_by_modality(
self,
model_config: "ModelConfig",
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
@ -165,11 +135,11 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
return {}
processor = self.create_processor(model_config, disable_cache=False)
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config)
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
return profiler.get_mm_max_contiguous_tokens(
seq_len,
@ -182,6 +152,8 @@ class MultiModalRegistry:
def get_max_tokens_per_item_by_nonzero_modality(
self,
model_config: "ModelConfig",
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
@ -192,15 +164,19 @@ class MultiModalRegistry:
This is currently directly used only in V1 for profiling the memory
usage of a model.
"""
mm_limits = self.get_mm_limits_per_prompt(model_config)
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
model_config,
cache=cache,
)
return {
key: max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
for key, max_tokens_per_mm_item in max_tokens_per_item.items()
if mm_limits[key] > 0
}
# TODO: Remove once V0 is gone
def get_max_tokens_by_modality(
self,
model_config: "ModelConfig",
@ -209,14 +185,19 @@ class MultiModalRegistry:
Get the maximum number of tokens from each modality
for profiling the memory usage of a model.
"""
mm_limits = self.get_mm_limits_per_prompt(model_config)
cache = processor_only_cache_from_config(model_config, self)
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
model_config,
cache=cache,
)
return {
key: mm_limits[key] * max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
for key, max_tokens_per_mm_item in max_tokens_per_item.items()
}
# TODO: Remove once V0 is gone
def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
"""
Get the maximum number of multi-modal tokens
@ -227,6 +208,8 @@ class MultiModalRegistry:
def get_mm_limits_per_prompt(
self,
model_config: "ModelConfig",
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> Mapping[str, int]:
"""
Get the maximum number of multi-modal input instances for each modality
@ -235,7 +218,7 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
return {}
processor = self.create_processor(model_config, disable_cache=False)
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
return profiler.get_mm_limits()
@ -303,7 +286,7 @@ class MultiModalRegistry:
model_config: "ModelConfig",
*,
tokenizer: Optional[AnyTokenizer] = None,
disable_cache: Optional[bool] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
"""
Create a multi-modal processor for a specific model and tokenizer.
@ -311,15 +294,10 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model")
if disable_cache is None:
disable_cache = not model_config.enable_mm_processor_cache
model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls]
ctx = self._create_processing_ctx(model_config, tokenizer)
cache = None if disable_cache else self._get_processor_cache(
model_config)
return factories.build_processor(ctx, cache=cache)
@ -328,13 +306,15 @@ class MultiModalRegistry:
model_config: "ModelConfig",
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> DummyDecoderData:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by ``model_config``.
"""
processor = self.create_processor(model_config, disable_cache=False)
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)
@ -352,13 +332,15 @@ class MultiModalRegistry:
model_config: "ModelConfig",
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> DummyEncoderData:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by ``model_config``.
"""
processor = self.create_processor(model_config, disable_cache=False)
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
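
Taken together, the registry hunks above replace the old `disable_cache` flag with an optional `BaseMultiModalProcessorCache` argument. The following is a minimal sketch of the resulting call pattern, assuming only the signatures shown in this diff; the wrapper name `profile_mm_limits` is hypothetical and not part of the change:

```python
from typing import TYPE_CHECKING

from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import processor_only_cache_from_config

if TYPE_CHECKING:
    from vllm.config import ModelConfig


def profile_mm_limits(model_config: "ModelConfig"):
    # Build the processor cache once so repeated profiling calls for the
    # same model can reuse their processed multi-modal items.
    cache = processor_only_cache_from_config(model_config, MULTIMODAL_REGISTRY)

    mm_limits = MULTIMODAL_REGISTRY.get_mm_limits_per_prompt(
        model_config, cache=cache)
    max_tokens = MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_modality(
        model_config, cache=cache)
    return mm_limits, max_tokens
```

Passing the same cache object to `create_processor(model_config, cache=cache)` follows the same pattern, as shown in the hunks above.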

View File

@ -597,8 +597,7 @@ class AsyncLLM(EngineClient):
await asyncio.gather(*coros)
async def reset_mm_cache(self) -> None:
self.processor.mm_registry.reset_processor_cache(self.model_config)
self.processor.mm_input_cache_client.reset()
self.processor.clear_cache()
await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self,

View File

@ -22,6 +22,7 @@ from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import receiver_cache_from_config
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
@ -38,7 +39,6 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType,
UtilityOutput, UtilityResult)
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheServer
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
@ -128,8 +128,9 @@ class EngineCore:
)
self.use_spec_decode = vllm_config.speculative_config is not None
self.mm_input_cache_server = MultiModalInputCacheServer(
vllm_config.model_config, MULTIMODAL_REGISTRY)
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
self.mm_receiver_cache = receiver_cache_from_config(
vllm_config, mm_registry)
# Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously
@ -370,7 +371,8 @@ class EngineCore:
logger.warning("Resetting the multi-modal cache when requests are "
"in progress may lead to desynced internal caches.")
self.mm_input_cache_server.reset()
if self.mm_receiver_cache is not None:
self.mm_receiver_cache.clear_cache()
def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()
@ -435,10 +437,11 @@ class EngineCore:
assert request.mm_kwargs is not None
# Note on thread safety: no race condition.
# `mm_input_cache_server` is reset at the end of LLMEngine init,
# `mm_receiver_cache` is reset at the end of LLMEngine init,
# and will only be accessed in the input processing thread afterwards.
request.mm_kwargs = self.mm_input_cache_server.get_and_update(
request.mm_kwargs, request.mm_hashes)
if self.mm_receiver_cache is not None:
request.mm_kwargs = self.mm_receiver_cache.get_and_update(
request.mm_kwargs, request.mm_hashes)
req = Request.from_engine_core_request(request,
self.request_block_hasher)

View File

@ -271,8 +271,7 @@ class LLMEngine:
self.engine_core.profile(False)
def reset_mm_cache(self):
self.processor.mm_registry.reset_processor_cache(self.model_config)
self.processor.mm_input_cache_client.reset()
self.processor.clear_cache()
self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, device: Optional[Device] = None):

View File

@ -1,121 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.utils import is_list_of
if TYPE_CHECKING:
from vllm.config import ModelConfig
# The idea of multimodal input caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- P0:
# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
# each input multi-modal item (e.g. image),
# - BaseMultiModalProcessor processes the input items into `mm_kwargs`,
# which are MultiModalKwargsItem instances that each correspond to an
# input multi-modal item.
# - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding
# `mm_hash` for each item. It stores each `mm_hash` as a key together with
# the size of its `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid
# taking up additional memory in P0.
# - The `mm_hash` is always sent to P1.
# - The corresponding `mm_kwargs` are only sent to P1 if they are not cached
# in MultiModalInputCacheServer.
#
# -- P1:
# - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0),
# MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`.
# - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0),
# MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`.
# - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to
# the engine for model execution.
#
# Both Client and Server must perform cache update and eviction based on the
# same item size. This ensures that the keys of MultiModalInputCacheClient
# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0
# whether a key is cached in MultiModalInputCacheServer by querying
# MultiModalInputCacheClient without having to communicate with P1.
class MultiModalInputCacheClient:
"""Used by P0 to check whether multi-modal kwargs are cached in P1."""
def __init__(self, model_config: "ModelConfig",
mm_registry: MultiModalRegistry) -> None:
super().__init__()
self.enabled = mm_registry.enable_mm_input_cache(model_config)
self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(),
MultiModalCacheItemMetadata,
)
def get_and_update(
self,
mm_kwargs: Sequence[MultiModalKwargsItem],
mm_hashes: list[str],
) -> list[Optional[MultiModalKwargsItem]]:
if not self.enabled:
return list(mm_kwargs)
assert len(mm_kwargs) == len(mm_hashes)
out_mm_items = list[Optional[MultiModalKwargsItem]]()
for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
if self.mm_cache.get(mm_hash) is not None:
out_mm_items.append(None)
else:
self.mm_cache[mm_hash] = \
MultiModalCacheItemMetadata.wraps(mm_item)
out_mm_items.append(mm_item)
return out_mm_items
def reset(self) -> None:
self.mm_cache.clear()
class MultiModalInputCacheServer:
"""Used by P1 to avoid requiring past multi-modal kwargs from P0."""
def __init__(self, model_config: "ModelConfig",
mm_registry: MultiModalRegistry) -> None:
super().__init__()
self.enabled = mm_registry.enable_mm_input_cache(model_config)
self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(),
MultiModalKwargsItem,
)
def get_and_update(
self,
mm_kwargs: Sequence[Optional[MultiModalKwargsItem]],
mm_hashes: list[str],
) -> list[MultiModalKwargsItem]:
if not self.enabled:
mm_kwargs_lst = list(mm_kwargs)
assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem)
return mm_kwargs_lst
assert len(mm_kwargs) == len(mm_hashes)
out_mm_items = list[MultiModalKwargsItem]()
for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
if mm_item is None:
out_mm_items.append(self.mm_cache[mm_hash])
else:
self.mm_cache[mm_hash] = mm_item
out_mm_items.append(mm_item)
return out_mm_items
def reset(self) -> None:
self.mm_cache.clear()
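
To make the removed module's contract concrete, here is a self-contained toy sketch of the key-only handshake its header comment describes. `ToySenderCache` and `ToyReceiverCache` are illustrative names, and plain in-memory containers stand in for vLLM's size-bounded LRU caches, so this is not the actual API:

```python
from typing import Any, Optional


class ToySenderCache:
    """P0 side: remembers only which hashes P1 already holds."""

    def __init__(self) -> None:
        self._seen: set[str] = set()

    def get_and_update(
        self,
        items: list[Any],
        hashes: list[str],
    ) -> list[Optional[Any]]:
        out: list[Optional[Any]] = []
        for item, h in zip(items, hashes):
            if h in self._seen:
                out.append(None)  # P1 already has it; send only the hash
            else:
                self._seen.add(h)
                out.append(item)  # first sighting; send the full payload
        return out


class ToyReceiverCache:
    """P1 side: stores full payloads and restores them from hashes."""

    def __init__(self) -> None:
        self._store: dict[str, Any] = {}

    def get_and_update(
        self,
        items: list[Optional[Any]],
        hashes: list[str],
    ) -> list[Any]:
        out: list[Any] = []
        for item, h in zip(items, hashes):
            if item is None:
                out.append(self._store[h])  # restore from the local cache
            else:
                self._store[h] = item
                out.append(item)
        return out


# Round trip: the second send of "img-1" travels as a hash only.
sender, receiver = ToySenderCache(), ToyReceiverCache()
wire = sender.get_and_update(["pixels"], ["img-1"])      # ["pixels"]
assert receiver.get_and_update(wire, ["img-1"]) == ["pixels"]
wire = sender.get_and_update(["pixels"], ["img-1"])      # [None]
assert receiver.get_and_update(wire, ["img-1"]) == ["pixels"]
```

In this commit, the same two roles are taken over by the caches built via `processor_cache_from_config` on P0 and `receiver_cache_from_config` on P1, as the hunks below show.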

View File

@ -11,6 +11,7 @@ from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
@ -18,7 +19,6 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_lm_format_enforcer import (
@ -47,16 +47,17 @@ class Processor:
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
mm_registry)
self.mm_input_cache_client = MultiModalInputCacheClient(
self.model_config, mm_registry)
self.mm_registry = mm_registry
self.mm_processor_cache = processor_cache_from_config(
vllm_config, mm_registry)
@property
def mm_registry(self):
return self.input_preprocessor.mm_registry
self.input_preprocessor = InputPreprocessor(
self.model_config,
self.tokenizer,
mm_registry,
mm_processor_cache=self.mm_processor_cache,
)
def _validate_logprobs(
self,
@ -310,7 +311,7 @@ class Processor:
# in the input sequence.
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
orig_sorted_mm_inputs = [
sorted_mm_inputs = [
decoder_mm_inputs[modality][idx]
for modality, idx in sorted_mm_idxs
]
@ -323,11 +324,6 @@ class Processor:
for modality, idx in sorted_mm_idxs
]
sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
orig_sorted_mm_inputs,
sorted_mm_hashes,
)
return decoder_inputs.get("prompt"), EngineCoreRequest(
request_id=request_id,
prompt_token_ids=decoder_inputs["prompt_token_ids"],
@ -415,3 +411,6 @@ class Processor:
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def clear_cache(self) -> None:
self.input_preprocessor.clear_cache()

View File

@ -2186,10 +2186,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_items_per_batch: int,
) -> BatchedTensorInputs:
"""Dummy data for profiling and precompiling multimodal models."""
assert self.mm_budget is not None
dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_counts={modality: 1},
cache=self.mm_budget.cache,
)
dummy_mm_data = dummy_decoder_data.multi_modal_data

View File

@ -1813,10 +1813,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_items_per_batch: int,
) -> BatchedTensorInputs:
"""Dummy data for profiling and precompiling multimodal models."""
assert self.mm_budget is not None
dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_counts={modality: 1},
cache=self.mm_budget.cache,
)
dummy_mm_data = dummy_decoder_data.multi_modal_data

View File

@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import ModelConfig, SchedulerConfig
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
@ -33,14 +34,18 @@ class MultiModalBudget:
self.model_config = model_config
self.scheduler_config = scheduler_config
self.mm_registry = mm_registry
self.cache = cache = processor_only_cache_from_config(
model_config, mm_registry)
self.max_model_len = model_config.max_model_len
self.max_num_reqs = scheduler_config.max_num_seqs
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config,
cache=cache)
max_tokens_by_modality = mm_registry \
.get_max_tokens_per_item_by_nonzero_modality(model_config)
.get_max_tokens_per_item_by_nonzero_modality(model_config,
cache=cache)
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
scheduler_config,