[Refactor] Pass tokenizer explicitly instead of binding to prompt update (#23542)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-08-25 21:31:57 +08:00
committed by GitHub
parent e269be2ba2
commit 6879cd80ae
4 changed files with 95 additions and 144 deletions

View File

@ -243,7 +243,7 @@ def test_find_token_matches(
mock_tokenizer = cast(AnyTokenizer, object())
prompt_updates = {
key: update_type(key, target, []).resolve(mock_tokenizer, 0)
key: update_type(key, target, []).resolve(0)
for key, target in target_by_key.items()
}
result = {
@ -392,7 +392,7 @@ def test_find_text_matches(
mock_tokenizer = cast(AnyTokenizer, object())
prompt_updates = {
key: update_type(key, target, []).resolve(mock_tokenizer, 0)
key: update_type(key, target, []).resolve(0)
for key, target in target_by_key.items()
}
result = {
@ -559,10 +559,8 @@ def test_find_update_text(
) in expected_by_update_type_mm_count.items():
for mm_count, expected in expected_by_mm_count.items():
mm_prompt_updates = {
key: [[
update_type(key, target,
repl_by_key[key]).resolve(mock_tokenizer, i)
] for i in range(mm_count)]
key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
for i in range(mm_count)]
for key, target in target_by_key.items()
}
@ -731,10 +729,8 @@ def test_find_update_tokens(
) in expected_by_update_type_mm_count.items():
for mm_count, expected in expected_by_mm_count.items():
mm_prompt_updates = {
key: [[
update_type(key, target,
repl_by_key[key]).resolve(mock_tokenizer, i)
] for i in range(mm_count)]
key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
for i in range(mm_count)]
for key, target in target_by_key.items()
}
@ -879,12 +875,11 @@ def test_find_mm_placeholders(
mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_updates = {
key: [[update_type(key, [], repl).resolve(mock_tokenizer, i)]
for i in range(3)]
key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
for key, repl in repl_by_key.items()
}
result = find_mm_placeholders(prompt, mm_prompt_updates)
result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer)
# Only displayed on error
print("result:", result)

View File

@ -28,7 +28,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PlaceholderFeaturesInfo,
PromptReplacement, PromptUpdate,
PromptUpdateDetails,
find_mm_placeholders,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder
@ -401,7 +400,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
repls = super()._find_mm_placeholders(repl_token_ids,
mm_prompt_updates)
return {
modality: [

View File

@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PlaceholderFeaturesInfo,
PromptReplacement, PromptUpdate,
PromptUpdateDetails,
find_mm_placeholders,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder
@ -318,7 +317,8 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
repls = super()._find_mm_placeholders(repl_token_ids,
mm_prompt_updates)
return {
modality: [

View File

@ -44,6 +44,44 @@ PromptSeq = Union[str, list[int]]
"""A token sequence (list of token IDs) or text."""
@lru_cache(maxsize=2048)
def _cached_encode(
tokenizer: AnyTokenizer,
text: str,
*,
add_special_tokens: Optional[bool] = None,
) -> list[int]:
return encode_tokens(tokenizer,
text,
add_special_tokens=add_special_tokens)
@lru_cache(maxsize=2048)
def _cached_decode(
tokenizer: AnyTokenizer,
token_ids: tuple[int, ...],
*,
skip_special_tokens: Optional[bool] = None,
) -> str:
return decode_tokens(tokenizer,
list(token_ids),
skip_special_tokens=skip_special_tokens)
def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str:
if isinstance(seq, str):
return seq
return _cached_decode(tokenizer, tuple(seq))
def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]:
if isinstance(seq, str):
return _cached_encode(tokenizer, seq, add_special_tokens=False)
return seq
class _GetMatchIndex(Protocol):
def __call__(
@ -137,7 +175,8 @@ class PromptUpdateDetails(Generic[_S]):
full: _S
"""The full content."""
is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
is_embed: Optional[Callable[[AnyTokenizer, PromptSeq],
torch.Tensor]] = None
"""
Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
return a boolean mask of shape `(len(full),)` indicating which positions
@ -159,11 +198,12 @@ class PromptUpdateDetails(Generic[_S]):
embed_text: str,
) -> "PromptUpdateDetails[_S]":
def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
embed_token_ids = encode_tokens(full.tokenizer, embed_text)
def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
embed_token_ids = encode_tokens(tokenizer, embed_text)
token_ids = _seq2tokens(tokenizer, full)
return torch.isin(
torch.tensor(full.token_ids),
torch.tensor(token_ids),
torch.tensor(embed_token_ids),
)
@ -174,10 +214,13 @@ class PromptUpdateDetails(Generic[_S]):
seq: _S,
embed_token_id: int,
) -> "PromptUpdateDetails[_S]":
return PromptUpdateDetails(
full=seq,
is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
)
def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
token_ids = _seq2tokens(tokenizer, full)
return torch.tensor(token_ids) == embed_token_id
return PromptUpdateDetails(full=seq, is_embed=is_embed)
PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
@ -230,25 +273,14 @@ class PromptUpdate(ABC):
"""Defines how to update the prompt."""
raise NotImplementedError
def _resolve_target(
self,
tokenizer: AnyTokenizer,
item_idx: int,
) -> Union["_BoundPromptSequence", PromptIndex]:
def _resolve_target(self, item_idx: int) -> UpdateTarget:
target = self.target
if callable(target):
target = target(item_idx)
if isinstance(target, PromptIndex):
return target
return target
return _BoundPromptSequence.from_seq(tokenizer, target)
def _resolve_content(
self,
tokenizer: AnyTokenizer,
item_idx: int,
) -> "_BoundPromptContent":
def _resolve_content(self, item_idx: int) -> PromptUpdateDetails:
content = self.content
if callable(content):
content = content(item_idx)
@ -256,17 +288,9 @@ class PromptUpdate(ABC):
if not isinstance(content, PromptUpdateDetails):
content = PromptUpdateDetails.from_seq(content)
bound_full = _BoundPromptSequence.from_seq(tokenizer, content.full)
bound_content = _BoundPromptContent(full=bound_full,
is_embed=content.is_embed)
return content
return bound_content
def resolve(
self,
tokenizer: AnyTokenizer,
item_idx: int,
) -> "ResolvedPromptUpdate":
def resolve(self, item_idx: int) -> "ResolvedPromptUpdate":
"""
Given the index of the processed item within
[`modality`][vllm.multimodal.processing.PromptUpdate.modality],
@ -276,8 +300,8 @@ class PromptUpdate(ABC):
modality=self.modality,
item_idx=item_idx,
mode=self.mode,
target=self._resolve_target(tokenizer, item_idx),
content=self._resolve_content(tokenizer, item_idx),
target=self._resolve_target(item_idx),
content=self._resolve_content(item_idx),
)
@ -424,30 +448,6 @@ class PromptReplacement(PromptUpdate):
return UpdateMode.REPLACE
@lru_cache(maxsize=2048)
def _cached_encode(
tokenizer: AnyTokenizer,
text: str,
*,
add_special_tokens: Optional[bool] = None,
) -> list[int]:
return encode_tokens(tokenizer,
text,
add_special_tokens=add_special_tokens)
@lru_cache(maxsize=2048)
def _cached_decode(
tokenizer: AnyTokenizer,
token_ids: tuple[int, ...],
*,
skip_special_tokens: Optional[bool] = None,
) -> str:
return decode_tokens(tokenizer,
list(token_ids),
skip_special_tokens=skip_special_tokens)
class _HasModalityAttr(Protocol):
modality: str
@ -468,59 +468,6 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
return full_groupby(values, key=lambda x: x.modality)
@dataclass
class _BoundPromptSequence:
"""
A [`_PromptSeq`][vllm.multimodal.processing.PromptSeq] bound
to a tokenizer to automatically
convert between token sequence and text representations.
"""
tokenizer: AnyTokenizer = field(repr=False)
_text: Optional[str]
_token_ids: Optional[list[int]]
@staticmethod
def from_seq(
tokenizer: AnyTokenizer,
seq: PromptSeq,
) -> "_BoundPromptSequence":
return _BoundPromptSequence(
tokenizer=tokenizer,
_text=seq if isinstance(seq, str) else None,
_token_ids=seq if isinstance(seq, list) else None,
)
def __post_init__(self) -> None:
if self._text is None and self._token_ids is None:
raise ValueError("At least one of 'text' and 'token_ids' must be "
"specified")
@property
def text(self) -> str:
if self._text is None:
assert self._token_ids is not None
self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
return self._text
@property
def token_ids(self) -> list[int]:
if self._token_ids is None:
assert self._text is not None
self._token_ids = _cached_encode(self.tokenizer,
self._text,
add_special_tokens=False)
return self._token_ids
@dataclass
class _BoundPromptContent:
full: _BoundPromptSequence
is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
class PromptTargetMatch(NamedTuple):
start_idx: int
end_idx: int
@ -542,10 +489,10 @@ class ResolvedPromptUpdate:
mode: UpdateMode
"""Defines how to update the prompt."""
target: Union[_BoundPromptSequence, PromptIndex]
target: UpdateTarget
"""The token sequence (or text) to update."""
content: _BoundPromptContent = field(repr=False)
content: PromptUpdateDetails = field(repr=False)
"""The placeholder tokens that are part of the update."""
def iter_token_matches(
@ -565,8 +512,10 @@ class ResolvedPromptUpdate:
return
target_token_ids = _seq2tokens(tokenizer, target)
for match in iter_token_matches(prompt,
target.token_ids,
target_token_ids,
start_idx=start_idx):
yield PromptTargetMatch(match.start_idx, match.end_idx)
@ -587,7 +536,9 @@ class ResolvedPromptUpdate:
return
for match in re.finditer(re.escape(target.text), prompt,
target_text = _seq2text(tokenizer, target)
for match in re.finditer(re.escape(target_text), prompt,
pos=start_idx):
yield PromptTargetMatch(match.start(), match.end())
@ -779,7 +730,7 @@ def _apply_matches(
matched_update = mm_prompt_updates[modality][item_idx][
update_idx]
matched_content = matched_update.content
matched_content = matched_update.content.full
if mode == UpdateMode.INSERT:
end_idx_to_insert = match.end_idx
@ -789,8 +740,10 @@ def _apply_matches(
assert_never(mode)
out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
out_seqs.append(matched_content.full.text if isinstance(
prompt, str) else matched_content.full.token_ids)
out_seqs.append(
_seq2text(tokenizer, matched_content
) if isinstance(prompt, str) else _seq2tokens(
tokenizer, matched_content))
out_result[modality][item_idx] = update_idx
# Exclude overlapping matches
@ -842,6 +795,7 @@ def apply_text_matches(
def _iter_placeholders(
prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
) -> Iterable[PlaceholderFeaturesInfo]:
"""
Yield each set of placeholder tokens found in `prompt`.
@ -868,7 +822,7 @@ def _iter_placeholders(
for update in modality_updates[item_idx]:
content = update.content
content_tokens_full = content.full.token_ids
content_tokens_full = _seq2tokens(tokenizer, content.full)
content_len_full = len(content_tokens_full)
end_idx_full = start_idx + content_len_full
@ -878,7 +832,8 @@ def _iter_placeholders(
if prompt[start_idx:end_idx_full] == content_tokens_full:
content_is_embed = content.is_embed
if content_is_embed is not None:
content_is_embed = content_is_embed(content.full)
content_is_embed = content_is_embed(
tokenizer, content.full)
yield PlaceholderFeaturesInfo(
modality=modality,
@ -904,8 +859,9 @@ def _iter_placeholders(
def find_mm_placeholders(
prompt: list[int],
mm_prompt_updates: "MultiModalPromptUpdates",
tokenizer: AnyTokenizer,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
it = _iter_placeholders(prompt, mm_prompt_updates)
it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
return dict(full_groupby_modality(it))
@ -1160,12 +1116,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_updates: Sequence[PromptUpdate],
mm_item_counts: Mapping[str, int],
) -> MultiModalPromptUpdates:
tokenizer = self.info.get_tokenizer()
return {
modality:
[[update.resolve(tokenizer, item_idx) for update in updates]
for item_idx in range(mm_item_counts.get(modality, 0))]
modality: [[update.resolve(item_idx) for update in updates]
for item_idx in range(mm_item_counts.get(modality, 0))]
for modality, updates in full_groupby_modality(prompt_updates)
}
@ -1208,7 +1161,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
new_token_ids: list[int],
mm_prompt_updates: MultiModalPromptUpdates,
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
return find_mm_placeholders(new_token_ids, mm_prompt_updates)
tokenizer = self.info.get_tokenizer()
return find_mm_placeholders(new_token_ids, mm_prompt_updates,
tokenizer)
def _get_hf_mm_data(
self,