[Optim] Compute multimodal hash only once per item (#17314)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
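The change threads the per-item cache key through the whole flow so that each multimodal item is hashed exactly once: the key computed for the cache lookup is reused when storing the processed kwargs and is also what gets reported back as the item's hash (previously the hashes returned to V1 were recomputed by `_hash_mm_items` even though the cache had already hashed the same items). A minimal sketch of the idea, using hypothetical stand-ins (`hash_item`, `ToyProcessingCache`) rather than the real `MultiModalHasher` / `ProcessingCache` API:

import hashlib
import json
from typing import Any, Optional

# Hypothetical stand-ins for vLLM's MultiModalHasher / ProcessingCache;
# simplified for illustration only.
def hash_item(model_id: str, modality: str, item: Any, **kwargs: Any) -> str:
    payload = json.dumps(
        {"model_id": model_id, modality: repr(item), **kwargs},
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode()).hexdigest()


class ToyProcessingCache:
    def __init__(self) -> None:
        self._cache: dict[str, Any] = {}

    def get_item(self, model_id: str, modality: str, item: Any,
                 kwargs: dict[str, Any]) -> tuple[str, Optional[Any]]:
        # The hash is computed exactly once, during the cache lookup ...
        key = hash_item(model_id, modality, item, **kwargs)
        return key, self._cache.get(key)

    def put_item(self, key: str, value: Any) -> None:
        # ... and reused verbatim when storing the processed output.
        self._cache[key] = value


cache = ToyProcessingCache()
key, cached = cache.get_item("some-model", "image", b"raw-image-bytes", {})
if cached is None:
    cache.put_item(key, {"pixel_values": "..."})

# The same key doubles as the returned multimodal hash, so nothing is rehashed.
mm_hashes = {"image": [key]}

In the diff itself this shape corresponds to `ProcessingCache.get_item` / `put_item` working on `ProcessingCacheOptionalItem` / `ProcessingCacheItem` tuples, with `_cached_apply_hf_processor` assembling `mm_hashes` from the merged items' keys.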
@@ -22,8 +22,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
BaseProcessingInfo, MultiModalHashes,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
@@ -279,24 +279,26 @@ class DeepseekVL2MultiModalProcessor(
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 2:
# This code path corresponds to the cache being disabled
return self._apply_hf_processor_main(
return self._apply_hf_processor(
prompt=prompt,
mm_items=mm_data_items,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
return_mm_hashes=return_mm_hashes,
)

return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
)
@@ -19,8 +19,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.transformers_utils.tokenizer import AnyTokenizer

from .intern_vit import InternVisionModel
@@ -488,24 +488,26 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
# The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 1:
# This code path corresponds to the cache being disabled
return self._apply_hf_processor_main(
return self._apply_hf_processor(
prompt=prompt,
mm_items=mm_data_items,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
return_mm_hashes=return_mm_hashes,
)

return super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
)
@@ -396,14 +396,12 @@ def _build_llava_or_pixtral_hf_processor(
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
if isinstance(info, PixtralHFProcessingInfo):
return PixtralHFMultiModalProcessor(
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)

if isinstance(info, LlavaProcessingInfo):
@@ -411,7 +409,6 @@ def _build_llava_or_pixtral_hf_processor(
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)

raise NotImplementedError(type(info))
@@ -312,14 +312,12 @@ def _build_mistral3_processor(
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor:
assert isinstance(info, Mistral3ProcessingInfo)
return Mistral3MultiModalProcessor(
info,
dummy_inputs, # type: ignore
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
@@ -36,8 +36,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
BaseProcessingInfo, MultiModalHashes,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
@@ -271,15 +272,19 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
prompt_ids, mm_kwargs, _ = super()._cached_apply_hf_processor(
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
prompt_ids, mm_kwargs, mm_hashes, _ = super(
)._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
)

# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_kwargs, True
return prompt_ids, mm_kwargs, mm_hashes, True


@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
@@ -876,6 +876,16 @@ def find_mm_placeholders(
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")


class ProcessingCacheOptionalItem(NamedTuple):
key: str
value: Optional[MultiModalKwargsItem]


class ProcessingCacheItem(NamedTuple):
key: str
value: MultiModalKwargsItem


class ProcessingCache:

@staticmethod
@@ -980,6 +990,22 @@ class ProcessingCache:

return self._cache.get(cache_key)

def get_item(
self,
model_id: str,
modality: str,
input_item: object,
input_kwargs: Mapping[str, object],
) -> ProcessingCacheOptionalItem:
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)

return ProcessingCacheOptionalItem(
key=cache_key,
value=self._cache.get(cache_key),
)

def put(
self,
model_id: str,
@@ -997,6 +1023,9 @@ class ProcessingCache:
**input_kwargs)
self._cache[cache_key] = output_kwargs

def put_item(self, item: ProcessingCacheItem) -> None:
self._cache[item.key] = item.value


class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing."""
@@ -1052,6 +1081,11 @@ class BaseProcessingInfo:

_I = TypeVar("_I", bound=BaseProcessingInfo)

MultiModalHashes = dict[str, list[str]]
"""
A collection of hashes with a similar structure as :class:`MultiModalKwargs`.
"""


class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
@@ -1064,14 +1098,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
info: _I,
dummy_inputs: "BaseDummyInputsBuilder[_I]",
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
cache: Optional[ProcessingCache] = None) -> None:
super().__init__()

self.info = info
self.dummy_inputs = dummy_inputs
self.cache = cache
self.enable_sanity_checks = enable_sanity_checks

self.data_parser = self._get_data_parser()
@@ -1340,46 +1372,144 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):

return prompt_ids, mm_kwargs, False

def _cached_apply_hf_processor(
def _get_cache_missing_items(
self,
prompt: Union[str, list[int]],
cache: ProcessingCache,
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]:
"""
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
"""
cache = self.cache
) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[
str, list[object]]]:
model_id = self.info.model_id

_, passthrough_data = self._get_hf_mm_data(mm_data_items)
if cache is None or passthrough_data:
return self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
)

mm_maybe_cached_kw_items = {
mm_cache_items = {
modality: [
cache.get(model_id, modality, item, hf_processor_mm_kwargs)
for item in items
cache.get_item(model_id, modality, item,
hf_processor_mm_kwargs) for item in items
]
for modality, items in mm_data_items.items()
}

mm_missing_idxs = {
modality:
[idx for idx, item in enumerate(kw_items) if item is None]
for modality, kw_items in mm_maybe_cached_kw_items.items()
modality: [
idx for idx, item in enumerate(cache_items)
if item.value is None
]
for modality, cache_items in mm_cache_items.items()
}
mm_missing_data = {
modality: [mm_data_items[modality][idx] for idx in idxs]
for modality, idxs in mm_missing_idxs.items()
}
mm_missing_data_items = self._to_mm_items(mm_missing_data)

return mm_cache_items, mm_missing_data

def _hash_mm_items(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalHashes:
"""Create MM hashes to be returned (only used in V1)."""
model_id = self.info.model_id

return {
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs)
for item in items
]
for modality, items in mm_items.items()
}

def _merge_mm_kwargs(
self,
cache: ProcessingCache,
mm_cache_items: dict[str, list[ProcessingCacheOptionalItem]],
mm_missing_data: dict[str, list[object]],
mm_missing_kwargs: MultiModalKwargs,
) -> dict[str, list[ProcessingCacheItem]]:
mm_missing_next_idx = {modality: 0 for modality in mm_missing_data}

merged_items = defaultdict[str, list[ProcessingCacheItem]](list)
for modality, cache_items in mm_cache_items.items():
for cache_item in cache_items:
if cache_item.value is None:
kw_item = mm_missing_kwargs.get_item(
modality,
mm_missing_next_idx[modality],
)
cache_item_new = ProcessingCacheItem(
key=cache_item.key,
value=kw_item,
)

cache.put_item(cache_item_new)
mm_missing_next_idx[modality] += 1
else:
cache_item_new = ProcessingCacheItem(
key=cache_item.key,
value=cache_item.value,
)

merged_items[modality].append(cache_item_new)

return dict(merged_items)

def _apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
(
prompt_ids,
mm_kwargs,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
)

mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs)
if return_mm_hashes else None)

return prompt_ids, mm_kwargs, mm_hashes, is_update_applied

def _cached_apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
"""
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
"""
cache = self.cache

_, passthrough_data = self._get_hf_mm_data(mm_data_items)
if cache is None or passthrough_data:
return self._apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
)

(
mm_cache_items,
mm_missing_data,
) = self._get_cache_missing_items(
cache=cache,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)

# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal
@@ -1390,48 +1520,29 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
is_update_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_missing_data_items,
mm_items=self._to_mm_items(mm_missing_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=False,
)

mm_missing_next_idx = {
modality: 0
for modality in mm_missing_data_items
}
mm_cache_items_merged = self._merge_mm_kwargs(
cache,
mm_cache_items=mm_cache_items,
mm_missing_data=mm_missing_data,
mm_missing_kwargs=mm_missing_kwargs,
)

merged_kw_items = list[MultiModalKwargsItem]()
for modality, kw_items in mm_maybe_cached_kw_items.items():
for idx, kw_item in enumerate(kw_items):
if kw_item is None:
kw_item = mm_missing_kwargs.get_item(
modality,
mm_missing_next_idx[modality],
)
mm_kwargs = MultiModalKwargs.from_items([
item.value for cache_items in mm_cache_items_merged.values()
for item in cache_items
])

cache.put(
model_id,
modality,
mm_data_items[modality][idx],
hf_processor_mm_kwargs,
kw_item,
)
mm_hashes = {
modality: [item.key for item in cache_items]
for modality, cache_items in mm_cache_items_merged.items()
} if return_mm_hashes else None

mm_missing_next_idx[modality] += 1

merged_kw_items.append(kw_item)

if self.enable_sanity_checks:
mm_missing_counts = mm_missing_data_items.get_all_counts()
assert all(
item_count == mm_missing_counts[modality]
for modality, item_count in mm_missing_next_idx.items()), dict(
mm_missing_next_idx=mm_missing_next_idx,
mm_missing_counts=mm_missing_counts)

mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

return prompt_ids, mm_kwargs, is_update_applied
return prompt_ids, mm_kwargs, mm_hashes, is_update_applied

def _bind_and_group_updates(
self,
@@ -1569,27 +1680,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_updates`).")

def _hash_mm_items(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> dict[str, list[str]]:
"""Create MM hashes to be returned (only used in V1)."""

# TODO: Use these hash keys for caching operations in apply_hf_processor
# instead of rehashing.
model_id = self.info.model_id

return {
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs)
for item in items
]
for modality, items in mm_items.items()
}

def _maybe_apply_prompt_updates(
self,
mm_items: MultiModalDataItems,
@@ -1655,17 +1745,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
mm_items = self._to_mm_items(mm_data)

mm_hashes = (self._hash_mm_items(mm_items, hf_processor_mm_kwargs)
if return_mm_hashes else None)

(
prompt_ids,
mm_kwargs,
mm_hashes,
is_update_applied,
) = self._cached_apply_hf_processor(
prompt,
mm_items,
hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
)

prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
@@ -1717,6 +1806,32 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""Create input prompt for the decoder."""
return prompt

def _get_enc_dec_inputs(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
encoder_inputs: MultiModalInputs,
):
tokenizer = self.info.get_tokenizer()
decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
if isinstance(decoder_prompt, str):
decoder_prompt_ids = encode_tokens(tokenizer,
decoder_prompt,
add_special_tokens=False)
else:
decoder_prompt_ids = decoder_prompt
decoder_prompt = decode_tokens(tokenizer, decoder_prompt)

mm_inputs = MultiModalEncDecInputs(
encoder_prompt=encoder_inputs["prompt"],
encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
**encoder_inputs)
mm_inputs.update({
"prompt": decoder_prompt,
"prompt_token_ids": decoder_prompt_ids
})
return mm_inputs

def apply(
self,
prompt: Union[str, list[int]],
@@ -1739,22 +1854,8 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
return_mm_hashes,
)

tokenizer = self.info.get_tokenizer()
decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
if isinstance(decoder_prompt, str):
decoder_prompt_ids = encode_tokens(tokenizer,
decoder_prompt,
add_special_tokens=False)
else:
decoder_prompt_ids = decoder_prompt
decoder_prompt = decode_tokens(tokenizer, decoder_prompt)

mm_inputs = MultiModalEncDecInputs(
encoder_prompt=encoder_inputs["prompt"],
encoder_prompt_token_ids=encoder_inputs["prompt_token_ids"],
**encoder_inputs)
mm_inputs.update({
"prompt": decoder_prompt,
"prompt_token_ids": decoder_prompt_ids
})
return mm_inputs
return self._get_enc_dec_inputs(
prompt=prompt,
mm_data=mm_data,
encoder_inputs=encoder_inputs,
)
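To make the merge step in the diff above easier to follow, here is a hedged sketch (simplified types standing in for `MultiModalKwargs` and the cache-item tuples, not the actual vLLM implementation) of what `_merge_mm_kwargs` does: cache misses are filled from the freshly processed kwargs, written back under the keys already computed during the lookup, and those same keys are what `return_mm_hashes=True` reports.

from collections import defaultdict
from typing import NamedTuple, Optional


class OptionalItem(NamedTuple):  # stand-in for ProcessingCacheOptionalItem
    key: str
    value: Optional[dict]


class Item(NamedTuple):  # stand-in for ProcessingCacheItem
    key: str
    value: dict


def merge(cache: dict[str, dict],
          cache_items: dict[str, list[OptionalItem]],
          missing_kwargs: dict[str, list[dict]]) -> dict[str, list[Item]]:
    """Fill cache misses from freshly processed kwargs, keeping every key."""
    next_idx: dict[str, int] = defaultdict(int)
    merged: dict[str, list[Item]] = defaultdict(list)
    for modality, items in cache_items.items():
        for item in items:
            if item.value is None:
                # Cache miss: consume the next freshly processed item and store
                # it under the key already computed during the lookup.
                value = missing_kwargs[modality][next_idx[modality]]
                next_idx[modality] += 1
                cache[item.key] = value
                merged[modality].append(Item(item.key, value))
            else:
                merged[modality].append(Item(item.key, item.value))
    return dict(merged)


cache: dict[str, dict] = {}
merged = merge(
    cache,
    {"image": [OptionalItem("k1", None), OptionalItem("k2", {"pixels": 2})]},
    {"image": [{"pixels": 1}]},
)
# The keys double as the hashes returned when return_mm_hashes=True.
mm_hashes = {modality: [item.key for item in items]
             for modality, items in merged.items()}
assert mm_hashes == {"image": ["k1", "k2"]}
assert cache == {"k1": {"pixels": 1}}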