[Optim] Compute multimodal hash only once per item (#17314)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung (committed by GitHub)
Date: 2025-04-29 09:40:35 +08:00
Parent: cfe4532093
Commit: 506475de5f
6 changed files with 234 additions and 129 deletions
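
The gist of the change, sketched below with toy stand-ins rather than the real vLLM classes: each multimodal item is hashed once, and that single digest serves both as the processing-cache key and as the hash returned to the V1 engine. Previously, `apply` hashed every item in `_hash_mm_items` while `_cached_apply_hf_processor` hashed the same items again inside `ProcessingCache.get`/`put`.

# A minimal, self-contained sketch of the idea, using toy stand-ins for
# MultiModalHasher and ProcessingCache (the real code is in
# vllm/multimodal/processing.py).
import hashlib
import pickle
from typing import Any, NamedTuple, Optional

def hash_item(model_id: str, modality: str, item: Any, **kwargs: Any) -> str:
    # Stand-in for MultiModalHasher.hash_kwargs: one deterministic digest
    # per (model, modality, item, processor kwargs) combination.
    payload = pickle.dumps((model_id, modality, item, sorted(kwargs.items())))
    return hashlib.sha256(payload).hexdigest()

class CacheEntry(NamedTuple):
    key: str                  # the digest, reused later as the item's hash
    value: Optional[dict]     # cached processor output, or None on a miss

_cache: dict[str, dict] = {}

def get_entry(model_id: str, modality: str, item: Any, **kwargs: Any) -> CacheEntry:
    key = hash_item(model_id, modality, item, **kwargs)  # hashed exactly once
    return CacheEntry(key=key, value=_cache.get(key))

def process_images(model_id: str, images: list[Any]) -> tuple[list[dict], list[str]]:
    outputs: list[dict] = []
    hashes: list[str] = []
    for image in images:
        entry = get_entry(model_id, "image", image)
        value = entry.value
        if value is None:                    # miss: run the (fake) HF processor
            value = {"pixel_values": image}
            _cache[entry.key] = value        # store under the same key
        outputs.append(value)
        hashes.append(entry.key)             # reuse the key; no second hashing pass
    return outputs, hashes

In the diff below, `ProcessingCache.get_item` plays the role of `get_entry`, and the returned keys are collected into `MultiModalHashes` instead of being recomputed by `_hash_mm_items` in `apply`.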

View File: vllm/model_executor/models/deepseek_vl2.py

@@ -22,8 +22,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
BaseProcessingInfo, MultiModalHashes,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
@@ -279,24 +279,26 @@ class DeepseekVL2MultiModalProcessor(
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 2:
# This code path corresponds to the cache being disabled
return self._apply_hf_processor_main(
return self._apply_hf_processor(
prompt=prompt,
mm_items=mm_data_items,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
return_mm_hashes=return_mm_hashes,
)
return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
)
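
The comment above explains why DeepseekVL2 (and H2OVL below, with a threshold of 1 instead of 2) bypasses the cache once the image count exceeds its threshold: the processing cache assumes a given item always yields the same processor output no matter how many other items share the prompt. A toy illustration of why per-item caching would be unsafe without that invariance (hypothetical processor, not vLLM code):

def toy_processor(images: list[str]) -> list[str]:
    # Pretend the HF processor switches tiling strategy above two images,
    # so the per-image output depends on the total image count.
    mode = "tiled" if len(images) <= 2 else "global"
    return [f"{mode}:{img}" for img in images]

# The same image produces different outputs in the two regimes, so an entry
# cached from a 1-image prompt must not be reused for a 4-image prompt.
assert toy_processor(["a"])[0] != toy_processor(["a", "b", "c", "d"])[0]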

View File: vllm/model_executor/models/h2ovl.py

@@ -19,8 +19,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .intern_vit import InternVisionModel
@@ -488,24 +488,26 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
# The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 1:
# This code path corresponds to the cache being disabled
return self._apply_hf_processor_main(
return self._apply_hf_processor(
prompt=prompt,
mm_items=mm_data_items,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
return_mm_hashes=return_mm_hashes,
)
return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
)

View File: vllm/model_executor/models/llava.py

@@ -396,14 +396,12 @@ def _build_llava_or_pixtral_hf_processor(
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
if isinstance(info, PixtralHFProcessingInfo):
return PixtralHFMultiModalProcessor(
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
if isinstance(info, LlavaProcessingInfo):
@@ -411,7 +409,6 @@ def _build_llava_or_pixtral_hf_processor(
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
raise NotImplementedError(type(info))

View File: vllm/model_executor/models/mistral3.py

@@ -312,14 +312,12 @@ def _build_mistral3_processor(
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
assert isinstance(info, Mistral3ProcessingInfo)
return Mistral3MultiModalProcessor(
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)

View File: vllm/model_executor/models/pixtral.py

@@ -36,8 +36,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
BaseProcessingInfo, MultiModalHashes,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
@@ -271,15 +272,19 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
prompt_ids, mm_kwargs, _ = super()._cached_apply_hf_processor(
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
prompt_ids, mm_kwargs, mm_hashes, _ = super(
)._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
)
# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_kwargs, True
return prompt_ids, mm_kwargs, mm_hashes, True
@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,

View File: vllm/multimodal/processing.py

@@ -876,6 +876,16 @@ def find_mm_placeholders(
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")
class ProcessingCacheOptionalItem(NamedTuple):
key: str
value: Optional[MultiModalKwargsItem]
class ProcessingCacheItem(NamedTuple):
key: str
value: MultiModalKwargsItem
class ProcessingCache:
@staticmethod
@@ -980,6 +990,22 @@
return self._cache.get(cache_key)
def get_item(
self,
model_id: str,
modality: str,
input_item: object,
input_kwargs: Mapping[str, object],
) -> ProcessingCacheOptionalItem:
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
return ProcessingCacheOptionalItem(
key=cache_key,
value=self._cache.get(cache_key),
)
def put(
self,
model_id: str,
@@ -997,6 +1023,9 @@
**input_kwargs)
self._cache[cache_key] = output_kwargs
def put_item(self, item: ProcessingCacheItem) -> None:
self._cache[item.key] = item.value
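# A hedged usage sketch of the get_item/put_item pair added above; the names
# `cache`, `image`, `run_hf_processor` and `get_or_compute` are illustrative,
# not part of the diff. One get_item call both hashes the item and performs
# the lookup; on a miss, the same key is written back via put_item, so the
# item is never hashed a second time and the key can double as its mm hash.
def get_or_compute(cache, model_id, image, hf_processor_mm_kwargs, run_hf_processor):
    cache_item = cache.get_item(model_id, "image", image, hf_processor_mm_kwargs)
    if cache_item.value is not None:                 # hit: key and value together
        return cache_item.value, cache_item.key
    processed = run_hf_processor(image)              # miss: process, then store
    cache.put_item(ProcessingCacheItem(key=cache_item.key, value=processed))
    return processed, cache_item.key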
class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing."""
@@ -1052,6 +1081,11 @@
_I = TypeVar("_I", bound=BaseProcessingInfo)
MultiModalHashes = dict[str, list[str]]
"""
A collection of hashes with a similar structure as :class:`MultiModalKwargs`.
"""
class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
@@ -1064,14 +1098,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
info: _I,
dummy_inputs: "BaseDummyInputsBuilder[_I]",
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
cache: Optional[ProcessingCache] = None) -> None:
super().__init__()
self.info = info
self.dummy_inputs = dummy_inputs
self.cache = cache
self.enable_sanity_checks = enable_sanity_checks
self.data_parser = self._get_data_parser()
@@ -1340,46 +1372,144 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return prompt_ids, mm_kwargs, False
def _cached_apply_hf_processor(
def _get_cache_missing_items(
self,
prompt: Union[str, list[int]],
cache: ProcessingCache,
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
"""
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
"""
cache = self.cache
) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[
str, list[object]]]:
model_id = self.info.model_id
_, passthrough_data = self._get_hf_mm_data(mm_data_items)
if cache is None or passthrough_data:
return self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
)
mm_maybe_cached_kw_items = {
mm_cache_items = {
modality: [
cache.get(model_id, modality, item, hf_processor_mm_kwargs)
for item in items
cache.get_item(model_id, modality, item,
hf_processor_mm_kwargs) for item in items
]
for modality, items in mm_data_items.items()
}
mm_missing_idxs = {
modality:
[idx for idx, item in enumerate(kw_items) if item is None]
for modality, kw_items in mm_maybe_cached_kw_items.items()
modality: [
idx for idx, item in enumerate(cache_items)
if item.value is None
]
for modality, cache_items in mm_cache_items.items()
}
mm_missing_data = {
modality: [mm_data_items[modality][idx] for idx in idxs]
for modality, idxs in mm_missing_idxs.items()
}
mm_missing_data_items = self._to_mm_items(mm_missing_data)
return mm_cache_items, mm_missing_data
def _hash_mm_items(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalHashes:
"""Create MM hashes to be returned (only used in V1)."""
model_id = self.info.model_id
return {
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs)
for item in items
]
for modality, items in mm_items.items()
}
def _merge_mm_kwargs(
self,
cache: ProcessingCache,
mm_cache_items: dict[str, list[ProcessingCacheOptionalItem]],
mm_missing_data: dict[str, list[object]],
mm_missing_kwargs: MultiModalKwargs,
) -> dict[str, list[ProcessingCacheItem]]:
mm_missing_next_idx = {modality: 0 for modality in mm_missing_data}
merged_items = defaultdict[str, list[ProcessingCacheItem]](list)
for modality, cache_items in mm_cache_items.items():
for cache_item in cache_items:
if cache_item.value is None:
kw_item = mm_missing_kwargs.get_item(
modality,
mm_missing_next_idx[modality],
)
cache_item_new = ProcessingCacheItem(
key=cache_item.key,
value=kw_item,
)
cache.put_item(cache_item_new)
mm_missing_next_idx[modality] += 1
else:
cache_item_new = ProcessingCacheItem(
key=cache_item.key,
value=cache_item.value,
)
merged_items[modality].append(cache_item_new)
return dict(merged_items)
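# A worked trace of the merge above, under assumed inputs (names made up):
# images [A, B, C] where only B missed the cache.
#   mm_cache_items["image"]    = [(key_A, kw_A), (key_B, None), (key_C, kw_C)]
#   mm_missing_kwargs["image"] = [kw_B]            # HF processor ran on B only
#   mm_missing_next_idx["image"] advances 0 -> 1 when the miss for B is filled
#   merged_items["image"]      = [(key_A, kw_A), (key_B, kw_B), (key_C, kw_C)]
# Side effect: cache.put_item((key_B, kw_B)), so B is a hit on the next request.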
def _apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
(
prompt_ids,
mm_kwargs,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
)
mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs)
if return_mm_hashes else None)
return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
def _cached_apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
"""
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
"""
cache = self.cache
_, passthrough_data = self._get_hf_mm_data(mm_data_items)
if cache is None or passthrough_data:
return self._apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
)
(
mm_cache_items,
mm_missing_data,
) = self._get_cache_missing_items(
cache=cache,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal
@@ -1390,48 +1520,29 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
is_update_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_missing_data_items,
mm_items=self._to_mm_items(mm_missing_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=False,
)
mm_missing_next_idx = {
modality: 0
for modality in mm_missing_data_items
}
mm_cache_items_merged = self._merge_mm_kwargs(
cache,
mm_cache_items=mm_cache_items,
mm_missing_data=mm_missing_data,
mm_missing_kwargs=mm_missing_kwargs,
)
merged_kw_items = list[MultiModalKwargsItem]()
for modality, kw_items in mm_maybe_cached_kw_items.items():
for idx, kw_item in enumerate(kw_items):
if kw_item is None:
kw_item = mm_missing_kwargs.get_item(
modality,
mm_missing_next_idx[modality],
)
mm_kwargs = MultiModalKwargs.from_items([
item.value for cache_items in mm_cache_items_merged.values()
for item in cache_items
])
cache.put(
model_id,
modality,
mm_data_items[modality][idx],
hf_processor_mm_kwargs,
kw_item,
)
mm_hashes = {
modality: [item.key for item in cache_items]
for modality, cache_items in mm_cache_items_merged.items()
} if return_mm_hashes else None
mm_missing_next_idx[modality] += 1
merged_kw_items.append(kw_item)
if self.enable_sanity_checks:
mm_missing_counts = mm_missing_data_items.get_all_counts()
assert all(
item_count == mm_missing_counts[modality]
for modality, item_count in mm_missing_next_idx.items()), dict(
mm_missing_next_idx=mm_missing_next_idx,
mm_missing_counts=mm_missing_counts)
mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
return prompt_ids, mm_kwargs, is_update_applied
return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
def _bind_and_group_updates(
self,
@@ -1569,27 +1680,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_updates`).")
def _hash_mm_items(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> dict[str, list[str]]:
"""Create MM hashes to be returned (only used in V1)."""
# TODO: Use these hash keys for caching operations in apply_hf_processor
# instead of rehashing.
model_id = self.info.model_id
return {
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs)
for item in items
]
for modality, items in mm_items.items()
}
def _maybe_apply_prompt_updates(
self,
mm_items: MultiModalDataItems,
@@ -1655,17 +1745,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
mm_items = self._to_mm_items(mm_data)
mm_hashes = (self._hash_mm_items(mm_items, hf_processor_mm_kwargs)
if return_mm_hashes else None)
(
prompt_ids,
mm_kwargs,
mm_hashes,
is_update_applied,
) = self._cached_apply_hf_processor(
prompt,
mm_items,
hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
)
prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
@@ -1717,6 +1806,32 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""Create input prompt for the decoder."""
return prompt
def _get_enc_dec_inputs(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
encoder_inputs: MultiModalInputs,
):
tokenizer = self.info.get_tokenizer()
decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
if isinstance(decoder_prompt, str):
decoder_prompt_ids = encode_tokens(tokenizer,
decoder_prompt,
add_special_tokens=False)
else:
decoder_prompt_ids = decoder_prompt
decoder_prompt = decode_tokens(tokenizer, decoder_prompt)
mm_inputs = MultiModalEncDecInputs(
encoder_prompt=encoder_inputs["prompt"],
encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
**encoder_inputs)
mm_inputs.update({
"prompt": decoder_prompt,
"prompt_token_ids": decoder_prompt_ids
})
return mm_inputs
def apply(
self,
prompt: Union[str, list[int]],
@@ -1739,22 +1854,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
return_mm_hashes,
)
tokenizer = self.info.get_tokenizer()
decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
if isinstance(decoder_prompt, str):
decoder_prompt_ids = encode_tokens(tokenizer,
decoder_prompt,
add_special_tokens=False)
else:
decoder_prompt_ids = decoder_prompt
decoder_prompt = decode_tokens(tokenizer, decoder_prompt)
mm_inputs = MultiModalEncDecInputs(
encoder_prompt=encoder_inputs["prompt"],
encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
**encoder_inputs)
mm_inputs.update({
"prompt": decoder_prompt,
"prompt_token_ids": decoder_prompt_ids
})
return mm_inputs
return self._get_enc_dec_inputs(
prompt=prompt,
mm_data=mm_data,
encoder_inputs=encoder_inputs,
)
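
For context, a hedged sketch of the shape `_get_enc_dec_inputs` produces for an encoder-decoder request; field values are made up, and only the fields visible in this diff are spelled out:

# Illustrative only: the encoder prompt stays under the encoder_* keys, while
# "prompt"/"prompt_token_ids" are overridden with the decoder prompt returned
# by create_decoder_prompt(); the remaining fields are copied from
# encoder_inputs unchanged.
mm_inputs = {
    "encoder_prompt": "<encoder prompt text>",
    "encoder_prompt_token_ids": [101, 102, 103],
    "prompt": "<decoder prompt text>",
    "prompt_token_ids": [201, 202],
    # ...plus the other fields carried over from encoder_inputs
}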