[Core] Use key-only cache for BaseMultiModalProcessor (#23018)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
If you run out of CPU RAM, try the following options:

- (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process)
- (Multi-modal models only) you can set the size of multi-modal cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB).
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).

## Multi-modal input limits
@ -204,20 +204,33 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
to avoid CPU resource exhaustion.

!!! note
    [Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled
    because it requires a one-to-one correspondence between API and engine core processes.
    API server scale-out disables [multi-modal IPC caching](#ipc-caching)
    because it requires a one-to-one correspondence between API and engine core processes.

    This does not impact [multi-modal processor caching](#processor-caching).
## Multi-Modal Caching
### Processor Cache

By default, the multi-modal processor cache is enabled to avoid repeatedly processing
the same multi-modal inputs via Hugging Face `AutoProcessor`,
Multi-modal caching avoids repeated transfer or processing of the same multi-modal data,
which commonly occurs in multi-turn conversations.

You can adjust the size of the cache by setting the value of `mm_processor_cache_gb`
(default 4 GiB per API process + 4 GiB per engine core process).
If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`.

### Processor Caching

Multi-modal processor caching is automatically enabled
to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`.

### IPC Caching

Multi-modal IPC caching is automatically enabled when
there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes,
to avoid repeatedly transferring the same multi-modal inputs between them.

### Configuration

You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB).

If you do not benefit much from the cache, you can disable both IPC
and processor caching completely via `mm_processor_cache_gb=0`.

Examples:
@ -230,3 +243,16 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          mm_processor_cache_gb=0)
```
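
The same setting is available for online serving; a minimal sketch, assuming the usual engine-argument-to-CLI mapping (`mm_processor_cache_gb` becomes `--mm-processor-cache-gb`):

```console
vllm serve Qwen/Qwen2.5-VL-3B-Instruct --mm-processor-cache-gb 0
```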
### Cache Placement
Based on the configuration, the contents of the multi-modal caches on `P0` and `P1` are as follows:

| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory |
|-------------------|-------------|------------|------------|-------------|
| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` |
| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` |
| ❌ | ❌ | N/A | N/A | `0` |

K: Stores the hashes of multi-modal items
V: Stores the processed tensor data of multi-modal items
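
As a rough reading of the table, a hypothetical helper (not part of vLLM; the parameter names mirror the corresponding engine/CLI settings) that estimates the worst-case host memory used by these caches:

```python
def estimated_mm_cache_gb(
    mm_processor_cache_gb: float,
    *,
    api_server_count: int = 1,
    data_parallel_size: int = 1,
    ipc_caching: bool = True,
) -> float:
    """Worst-case memory (GiB) taken by the multi-modal caches, per the table above."""
    if mm_processor_cache_gb == 0:
        # Both processor and IPC caching are disabled.
        return 0.0
    if ipc_caching:
        # P0 keeps only keys/metadata; each engine core (P1) holds the tensor data.
        return mm_processor_cache_gb * data_parallel_size
    # Without IPC caching, each API server process keeps both keys and tensor data.
    return mm_processor_cache_gb * api_server_count


print(estimated_mm_cache_gb(4, data_parallel_size=2))  # 8.0
```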
@ -14,8 +14,9 @@ from PIL import Image
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
||||
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
|
||||
from vllm.multimodal.inputs import MultiModalInputs
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
cached_tokenizer_from_config,
|
||||
encode_tokens)
|
||||
@ -63,6 +64,8 @@ def _test_processing_correctness(
|
||||
revision=model_info.revision,
|
||||
trust_remote_code=model_info.trust_remote_code,
|
||||
hf_overrides=model_info.hf_overrides,
|
||||
# Ensure that the cache can fit all of the data
|
||||
mm_processor_cache_gb=2048,
|
||||
)
|
||||
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
@ -71,8 +74,7 @@ def _test_processing_correctness(
|
||||
model_config,
|
||||
tokenizer=cached_tokenizer_from_config(model_config),
|
||||
)
|
||||
# Ensure that it can fit all of the data
|
||||
cache = ProcessingCache(capacity_gb=2048)
|
||||
cache = MultiModalProcessorOnlyCache(model_config)
|
||||
|
||||
processing_info = factories.info(ctx)
|
||||
supported_mm_limits = processing_info.get_supported_mm_limits()
|
||||
|
@ -1,32 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
|
||||
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
|
||||
from vllm.multimodal.cache import (MultiModalCache,
|
||||
MultiModalProcessorCacheItem,
|
||||
MultiModalProcessorCacheItemMetadata,
|
||||
processor_cache_from_config,
|
||||
receiver_cache_from_config)
|
||||
from vllm.multimodal.hasher import MultiModalHasher
|
||||
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalSharedField)
|
||||
from vllm.multimodal.processing import PromptInsertion
|
||||
from vllm.multimodal.registry import MultiModalRegistry
|
||||
|
||||
|
||||
def _dummy_elem(modality: str, key: str, size: int):
|
||||
def _dummy_elem(
|
||||
modality: str,
|
||||
key: str,
|
||||
size: int,
|
||||
*,
|
||||
rng: Optional[np.random.RandomState] = None,
|
||||
):
|
||||
if rng is None:
|
||||
data = torch.empty((size, ), dtype=torch.int8)
|
||||
else:
|
||||
data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8))
|
||||
|
||||
return MultiModalFieldElem(
|
||||
modality=modality,
|
||||
key=key,
|
||||
data=torch.empty((size, ), dtype=torch.int8),
|
||||
data=data,
|
||||
field=MultiModalSharedField(1),
|
||||
)
|
||||
|
||||
|
||||
def _dummy_item(modality: str, size_by_key: dict[str, int]):
|
||||
def _dummy_item(
|
||||
modality: str,
|
||||
size_by_key: dict[str, int],
|
||||
*,
|
||||
rng: Optional[np.random.RandomState] = None,
|
||||
):
|
||||
return MultiModalKwargsItem.from_elems([
|
||||
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
|
||||
_dummy_elem(modality, key, size, rng=rng)
|
||||
for key, size in size_by_key.items()
|
||||
])
|
||||
|
||||
|
||||
def _dummy_items(size_by_key_modality: dict[str, dict[str, int]]):
|
||||
def _dummy_items(
|
||||
size_by_key_modality: dict[str, dict[str, int]],
|
||||
*,
|
||||
rng: Optional[np.random.RandomState] = None,
|
||||
):
|
||||
return MultiModalKwargsItems.from_seq([
|
||||
_dummy_item(modality, size_by_key)
|
||||
_dummy_item(modality, size_by_key, rng=rng)
|
||||
for modality, size_by_key in size_by_key_modality.items()
|
||||
])
|
||||
|
||||
@ -48,5 +80,139 @@ def test_cache_item_size(item, expected_size):
|
||||
cache[""] = item
|
||||
assert cache.currsize == expected_size
|
||||
|
||||
cache[""] = MultiModalCacheItemMetadata.wraps(item)
|
||||
prompt_update = PromptInsertion("dummy", "target", "insertion") \
|
||||
.resolve(0)
|
||||
|
||||
cache[""] = MultiModalProcessorCacheItem(item, [prompt_update])
|
||||
assert cache.currsize == expected_size
|
||||
|
||||
cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update])
|
||||
assert cache.currsize == expected_size
|
||||
|
||||
|
||||
def _create_vllm_config(
|
||||
*,
|
||||
mm_processor_cache_gb: float,
|
||||
enable_ipc: bool,
|
||||
):
|
||||
return VllmConfig(
|
||||
model_config=ModelConfig(mm_processor_cache_gb=mm_processor_cache_gb),
|
||||
parallel_config=ParallelConfig(
|
||||
data_parallel_size=1 if enable_ipc else 2),
|
||||
)
|
||||
|
||||
|
||||
def _compare_caches(
|
||||
config_0: VllmConfig,
|
||||
config_1: VllmConfig,
|
||||
*,
|
||||
item_capacity: int = 8,
|
||||
hit_rate: float = 0.5,
|
||||
max_items_per_iter: int = 3,
|
||||
is_cached_calls_per_iter: int,
|
||||
n_iter: int = 100,
|
||||
seed: int = 0,
|
||||
):
|
||||
mm_registry = MultiModalRegistry()
|
||||
cache_0_p0 = processor_cache_from_config(config_0, mm_registry)
|
||||
cache_0_p1 = receiver_cache_from_config(config_0, mm_registry)
|
||||
cache_1_p0 = processor_cache_from_config(config_1, mm_registry)
|
||||
cache_1_p1 = receiver_cache_from_config(config_1, mm_registry)
|
||||
|
||||
cache_size_gb = max(
|
||||
config_0.model_config.mm_processor_cache_gb,
|
||||
config_1.model_config.mm_processor_cache_gb,
|
||||
)
|
||||
item_size_gb = int(cache_size_gb / item_capacity)
|
||||
|
||||
rng = np.random.RandomState(seed)
|
||||
all_items = [
|
||||
_dummy_item("item", {"key": item_size_gb}, rng=rng)
|
||||
for _ in range(int(item_capacity / hit_rate))
|
||||
]
|
||||
all_hashes = [
|
||||
MultiModalHasher.hash_kwargs(item=item.get_data())
|
||||
for item in all_items
|
||||
]
|
||||
|
||||
# Should not be used since there is nothing to convert to text
|
||||
prompt_update = PromptInsertion("dummy", "target", "insertion")
|
||||
|
||||
for it in range(n_iter):
|
||||
num_items_to_select = rng.randint(0, max_items_per_iter)
|
||||
item_idxs_to_select = rng.choice(len(all_items), num_items_to_select)
|
||||
|
||||
selected_items = [all_items[idx] for idx in item_idxs_to_select]
|
||||
selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select]
|
||||
|
||||
if cache_0_p0 is None:
|
||||
cache_0_p0_out = selected_items
|
||||
else:
|
||||
for _ in range(is_cached_calls_per_iter):
|
||||
cache_0_p0.is_cached(selected_hashes)
|
||||
cache_0_p0_out = [
|
||||
item for item, _ in cache_0_p0.get_and_update(
|
||||
[(item, prompt_update.content) for item in selected_items],
|
||||
selected_hashes,
|
||||
)
|
||||
]
|
||||
|
||||
if cache_1_p0 is None:
|
||||
cache_1_p0_out = selected_items
|
||||
else:
|
||||
for _ in range(is_cached_calls_per_iter):
|
||||
cache_1_p0.is_cached(selected_hashes)
|
||||
cache_1_p0_out = [
|
||||
item for item, _ in cache_1_p0.get_and_update(
|
||||
[(item, prompt_update.content) for item in selected_items],
|
||||
selected_hashes,
|
||||
)
|
||||
]
|
||||
|
||||
if cache_0_p1 is None:
|
||||
cache_0_p1_out = cache_0_p0_out
|
||||
else:
|
||||
cache_0_p1_out = cache_0_p1.get_and_update(cache_0_p0_out,
|
||||
selected_hashes)
|
||||
|
||||
if cache_1_p1 is None:
|
||||
cache_1_p1_out = cache_1_p0_out
|
||||
else:
|
||||
cache_1_p1_out = cache_1_p1.get_and_update(cache_1_p0_out,
|
||||
selected_hashes)
|
||||
|
||||
assert cache_0_p1_out == cache_1_p1_out, f"Failed at {it=}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3])
|
||||
def test_ipc_enable_disable_consistency(is_cached_calls_per_iter):
|
||||
cache_size_gb = 1 / (1 << 20)
|
||||
|
||||
vllm_config_ipc_enabled = _create_vllm_config(
|
||||
mm_processor_cache_gb=cache_size_gb,
|
||||
enable_ipc=True,
|
||||
)
|
||||
vllm_config_ipc_disabled = _create_vllm_config(
|
||||
mm_processor_cache_gb=0,
|
||||
enable_ipc=False,
|
||||
)
|
||||
vllm_config_cache_disabled = _create_vllm_config(
|
||||
mm_processor_cache_gb=cache_size_gb,
|
||||
enable_ipc=True,
|
||||
)
|
||||
|
||||
_compare_caches(
|
||||
vllm_config_ipc_enabled,
|
||||
vllm_config_ipc_disabled,
|
||||
is_cached_calls_per_iter=is_cached_calls_per_iter,
|
||||
)
|
||||
_compare_caches(
|
||||
vllm_config_ipc_disabled,
|
||||
vllm_config_cache_disabled,
|
||||
is_cached_calls_per_iter=is_cached_calls_per_iter,
|
||||
)
|
||||
_compare_caches(
|
||||
vllm_config_cache_disabled,
|
||||
vllm_config_ipc_enabled,
|
||||
is_cached_calls_per_iter=is_cached_calls_per_iter,
|
||||
)
|
||||
|
@ -437,7 +437,7 @@ class ModelConfig:
|
||||
from `AutoProcessor.from_pretrained`. The available overrides depend on the
|
||||
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
|
||||
"""
|
||||
mm_processor_cache_gb: int = 4
|
||||
mm_processor_cache_gb: float = 4
|
||||
"""The size (in GiB) of the multi-modal processor cache, which is used to
|
||||
avoid re-processing past multi-modal inputs.
|
||||
|
||||
@ -884,12 +884,6 @@ class ModelConfig:
|
||||
|
||||
return None
|
||||
|
||||
def set_mm_processor_cache_gb(self, value: int) -> None:
|
||||
mm_config = self.get_multimodal_config()
|
||||
|
||||
self.mm_processor_cache_gb = value
|
||||
mm_config.mm_processor_cache_gb = value
|
||||
|
||||
def _get_encoder_config(self):
|
||||
return get_sentence_transformer_tokenizer_config(
|
||||
self.model, self.revision)
|
||||
@ -1697,22 +1691,6 @@ class ModelConfig:
|
||||
def is_multimodal_model(self) -> bool:
|
||||
return self.multimodal_config is not None
|
||||
|
||||
@property
|
||||
def enable_mm_processor_cache(self) -> bool:
|
||||
"""Whether the multi-modal processor cache should be enabled."""
|
||||
mm_config = self.multimodal_config
|
||||
if mm_config is None:
|
||||
return False
|
||||
|
||||
return mm_config.mm_processor_cache_gb > 0
|
||||
|
||||
def get_mm_input_cache_gb(self) -> int:
|
||||
mm_config = self.multimodal_config
|
||||
if mm_config is None:
|
||||
return 0
|
||||
|
||||
return envs.VLLM_MM_INPUT_CACHE_GIB
|
||||
|
||||
@property
|
||||
def is_cross_encoder(self) -> bool:
|
||||
return (self._model_info.supports_cross_encoding
|
||||
@ -2561,7 +2539,7 @@ class MultiModalConfig:
|
||||
`{"num_crops": 4}`.
|
||||
"""
|
||||
|
||||
mm_processor_cache_gb: int = 4
|
||||
mm_processor_cache_gb: float = 4
|
||||
"""
|
||||
The size (in GiB) of the multi-modal processor cache, which is used to
|
||||
|
||||
|
@ -351,7 +351,7 @@ class EngineArgs:
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = \
|
||||
MultiModalConfig.mm_processor_kwargs
|
||||
disable_mm_preprocessor_cache: bool = False # DEPRECATED
|
||||
mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
|
||||
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
|
||||
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
|
||||
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
|
||||
# LoRA fields
|
||||
@ -1293,18 +1293,6 @@ class EngineArgs:
|
||||
worker_extension_cls=self.worker_extension_cls,
|
||||
)
|
||||
|
||||
if model_config.is_multimodal_model:
|
||||
dp_supports_mm_processor_cache = (self.data_parallel_size == 1
|
||||
or data_parallel_external_lb)
|
||||
if (not dp_supports_mm_processor_cache
|
||||
and model_config.mm_processor_cache_gb > 0):
|
||||
logger.warning(
|
||||
"Multi-modal processor cache is disabled because "
|
||||
"it is not compatible with data parallelism when "
|
||||
"there does not exist a one-to-one correspondance "
|
||||
"between API and engine core processes.")
|
||||
model_config.set_mm_processor_cache_gb(0)
|
||||
|
||||
speculative_config = self.create_speculative_config(
|
||||
target_model_config=model_config,
|
||||
target_parallel_config=parallel_config,
|
||||
|
@ -36,6 +36,7 @@ from vllm.logits_process import get_bad_words_logits_processors
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.cache import processor_only_cache_from_config
|
||||
from vllm.multimodal.processing import EncDecMultiModalProcessor
|
||||
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
|
||||
RequestOutputFactory)
|
||||
@ -250,9 +251,13 @@ class LLMEngine:
|
||||
self.generation_config_fields = (
|
||||
self.model_config.try_get_generation_config())
|
||||
|
||||
self.input_preprocessor = InputPreprocessor(self.model_config,
|
||||
self.input_preprocessor = InputPreprocessor(
|
||||
self.model_config,
|
||||
self.tokenizer,
|
||||
mm_registry)
|
||||
mm_registry,
|
||||
mm_processor_cache=processor_only_cache_from_config(
|
||||
self.model_config, mm_registry),
|
||||
)
|
||||
|
||||
self.model_executor = executor_class(vllm_config=vllm_config)
|
||||
|
||||
@ -840,8 +845,8 @@ class LLMEngine:
|
||||
|
||||
def reset_mm_cache(self) -> bool:
|
||||
"""Reset the multi-modal cache."""
|
||||
return self.input_preprocessor.mm_registry.reset_processor_cache(
|
||||
self.model_config)
|
||||
self.input_preprocessor.clear_cache()
|
||||
return True
|
||||
|
||||
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
|
||||
"""Reset prefix cache for all devices."""
|
||||
|
@ -11,6 +11,7 @@ from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
MultiModalInputs)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
@ -32,12 +33,14 @@ class InputPreprocessor:
|
||||
model_config: ModelConfig,
|
||||
tokenizer: Optional[TokenizerGroup],
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.model_config = model_config
|
||||
self.tokenizer = tokenizer
|
||||
self.mm_registry = mm_registry
|
||||
self.mm_processor_cache = mm_processor_cache
|
||||
|
||||
def get_tokenizer_group(self) -> TokenizerGroup:
|
||||
if self.tokenizer is None:
|
||||
@ -261,8 +264,11 @@ class InputPreprocessor:
|
||||
"""
|
||||
tokenizer = self._get_mm_tokenizer(lora_request)
|
||||
|
||||
mm_processor = self.mm_registry.create_processor(self.model_config,
|
||||
tokenizer=tokenizer)
|
||||
mm_processor = self.mm_registry.create_processor(
|
||||
self.model_config,
|
||||
tokenizer=tokenizer,
|
||||
cache=self.mm_processor_cache,
|
||||
)
|
||||
|
||||
if mm_processor_kwargs is None:
|
||||
mm_processor_kwargs = {}
|
||||
@ -286,8 +292,12 @@ class InputPreprocessor:
|
||||
"""
|
||||
tokenizer = await self._get_mm_tokenizer_async(lora_request)
|
||||
|
||||
mm_processor = self.mm_registry.create_processor(self.model_config,
|
||||
tokenizer=tokenizer)
|
||||
mm_processor = self.mm_registry.create_processor(
|
||||
self.model_config,
|
||||
tokenizer=tokenizer,
|
||||
cache=self.mm_processor_cache,
|
||||
)
|
||||
|
||||
if mm_processor_kwargs is None:
|
||||
mm_processor_kwargs = {}
|
||||
|
||||
@ -860,3 +870,7 @@ class InputPreprocessor:
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
if self.mm_processor_cache is not None:
|
||||
self.mm_processor_cache.clear_cache()
|
||||
|
@ -223,20 +223,26 @@ class InputRegistry:
|
||||
The model is identified by ``model_config``.
|
||||
"""
|
||||
# Avoid circular import
|
||||
from vllm.multimodal.cache import processor_only_cache_from_config
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
if not model_config.is_multimodal_model:
|
||||
seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
|
||||
return DummyData(seq_data=seq_data)
|
||||
|
||||
cache = processor_only_cache_from_config(model_config, mm_registry)
|
||||
|
||||
# Encoder dummy data does not contain multi-modal data
|
||||
if is_encoder_data:
|
||||
enc_data = mm_registry.get_encoder_dummy_data(
|
||||
model_config, seq_len)
|
||||
enc_data = mm_registry.get_encoder_dummy_data(model_config,
|
||||
seq_len,
|
||||
cache=cache)
|
||||
seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
|
||||
return DummyData(seq_data=seq_data)
|
||||
|
||||
dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len)
|
||||
dec_data = mm_registry.get_decoder_dummy_data(model_config,
|
||||
seq_len,
|
||||
cache=cache)
|
||||
|
||||
return DummyData(
|
||||
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
|
||||
|
@ -33,12 +33,13 @@ from vllm.inputs import InputProcessingContext
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, ProcessingCache,
|
||||
PromptReplacement, PromptUpdate)
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@ -367,7 +368,7 @@ def _build_hcxvision_hf_processor(
|
||||
info: HCXVisionProcessingInfo,
|
||||
dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo],
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
cache: Optional[BaseMultiModalProcessorCache] = None,
|
||||
) -> BaseMultiModalProcessor:
|
||||
if isinstance(info, HCXVisionProcessingInfo):
|
||||
return HCXVisionMultiModalProcessor(
|
||||
|
@ -22,14 +22,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputs, MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, ProcessingCache,
|
||||
PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
@ -394,7 +394,7 @@ def _build_llava_or_pixtral_hf_processor(
|
||||
info: _I,
|
||||
dummy_inputs: BaseDummyInputsBuilder[_I],
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
cache: Optional[BaseMultiModalProcessorCache] = None,
|
||||
) -> BaseMultiModalProcessor:
|
||||
if isinstance(info, PixtralHFProcessingInfo):
|
||||
return PixtralHFMultiModalProcessor(
|
||||
|
@ -58,7 +58,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
|
||||
VideoItem, VideoProcessorItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
PromptUpdate, PromptUpdateDetails,
|
||||
ResolvedPromptUpdate, _seq2text)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@ -744,6 +745,43 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
for modality, pattern in placeholders
|
||||
]
|
||||
|
||||
def _recompute_cached_prompt_update(
|
||||
self,
|
||||
cached_update: ResolvedPromptUpdate,
|
||||
new_item_idx: int,
|
||||
) -> ResolvedPromptUpdate:
|
||||
new_update = super()._recompute_cached_prompt_update(
|
||||
cached_update,
|
||||
new_item_idx,
|
||||
)
|
||||
|
||||
if cached_update.modality == "image":
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
image_processor = self.info.get_image_processor()
|
||||
version = self.info.get_model_version()
|
||||
|
||||
text = _seq2text(tokenizer, cached_update.content.full)
|
||||
prev_item_idx = cached_update.item_idx
|
||||
|
||||
if version == (2, 0) or version == (2, 5):
|
||||
im_start = image_processor.im_start_token
|
||||
im_end = image_processor.im_end_token
|
||||
else:
|
||||
im_start = image_processor.im_id_start
|
||||
im_end = image_processor.im_id_end
|
||||
|
||||
new_update = new_update.with_content(
|
||||
PromptUpdateDetails.select_text(
|
||||
text.replace(
|
||||
f"{im_start}{prev_item_idx}{im_end}",
|
||||
f"{im_start}{new_item_idx}{im_end}",
|
||||
1,
|
||||
),
|
||||
"<unk>",
|
||||
))
|
||||
|
||||
return new_update
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
|
@ -22,14 +22,14 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, ProcessingCache,
|
||||
PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
@ -322,7 +322,7 @@ def _build_mistral3_processor(
|
||||
info: _I,
|
||||
dummy_inputs: BaseDummyInputsBuilder[_I],
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
cache: Optional[BaseMultiModalProcessorCache] = None,
|
||||
) -> BaseMultiModalProcessor:
|
||||
assert isinstance(info, Mistral3ProcessingInfo)
|
||||
return Mistral3MultiModalProcessor(
|
||||
|
@ -41,7 +41,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
MultiModalPromptUpdates,
|
||||
PlaceholderFeaturesInfo,
|
||||
PromptReplacement, PromptUpdate)
|
||||
PromptReplacement, PromptUpdate,
|
||||
ResolvedPromptUpdate)
|
||||
# yapf: enable
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
@ -440,6 +441,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
|
||||
)
|
||||
]
|
||||
|
||||
def _recompute_cached_prompt_update(
|
||||
self,
|
||||
cached_update: ResolvedPromptUpdate,
|
||||
new_item_idx: int,
|
||||
) -> ResolvedPromptUpdate:
|
||||
new_update = super()._recompute_cached_prompt_update(
|
||||
cached_update,
|
||||
new_item_idx,
|
||||
)
|
||||
|
||||
if cached_update.modality == "image":
|
||||
hf_processor = self.info.get_hf_processor()
|
||||
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
|
||||
new_update = new_update.with_target(image_tokens[new_item_idx])
|
||||
|
||||
return new_update
|
||||
|
||||
def _apply_prompt_updates(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
|
@ -27,7 +27,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems,
|
||||
MultiModalDataItems, MultiModalDataParser)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
PromptUpdate, ResolvedPromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
@ -850,6 +850,25 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
|
||||
),
|
||||
]
|
||||
|
||||
def _recompute_cached_prompt_update(
|
||||
self,
|
||||
cached_update: ResolvedPromptUpdate,
|
||||
new_item_idx: int,
|
||||
) -> ResolvedPromptUpdate:
|
||||
new_update = super()._recompute_cached_prompt_update(
|
||||
cached_update,
|
||||
new_item_idx,
|
||||
)
|
||||
|
||||
if cached_update.modality == "image":
|
||||
image_tokens: list[str] = self.info.image_tokens # type: ignore
|
||||
new_update = new_update.with_target(image_tokens[new_item_idx])
|
||||
elif cached_update.modality == "audio":
|
||||
audio_tokens: list[str] = self.info.audio_tokens # type: ignore
|
||||
new_update = new_update.with_target(audio_tokens[new_item_idx])
|
||||
|
||||
return new_update
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Phi4MMMultiModalProcessor,
|
||||
|
@ -25,12 +25,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.llava import LlavaDummyInputsBuilder
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import BaseMultiModalProcessorCache
|
||||
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, ProcessingCache,
|
||||
PromptReplacement, PromptUpdate)
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.jsontree import json_map_leaves
|
||||
@ -332,7 +333,7 @@ def _build_tarsier_hf_processor(
|
||||
info: _I_Tarsier,
|
||||
dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier],
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
cache: Optional[BaseMultiModalProcessorCache] = None,
|
||||
) -> BaseMultiModalProcessor:
|
||||
if isinstance(info, TarsierProcessingInfo):
|
||||
return TarsierMultiModalProcessor(
|
||||
|
@ -1,11 +1,12 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import sys
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import TypeVar, Union
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
|
||||
|
||||
import torch
|
||||
from typing_extensions import TypeAlias, override
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import GiB_bytes, LRUCache
|
||||
@ -15,24 +16,67 @@ from .inputs import (MultiModalFieldElem, MultiModalKwargs,
|
||||
MultiModalKwargsItem, MultiModalKwargsItems,
|
||||
NestedTensors)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
from .processing import ResolvedPromptUpdate
|
||||
from .registry import MultiModalRegistry
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiModalCacheItemMetadata:
|
||||
size: int
|
||||
class MultiModalProcessorCacheItem:
|
||||
"""
|
||||
The data to store inside `MultiModalProcessorOnlyCache`.
|
||||
|
||||
@classmethod
|
||||
def wraps(cls, value: "MultiModalCacheValue"):
|
||||
return cls(size=MultiModalCache.get_item_size(value))
|
||||
Args:
|
||||
item: The processed tensor data corresponding to a multi-modal item.
|
||||
prompt_updates: The prompt updates corresponding to `item`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item: MultiModalKwargsItem,
|
||||
prompt_updates: Sequence["ResolvedPromptUpdate"],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.item = item
|
||||
self.prompt_updates = prompt_updates
|
||||
|
||||
|
||||
class MultiModalProcessorCacheItemMetadata:
|
||||
"""
|
||||
The metadata to store inside `MultiModalProcessorSenderCache`.
|
||||
|
||||
Args:
|
||||
item: The processed tensor data corresponding to a multi-modal item.
|
||||
Since P1 already stores the tensor data, we only store its size
|
||||
metadata in P0 to reduce memory usage. The size metadata is still
|
||||
needed to keep the same cache eviction policy as the cache on P1.
|
||||
prompt_updates: The prompt updates corresponding to `item`.
|
||||
This needs to stay on P0 because for some models, they are
|
||||
dependent on the processed tensor data (cached on P1).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item: MultiModalKwargsItem,
|
||||
prompt_updates: Sequence["ResolvedPromptUpdate"],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.item_size = MultiModalCache.get_item_size(item)
|
||||
self.prompt_updates = prompt_updates
|
||||
|
||||
|
||||
MultiModalCacheValue = Union[
|
||||
MultiModalProcessorCacheItem,
|
||||
MultiModalProcessorCacheItemMetadata,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargs,
|
||||
Mapping[str, NestedTensors],
|
||||
MultiModalCacheItemMetadata,
|
||||
]
|
||||
|
||||
_V = TypeVar("_V", bound=MultiModalCacheValue)
|
||||
@ -47,8 +91,10 @@ class MultiModalCache:
|
||||
*,
|
||||
debug: bool = False,
|
||||
) -> int:
|
||||
if isinstance(leaf, MultiModalFieldElem):
|
||||
return cls.get_item_size(leaf.data) # type: ignore
|
||||
if isinstance(leaf, MultiModalProcessorCacheItem):
|
||||
return cls.get_leaf_size(leaf.item)
|
||||
if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
|
||||
return leaf.item_size
|
||||
|
||||
# These are not subclasses of dict
|
||||
if isinstance(leaf, MultiModalKwargsItems):
|
||||
@ -58,13 +104,13 @@ class MultiModalCache:
|
||||
if isinstance(leaf, MultiModalKwargs):
|
||||
return cls.get_item_size(leaf.data) # type: ignore
|
||||
|
||||
if isinstance(leaf, MultiModalFieldElem):
|
||||
return cls.get_item_size(leaf.data) # type: ignore
|
||||
|
||||
# sys.getsizeof doesn't work for tensors
|
||||
if isinstance(leaf, torch.Tensor):
|
||||
return leaf.nbytes
|
||||
|
||||
if isinstance(leaf, MultiModalCacheItemMetadata):
|
||||
return leaf.size
|
||||
|
||||
return sys.getsizeof(leaf)
|
||||
|
||||
@classmethod
|
||||
@ -98,3 +144,332 @@ class MultiModalCache:
|
||||
GiB_bytes * capacity_gb,
|
||||
getsizeof=lambda x: cls.get_item_size(x, debug=debug),
|
||||
)
|
||||
|
||||
|
||||
_I = TypeVar("_I", contravariant=True)
|
||||
_O = TypeVar("_O", covariant=True)
|
||||
|
||||
|
||||
class BaseMultiModalCache(ABC, Generic[_I, _O]):
|
||||
"""
|
||||
Abstract base class to read/write multi-modal items from cache.
|
||||
|
||||
The idea of multi-modal caching is based on having a client and server
|
||||
where the client executes in the frontend process (=P0) and
|
||||
the server in the core process (=P1). The data flow is as follows:
|
||||
|
||||
```
|
||||
is_cached() x N get_and_update()
|
||||
P0: From API -----------------> -----------------> To P1
|
||||
|
||||
get_and_update()
|
||||
P1: From P0 -----------------> To model
|
||||
```
|
||||
|
||||
`is_cached()` can be called any number of times in P0. However,
|
||||
`get_and_update()` must be called in P0 and P1 one after another
|
||||
so that their cache eviction order remains the same.
|
||||
|
||||
This ensures that the keys in P0 and P1 caches are mirrored,
|
||||
allowing us to determine whether a key is cached in P1 by looking
|
||||
up the P0 cache, without having to communicate with P1.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: _I,
|
||||
mm_hash: str,
|
||||
) -> _O:
|
||||
"""
|
||||
Possibly update a multi-modal item based on whether it is
|
||||
in the underlying cache.
|
||||
|
||||
This update is done out-of-place and updates the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_item: The multi-modal item to update.
|
||||
mm_hash: The hash of `mm_item`.
|
||||
|
||||
Returns:
|
||||
The updated multi-modal item.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_and_update(
|
||||
self,
|
||||
mm_items: Sequence[_I],
|
||||
mm_hashes: list[str],
|
||||
) -> list[_O]:
|
||||
"""
|
||||
Possibly update a sequence of multi-modal items based on whether they
|
||||
are in the underlying cache.
|
||||
|
||||
This update is done out-of-place and updates the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_items: The multi-modal items to update.
|
||||
mm_hashes: The hash of each item in `mm_items`.
|
||||
|
||||
Returns:
|
||||
A new list of updated multi-modal items.
|
||||
"""
|
||||
assert len(mm_items) == len(mm_hashes)
|
||||
|
||||
return [
|
||||
self.get_and_update_item(mm_item, mm_hash)
|
||||
for mm_item, mm_hash in zip(mm_items, mm_hashes)
|
||||
]
|
||||
|
||||
@abstractmethod
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the underlying cache."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
MultiModalProcessorCacheInItem: TypeAlias = \
|
||||
Optional[tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]]]
|
||||
|
||||
|
||||
MultiModalProcessorCacheOutItem: TypeAlias = \
|
||||
tuple[Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"]]
|
||||
|
||||
|
||||
class BaseMultiModalProcessorCache(
|
||||
BaseMultiModalCache[MultiModalProcessorCacheInItem,
|
||||
MultiModalProcessorCacheOutItem]):
|
||||
"""The required interface for caches on P0."""
|
||||
|
||||
@abstractmethod
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
"""
|
||||
Check whether a multi-modal item is
|
||||
in the underlying cache.
|
||||
|
||||
This **DOES NOT** update the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_hash: The hash of the item to check.
|
||||
|
||||
Returns:
|
||||
`True` if the item is cached, otherwise `False`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def is_cached(self, mm_hashes: list[str]) -> list[bool]:
|
||||
"""
|
||||
Check whether a sequence of multi-modal items are
|
||||
in the underlying cache.
|
||||
|
||||
This **DOES NOT** update the cache eviction order.
|
||||
|
||||
Args:
|
||||
mm_hashes: The hash of each item to check.
|
||||
|
||||
Returns:
|
||||
For each item, `True` if the item is cached, otherwise `False`.
|
||||
"""
|
||||
return [self.is_cached_item(mm_hash) for mm_hash in mm_hashes]
|
||||
|
||||
|
||||
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
|
||||
"""
|
||||
The cache which is used on P0 when IPC caching is disabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is in the cache, replace the input with the cached item.
|
||||
- If the item is not in the cache, store that item (which includes
|
||||
tensor data and metadata) into the cache, and return the input.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: "ModelConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
||||
self._cache = MultiModalCache.get_lru_cache(
|
||||
mm_config.mm_processor_cache_gb,
|
||||
MultiModalProcessorCacheItem,
|
||||
)
|
||||
|
||||
@override
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
return mm_hash in self._cache
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalProcessorCacheInItem,
|
||||
mm_hash: str,
|
||||
) -> MultiModalProcessorCacheOutItem:
|
||||
if (cached_item := self._cache.get(mm_hash)) is not None:
|
||||
return cached_item.item, cached_item.prompt_updates
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)
|
||||
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
|
||||
|
||||
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
|
||||
"""
|
||||
The cache which is used on P0 when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is already in the cache, clear the input to avoid
|
||||
unnecessary IPC.
|
||||
|
||||
- If the item is not in the cache, store the metadata of that item so
|
||||
that the eviction policy remains the same as the cache on P1,
|
||||
and return the input.
|
||||
By only storing the metadata, we avoid keeping the data itself in
|
||||
memory inside P0.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: "ModelConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
||||
self._cache = MultiModalCache.get_lru_cache(
|
||||
mm_config.mm_processor_cache_gb,
|
||||
MultiModalProcessorCacheItemMetadata,
|
||||
)
|
||||
|
||||
@override
|
||||
def is_cached_item(self, mm_hash: str) -> bool:
|
||||
return mm_hash in self._cache
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: MultiModalProcessorCacheInItem,
|
||||
mm_hash: str,
|
||||
) -> MultiModalProcessorCacheOutItem:
|
||||
if (cached_item := self._cache.get(mm_hash)) is not None:
|
||||
return None, cached_item.prompt_updates
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)
|
||||
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
|
||||
|
||||
def _enable_processor_cache(
|
||||
model_config: "ModelConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> bool:
|
||||
if not mm_registry.supports_multimodal_inputs(model_config):
|
||||
return False
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
return mm_config.mm_processor_cache_gb > 0
|
||||
|
||||
|
||||
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
supports_ipc_cache = (parallel_config.data_parallel_size == 1
|
||||
or parallel_config.data_parallel_external_lb)
|
||||
|
||||
return supports_ipc_cache
|
||||
|
||||
|
||||
def processor_cache_from_config(
|
||||
vllm_config: "VllmConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> Optional[BaseMultiModalProcessorCache]:
|
||||
"""Return a `BaseMultiModalProcessorCache`, if enabled."""
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return MultiModalProcessorOnlyCache(model_config)
|
||||
|
||||
return MultiModalProcessorSenderCache(model_config)
|
||||
|
||||
|
||||
def processor_only_cache_from_config(
|
||||
model_config: "ModelConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
):
|
||||
"""Return a `MultiModalProcessorOnlyCache`, if enabled."""
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
return MultiModalProcessorOnlyCache(model_config)
|
||||
|
||||
|
||||
class BaseMultiModalReceiverCache(
|
||||
BaseMultiModalCache[Optional[MultiModalKwargsItem],
|
||||
MultiModalKwargsItem]):
|
||||
"""The required interface for caches on P1."""
|
||||
|
||||
|
||||
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
|
||||
"""
|
||||
The cache which is used on P1 when IPC caching is enabled.
|
||||
|
||||
How to update each item:
|
||||
|
||||
- If the item is in the cache, replace the input with the cached item.
|
||||
- If the item is not in the cache, store that item (which includes tensor
|
||||
data) into the cache, and return the input.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: "ModelConfig") -> None:
|
||||
super().__init__()
|
||||
|
||||
mm_config = model_config.get_multimodal_config()
|
||||
|
||||
self._cache = MultiModalCache.get_lru_cache(
|
||||
mm_config.mm_processor_cache_gb,
|
||||
MultiModalKwargsItem,
|
||||
)
|
||||
|
||||
@override
|
||||
def get_and_update_item(
|
||||
self,
|
||||
mm_item: Optional[MultiModalKwargsItem],
|
||||
mm_hash: str,
|
||||
) -> MultiModalKwargsItem:
|
||||
if (cached_item := self._cache.get(mm_hash)) is not None:
|
||||
return cached_item
|
||||
|
||||
assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
|
||||
|
||||
self._cache[mm_hash] = mm_item
|
||||
return mm_item
|
||||
|
||||
@override
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
|
||||
|
||||
def receiver_cache_from_config(
|
||||
vllm_config: "VllmConfig",
|
||||
mm_registry: "MultiModalRegistry",
|
||||
) -> Optional[BaseMultiModalReceiverCache]:
|
||||
"""Return a `BaseMultiModalReceiverCache`, if enabled."""
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
if not _enable_processor_cache(model_config, mm_registry):
|
||||
return None
|
||||
|
||||
if not _enable_ipc_cache(vllm_config):
|
||||
return None
|
||||
|
||||
return MultiModalReceiverCache(model_config)
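
A minimal usage sketch (not part of this diff) of how the two factories above pair up across processes, following the P0/P1 flow described in `BaseMultiModalCache`:

```python
# Illustrative wiring only: mirrors how the caches are expected to be built
# on the API process (P0) and the engine core process (P1).
from vllm.config import VllmConfig
from vllm.multimodal.cache import (processor_cache_from_config,
                                   receiver_cache_from_config)
from vllm.multimodal.registry import MultiModalRegistry


def build_mm_caches(vllm_config: VllmConfig, mm_registry: MultiModalRegistry):
    # P0: full cache when IPC caching is off, key-only (metadata) cache when
    # it is on, or None when the processor cache is disabled.
    p0_cache = processor_cache_from_config(vllm_config, mm_registry)
    # P1: holds the processed tensor data; None unless both the processor
    # cache and IPC caching are enabled.
    p1_cache = receiver_cache_from_config(vllm_config, mm_registry)
    return p0_cache, p1_cache
```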
|
||||
|
@ -7,11 +7,11 @@ from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from itertools import accumulate
|
||||
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
|
||||
Union, cast, final)
|
||||
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union,
|
||||
cast, final)
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import NotRequired, TypeAlias, deprecated
|
||||
from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated
|
||||
|
||||
from vllm.utils import LazyLoader, full_groupby, is_list_of
|
||||
from vllm.utils.jsontree import JSONTree, json_map_leaves
|
||||
@ -668,7 +668,15 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
||||
return {key: elem.data for key, elem in self.items()}
|
||||
|
||||
|
||||
class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]):
|
||||
_I = TypeVar(
|
||||
"_I",
|
||||
MultiModalKwargsItem,
|
||||
Optional[MultiModalKwargsItem],
|
||||
default=MultiModalKwargsItem,
|
||||
)
|
||||
|
||||
|
||||
class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
|
||||
"""
|
||||
A dictionary of
|
||||
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
|
||||
@ -714,27 +722,37 @@ class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]):
|
||||
items_by_modality = full_groupby(items, key=lambda x: x.modality)
|
||||
return MultiModalKwargsItems(items_by_modality)
|
||||
|
||||
def __getitem__(self, modality: str):
|
||||
def __getitem__(self, modality: str) -> Sequence[_I]:
|
||||
if modality not in self:
|
||||
raise KeyError(f"Modality {modality!r} not found. "
|
||||
f"Available modalities: {set(self.keys())}")
|
||||
|
||||
return super().__getitem__(modality)
|
||||
return super().__getitem__(modality) # type: ignore[return-value]
|
||||
|
||||
def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
|
||||
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
|
||||
for items in self.values():
|
||||
for item in items:
|
||||
for modality, items in self.items():
|
||||
for i, item in enumerate(items):
|
||||
if item is None:
|
||||
raise RuntimeError("Cannot build data from empty "
|
||||
f"mm_items[{modality}][{i}]")
|
||||
|
||||
for key, elem in item.items():
|
||||
elems_by_key[key].append(elem)
|
||||
|
||||
return MultiModalKwargs({
|
||||
key:
|
||||
elems[0].field.reduce_data(elems, pin_memory=pin_memory)
|
||||
for key, elems in elems_by_key.items() if len(elems) > 0
|
||||
for key, elems in elems_by_key.items()
|
||||
})
|
||||
|
||||
|
||||
MultiModalKwargsOptionalItems: TypeAlias = Union[
|
||||
MultiModalKwargsItems[MultiModalKwargsItem],
|
||||
MultiModalKwargsItems[Optional[MultiModalKwargsItem]],
|
||||
]
|
||||
|
||||
|
||||
class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
"""
|
||||
A dictionary that represents the keyword arguments to
|
||||
@ -898,7 +916,7 @@ class MultiModalInputs(TypedDict):
|
||||
token_type_ids: NotRequired[list[int]]
|
||||
"""The token type IDs of the prompt."""
|
||||
|
||||
mm_kwargs: MultiModalKwargsItems
|
||||
mm_kwargs: MultiModalKwargsOptionalItems
|
||||
"""Keyword arguments to be directly passed to the model after batching."""
|
||||
|
||||
mm_hashes: "MultiModalHashDict"
|
||||
|
@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
|
||||
Sequence)
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, replace
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
|
||||
@ -20,12 +20,11 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
|
||||
encode_tokens)
|
||||
from vllm.utils import flatten_2d_lists, full_groupby
|
||||
|
||||
from .cache import MultiModalCache
|
||||
from .hasher import MultiModalHasher
|
||||
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
MultiModalFieldConfig, MultiModalInputs,
|
||||
MultiModalKwargsItem, MultiModalKwargsItems,
|
||||
PlaceholderRange)
|
||||
MultiModalKwargsOptionalItems, PlaceholderRange)
|
||||
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
|
||||
MultiModalDataParser)
|
||||
|
||||
@ -34,6 +33,7 @@ if TYPE_CHECKING:
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
from .cache import BaseMultiModalProcessorCache
|
||||
from .profiling import BaseDummyInputsBuilder
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -557,6 +557,15 @@ class ResolvedPromptUpdate:
|
||||
|
||||
return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx)
|
||||
|
||||
def with_target(self, target: UpdateTarget):
|
||||
return replace(self, target=target)
|
||||
|
||||
def with_content(self, content: PromptUpdateInfo):
|
||||
if not isinstance(content, PromptUpdateDetails):
|
||||
content = PromptUpdateDetails.from_seq(content)
|
||||
|
||||
return replace(self, content=content)
|
||||
|
||||
|
||||
class _TokenMatch(NamedTuple):
|
||||
start_idx: int
|
||||
@ -865,21 +874,6 @@ def find_mm_placeholders(
|
||||
return dict(full_groupby_modality(it))
|
||||
|
||||
|
||||
class ProcessingCache(MultiModalCache):
|
||||
|
||||
def __init__(self, capacity_gb: float) -> None:
|
||||
super().__init__()
|
||||
|
||||
self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
|
||||
|
||||
self.get = self._cache.get
|
||||
self.put = self._cache.put
|
||||
self.reset = self._cache.clear
|
||||
|
||||
|
||||
_CacheItemOrHash = Union[MultiModalKwargsItem, str]
|
||||
|
||||
|
||||
class BaseProcessingInfo:
|
||||
"""Base class to provide the information necessary for data processing."""
|
||||
|
||||
@ -982,7 +976,7 @@ For an item `MultiModalPromptUpdates[k][i]`,
|
||||
|
||||
|
||||
class MultiModalProcessingInfo(NamedTuple):
|
||||
kwargs: MultiModalKwargsItems
|
||||
kwargs: MultiModalKwargsOptionalItems
|
||||
hashes: MultiModalHashes
|
||||
prompt_updates: MultiModalPromptUpdates
|
||||
|
||||
@ -994,11 +988,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
Not to be confused with `transformers.ProcessorMixin`.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
def __init__(
|
||||
self,
|
||||
info: _I,
|
||||
dummy_inputs: "BaseDummyInputsBuilder[_I]",
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None) -> None:
|
||||
cache: Optional["BaseMultiModalProcessorCache"] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.info = info
|
||||
@ -1355,32 +1351,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
|
||||
return prompt_ids, mm_processed_data, False
|
||||
|
||||
def _get_cache_missing_items(
|
||||
self,
|
||||
cache: ProcessingCache,
|
||||
mm_data_items: MultiModalDataItems,
|
||||
mm_hashes: MultiModalHashes,
|
||||
) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]:
|
||||
mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = {
|
||||
modality: [(h if (v := cache.get(h)) is None else v)
|
||||
for h in hashes]
|
||||
for modality, hashes in mm_hashes.items()
|
||||
}
|
||||
|
||||
mm_missing_idxs = {
|
||||
modality: [
|
||||
idx for idx, item_or_hash in enumerate(items_or_hashes)
|
||||
if isinstance(item_or_hash, str)
|
||||
]
|
||||
for modality, items_or_hashes in mm_cache_items_or_hashes.items()
|
||||
}
|
||||
mm_missing_data = {
|
||||
modality: [mm_data_items[modality][idx] for idx in idxs]
|
||||
for modality, idxs in mm_missing_idxs.items()
|
||||
}
|
||||
|
||||
return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data)

def _hash_mm_items(
self,
mm_items: MultiModalDataItems,
@ -1401,28 +1371,92 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
for modality, items in mm_items.items()
}

def _get_cache_missing_items(
self,
cache: "BaseMultiModalProcessorCache",
mm_data_items: MultiModalDataItems,
mm_hashes: MultiModalHashes,
) -> MultiModalDataItems:
mm_is_cached = {
modality: cache.is_cached(hashes)
for modality, hashes in mm_hashes.items()
}

mm_missing_idxs = {
modality: [
idx for idx, item_is_cached in enumerate(items_is_cached)
if not item_is_cached
]
for modality, items_is_cached in mm_is_cached.items()
}
mm_missing_data = {
modality: [mm_data_items[modality][idx] for idx in idxs]
for modality, idxs in mm_missing_idxs.items()
}

return self._to_mm_items(mm_missing_data)

def _recompute_cached_prompt_update(
self,
cached_update: ResolvedPromptUpdate,
new_item_idx: int,
) -> ResolvedPromptUpdate:
"""
Override this if other attributes of `ResolvedPromptUpdate`
also need to be recomputed after retrieving from the cache.
"""
return replace(cached_update, item_idx=new_item_idx)

def _merge_mm_kwargs(
self,
cache: ProcessingCache,
mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
cache: "BaseMultiModalProcessorCache",
mm_hashes: MultiModalHashes,
mm_missing_kwargs: MultiModalKwargsItems,
) -> MultiModalKwargsItems:
mm_missing_prompt_updates: MultiModalPromptUpdates,
) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]:
# Need to calculate this at the beginning to avoid skipping cache logic
# for subsequently repeated items in the same modality
mm_is_cached = {
modality: cache.is_cached(hashes)
for modality, hashes in mm_hashes.items()
}

mm_missing_next_idx = defaultdict[str, int](lambda: 0)

merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
for modality, items_or_hashes in mm_cache_items_or_hashes.items():
for item_or_hash in items_or_hashes:
if isinstance(item_or_hash, str):
kw_item = mm_missing_kwargs[modality][
mm_missing_next_idx[modality]]
cache.put(item_or_hash, kw_item)
merged_kwargs = defaultdict[str,
list[Optional[MultiModalKwargsItem]]](list)
merged_prompt_updates = defaultdict[
str, list[Sequence[ResolvedPromptUpdate]]](list)
for modality, hashes in mm_hashes.items():
missing_kwargs = mm_missing_kwargs.get(modality, [])
missing_prompt_updates = mm_missing_prompt_updates.get(
modality, [])

for item_idx, item_hash in enumerate(hashes):
kwargs: Optional[MultiModalKwargsItem]
if not mm_is_cached[modality][item_idx]:
missing_next_idx = mm_missing_next_idx[modality]
kwargs = missing_kwargs[missing_next_idx]
updates = missing_prompt_updates[missing_next_idx]

mm_missing_next_idx[modality] += 1

item = kwargs, updates
else:
kw_item = item_or_hash
item = None

merged_items[modality].append(kw_item)
kwargs, updates = cache.get_and_update_item(item, item_hash)

return MultiModalKwargsItems(merged_items)
merged_kwargs[modality].append(kwargs)
merged_prompt_updates[modality].append([
self._recompute_cached_prompt_update(update, item_idx)
for update in updates
])

mm_kwargs = MultiModalKwargsItems(merged_kwargs)
mm_prompt_updates = dict(merged_prompt_updates)

return mm_kwargs, mm_prompt_updates

def _apply_hf_processor(
self,
@ -1490,10 +1524,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):

mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs)
(
mm_cache_items_or_hashes,
mm_missing_data_items,
) = self._get_cache_missing_items(

mm_missing_data_items = self._get_cache_missing_items(
cache=cache,
mm_data_items=mm_data_items,
mm_hashes=mm_hashes,
@ -1520,16 +1552,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs),
)

mm_kwargs = self._merge_mm_kwargs(
cache,
mm_cache_items_or_hashes=mm_cache_items_or_hashes,
mm_missing_kwargs=mm_missing_kwargs,
mm_missing_prompt_updates = self._get_mm_prompt_updates(
mm_missing_data_items,
hf_processor_mm_kwargs,
mm_missing_kwargs,
)

mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items,
hf_processor_mm_kwargs,
mm_kwargs,
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
cache,
mm_hashes=mm_hashes,
mm_missing_kwargs=mm_missing_kwargs,
mm_missing_prompt_updates=mm_missing_prompt_updates,
)

mm_info = MultiModalProcessingInfo(
@ -1614,7 +1647,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):

def _validate_mm_kwargs(
self,
mm_kwargs: MultiModalKwargsItems,
mm_kwargs: MultiModalKwargsOptionalItems,
mm_item_counts: Mapping[str, int],
) -> None:
for modality, item_count in mm_item_counts.items():
@ -1655,7 +1688,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self,
mm_items: MultiModalDataItems,
prompt_ids: list[int],
mm_kwargs: MultiModalKwargsItems,
mm_kwargs: MultiModalKwargsOptionalItems,
mm_prompt_updates: MultiModalPromptUpdates,
is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
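The `_merge_mm_kwargs` hunk above walks every item hash in request order, pulls freshly processed kwargs from the missing list with a per-modality cursor, and lets the cache hand back everything else. The standalone sketch below only illustrates that control flow; `ToyProcessorCache` and `merge_by_hash` are invented names rather than the vLLM API, and the real implementation also carries prompt updates alongside the kwargs.

```python
# Minimal sketch of the hash-keyed merge pattern (hypothetical names).
from collections import defaultdict
from typing import Optional


class ToyProcessorCache:
    """Hash -> processed-item store with a get-and-update style helper."""

    def __init__(self) -> None:
        self._store: dict[str, dict] = {}

    def is_cached(self, hashes: list[str]) -> list[bool]:
        # Evaluated once, up front, so the hit/miss flags stay consistent
        # with how the "missing" list was built; otherwise storing the first
        # occurrence of a repeated item would flip later occurrences to
        # "cached" and desynchronize the cursor below.
        return [h in self._store for h in hashes]

    def get_and_update(self, item: Optional[dict], item_hash: str) -> dict:
        if item is not None:            # freshly processed -> store it
            self._store[item_hash] = item
            return item
        return self._store[item_hash]   # cache hit -> retrieve it


def merge_by_hash(
    cache: ToyProcessorCache,
    mm_hashes: dict[str, list[str]],
    mm_missing_processed: dict[str, list[dict]],
) -> dict[str, list[dict]]:
    # Compute hit/miss flags before mutating the cache (see comment above).
    mm_is_cached = {m: cache.is_cached(hs) for m, hs in mm_hashes.items()}

    missing_next_idx = defaultdict(int)  # per-modality cursor into missing items
    merged = defaultdict(list)

    for modality, hashes in mm_hashes.items():
        for item_idx, item_hash in enumerate(hashes):
            if not mm_is_cached[modality][item_idx]:
                item = mm_missing_processed[modality][missing_next_idx[modality]]
                missing_next_idx[modality] += 1
            else:
                item = None  # signal "look it up in the cache"

            merged[modality].append(cache.get_and_update(item, item_hash))

    return dict(merged)
```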
@ -13,7 +13,7 @@ import vllm.envs as envs
from vllm.logger import init_logger

from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalKwargsItems,
MultiModalInputs, MultiModalKwargsOptionalItems,
MultiModalPlaceholderDict)
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
EncDecMultiModalProcessor)
@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple):
"""Dummy data used for profiling."""

prompt_token_ids: list[int]
multi_modal_data: MultiModalKwargsItems
multi_modal_data: MultiModalKwargsOptionalItems
multi_modal_placeholders: MultiModalPlaceholderDict
@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar

import torch.nn as nn
@ -13,8 +12,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
cached_tokenizer_from_config)
from vllm.utils import ClassRegistry

from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
ProcessingCache)
from .cache import (BaseMultiModalProcessorCache,
processor_only_cache_from_config)
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
DummyEncoderData, MultiModalProfiler)

@ -65,7 +65,7 @@ class MultiModalProcessorFactory(Protocol[_I]):
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor[_I]:
...

@ -80,20 +80,13 @@ class _ProcessorFactories(Generic[_I]):
self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
):
info = self.info(ctx)
dummy_inputs_builder = self.dummy_inputs(info)
return self.processor(info, dummy_inputs_builder, cache=cache)


# Make sure a different cache is used for each model config
# NOTE: ModelConfig is not hashable so it cannot be passed directly
@lru_cache(maxsize=1)
def _get_processor_cache(model_id: str, capacity_gb: int):
return ProcessingCache(capacity_gb) if capacity_gb > 0 else None


class MultiModalRegistry:
"""
A registry that dispatches data processing according to the model.
@ -103,31 +96,6 @@ class MultiModalRegistry:
self._processor_factories = ClassRegistry[nn.Module,
_ProcessorFactories]()

def _get_processor_cache(self, model_config: "ModelConfig"):
model_id = model_config.model
capacity_gb = model_config.mm_processor_cache_gb
return _get_processor_cache(model_id, capacity_gb)

def reset_processor_cache(self, model_config: "ModelConfig") -> bool:
"""Reset the multi-modal processing cache."""
if processor_cache := self._get_processor_cache(model_config):
processor_cache.reset()

return True # Success

def enable_mm_input_cache(self, model_config: "ModelConfig") -> bool:
"""Whether the multi-modal input cache should be enabled.
NOTE: This is put under MultiModalRegistry on purpose to respect
text-only mode for multimodal models.
"""

if not self.supports_multimodal_inputs(model_config):
return False

mm_config = model_config.get_multimodal_config()

return mm_config.mm_processor_cache_gb > 0

def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
"""
Checks if the model supports multimodal inputs.
@ -157,6 +125,8 @@ class MultiModalRegistry:
def get_max_tokens_per_item_by_modality(
self,
model_config: "ModelConfig",
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
@ -165,11 +135,11 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
return {}

processor = self.create_processor(model_config, disable_cache=False)
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)

seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config)
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)

return profiler.get_mm_max_contiguous_tokens(
seq_len,
@ -182,6 +152,8 @@ class MultiModalRegistry:
def get_max_tokens_per_item_by_nonzero_modality(
self,
model_config: "ModelConfig",
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
@ -192,15 +164,19 @@ class MultiModalRegistry:
This is currently directly used only in V1 for profiling the memory
usage of a model.
"""
mm_limits = self.get_mm_limits_per_prompt(model_config)
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
model_config,
cache=cache,
)

return {
key: max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
for key, max_tokens_per_mm_item in max_tokens_per_item.items()
if mm_limits[key] > 0
}

# TODO: Remove once V0 is gone
def get_max_tokens_by_modality(
self,
model_config: "ModelConfig",
@ -209,14 +185,19 @@ class MultiModalRegistry:
Get the maximum number of tokens from each modality
for profiling the memory usage of a model.
"""
mm_limits = self.get_mm_limits_per_prompt(model_config)
cache = processor_only_cache_from_config(model_config, self)
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
model_config,
cache=cache,
)

return {
key: mm_limits[key] * max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
for key, max_tokens_per_mm_item in max_tokens_per_item.items()
}

# TODO: Remove once V0 is gone
def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
"""
Get the maximum number of multi-modal tokens
@ -227,6 +208,8 @@ class MultiModalRegistry:
def get_mm_limits_per_prompt(
self,
model_config: "ModelConfig",
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> Mapping[str, int]:
"""
Get the maximum number of multi-modal input instances for each modality
@ -235,7 +218,7 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
return {}

processor = self.create_processor(model_config, disable_cache=False)
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
return profiler.get_mm_limits()

@ -303,7 +286,7 @@ class MultiModalRegistry:
model_config: "ModelConfig",
*,
tokenizer: Optional[AnyTokenizer] = None,
disable_cache: Optional[bool] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
"""
Create a multi-modal processor for a specific model and tokenizer.
@ -311,15 +294,10 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model")

if disable_cache is None:
disable_cache = not model_config.enable_mm_processor_cache

model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls]

ctx = self._create_processing_ctx(model_config, tokenizer)
cache = None if disable_cache else self._get_processor_cache(
model_config)

return factories.build_processor(ctx, cache=cache)

@ -328,13 +306,15 @@ class MultiModalRegistry:
model_config: "ModelConfig",
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> DummyDecoderData:
"""
Create dummy data for profiling the memory usage of a model.

The model is identified by ``model_config``.
"""
processor = self.create_processor(model_config, disable_cache=False)
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)

@ -352,13 +332,15 @@ class MultiModalRegistry:
model_config: "ModelConfig",
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> DummyEncoderData:
"""
Create dummy data for profiling the memory usage of a model.

The model is identified by ``model_config``.
"""
processor = self.create_processor(model_config, disable_cache=False)
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
@ -597,8 +597,7 @@ class AsyncLLM(EngineClient):
await asyncio.gather(*coros)

async def reset_mm_cache(self) -> None:
self.processor.mm_registry.reset_processor_cache(self.model_config)
self.processor.mm_input_cache_client.reset()
self.processor.clear_cache()
await self.engine_core.reset_mm_cache_async()

async def reset_prefix_cache(self,
@ -22,6 +22,7 @@ from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import receiver_cache_from_config
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
@ -38,7 +39,6 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType,
UtilityOutput, UtilityResult)
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheServer
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
@ -128,8 +128,9 @@ class EngineCore:
)
self.use_spec_decode = vllm_config.speculative_config is not None

self.mm_input_cache_server = MultiModalInputCacheServer(
vllm_config.model_config, MULTIMODAL_REGISTRY)
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
self.mm_receiver_cache = receiver_cache_from_config(
vllm_config, mm_registry)

# Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously
@ -370,7 +371,8 @@ class EngineCore:
logger.warning("Resetting the multi-modal cache when requests are "
"in progress may lead to desynced internal caches.")

self.mm_input_cache_server.reset()
if self.mm_receiver_cache is not None:
self.mm_receiver_cache.clear_cache()

def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()
@ -435,9 +437,10 @@ class EngineCore:
assert request.mm_kwargs is not None

# Note on thread safety: no race condition.
# `mm_input_cache_server` is reset at the end of LLMEngine init,
# `mm_receiver_cache` is reset at the end of LLMEngine init,
# and will only be accessed in the input processing thread afterwards.
request.mm_kwargs = self.mm_input_cache_server.get_and_update(
if self.mm_receiver_cache is not None:
request.mm_kwargs = self.mm_receiver_cache.get_and_update(
request.mm_kwargs, request.mm_hashes)

req = Request.from_engine_core_request(request,
@ -271,8 +271,7 @@ class LLMEngine:
self.engine_core.profile(False)

def reset_mm_cache(self):
self.processor.mm_registry.reset_processor_cache(self.model_config)
self.processor.mm_input_cache_client.reset()
self.processor.clear_cache()
self.engine_core.reset_mm_cache()

def reset_prefix_cache(self, device: Optional[Device] = None):
@ -1,121 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional

from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.utils import is_list_of

if TYPE_CHECKING:
from vllm.config import ModelConfig

# The idea of multimodal input caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- P0:
# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
# each input multi-modal item (e.g. image),
# - BaseMultiModalProcessor processes the input items into `mm_kwargs`,
# which are MultiModalKwargsItem instances that each correspond to an
# input multi-modal item.
# - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding
# `mm_hash` for each item. It stores the `mm_hash` as keys and the size
# of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking
# up additional memory in P0.
# - The `mm_hash` is always sent to P1.
# - The corresponding `mm_kwargs` are only sent to P1 if they are not cached
# in MultiModalInputCacheServer.
#
# -- P1:
# - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0),
# MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`.
# - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0),
# MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`.
# - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to
# the engine for model execution.
#
# Both Client and Server must perform cache update and eviction based on the
# same item size. This ensures that the keys of MultiModalInputCacheClient
# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0
# whether a key is cached in MultiModalInputCacheServer by querying
# MultiModalInputCacheClient without having to communicate with P1.


class MultiModalInputCacheClient:
"""Used by P0 to check whether multi-modal kwargs are cached in P1."""

def __init__(self, model_config: "ModelConfig",
mm_registry: MultiModalRegistry) -> None:
super().__init__()

self.enabled = mm_registry.enable_mm_input_cache(model_config)
self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(),
MultiModalCacheItemMetadata,
)

def get_and_update(
self,
mm_kwargs: Sequence[MultiModalKwargsItem],
mm_hashes: list[str],
) -> list[Optional[MultiModalKwargsItem]]:
if not self.enabled:
return list(mm_kwargs)

assert len(mm_kwargs) == len(mm_hashes)

out_mm_items = list[Optional[MultiModalKwargsItem]]()
for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
if self.mm_cache.get(mm_hash) is not None:
out_mm_items.append(None)
else:
self.mm_cache[mm_hash] = \
MultiModalCacheItemMetadata.wraps(mm_item)
out_mm_items.append(mm_item)

return out_mm_items

def reset(self) -> None:
self.mm_cache.clear()


class MultiModalInputCacheServer:
"""Used by P1 to avoid requiring past multi-modal kwargs from P0."""

def __init__(self, model_config: "ModelConfig",
mm_registry: MultiModalRegistry) -> None:
super().__init__()

self.enabled = mm_registry.enable_mm_input_cache(model_config)
self.mm_cache = MultiModalCache.get_lru_cache(
model_config.get_mm_input_cache_gb(),
MultiModalKwargsItem,
)

def get_and_update(
self,
mm_kwargs: Sequence[Optional[MultiModalKwargsItem]],
mm_hashes: list[str],
) -> list[MultiModalKwargsItem]:
if not self.enabled:
mm_kwargs_lst = list(mm_kwargs)
assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem)
return mm_kwargs_lst

assert len(mm_kwargs) == len(mm_hashes)

out_mm_items = list[MultiModalKwargsItem]()
for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
if mm_item is None:
out_mm_items.append(self.mm_cache[mm_hash])
else:
self.mm_cache[mm_hash] = mm_item
out_mm_items.append(mm_item)

return out_mm_items

def reset(self) -> None:
self.mm_cache.clear()
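The comment block in the deleted module above describes the mirrored key-only/key-value cache pair that this commit generalizes. Below is a rough, self-contained sketch of that idea under simplifying assumptions: `KeyOnlySenderCache` and `KeyValueReceiverCache` are invented names (not the vLLM API), payloads are plain dicts, and both sides evict by item count with the same capacity, standing in for the "same item size" eviction rule that keeps the two key sets mirrored.

```python
# Hypothetical sketch of the P0/P1 mirrored caching scheme described above.
from collections import OrderedDict
from typing import Optional


class KeyOnlySenderCache:
    """P0 side: remembers only which hashes P1 holds, never the payloads."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._keys: OrderedDict[str, None] = OrderedDict()

    def filter_payloads(
        self,
        items: list[dict],
        hashes: list[str],
    ) -> list[Optional[dict]]:
        out: list[Optional[dict]] = []
        for item, h in zip(items, hashes):
            if h in self._keys:
                self._keys.move_to_end(h)
                out.append(None)   # payload already on P1; send only the hash
            else:
                self._keys[h] = None
                if len(self._keys) > self.capacity:
                    self._keys.popitem(last=False)  # evict LRU, mirroring P1
                out.append(item)   # first sighting; send the payload once
        return out


class KeyValueReceiverCache:
    """P1 side: fills in payloads that P0 elided because they were cached."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self._store: OrderedDict[str, dict] = OrderedDict()

    def resolve(
        self,
        maybe_items: list[Optional[dict]],
        hashes: list[str],
    ) -> list[dict]:
        out = []
        for item, h in zip(maybe_items, hashes):
            if item is None:
                item = self._store[h]   # hash-only message: look up payload
                self._store.move_to_end(h)
            else:
                self._store[h] = item   # new payload: remember it
                if len(self._store) > self.capacity:
                    self._store.popitem(last=False)
            out.append(item)
        return out
```

Because both sides see the same stream of hashes and apply the same eviction rule, P0 can decide locally whether P1 already holds a payload, with no extra round trip.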
@ -11,6 +11,7 @@ from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
@ -18,7 +19,6 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_lm_format_enforcer import (
@ -47,16 +47,17 @@ class Processor:

self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,

self.mm_registry = mm_registry
self.mm_processor_cache = processor_cache_from_config(
vllm_config, mm_registry)

self.input_preprocessor = InputPreprocessor(
self.model_config,
self.tokenizer,
mm_registry)

self.mm_input_cache_client = MultiModalInputCacheClient(
self.model_config, mm_registry)

@property
def mm_registry(self):
return self.input_preprocessor.mm_registry
mm_registry,
mm_processor_cache=self.mm_processor_cache,
)

def _validate_logprobs(
self,
@ -310,7 +311,7 @@ class Processor:
# in the input sequence.
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)

orig_sorted_mm_inputs = [
sorted_mm_inputs = [
decoder_mm_inputs[modality][idx]
for modality, idx in sorted_mm_idxs
]
@ -323,11 +324,6 @@ class Processor:
for modality, idx in sorted_mm_idxs
]

sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
orig_sorted_mm_inputs,
sorted_mm_hashes,
)

return decoder_inputs.get("prompt"), EngineCoreRequest(
request_id=request_id,
prompt_token_ids=decoder_inputs["prompt_token_ids"],
@ -415,3 +411,6 @@ class Processor:
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens

def clear_cache(self) -> None:
self.input_preprocessor.clear_cache()
@ -2186,10 +2186,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_items_per_batch: int,
) -> BatchedTensorInputs:
"""Dummy data for profiling and precompiling multimodal models."""
assert self.mm_budget is not None

dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_counts={modality: 1},
cache=self.mm_budget.cache,
)
dummy_mm_data = dummy_decoder_data.multi_modal_data
@ -1813,10 +1813,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_items_per_batch: int,
) -> BatchedTensorInputs:
"""Dummy data for profiling and precompiling multimodal models."""
assert self.mm_budget is not None

dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_counts={modality: 1},
cache=self.mm_budget.cache,
)
dummy_mm_data = dummy_decoder_data.multi_modal_data
@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import ModelConfig, SchedulerConfig
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
@ -33,14 +34,18 @@ class MultiModalBudget:
self.model_config = model_config
self.scheduler_config = scheduler_config
self.mm_registry = mm_registry
self.cache = cache = processor_only_cache_from_config(
model_config, mm_registry)

self.max_model_len = model_config.max_model_len
self.max_num_reqs = scheduler_config.max_num_seqs

self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config,
cache=cache)

max_tokens_by_modality = mm_registry \
.get_max_tokens_per_item_by_nonzero_modality(model_config)
.get_max_tokens_per_item_by_nonzero_modality(model_config,
cache=cache)

encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
scheduler_config,
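The `MultiModalBudget` hunk above builds one processor-only cache in `__init__` and threads the same `cache=` object through every registry query, so profiling does not reprocess identical dummy multi-modal data for each call. The toy sketch below illustrates only that sharing pattern; `SharedDummyCache`, `query_mm_limits`, and `query_max_tokens` are invented names, not the vLLM API.

```python
# Hypothetical illustration of "build the cache once, pass it to each query".
from typing import Callable


class SharedDummyCache:
    """Memoizes expensive per-(modality, seq_len) processing results."""

    def __init__(self) -> None:
        self._store: dict[tuple[str, int], object] = {}
        self.misses = 0

    def get_or_compute(self, key: tuple[str, int],
                       compute: Callable[[], object]) -> object:
        if key not in self._store:
            self.misses += 1
            self._store[key] = compute()  # e.g. run the slow processing once
        return self._store[key]


def query_mm_limits(cache: SharedDummyCache) -> dict[str, int]:
    cache.get_or_compute(("image", 8192), lambda: "dummy image features")
    return {"image": 1}


def query_max_tokens(cache: SharedDummyCache) -> dict[str, int]:
    # Reuses the entry that query_mm_limits already populated.
    cache.get_or_compute(("image", 8192), lambda: "dummy image features")
    return {"image": 576}


if __name__ == "__main__":
    shared = SharedDummyCache()
    query_mm_limits(shared)
    query_max_tokens(shared)
    assert shared.misses == 1  # the expensive step ran only once
```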