From 69244e67e6822f1c15816f887659e1ccc18c2632 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 27 Aug 2025 14:19:13 +0800 Subject: [PATCH] [Core] Use key-only cache for `BaseMultiModalProcessor` (#23018) Signed-off-by: DarkLight1337 --- docs/configuration/conserving_memory.md | 2 +- docs/configuration/optimization.md | 44 +- .../multimodal/processing/test_common.py | 8 +- tests/multimodal/test_cache.py | 182 +++++++- vllm/config/__init__.py | 26 +- vllm/engine/arg_utils.py | 14 +- vllm/engine/llm_engine.py | 15 +- vllm/inputs/preprocess.py | 22 +- vllm/inputs/registry.py | 12 +- .../models/hyperclovax_vision.py | 7 +- vllm/model_executor/models/llava.py | 8 +- vllm/model_executor/models/minicpmv.py | 40 +- vllm/model_executor/models/mistral3.py | 8 +- vllm/model_executor/models/phi3v.py | 20 +- vllm/model_executor/models/phi4mm.py | 21 +- vllm/model_executor/models/tarsier.py | 7 +- vllm/multimodal/cache.py | 405 +++++++++++++++++- vllm/multimodal/inputs.py | 38 +- vllm/multimodal/processing.py | 187 ++++---- vllm/multimodal/profiling.py | 4 +- vllm/multimodal/registry.py | 90 ++-- vllm/v1/engine/async_llm.py | 3 +- vllm/v1/engine/core.py | 17 +- vllm/v1/engine/llm_engine.py | 3 +- vllm/v1/engine/mm_input_cache.py | 121 ------ vllm/v1/engine/processor.py | 29 +- vllm/v1/worker/gpu_model_runner.py | 3 + vllm/v1/worker/tpu_model_runner.py | 3 + vllm/v1/worker/utils.py | 9 +- 29 files changed, 954 insertions(+), 394 deletions(-) delete mode 100644 vllm/v1/engine/mm_input_cache.py diff --git a/docs/configuration/conserving_memory.md b/docs/configuration/conserving_memory.md index 058eba5fe0..efda9c8e01 100644 --- a/docs/configuration/conserving_memory.md +++ b/docs/configuration/conserving_memory.md @@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", If you run out of CPU RAM, try the following options: -- (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process) +- (Multi-modal models only) you can set the size of multi-modal cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). ## Multi-modal input limits diff --git a/docs/configuration/optimization.md b/docs/configuration/optimization.md index bb47e1b90f..3eaf2185a5 100644 --- a/docs/configuration/optimization.md +++ b/docs/configuration/optimization.md @@ -204,20 +204,33 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 to avoid CPU resource exhaustion. !!! note - [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled - because it requires a one-to-one correspondence between API and engine core processes. + API server scale-out disables [multi-modal IPC caching](#ipc-caching) + because it requires a one-to-one correspondance between API and engine core processes. + + This does not impact [multi-modal processor caching](#processor-caching). ## Multi-Modal Caching -### Processor Cache - -By default, the multi-modal processor cache is enabled to avoid repeatedly processing -the same multi-modal inputs via Hugging Face `AutoProcessor`, +Multi-modal caching avoids repeated transfer or processing of the same multi-modal data, which commonly occurs in multi-turn conversations. 
-You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` -(default 4 GiB per API process + 4 GiB per engine core process). -If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`. +### Processor Caching + +Multi-modal processor caching is automatically enabled +to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`. + +### IPC Caching + +Multi-modal IPC caching is automatically enabled when +there is a one-to-one correspondance between API (`P0`) and engine core (`P1`) processes, +to avoid repeatedly transferring the same multi-modal inputs between them. + +### Configuration + +You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB). + +If you do not benefit much from the cache, you can disable both IPC +and processor caching completely via `mm_processor_cache_gb=0`. Examples: @@ -230,3 +243,16 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", mm_processor_cache_gb=0) ``` + +### Cache Placement + +Based on the configuration, the content of the multi-modal caches on `P0` and `P1` are as follows: + +| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory | +|-------------------|-------------|------------|------------|-------------| +| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` | +| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` | +| ❌ | ❌ | N/A | N/A | `0` | + +K: Stores the hashes of multi-modal items +V: Stores the processed tensor data of multi-modal items diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 6361cb9b55..3ff4360b83 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -14,8 +14,9 @@ from PIL import Image from vllm.config import ModelConfig from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal.cache import MultiModalProcessorOnlyCache from vllm.multimodal.inputs import MultiModalInputs -from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache +from vllm.multimodal.processing import BaseMultiModalProcessor from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, cached_tokenizer_from_config, encode_tokens) @@ -63,6 +64,8 @@ def _test_processing_correctness( revision=model_info.revision, trust_remote_code=model_info.trust_remote_code, hf_overrides=model_info.hf_overrides, + # Ensure that the cache can fit all of the data + mm_processor_cache_gb=2048, ) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) @@ -71,8 +74,7 @@ def _test_processing_correctness( model_config, tokenizer=cached_tokenizer_from_config(model_config), ) - # Ensure that it can fit all of the data - cache = ProcessingCache(capacity_gb=2048) + cache = MultiModalProcessorOnlyCache(model_config) processing_info = factories.info(ctx) supported_mm_limits = processing_info.get_supported_mm_limits() diff --git a/tests/multimodal/test_cache.py b/tests/multimodal/test_cache.py index 088cd00db2..44c05db227 100644 --- a/tests/multimodal/test_cache.py +++ b/tests/multimodal/test_cache.py @@ -1,32 +1,64 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import numpy as np import pytest import torch -from 
vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata +from vllm.config import ModelConfig, ParallelConfig, VllmConfig +from vllm.multimodal.cache import (MultiModalCache, + MultiModalProcessorCacheItem, + MultiModalProcessorCacheItemMetadata, + processor_cache_from_config, + receiver_cache_from_config) +from vllm.multimodal.hasher import MultiModalHasher from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, MultiModalKwargsItems, MultiModalSharedField) +from vllm.multimodal.processing import PromptInsertion +from vllm.multimodal.registry import MultiModalRegistry -def _dummy_elem(modality: str, key: str, size: int): +def _dummy_elem( + modality: str, + key: str, + size: int, + *, + rng: Optional[np.random.RandomState] = None, +): + if rng is None: + data = torch.empty((size, ), dtype=torch.int8) + else: + data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8)) + return MultiModalFieldElem( modality=modality, key=key, - data=torch.empty((size, ), dtype=torch.int8), + data=data, field=MultiModalSharedField(1), ) -def _dummy_item(modality: str, size_by_key: dict[str, int]): +def _dummy_item( + modality: str, + size_by_key: dict[str, int], + *, + rng: Optional[np.random.RandomState] = None, +): return MultiModalKwargsItem.from_elems([ - _dummy_elem(modality, key, size) for key, size in size_by_key.items() + _dummy_elem(modality, key, size, rng=rng) + for key, size in size_by_key.items() ]) -def _dummy_items(size_by_key_modality: dict[str, dict[str, int]]): +def _dummy_items( + size_by_key_modality: dict[str, dict[str, int]], + *, + rng: Optional[np.random.RandomState] = None, +): return MultiModalKwargsItems.from_seq([ - _dummy_item(modality, size_by_key) + _dummy_item(modality, size_by_key, rng=rng) for modality, size_by_key in size_by_key_modality.items() ]) @@ -48,5 +80,139 @@ def test_cache_item_size(item, expected_size): cache[""] = item assert cache.currsize == expected_size - cache[""] = MultiModalCacheItemMetadata.wraps(item) + prompt_update = PromptInsertion("dummy", "target", "insertion") \ + .resolve(0) + + cache[""] = MultiModalProcessorCacheItem(item, [prompt_update]) assert cache.currsize == expected_size + + cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update]) + assert cache.currsize == expected_size + + +def _create_vllm_config( + *, + mm_processor_cache_gb: float, + enable_ipc: bool, +): + return VllmConfig( + model_config=ModelConfig(mm_processor_cache_gb=mm_processor_cache_gb), + parallel_config=ParallelConfig( + data_parallel_size=1 if enable_ipc else 2), + ) + + +def _compare_caches( + config_0: VllmConfig, + config_1: VllmConfig, + *, + item_capacity: int = 8, + hit_rate: float = 0.5, + max_items_per_iter: int = 3, + is_cached_calls_per_iter: int, + n_iter: int = 100, + seed: int = 0, +): + mm_registry = MultiModalRegistry() + cache_0_p0 = processor_cache_from_config(config_0, mm_registry) + cache_0_p1 = receiver_cache_from_config(config_0, mm_registry) + cache_1_p0 = processor_cache_from_config(config_1, mm_registry) + cache_1_p1 = receiver_cache_from_config(config_1, mm_registry) + + cache_size_gb = max( + config_0.model_config.mm_processor_cache_gb, + config_1.model_config.mm_processor_cache_gb, + ) + item_size_gb = int(cache_size_gb / item_capacity) + + rng = np.random.RandomState(seed) + all_items = [ + _dummy_item("item", {"key": item_size_gb}, rng=rng) + for _ in range(int(item_capacity / hit_rate)) + ] + all_hashes = [ + MultiModalHasher.hash_kwargs(item=item.get_data()) + for item in 
all_items + ] + + # Should not be used since there is nothing to convert to text + prompt_update = PromptInsertion("dummy", "target", "insertion") + + for it in range(n_iter): + num_items_to_select = rng.randint(0, max_items_per_iter) + item_idxs_to_select = rng.choice(len(all_items), num_items_to_select) + + selected_items = [all_items[idx] for idx in item_idxs_to_select] + selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select] + + if cache_0_p0 is None: + cache_0_p0_out = selected_items + else: + for _ in range(is_cached_calls_per_iter): + cache_0_p0.is_cached(selected_hashes) + cache_0_p0_out = [ + item for item, _ in cache_0_p0.get_and_update( + [(item, prompt_update.content) for item in selected_items], + selected_hashes, + ) + ] + + if cache_1_p0 is None: + cache_1_p0_out = selected_items + else: + for _ in range(is_cached_calls_per_iter): + cache_1_p0.is_cached(selected_hashes) + cache_1_p0_out = [ + item for item, _ in cache_1_p0.get_and_update( + [(item, prompt_update.content) for item in selected_items], + selected_hashes, + ) + ] + + if cache_0_p1 is None: + cache_0_p1_out = cache_0_p0_out + else: + cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out, + selected_hashes) + + if cache_1_p1 is None: + cache_1_p1_out = cache_1_p0_out + else: + cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out, + selected_hashes) + + assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}" + + +@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3]) +def test_ipc_enable_disable_consistency(is_cached_calls_per_iter): + cache_size_gb = 1 / (1 << 20) + + vllm_config_ipc_enabled = _create_vllm_config( + mm_processor_cache_gb=cache_size_gb, + enable_ipc=True, + ) + vllm_config_ipc_disabled = _create_vllm_config( + mm_processor_cache_gb=0, + enable_ipc=False, + ) + vllm_config_cache_disabled = _create_vllm_config( + mm_processor_cache_gb=cache_size_gb, + enable_ipc=True, + ) + + _compare_caches( + vllm_config_ipc_enabled, + vllm_config_ipc_disabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) + _compare_caches( + vllm_config_ipc_disabled, + vllm_config_cache_disabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) + _compare_caches( + vllm_config_cache_disabled, + vllm_config_ipc_enabled, + is_cached_calls_per_iter=is_cached_calls_per_iter, + ) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index cd0e17977e..ac6f51df95 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -437,7 +437,7 @@ class ModelConfig: from `AutoProcessor.from_pretrained`. The available overrides depend on the model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. """ - mm_processor_cache_gb: int = 4 + mm_processor_cache_gb: float = 4 """The size (in GiB) of the multi-modal processor cache, which is used to avoid re-processing past multi-modal inputs. 
@@ -884,12 +884,6 @@ class ModelConfig: return None - def set_mm_processor_cache_gb(self, value: int) -> None: - mm_config = self.get_multimodal_config() - - self.mm_processor_cache_gb = value - mm_config.mm_processor_cache_gb = value - def _get_encoder_config(self): return get_sentence_transformer_tokenizer_config( self.model, self.revision) @@ -1697,22 +1691,6 @@ class ModelConfig: def is_multimodal_model(self) -> bool: return self.multimodal_config is not None - @property - def enable_mm_processor_cache(self) -> bool: - """Whether the multi-modal processor cache should be enabled.""" - mm_config = self.multimodal_config - if mm_config is None: - return False - - return mm_config.mm_processor_cache_gb > 0 - - def get_mm_input_cache_gb(self) -> int: - mm_config = self.multimodal_config - if mm_config is None: - return 0 - - return envs.VLLM_MM_INPUT_CACHE_GIB - @property def is_cross_encoder(self) -> bool: return (self._model_info.supports_cross_encoding @@ -2561,7 +2539,7 @@ class MultiModalConfig: `{"num_crops": 4}`. """ - mm_processor_cache_gb: int = 4 + mm_processor_cache_gb: float = 4 """ The size (in GiB) of the multi-modal processor cache, which is used to diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f24c50ad73..9e7c95ea52 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -351,7 +351,7 @@ class EngineArgs: mm_processor_kwargs: Optional[Dict[str, Any]] = \ MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED - mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb + mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling # LoRA fields @@ -1293,18 +1293,6 @@ class EngineArgs: worker_extension_cls=self.worker_extension_cls, ) - if model_config.is_multimodal_model: - dp_supports_mm_processor_cache = (self.data_parallel_size == 1 - or data_parallel_external_lb) - if (not dp_supports_mm_processor_cache - and model_config.mm_processor_cache_gb > 0): - logger.warning( - "Multi-modal processor cache is disabled because " - "it is not compatible with data parallelism when " - "there does not exist a one-to-one correspondance " - "between API and engine core processes.") - model_config.set_mm_processor_cache_gb(0) - speculative_config = self.create_speculative_config( target_model_config=model_config, target_parallel_config=parallel_config, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cbd714c159..03c2f0375d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -36,6 +36,7 @@ from vllm.logits_process import get_bad_words_logits_processors from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.outputs import (PoolingRequestOutput, RequestOutput, RequestOutputFactory) @@ -250,9 +251,13 @@ class LLMEngine: self.generation_config_fields = ( self.model_config.try_get_generation_config()) - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer, - mm_registry) + self.input_preprocessor = InputPreprocessor( + self.model_config, + self.tokenizer, + mm_registry, + mm_processor_cache=processor_only_cache_from_config( + 
self.model_config, mm_registry), + ) self.model_executor = executor_class(vllm_config=vllm_config) @@ -840,8 +845,8 @@ class LLMEngine: def reset_mm_cache(self) -> bool: """Reset the multi-modal cache.""" - return self.input_preprocessor.mm_registry.reset_processor_cache( - self.model_config) + self.input_preprocessor.clear_cache() + return True def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: """Reset prefix cache for all devices.""" diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 3f521012e8..f0d0cab3df 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -11,6 +11,7 @@ from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalInputs) from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -32,12 +33,14 @@ class InputPreprocessor: model_config: ModelConfig, tokenizer: Optional[TokenizerGroup], mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None, ) -> None: super().__init__() self.model_config = model_config self.tokenizer = tokenizer self.mm_registry = mm_registry + self.mm_processor_cache = mm_processor_cache def get_tokenizer_group(self) -> TokenizerGroup: if self.tokenizer is None: @@ -261,8 +264,11 @@ class InputPreprocessor: """ tokenizer = self._get_mm_tokenizer(lora_request) - mm_processor = self.mm_registry.create_processor(self.model_config, - tokenizer=tokenizer) + mm_processor = self.mm_registry.create_processor( + self.model_config, + tokenizer=tokenizer, + cache=self.mm_processor_cache, + ) if mm_processor_kwargs is None: mm_processor_kwargs = {} @@ -286,8 +292,12 @@ class InputPreprocessor: """ tokenizer = await self._get_mm_tokenizer_async(lora_request) - mm_processor = self.mm_registry.create_processor(self.model_config, - tokenizer=tokenizer) + mm_processor = self.mm_registry.create_processor( + self.model_config, + tokenizer=tokenizer, + cache=self.mm_processor_cache, + ) + if mm_processor_kwargs is None: mm_processor_kwargs = {} @@ -860,3 +870,7 @@ class InputPreprocessor: tokenization_kwargs=tokenization_kwargs, lora_request=lora_request, ) + + def clear_cache(self) -> None: + if self.mm_processor_cache is not None: + self.mm_processor_cache.clear_cache() diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index ef146fdfbf..f0b392e976 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -223,20 +223,26 @@ class InputRegistry: The model is identified by ``model_config``. 
""" # Avoid circular import + from vllm.multimodal.cache import processor_only_cache_from_config from vllm.sequence import SequenceData if not model_config.is_multimodal_model: seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) return DummyData(seq_data=seq_data) + cache = processor_only_cache_from_config(model_config, mm_registry) + # Encoder dummy data does not contain multi-modal data if is_encoder_data: - enc_data = mm_registry.get_encoder_dummy_data( - model_config, seq_len) + enc_data = mm_registry.get_encoder_dummy_data(model_config, + seq_len, + cache=cache) seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) return DummyData(seq_data=seq_data) - dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len) + dec_data = mm_registry.get_decoder_dummy_data(model_config, + seq_len, + cache=cache) return DummyData( seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), diff --git a/vllm/model_executor/models/hyperclovax_vision.py b/vllm/model_executor/models/hyperclovax_vision.py index eeb8291c77..53f0585541 100644 --- a/vllm/model_executor/models/hyperclovax_vision.py +++ b/vllm/model_executor/models/hyperclovax_vision.py @@ -33,12 +33,13 @@ from vllm.inputs import InputProcessingContext from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -367,7 +368,7 @@ def _build_hcxvision_hf_processor( info: HCXVisionProcessingInfo, dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, HCXVisionProcessingInfo): return HCXVisionMultiModalProcessor( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index bc53982c93..0ee26b6834 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -22,14 +22,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargsItems) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.jsontree import json_map_leaves @@ -394,7 +394,7 @@ def 
_build_llava_or_pixtral_hf_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, PixtralHFProcessingInfo): return PixtralHFMultiModalProcessor( diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index a2a71bdd12..c22d871ab2 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -58,7 +58,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, VideoItem, VideoProcessorItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptUpdate, PromptUpdateDetails) + PromptUpdate, PromptUpdateDetails, + ResolvedPromptUpdate, _seq2text) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors @@ -744,6 +745,43 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): for modality, pattern in placeholders ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + tokenizer = self.info.get_tokenizer() + image_processor = self.info.get_image_processor() + version = self.info.get_model_version() + + text = _seq2text(tokenizer, cached_update.content.full) + prev_item_idx = cached_update.item_idx + + if version == (2, 0) or version == (2, 5): + im_start = image_processor.im_start_token + im_end = image_processor.im_end_token + else: + im_start = image_processor.im_id_start + im_end = image_processor.im_id_end + + new_update = new_update.with_content( + PromptUpdateDetails.select_text( + text.replace( + f"{im_start}{prev_item_idx}{im_end}", + f"{im_start}{new_item_idx}{im_end}", + 1, + ), + "", + )) + + return new_update + def _get_mm_fields_config( self, hf_inputs: BatchFeature, diff --git a/vllm/model_executor/models/mistral3.py b/vllm/model_executor/models/mistral3.py index 438513433d..08948960b2 100644 --- a/vllm/model_executor/models/mistral3.py +++ b/vllm/model_executor/models/mistral3.py @@ -22,14 +22,14 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, MultiModalKwargsItems) from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - PromptReplacement, PromptUpdate, - PromptUpdateDetails) + BaseProcessingInfo, PromptReplacement, + PromptUpdate, PromptUpdateDetails) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.tensor_schema import TensorSchema, TensorShape @@ -322,7 +322,7 @@ def _build_mistral3_processor( info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: assert isinstance(info, 
Mistral3ProcessingInfo) return Mistral3MultiModalProcessor( diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 61e09d5604..4522c7043d 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -41,7 +41,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, MultiModalPromptUpdates, PlaceholderFeaturesInfo, - PromptReplacement, PromptUpdate) + PromptReplacement, PromptUpdate, + ResolvedPromptUpdate) # yapf: enable from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors @@ -440,6 +441,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ) ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + hf_processor = self.info.get_hf_processor() + image_tokens: list[str] = hf_processor.img_tokens # type: ignore + new_update = new_update.with_target(image_tokens[new_item_idx]) + + return new_update + def _apply_prompt_updates( self, token_ids: list[int], diff --git a/vllm/model_executor/models/phi4mm.py b/vllm/model_executor/models/phi4mm.py index 5129770e8d..211cbd9c81 100644 --- a/vllm/model_executor/models/phi4mm.py +++ b/vllm/model_executor/models/phi4mm.py @@ -27,7 +27,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, MultiModalDataItems, MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, BaseProcessingInfo, PromptReplacement, - PromptUpdate) + PromptUpdate, ResolvedPromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -850,6 +850,25 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ), ] + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + new_update = super()._recompute_cached_prompt_update( + cached_update, + new_item_idx, + ) + + if cached_update.modality == "image": + image_tokens: list[str] = self.info.image_tokens # type: ignore + new_update = new_update.with_target(image_tokens[new_item_idx]) + elif cached_update.modality == "audio": + audio_tokens: list[str] = self.info.audio_tokens # type: ignore + new_update = new_update.with_target(audio_tokens[new_item_idx]) + + return new_update + @MULTIMODAL_REGISTRY.register_processor( Phi4MMMultiModalProcessor, diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py index 9b9cca8c6b..c66867315e 100644 --- a/vllm/model_executor/models/tarsier.py +++ b/vllm/model_executor/models/tarsier.py @@ -25,12 +25,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.llava import LlavaDummyInputsBuilder from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import BaseMultiModalProcessorCache from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize, MultiModalDataItems) from vllm.multimodal.processing import (BaseMultiModalProcessor, - BaseProcessingInfo, ProcessingCache, - 
PromptReplacement, PromptUpdate) + BaseProcessingInfo, PromptReplacement, + PromptUpdate) from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.sequence import IntermediateTensors from vllm.utils.jsontree import json_map_leaves @@ -332,7 +333,7 @@ def _build_tarsier_hf_processor( info: _I_Tarsier, dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor: if isinstance(info, TarsierProcessingInfo): return TarsierMultiModalProcessor( diff --git a/vllm/multimodal/cache.py b/vllm/multimodal/cache.py index 5cec8e71fb..0e81cb6d4d 100644 --- a/vllm/multimodal/cache.py +++ b/vllm/multimodal/cache.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import sys -from collections.abc import Mapping -from dataclasses import dataclass -from typing import TypeVar, Union +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union import torch +from typing_extensions import TypeAlias, override from vllm.logger import init_logger from vllm.utils import GiB_bytes, LRUCache @@ -15,24 +16,67 @@ from .inputs import (MultiModalFieldElem, MultiModalKwargs, MultiModalKwargsItem, MultiModalKwargsItems, NestedTensors) +if TYPE_CHECKING: + from vllm.config import ModelConfig, VllmConfig + + from .processing import ResolvedPromptUpdate + from .registry import MultiModalRegistry + logger = init_logger(__name__) -@dataclass -class MultiModalCacheItemMetadata: - size: int +class MultiModalProcessorCacheItem: + """ + The data to store inside `MultiModalProcessorOnlyCache`. - @classmethod - def wraps(cls, value: "MultiModalCacheValue"): - return cls(size=MultiModalCache.get_item_size(value)) + Args: + item: The processed tensor data corresponding to a multi-modal item. + prompt_updates: The prompt updates corresponding to `item`. + """ + + def __init__( + self, + item: MultiModalKwargsItem, + prompt_updates: Sequence["ResolvedPromptUpdate"], + ) -> None: + super().__init__() + + self.item = item + self.prompt_updates = prompt_updates + + +class MultiModalProcessorCacheItemMetadata: + """ + The metadata to store inside `MultiModalProcessorSenderCache`. + + Args: + item: The processed tensor data corresponding to a multi-modal item. + Since P1 already stores the tensor data, we only store its size + metadata in P0 to reduce memory usage. The size metadata is still + needed to keep the same cache eviction policy as P0. + prompt_updates: The prompt updates corresponding to `item`. + This needs to stay on P0 because for some models, they are + dependent on the processed tensor data (cached on P1). 
+ """ + + def __init__( + self, + item: MultiModalKwargsItem, + prompt_updates: Sequence["ResolvedPromptUpdate"], + ) -> None: + super().__init__() + + self.item_size = MultiModalCache.get_item_size(item) + self.prompt_updates = prompt_updates MultiModalCacheValue = Union[ + MultiModalProcessorCacheItem, + MultiModalProcessorCacheItemMetadata, MultiModalKwargsItems, MultiModalKwargsItem, MultiModalKwargs, Mapping[str, NestedTensors], - MultiModalCacheItemMetadata, ] _V = TypeVar("_V", bound=MultiModalCacheValue) @@ -47,8 +91,10 @@ class MultiModalCache: *, debug: bool = False, ) -> int: - if isinstance(leaf, MultiModalFieldElem): - return cls.get_item_size(leaf.data) # type: ignore + if isinstance(leaf, MultiModalProcessorCacheItem): + return cls.get_leaf_size(leaf.item) + if isinstance(leaf, MultiModalProcessorCacheItemMetadata): + return leaf.item_size # These are not subclasses of dict if isinstance(leaf, MultiModalKwargsItems): @@ -58,13 +104,13 @@ class MultiModalCache: if isinstance(leaf, MultiModalKwargs): return cls.get_item_size(leaf.data) # type: ignore + if isinstance(leaf, MultiModalFieldElem): + return cls.get_item_size(leaf.data) # type: ignore + # sys.getsizeof doesn't work for tensors if isinstance(leaf, torch.Tensor): return leaf.nbytes - if isinstance(leaf, MultiModalCacheItemMetadata): - return leaf.size - return sys.getsizeof(leaf) @classmethod @@ -98,3 +144,332 @@ class MultiModalCache: GiB_bytes * capacity_gb, getsizeof=lambda x: cls.get_item_size(x, debug=debug), ) + + +_I = TypeVar("_I", contravariant=True) +_O = TypeVar("_O", covariant=True) + + +class BaseMultiModalCache(ABC, Generic[_I, _O]): + """ + Abstract base class to read/write multi-modal items from cache. + + The idea of multi-modal caching is based on having a client and server + where the client executes in the frontend process (=P0) and + the server in the core process (=P1). The data flow is as follows: + + ``` + is_cached() x N get_and_update() + P0: From API -----------------> -----------------> To P1 + + get_and_update() + P1: From P0 -----------------> To model + ``` + + `is_cached()` can be called any number of times in P0. However, + `get_and_update()` must be called in P0 and P1 one after another + so that their cache eviction order remains the same. + + This ensures that the keys in P0 and P1 caches are mirrored, + allowing us to determine whether a key is cached in P1 by looking + up the P0 cache, without having to communicate with P1. + """ + + @abstractmethod + def get_and_update_item( + self, + mm_item: _I, + mm_hash: str, + ) -> _O: + """ + Possibly update a multi-modal item based on whether it is + in the underlying cache. + + This update is done out-of-place and updates the cache eviction order. + + Args: + mm_item: The multi-modal item to update. + mm_hash: The hash of `mm_item`. + + Returns: + The update multi-modal item. + """ + raise NotImplementedError + + def get_and_update( + self, + mm_items: Sequence[_I], + mm_hashes: list[str], + ) -> list[_O]: + """ + Possibly update a sequence of multi-modal items based on whether they + are in the underlying cache. + + This update is done out-of-place and updates the cache eviction order. + + Args: + mm_items: The multi-modal items to update. + mm_hashes: The hash of each item in `mm_items`. + + Returns: + A new list of updated multi-modal items. 
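+
+        A minimal sketch of the expected call pattern (hypothetical hashes;
+        the exact in/out item types depend on the subclass):
+
+        ```
+        mm_hashes = ["hash_a", "hash_b"]
+        # Call with the same hashes on P0 and then on P1 so that the
+        # eviction order of both caches stays mirrored.
+        out_items = cache.get_and_update(in_items, mm_hashes)
+        assert len(out_items) == len(mm_hashes)
+        ```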
+ """ + assert len(mm_items) == len(mm_hashes) + + return [ + self.get_and_update_item(mm_item, mm_hash) + for mm_item, mm_hash in zip(mm_items, mm_hashes) + ] + + @abstractmethod + def clear_cache(self) -> None: + """Clear the underlying cache.""" + raise NotImplementedError + + +MultiModalProcessorCacheInItem: TypeAlias = \ + Optional[tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]]] + + +MultiModalProcessorCacheOutItem: TypeAlias = \ + tuple[Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"]] + + +class BaseMultiModalProcessorCache( + BaseMultiModalCache[MultiModalProcessorCacheInItem, + MultiModalProcessorCacheOutItem]): + """The required interface for caches on P0.""" + + @abstractmethod + def is_cached_item(self, mm_hash: str) -> bool: + """ + Check whether a multi-modal item is + in the underlying cache. + + This **DOES NOT** update the cache eviction order. + + Args: + mm_hash: The hash of the item to check. + + Returns: + `True` if the item is cached, otherwise `False`. + """ + raise NotImplementedError + + def is_cached(self, mm_hashes: list[str]) -> list[bool]: + """ + Check whether a sequence of multi-modal items are + in the underlying cache. + + This **DOES NOT** update the cache eviction order. + + Args: + mm_hashes: The hash of each item to check. + + Returns: + For each item, `True` if the item is cached, otherwise `False`. + """ + return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes] + + +class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is disabled. + + How to update each item: + + - If the item is in the cache, replace the input with the cached item. + - If the item is not in the cache, store that item (which includes + tensor data and metadata) into the cache, and return the input. + """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalProcessorCacheItem, + ) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return mm_hash in self._cache + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return cached_item.item, cached_item.prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item) + + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache): + """ + The cache which is used on P0 when IPC caching is enabled. + + How to update each item: + + - If the item is already in the cache, clear the input to avoid + unnecessary IPC. + + - If the item is not in the cache, store the metadata of that item so + that the eviction policy remains the same as the cache on P1, + and return the input. + By only storing the metadata, we avoid keeping the data itself in + memory inside P0. 
+ """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalProcessorCacheItemMetadata, + ) + + @override + def is_cached_item(self, mm_hash: str) -> bool: + return mm_hash in self._cache + + @override + def get_and_update_item( + self, + mm_item: MultiModalProcessorCacheInItem, + mm_hash: str, + ) -> MultiModalProcessorCacheOutItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return None, cached_item.prompt_updates + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item) + + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +def _enable_processor_cache( + model_config: "ModelConfig", + mm_registry: "MultiModalRegistry", +) -> bool: + if not mm_registry.supports_multimodal_inputs(model_config): + return False + + mm_config = model_config.get_multimodal_config() + return mm_config.mm_processor_cache_gb > 0 + + +def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool: + parallel_config = vllm_config.parallel_config + supports_ipc_cache = (parallel_config.data_parallel_size == 1 + or parallel_config.data_parallel_external_lb) + + return supports_ipc_cache + + +def processor_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", +) -> Optional[BaseMultiModalProcessorCache]: + """Return a `BaseMultiModalProcessorCache`, if enabled.""" + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return MultiModalProcessorOnlyCache(model_config) + + return MultiModalProcessorSenderCache(model_config) + + +def processor_only_cache_from_config( + model_config: "ModelConfig", + mm_registry: "MultiModalRegistry", +): + """Return a `MultiModalProcessorOnlyCache`, if enabled.""" + if not _enable_processor_cache(model_config, mm_registry): + return None + + return MultiModalProcessorOnlyCache(model_config) + + +class BaseMultiModalReceiverCache( + BaseMultiModalCache[Optional[MultiModalKwargsItem], + MultiModalKwargsItem]): + """The required interface for caches on P1.""" + + +class MultiModalReceiverCache(BaseMultiModalReceiverCache): + """ + The cache which is used on P1 when IPC caching is enabled. + + How to update each item: + + - If the item is in the cache, replace the input with the cached item. + - If the item is not in the cache, store that item (which includes tensor + data) into the cache, and return the input. 
+ """ + + def __init__(self, model_config: "ModelConfig") -> None: + super().__init__() + + mm_config = model_config.get_multimodal_config() + + self._cache = MultiModalCache.get_lru_cache( + mm_config.mm_processor_cache_gb, + MultiModalKwargsItem, + ) + + @override + def get_and_update_item( + self, + mm_item: Optional[MultiModalKwargsItem], + mm_hash: str, + ) -> MultiModalKwargsItem: + if (cached_item := self._cache.get(mm_hash)) is not None: + return cached_item + + assert mm_item is not None, f"Expected a cached item for {mm_hash=}" + + self._cache[mm_hash] = mm_item + return mm_item + + @override + def clear_cache(self) -> None: + self._cache.clear() + + +def receiver_cache_from_config( + vllm_config: "VllmConfig", + mm_registry: "MultiModalRegistry", +) -> Optional[BaseMultiModalReceiverCache]: + """Return a `BaseMultiModalReceiverCache`, if enabled.""" + model_config = vllm_config.model_config + + if not _enable_processor_cache(model_config, mm_registry): + return None + + if not _enable_ipc_cache(vllm_config): + return None + + return MultiModalReceiverCache(model_config) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 581f9a109c..2c0ebaced6 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -7,11 +7,11 @@ from collections.abc import Mapping, Sequence from dataclasses import dataclass from functools import partial from itertools import accumulate -from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, - Union, cast, final) +from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, + cast, final) import numpy as np -from typing_extensions import NotRequired, TypeAlias, deprecated +from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated from vllm.utils import LazyLoader, full_groupby, is_list_of from vllm.utils.jsontree import JSONTree, json_map_leaves @@ -668,7 +668,15 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): return {key: elem.data for key, elem in self.items()} -class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]): +_I = TypeVar( + "_I", + MultiModalKwargsItem, + Optional[MultiModalKwargsItem], + default=MultiModalKwargsItem, +) + + +class MultiModalKwargsItems(UserDict[str, Sequence[_I]]): """ A dictionary of [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s @@ -714,27 +722,37 @@ class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]): items_by_modality = full_groupby(items, key=lambda x: x.modality) return MultiModalKwargsItems(items_by_modality) - def __getitem__(self, modality: str): + def __getitem__(self, modality: str) -> Sequence[_I]: if modality not in self: raise KeyError(f"Modality {modality!r} not found. 
" f"Available modalities: {set(self.keys())}") - return super().__getitem__(modality) + return super().__getitem__(modality) # type: ignore[return-value] def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs": elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) - for items in self.values(): - for item in items: + for modality, items in self.items(): + for i, item in enumerate(items): + if item is None: + raise RuntimeError("Cannot build data from empty " + f"mm_items[{modality}][{i}]") + for key, elem in item.items(): elems_by_key[key].append(elem) return MultiModalKwargs({ key: elems[0].field.reduce_data(elems, pin_memory=pin_memory) - for key, elems in elems_by_key.items() if len(elems) > 0 + for key, elems in elems_by_key.items() }) +MultiModalKwargsOptionalItems: TypeAlias = Union[ + MultiModalKwargsItems[MultiModalKwargsItem], + MultiModalKwargsItems[Optional[MultiModalKwargsItem]], +] + + class MultiModalKwargs(UserDict[str, NestedTensors]): """ A dictionary that represents the keyword arguments to @@ -898,7 +916,7 @@ class MultiModalInputs(TypedDict): token_type_ids: NotRequired[list[int]] """The token type IDs of the prompt.""" - mm_kwargs: MultiModalKwargsItems + mm_kwargs: MultiModalKwargsOptionalItems """Keyword arguments to be directly passed to the model after batching.""" mm_hashes: "MultiModalHashDict" diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 8c225e2a3c..6ecdf80d4a 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from collections import defaultdict from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, Sequence) -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from enum import Enum from functools import lru_cache from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, @@ -20,12 +20,11 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, encode_tokens) from vllm.utils import flatten_2d_lists, full_groupby -from .cache import MultiModalCache from .hasher import MultiModalHasher from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargsItem, MultiModalKwargsItems, - PlaceholderRange) + MultiModalKwargsOptionalItems, PlaceholderRange) from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, MultiModalDataParser) @@ -34,6 +33,7 @@ if TYPE_CHECKING: from transformers.feature_extraction_utils import BatchFeature from transformers.processing_utils import ProcessorMixin + from .cache import BaseMultiModalProcessorCache from .profiling import BaseDummyInputsBuilder logger = init_logger(__name__) @@ -557,6 +557,15 @@ class ResolvedPromptUpdate: return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx) + def with_target(self, target: UpdateTarget): + return replace(self, target=target) + + def with_content(self, content: PromptUpdateInfo): + if not isinstance(content, PromptUpdateDetails): + content = PromptUpdateDetails.from_seq(content) + + return replace(self, content=content) + class _TokenMatch(NamedTuple): start_idx: int @@ -865,21 +874,6 @@ def find_mm_placeholders( return dict(full_groupby_modality(it)) -class ProcessingCache(MultiModalCache): - - def __init__(self, capacity_gb: float) -> None: - super().__init__() - - self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem) - - self.get = self._cache.get - 
self.put = self._cache.put - self.reset = self._cache.clear - - -_CacheItemOrHash = Union[MultiModalKwargsItem, str] - - class BaseProcessingInfo: """Base class to provide the information necessary for data processing.""" @@ -982,7 +976,7 @@ For an item `MultiModalPromptUpdates[k][i]`, class MultiModalProcessingInfo(NamedTuple): - kwargs: MultiModalKwargsItems + kwargs: MultiModalKwargsOptionalItems hashes: MultiModalHashes prompt_updates: MultiModalPromptUpdates @@ -994,11 +988,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): Not to be confused with `transformers.ProcessorMixin`. """ - def __init__(self, - info: _I, - dummy_inputs: "BaseDummyInputsBuilder[_I]", - *, - cache: Optional[ProcessingCache] = None) -> None: + def __init__( + self, + info: _I, + dummy_inputs: "BaseDummyInputsBuilder[_I]", + *, + cache: Optional["BaseMultiModalProcessorCache"] = None, + ) -> None: super().__init__() self.info = info @@ -1355,32 +1351,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): return prompt_ids, mm_processed_data, False - def _get_cache_missing_items( - self, - cache: ProcessingCache, - mm_data_items: MultiModalDataItems, - mm_hashes: MultiModalHashes, - ) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]: - mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = { - modality: [(h if (v := cache.get(h)) is None else v) - for h in hashes] - for modality, hashes in mm_hashes.items() - } - - mm_missing_idxs = { - modality: [ - idx for idx, item_or_hash in enumerate(items_or_hashes) - if isinstance(item_or_hash, str) - ] - for modality, items_or_hashes in mm_cache_items_or_hashes.items() - } - mm_missing_data = { - modality: [mm_data_items[modality][idx] for idx in idxs] - for modality, idxs in mm_missing_idxs.items() - } - - return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data) - def _hash_mm_items( self, mm_items: MultiModalDataItems, @@ -1401,28 +1371,92 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): for modality, items in mm_items.items() } + def _get_cache_missing_items( + self, + cache: "BaseMultiModalProcessorCache", + mm_data_items: MultiModalDataItems, + mm_hashes: MultiModalHashes, + ) -> MultiModalDataItems: + mm_is_cached = { + modality: cache.is_cached(hashes) + for modality, hashes in mm_hashes.items() + } + + mm_missing_idxs = { + modality: [ + idx for idx, item_is_cached in enumerate(items_is_cached) + if not item_is_cached + ] + for modality, items_is_cached in mm_is_cached.items() + } + mm_missing_data = { + modality: [mm_data_items[modality][idx] for idx in idxs] + for modality, idxs in mm_missing_idxs.items() + } + + return self._to_mm_items(mm_missing_data) + + def _recompute_cached_prompt_update( + self, + cached_update: ResolvedPromptUpdate, + new_item_idx: int, + ) -> ResolvedPromptUpdate: + """ + Override this if other attributes of `ResolvedPromptUpdate` + also need to be recomputed after retrieving from the cache. 
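+
+        For example, a model whose placeholder token differs per item index
+        (hypothetical `image_tokens` list, as in the Phi-3-V processor) may
+        retarget the cached update:
+
+        ```
+        new_update = super()._recompute_cached_prompt_update(
+            cached_update, new_item_idx)
+        if cached_update.modality == "image":
+            new_update = new_update.with_target(image_tokens[new_item_idx])
+        return new_update
+        ```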
+ """ + return replace(cached_update, item_idx=new_item_idx) + def _merge_mm_kwargs( self, - cache: ProcessingCache, - mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]], + cache: "BaseMultiModalProcessorCache", + mm_hashes: MultiModalHashes, mm_missing_kwargs: MultiModalKwargsItems, - ) -> MultiModalKwargsItems: + mm_missing_prompt_updates: MultiModalPromptUpdates, + ) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]: + # Need to calculate this at the beginning to avoid skipping cache logic + # for subsequently repeated items in the same modality + mm_is_cached = { + modality: cache.is_cached(hashes) + for modality, hashes in mm_hashes.items() + } + mm_missing_next_idx = defaultdict[str, int](lambda: 0) - merged_items = defaultdict[str, list[MultiModalKwargsItem]](list) - for modality, items_or_hashes in mm_cache_items_or_hashes.items(): - for item_or_hash in items_or_hashes: - if isinstance(item_or_hash, str): - kw_item = mm_missing_kwargs[modality][ - mm_missing_next_idx[modality]] - cache.put(item_or_hash, kw_item) + merged_kwargs = defaultdict[str, + list[Optional[MultiModalKwargsItem]]](list) + merged_prompt_updates = defaultdict[ + str, list[Sequence[ResolvedPromptUpdate]]](list) + for modality, hashes in mm_hashes.items(): + missing_kwargs = mm_missing_kwargs.get(modality, []) + missing_prompt_updates = mm_missing_prompt_updates.get( + modality, []) + + for item_idx, item_hash in enumerate(hashes): + kwargs: Optional[MultiModalKwargsItem] + if not mm_is_cached[modality][item_idx]: + missing_next_idx = mm_missing_next_idx[modality] + kwargs = missing_kwargs[missing_next_idx] + updates = missing_prompt_updates[missing_next_idx] + mm_missing_next_idx[modality] += 1 + + item = kwargs, updates else: - kw_item = item_or_hash + item = None - merged_items[modality].append(kw_item) + kwargs, updates = cache.get_and_update_item(item, item_hash) - return MultiModalKwargsItems(merged_items) + merged_kwargs[modality].append(kwargs) + merged_prompt_updates[modality].append([ + self._recompute_cached_prompt_update(update, item_idx) + for update in updates + ]) + + mm_kwargs = MultiModalKwargsItems(merged_kwargs) + mm_prompt_updates = dict(merged_prompt_updates) + + return mm_kwargs, mm_prompt_updates def _apply_hf_processor( self, @@ -1490,10 +1524,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, tokenization_kwargs) - ( - mm_cache_items_or_hashes, - mm_missing_data_items, - ) = self._get_cache_missing_items( + + mm_missing_data_items = self._get_cache_missing_items( cache=cache, mm_data_items=mm_data_items, mm_hashes=mm_hashes, @@ -1520,16 +1552,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): hf_processor_mm_kwargs), ) - mm_kwargs = self._merge_mm_kwargs( - cache, - mm_cache_items_or_hashes=mm_cache_items_or_hashes, - mm_missing_kwargs=mm_missing_kwargs, + mm_missing_prompt_updates = self._get_mm_prompt_updates( + mm_missing_data_items, + hf_processor_mm_kwargs, + mm_missing_kwargs, ) - mm_prompt_updates = self._get_mm_prompt_updates( - mm_data_items, - hf_processor_mm_kwargs, - mm_kwargs, + mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs( + cache, + mm_hashes=mm_hashes, + mm_missing_kwargs=mm_missing_kwargs, + mm_missing_prompt_updates=mm_missing_prompt_updates, ) mm_info = MultiModalProcessingInfo( @@ -1614,7 +1647,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): def _validate_mm_kwargs( self, - mm_kwargs: MultiModalKwargsItems, + mm_kwargs: MultiModalKwargsOptionalItems, 
mm_item_counts: Mapping[str, int], ) -> None: for modality, item_count in mm_item_counts.items(): @@ -1655,7 +1688,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): self, mm_items: MultiModalDataItems, prompt_ids: list[int], - mm_kwargs: MultiModalKwargsItems, + mm_kwargs: MultiModalKwargsOptionalItems, mm_prompt_updates: MultiModalPromptUpdates, is_update_applied: bool, ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index ea2efbdd8b..ffc69a2db6 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -13,7 +13,7 @@ import vllm.envs as envs from vllm.logger import init_logger from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, - MultiModalInputs, MultiModalKwargsItems, + MultiModalInputs, MultiModalKwargsOptionalItems, MultiModalPlaceholderDict) from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, EncDecMultiModalProcessor) @@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple): """Dummy data used for profiling.""" prompt_token_ids: list[int] - multi_modal_data: MultiModalKwargsItems + multi_modal_data: MultiModalKwargsOptionalItems multi_modal_placeholders: MultiModalPlaceholderDict diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 8cd9e56048..38adbf8f35 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Mapping from dataclasses import dataclass -from functools import lru_cache from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar import torch.nn as nn @@ -13,8 +12,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, cached_tokenizer_from_config) from vllm.utils import ClassRegistry -from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, - ProcessingCache) +from .cache import (BaseMultiModalProcessorCache, + processor_only_cache_from_config) +from .processing import BaseMultiModalProcessor, BaseProcessingInfo from .profiling import (BaseDummyInputsBuilder, DummyDecoderData, DummyEncoderData, MultiModalProfiler) @@ -65,7 +65,7 @@ class MultiModalProcessorFactory(Protocol[_I]): info: _I, dummy_inputs: BaseDummyInputsBuilder[_I], *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor[_I]: ... @@ -80,20 +80,13 @@ class _ProcessorFactories(Generic[_I]): self, ctx: InputProcessingContext, *, - cache: Optional[ProcessingCache] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ): info = self.info(ctx) dummy_inputs_builder = self.dummy_inputs(info) return self.processor(info, dummy_inputs_builder, cache=cache) -# Make sure a different cache is used for each model config -# NOTE: ModelConfig is not hashable so it cannot be passed directly -@lru_cache(maxsize=1) -def _get_processor_cache(model_id: str, capacity_gb: int): - return ProcessingCache(capacity_gb) if capacity_gb > 0 else None - - class MultiModalRegistry: """ A registry that dispatches data processing according to the model. 
@@ -103,31 +96,6 @@ class MultiModalRegistry: self._processor_factories = ClassRegistry[nn.Module, _ProcessorFactories]() - def _get_processor_cache(self, model_config: "ModelConfig"): - model_id = model_config.model - capacity_gb = model_config.mm_processor_cache_gb - return _get_processor_cache(model_id, capacity_gb) - - def reset_processor_cache(self, model_config: "ModelConfig") -> bool: - """Reset the multi-modal processing cache.""" - if processor_cache := self._get_processor_cache(model_config): - processor_cache.reset() - - return True # Success - - def enable_mm_input_cache(self, model_config: "ModelConfig") -> bool: - """Whether the multi-modal input cache should be enabled. - NOTE: This is put under MultiModalRegistry on purpose to respect - text-only mode for multimodal models. - """ - - if not self.supports_multimodal_inputs(model_config): - return False - - mm_config = model_config.get_multimodal_config() - - return mm_config.mm_processor_cache_gb > 0 - def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool: """ Checks if the model supports multimodal inputs. @@ -157,6 +125,8 @@ class MultiModalRegistry: def get_max_tokens_per_item_by_modality( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based @@ -165,11 +135,11 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) seq_len = model_config.max_model_len - mm_limits = self.get_mm_limits_per_prompt(model_config) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) return profiler.get_mm_max_contiguous_tokens( seq_len, @@ -182,6 +152,8 @@ class MultiModalRegistry: def get_max_tokens_per_item_by_nonzero_modality( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of tokens per data item from each modality based @@ -192,15 +164,19 @@ class MultiModalRegistry: This is currently directly used only in V1 for profiling the memory usage of a model. """ - mm_limits = self.get_mm_limits_per_prompt(model_config) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) + max_tokens_per_item = self.get_max_tokens_per_item_by_modality( + model_config, + cache=cache, + ) return { key: max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() + for key, max_tokens_per_mm_item in max_tokens_per_item.items() if mm_limits[key] > 0 } + # TODO: Remove once V0 is gone def get_max_tokens_by_modality( self, model_config: "ModelConfig", @@ -209,14 +185,19 @@ class MultiModalRegistry: Get the maximum number of tokens from each modality for profiling the memory usage of a model. 
""" - mm_limits = self.get_mm_limits_per_prompt(model_config) + cache = processor_only_cache_from_config(model_config, self) + mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache) + max_tokens_per_item = self.get_max_tokens_per_item_by_modality( + model_config, + cache=cache, + ) return { key: mm_limits[key] * max_tokens_per_mm_item - for key, max_tokens_per_mm_item in - self.get_max_tokens_per_item_by_modality(model_config).items() + for key, max_tokens_per_mm_item in max_tokens_per_item.items() } + # TODO: Remove once V0 is gone def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int: """ Get the maximum number of multi-modal tokens @@ -227,6 +208,8 @@ class MultiModalRegistry: def get_mm_limits_per_prompt( self, model_config: "ModelConfig", + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> Mapping[str, int]: """ Get the maximum number of multi-modal input instances for each modality @@ -235,7 +218,7 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: return {} - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) return profiler.get_mm_limits() @@ -303,7 +286,7 @@ class MultiModalRegistry: model_config: "ModelConfig", *, tokenizer: Optional[AnyTokenizer] = None, - disable_cache: Optional[bool] = None, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> BaseMultiModalProcessor[BaseProcessingInfo]: """ Create a multi-modal processor for a specific model and tokenizer. @@ -311,15 +294,10 @@ class MultiModalRegistry: if not model_config.is_multimodal_model: raise ValueError(f"{model_config.model} is not a multimodal model") - if disable_cache is None: - disable_cache = not model_config.enable_mm_processor_cache - model_cls = self._get_model_cls(model_config) factories = self._processor_factories[model_cls] ctx = self._create_processing_ctx(model_config, tokenizer) - cache = None if disable_cache else self._get_processor_cache( - model_config) return factories.build_processor(ctx, cache=cache) @@ -328,13 +306,15 @@ class MultiModalRegistry: model_config: "ModelConfig", seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> DummyDecoderData: """ Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. """ - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts) @@ -352,13 +332,15 @@ class MultiModalRegistry: model_config: "ModelConfig", seq_len: int, mm_counts: Optional[Mapping[str, int]] = None, + *, + cache: Optional[BaseMultiModalProcessorCache] = None, ) -> DummyEncoderData: """ Create dummy data for profiling the memory usage of a model. The model is identified by ``model_config``. 
""" - processor = self.create_processor(model_config, disable_cache=False) + processor = self.create_processor(model_config, cache=cache) profiler = MultiModalProfiler(processor) dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 342d7b24f8..dbea0b610b 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -597,8 +597,7 @@ class AsyncLLM(EngineClient): await asyncio.gather(*coros) async def reset_mm_cache(self) -> None: - self.processor.mm_registry.reset_processor_cache(self.model_config) - self.processor.mm_input_cache_client.reset() + self.processor.clear_cache() await self.engine_core.reset_mm_cache_async() async def reset_prefix_cache(self, diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 32765cda64..b614828061 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -22,6 +22,7 @@ from vllm.logger import init_logger from vllm.logging_utils.dump_input import dump_engine_exception from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.cache import receiver_cache_from_config from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.transformers_utils.config import ( maybe_register_config_serialize_by_value) @@ -38,7 +39,6 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, ReconfigureDistributedRequest, ReconfigureRankType, UtilityOutput, UtilityResult) -from vllm.v1.engine.mm_input_cache import MultiModalInputCacheServer from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses from vllm.v1.executor.abstract import Executor from vllm.v1.kv_cache_interface import KVCacheConfig @@ -128,8 +128,9 @@ class EngineCore: ) self.use_spec_decode = vllm_config.speculative_config is not None - self.mm_input_cache_server = MultiModalInputCacheServer( - vllm_config.model_config, MULTIMODAL_REGISTRY) + self.mm_registry = mm_registry = MULTIMODAL_REGISTRY + self.mm_receiver_cache = receiver_cache_from_config( + vllm_config, mm_registry) # Setup batch queue for pipeline parallelism. # Batch queue for scheduled batches. This enables us to asynchronously @@ -370,7 +371,8 @@ class EngineCore: logger.warning("Resetting the multi-modal cache when requests are " "in progress may lead to desynced internal caches.") - self.mm_input_cache_server.reset() + if self.mm_receiver_cache is not None: + self.mm_receiver_cache.clear_cache() def reset_prefix_cache(self): self.scheduler.reset_prefix_cache() @@ -435,10 +437,11 @@ class EngineCore: assert request.mm_kwargs is not None # Note on thread safety: no race condition. - # `mm_input_cache_server` is reset at the end of LLMEngine init, + # `mm_receiver_cache` is reset at the end of LLMEngine init, # and will only accessed in the input processing thread afterwards. 
- request.mm_kwargs = self.mm_input_cache_server.get_and_update( - request.mm_kwargs, request.mm_hashes) + if self.mm_receiver_cache is not None: + request.mm_kwargs = self.mm_receiver_cache.get_and_update( + request.mm_kwargs, request.mm_hashes) req = Request.from_engine_core_request(request, self.request_block_hasher) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 5a00a93095..7130f666ef 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -271,8 +271,7 @@ class LLMEngine: self.engine_core.profile(False) def reset_mm_cache(self): - self.processor.mm_registry.reset_processor_cache(self.model_config) - self.processor.mm_input_cache_client.reset() + self.processor.clear_cache() self.engine_core.reset_mm_cache() def reset_prefix_cache(self, device: Optional[Device] = None): diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py deleted file mode 100644 index aa7dc62fd4..0000000000 --- a/vllm/v1/engine/mm_input_cache.py +++ /dev/null @@ -1,121 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Sequence -from typing import TYPE_CHECKING, Optional - -from vllm.multimodal import MultiModalRegistry -from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata -from vllm.multimodal.inputs import MultiModalKwargsItem -from vllm.utils import is_list_of - -if TYPE_CHECKING: - from vllm.config import ModelConfig - -# The idea of multimodal input caching is based on having a client and -# a server, where the client executes in the frontend process (=P0) and the -# server in the core process (=P1). -# -# -- P0: -# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of -# each input multi-modal item (e.g. image), -# - BaseMultiModalProcessor processes the input items into `mm_kwargs`, -# which are MultiModalKwargsItem instances that each correspond to an -# input multi-modal item. -# - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding -# `mm_hash` for each item. It stores the `mm_hash` as keys and the size -# of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking -# up additional memory in P0. -# - The `mm_hash` is always sent to P1. -# - The corresponding `mm_kwargs` are only sent to P1 if they are not cached -# in MultiModalInputCacheServer. -# -# -- P1: -# - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0), -# MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`. -# - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0), -# MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`. -# - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to -# the engine for model execution. -# -# Both Client and Server must perform cache update and eviction based on the -# same item size. This ensures that the keys of MultiModalInputCacheClient -# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0 -# whether a key is cached in MultiModalInputCacheServer by querying -# MultiModalInputCacheClient without having to communicate with P1. 
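The comment block removed above describes the key-mirroring scheme between the frontend (`P0`) and the engine core (`P1`): the sender keeps only keys and item sizes, the receiver keeps the full values, and both evict in lockstep so `P0` can predict `P1`'s contents without extra IPC. As a purely hypothetical illustration of that idea — invented class names, item-count eviction instead of the real size-based accounting, and no relation to the actual classes in `vllm/multimodal/cache.py`:

```python
# Hypothetical, self-contained sketch of the key-mirroring idea described
# in the removed comment above; all names here are invented.
from collections import OrderedDict
from typing import Optional


class SenderKeyCache:
    """Frontend (P0) side: remembers only which keys the receiver holds."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._keys: OrderedDict[str, None] = OrderedDict()

    def filter(self, key: str, value: bytes) -> Optional[bytes]:
        """Return the value only if the receiver does not already have it."""
        if key in self._keys:
            self._keys.move_to_end(key)
            return None  # cached on the receiver; skip the transfer
        self._keys[key] = None
        if len(self._keys) > self.capacity:
            self._keys.popitem(last=False)  # evict oldest, mirroring P1
        return value


class ReceiverValueCache:
    """Engine core (P1) side: stores the full values under the same keys."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._items: OrderedDict[str, bytes] = OrderedDict()

    def get_and_update(self, key: str, value: Optional[bytes]) -> bytes:
        if value is None:
            self._items.move_to_end(key)
            return self._items[key]  # reuse the value cached earlier
        self._items[key] = value
        if len(self._items) > self.capacity:
            self._items.popitem(last=False)  # same eviction order as P0
        return value


# Toy usage: the sender decides what to ship; the receiver reconstructs it.
sender, receiver = SenderKeyCache(capacity=8), ReceiverValueCache(capacity=8)
wire = sender.filter("mm-hash-1", b"processed tensors")  # payload on first use
assert receiver.get_and_update("mm-hash-1", wire) == b"processed tensors"
wire = sender.filter("mm-hash-1", b"processed tensors")  # None: key is mirrored
assert receiver.get_and_update("mm-hash-1", wire) == b"processed tensors"
```

Because both sides are LRU caches of the same capacity seeing the same key sequence, their key sets stay mirrored; the classes removed below implemented the same contract with size-based eviction.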
- - -class MultiModalInputCacheClient: - """Used by P0 to check whether multi-modal kwargs are cached in P1.""" - - def __init__(self, model_config: "ModelConfig", - mm_registry: MultiModalRegistry) -> None: - super().__init__() - - self.enabled = mm_registry.enable_mm_input_cache(model_config) - self.mm_cache = MultiModalCache.get_lru_cache( - model_config.get_mm_input_cache_gb(), - MultiModalCacheItemMetadata, - ) - - def get_and_update( - self, - mm_kwargs: Sequence[MultiModalKwargsItem], - mm_hashes: list[str], - ) -> list[Optional[MultiModalKwargsItem]]: - if not self.enabled: - return list(mm_kwargs) - - assert len(mm_kwargs) == len(mm_hashes) - - out_mm_items = list[Optional[MultiModalKwargsItem]]() - for mm_item, mm_hash in zip(mm_kwargs, mm_hashes): - if self.mm_cache.get(mm_hash) is not None: - out_mm_items.append(None) - else: - self.mm_cache[mm_hash] = \ - MultiModalCacheItemMetadata.wraps(mm_item) - out_mm_items.append(mm_item) - - return out_mm_items - - def reset(self) -> None: - self.mm_cache.clear() - - -class MultiModalInputCacheServer: - """Used by P1 to avoid requiring past multi-modal kwargs from P0.""" - - def __init__(self, model_config: "ModelConfig", - mm_registry: MultiModalRegistry) -> None: - super().__init__() - - self.enabled = mm_registry.enable_mm_input_cache(model_config) - self.mm_cache = MultiModalCache.get_lru_cache( - model_config.get_mm_input_cache_gb(), - MultiModalKwargsItem, - ) - - def get_and_update( - self, - mm_kwargs: Sequence[Optional[MultiModalKwargsItem]], - mm_hashes: list[str], - ) -> list[MultiModalKwargsItem]: - if not self.enabled: - mm_kwargs_lst = list(mm_kwargs) - assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem) - return mm_kwargs_lst - - assert len(mm_kwargs) == len(mm_hashes) - - out_mm_items = list[MultiModalKwargsItem]() - for mm_item, mm_hash in zip(mm_kwargs, mm_hashes): - if mm_item is None: - out_mm_items.append(self.mm_cache[mm_hash]) - else: - self.mm_cache[mm_hash] = mm_item - out_mm_items.append(mm_item) - - return out_mm_items - - def reset(self) -> None: - self.mm_cache.clear() diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 300b0713b2..7ed6015662 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -11,6 +11,7 @@ from vllm.inputs.parse import split_enc_dec_inputs from vllm.inputs.preprocess import InputPreprocessor from vllm.lora.request import LoRARequest from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal.cache import processor_cache_from_config from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.utils import argsort_mm_positions @@ -18,7 +19,6 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreRequest -from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) from vllm.v1.structured_output.backend_lm_format_enforcer import ( @@ -47,16 +47,17 @@ class Processor: self.generation_config_fields = ( self.model_config.try_get_generation_config()) - self.input_preprocessor = InputPreprocessor(self.model_config, - self.tokenizer, - mm_registry) - self.mm_input_cache_client = MultiModalInputCacheClient( - self.model_config, mm_registry) + self.mm_registry = mm_registry + 
self.mm_processor_cache = processor_cache_from_config( + vllm_config, mm_registry) - @property - def mm_registry(self): - return self.input_preprocessor.mm_registry + self.input_preprocessor = InputPreprocessor( + self.model_config, + self.tokenizer, + mm_registry, + mm_processor_cache=self.mm_processor_cache, + ) def _validate_logprobs( self, @@ -310,7 +311,7 @@ class Processor: # in the input sequence. sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions) - orig_sorted_mm_inputs = [ + sorted_mm_inputs = [ decoder_mm_inputs[modality][idx] for modality, idx in sorted_mm_idxs ] @@ -323,11 +324,6 @@ class Processor: for modality, idx in sorted_mm_idxs ] - sorted_mm_inputs = self.mm_input_cache_client.get_and_update( - orig_sorted_mm_inputs, - sorted_mm_hashes, - ) - return decoder_inputs.get("prompt"), EngineCoreRequest( request_id=request_id, prompt_token_ids=decoder_inputs["prompt_token_ids"], @@ -415,3 +411,6 @@ class Processor: # TODO: Find out how many placeholder tokens are there so we can # check that chunked prefill does not truncate them # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def clear_cache(self) -> None: + self.input_preprocessor.clear_cache() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f1ceaaae62..053aaf4f96 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2186,10 +2186,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_items_per_batch: int, ) -> BatchedTensorInputs: """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, seq_len=self.max_num_tokens, mm_counts={modality: 1}, + cache=self.mm_budget.cache, ) dummy_mm_data = dummy_decoder_data.multi_modal_data diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 4a485b7e07..d364236604 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1813,10 +1813,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): max_items_per_batch: int, ) -> BatchedTensorInputs: """Dummy data for profiling and precompiling multimodal models.""" + assert self.mm_budget is not None + dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( model_config=self.model_config, seq_len=self.max_num_tokens, mm_counts={modality: 1}, + cache=self.mm_budget.cache, ) dummy_mm_data = dummy_decoder_data.multi_modal_data diff --git a/vllm/v1/worker/utils.py b/vllm/v1/worker/utils.py index b96473e7b1..82ede5ad8e 100644 --- a/vllm/v1/worker/utils.py +++ b/vllm/v1/worker/utils.py @@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.config import ModelConfig, SchedulerConfig from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.utils import extract_layer_index +from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.registry import MultiModalRegistry from vllm.v1.attention.backends.utils import AttentionMetadataBuilder from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget @@ -33,14 +34,18 @@ class MultiModalBudget: self.model_config = model_config self.scheduler_config = scheduler_config self.mm_registry = mm_registry + self.cache = cache = processor_only_cache_from_config( + model_config, mm_registry) self.max_model_len = 
model_config.max_model_len self.max_num_reqs = scheduler_config.max_num_seqs - self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config) + self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config, + cache=cache) max_tokens_by_modality = mm_registry \ - .get_max_tokens_per_item_by_nonzero_modality(model_config) + .get_max_tokens_per_item_by_nonzero_modality(model_config, + cache=cache) encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget( scheduler_config,