Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Misc] Clean up Qwen2.5-Omni code (#17301)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -51,11 +51,9 @@ from vllm.model_executor.models.qwen2_audio import (
 from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalDataParser
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.hasher import MultiModalHasher
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalInputs, MultiModalKwargs,
-                                    NestedTensors)
+                                    MultiModalKwargs, NestedTensors)
 from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
                                    ModalityDataItems, MultiModalDataItems,
                                    MultiModalDataParser)
@@ -279,46 +277,17 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
     ) -> Mapping[str, MultiModalFieldConfig]:
         return _qwen2_5_omni_thinker_field_config(hf_inputs)
 
-    def apply(
+    def _maybe_apply_prompt_updates(
         self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
+        mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        return_mm_hashes: bool = False,
-    ) -> MultiModalInputs:
+        prompt_ids: list[int],
+        mm_kwargs: MultiModalKwargs,
+        is_update_applied: bool,
+    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
         """
         Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
         """
-        mm_items = self._to_mm_items(mm_data)
-
-        # Create MM hashes to be returned (only used in V1)
-        # TODO: Use these hash keys for caching operations in apply_hf_processor
-        # instead of rehashing.
-
-        if return_mm_hashes:
-            model_id = self.info.model_id
-            mm_hashes = {
-                modality: [
-                    MultiModalHasher.hash_kwargs(model_id=model_id,
-                                                 **{modality: item},
-                                                 **hf_processor_mm_kwargs)
-                    for item in items
-                ]
-                for modality, items in mm_items.items()
-            }
-        else:
-            mm_hashes = None
-
-        (
-            prompt_ids,
-            mm_kwargs,
-            is_update_applied,
-        ) = self._cached_apply_hf_processor(
-            prompt,
-            mm_items,
-            hf_processor_mm_kwargs,
-        )
-
         unbound_prompt_updates = self._get_prompt_updates(
             mm_items,
             hf_processor_mm_kwargs,
@@ -364,22 +333,10 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
         tokenizer = self.info.get_tokenizer()
         prompt = decode_tokens(tokenizer, prompt_ids)
 
-        mm_placeholder_ranges = {
-            modality: [item.to_range() for item in placeholders]
-            for modality, placeholders in mm_placeholders.items()
-        }
-
         if use_audio_in_video:
             mm_kwargs["use_audio_in_video"] = True
 
-        return MultiModalInputs(
-            type="multimodal",
-            prompt=prompt,
-            prompt_token_ids=prompt_ids,
-            mm_kwargs=mm_kwargs,
-            mm_hashes=mm_hashes,
-            mm_placeholders=mm_placeholder_ranges,
-        )
+        return prompt_ids, prompt, mm_placeholders
 
     def _get_prompt_updates(
         self,
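With the two hunks above, the Qwen2.5-Omni thinker processor no longer reimplements apply(); it only overrides the _maybe_apply_prompt_updates hook to handle use_audio_in_video. A minimal sketch of the hook's contract, assuming a processor instance and already-computed HF outputs (only the method name, parameters, and return tuple come from the diff; the wrapper function below is illustrative):

from collections.abc import Mapping


def apply_omni_prompt_updates(processor, mm_items, prompt_ids: list[int],
                              mm_kwargs, is_update_applied: bool,
                              hf_processor_mm_kwargs: Mapping[str, object]):
    """Illustrative wrapper: invoke the overridden hook and return its tuple.

    When use_audio_in_video is requested, the override records it directly on
    mm_kwargs (mm_kwargs["use_audio_in_video"] = True) instead of assembling a
    MultiModalInputs dict itself; that assembly now lives in the base apply().
    """
    prompt_ids, prompt, mm_placeholders = processor._maybe_apply_prompt_updates(
        mm_items=mm_items,
        hf_processor_mm_kwargs=hf_processor_mm_kwargs,
        prompt_ids=prompt_ids,
        mm_kwargs=mm_kwargs,
        is_update_applied=is_update_applied,
    )
    return prompt_ids, prompt, mm_placeholders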
@@ -1569,56 +1569,35 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
                 "model (usually arising from an inconsistency between "
                 "`_call_hf_processor` and `_get_prompt_updates`).")
 
-    def apply(
+    def _hash_mm_items(
         self,
-        prompt: Union[str, list[int]],
-        mm_data: MultiModalDataDict,
+        mm_items: MultiModalDataItems,
         hf_processor_mm_kwargs: Mapping[str, object],
-        return_mm_hashes: bool = False,
-    ) -> MultiModalInputs:
-        """
-        Process multi-modal inputs to be used in vLLM.
+    ) -> dict[str, list[str]]:
+        """Create MM hashes to be returned (only used in V1)."""
 
-        The main steps are:
-
-        1. Apply HF Processor on prompt text and multi-modal data together,
-           outputting token IDs and processed tensors.
-        2. Find and update sequences in the token IDs with placeholder tokens.
-           The number of placeholder tokens equals the feature size of the
-           multi-modal data outputted by the multi-modal encoder.
-        3. Extract information about the placeholder tokens from the
-           processed token IDs.
-        """
-        mm_items = self._to_mm_items(mm_data)
-
-        # Create MM hashes to be returned (only used in V1)
         # TODO: Use these hash keys for caching operations in apply_hf_processor
        # instead of rehashing.
+        model_id = self.info.model_id
 
-        if return_mm_hashes:
-            model_id = self.info.model_id
-            mm_hashes = {
-                modality: [
-                    MultiModalHasher.hash_kwargs(model_id=model_id,
-                                                 **{modality: item},
-                                                 **hf_processor_mm_kwargs)
-                    for item in items
-                ]
-                for modality, items in mm_items.items()
-            }
-        else:
-            mm_hashes = None
-
-        (
-            prompt_ids,
-            mm_kwargs,
-            is_update_applied,
-        ) = self._cached_apply_hf_processor(
-            prompt,
-            mm_items,
-            hf_processor_mm_kwargs,
-        )
+        return {
+            modality: [
+                MultiModalHasher.hash_kwargs(model_id=model_id,
+                                             **{modality: item},
+                                             **hf_processor_mm_kwargs)
+                for item in items
+            ]
+            for modality, items in mm_items.items()
+        }
 
+    def _maybe_apply_prompt_updates(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        prompt_ids: list[int],
+        mm_kwargs: MultiModalKwargs,
+        is_update_applied: bool,
+    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        unbound_prompt_updates = self._get_prompt_updates(
             mm_items,
             hf_processor_mm_kwargs,
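The new _hash_mm_items helper produces one hash per item, per modality, as its docstring says. Below is a self-contained analogue of that dict-comprehension shape; it uses hashlib instead of vLLM's MultiModalHasher, so it is only a sketch of the pattern, not the actual implementation:

import hashlib
from collections.abc import Mapping


def hash_mm_items_sketch(
    mm_items: Mapping[str, list[bytes]],
    hf_processor_mm_kwargs: Mapping[str, object],
) -> dict[str, list[str]]:
    """One hash string per multi-modal item, grouped by modality."""
    # Fold the processor kwargs into every hash, mirroring **hf_processor_mm_kwargs above.
    extra = repr(sorted(hf_processor_mm_kwargs.items())).encode()
    return {
        modality: [
            hashlib.sha256(modality.encode() + item + extra).hexdigest()
            for item in items
        ]
        for modality, items in mm_items.items()
    }


# Example: two images and one audio clip (raw bytes stand in for decoded media).
print(hash_mm_items_sketch(
    {"image": [b"img-0", b"img-1"], "audio": [b"clip-0"]},
    {"use_audio_in_video": False},
))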
@@ -1652,6 +1631,51 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         )
         self._validate_mm_placeholders(mm_placeholders, mm_item_counts)
 
+        return prompt_ids, prompt, mm_placeholders
+
+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        return_mm_hashes: bool = False,
+    ) -> MultiModalInputs:
+        """
+        Process multi-modal inputs to be used in vLLM.
+
+        The main steps are:
+
+        1. Apply HF Processor on prompt text and multi-modal data together,
+           outputting token IDs and processed tensors.
+        2. Find and update sequences in the token IDs with placeholder tokens.
+           The number of placeholder tokens equals the feature size of the
+           multi-modal data outputted by the multi-modal encoder.
+        3. Extract information about the placeholder tokens from the
+           processed token IDs.
+        """
+        mm_items = self._to_mm_items(mm_data)
+
+        mm_hashes = (self._hash_mm_items(mm_items, hf_processor_mm_kwargs)
+                     if return_mm_hashes else None)
+
+        (
+            prompt_ids,
+            mm_kwargs,
+            is_update_applied,
+        ) = self._cached_apply_hf_processor(
+            prompt,
+            mm_items,
+            hf_processor_mm_kwargs,
+        )
+
+        prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
+            mm_items=mm_items,
+            hf_processor_mm_kwargs=hf_processor_mm_kwargs,
+            prompt_ids=prompt_ids,
+            mm_kwargs=mm_kwargs,
+            is_update_applied=is_update_applied,
+        )
+
         mm_placeholder_ranges = {
             modality: [item.to_range() for item in placeholders]
             for modality, placeholders in mm_placeholders.items()
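After this change, apply() is a thin pipeline over the three helpers: _hash_mm_items (optional V1 hashing), _cached_apply_hf_processor (HF processing with caching), and _maybe_apply_prompt_updates (placeholder expansion). A hedged usage sketch, assuming a concrete processor instance and already-decoded multi-modal data; the variable names and surrounding setup are illustrative, only the apply() signature and the MultiModalInputs keys come from the diff above:

def process_request(processor, prompt: str, image) -> tuple[list[int], dict]:
    """Run the refactored entry point and pull out what a V1 scheduler needs."""
    mm_inputs = processor.apply(
        prompt,                      # prompt text (or pre-tokenized IDs)
        {"image": [image]},          # multi-modal data, keyed by modality
        hf_processor_mm_kwargs={},
        return_mm_hashes=True,       # populated via _hash_mm_items (V1 only)
    )
    # MultiModalInputs is dict-like; these keys appear in the code shown above.
    return mm_inputs["prompt_token_ids"], mm_inputs["mm_placeholders"]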