Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

[Refactor] Pass tokenizer explicitly instead of binding to prompt update (#23542)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
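In short: `PromptUpdate.resolve()` drops its tokenizer parameter, while `find_mm_placeholders()` gains one, so the tokenizer is passed explicitly at the call sites that need it instead of being bound into every resolved update. A minimal before/after sketch (`update`, `token_ids`, and `mm_prompt_updates` are illustrative placeholders):

    # Before this commit: the tokenizer was bound into each resolved update.
    resolved = update.resolve(tokenizer, item_idx)
    placeholders = find_mm_placeholders(token_ids, mm_prompt_updates)

    # After this commit: resolution is tokenizer-free; the tokenizer is
    # passed explicitly where token/text conversion actually happens.
    resolved = update.resolve(item_idx)
    placeholders = find_mm_placeholders(token_ids, mm_prompt_updates, tokenizer)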
@@ -243,7 +243,7 @@ def test_find_token_matches(
     mock_tokenizer = cast(AnyTokenizer, object())
 
     prompt_updates = {
-        key: update_type(key, target, []).resolve(mock_tokenizer, 0)
+        key: update_type(key, target, []).resolve(0)
         for key, target in target_by_key.items()
     }
     result = {
@@ -392,7 +392,7 @@ def test_find_text_matches(
     mock_tokenizer = cast(AnyTokenizer, object())
 
     prompt_updates = {
-        key: update_type(key, target, []).resolve(mock_tokenizer, 0)
+        key: update_type(key, target, []).resolve(0)
         for key, target in target_by_key.items()
     }
     result = {
@@ -559,10 +559,8 @@ def test_find_update_text(
     ) in expected_by_update_type_mm_count.items():
         for mm_count, expected in expected_by_mm_count.items():
             mm_prompt_updates = {
-                key: [[
-                    update_type(key, target,
-                                repl_by_key[key]).resolve(mock_tokenizer, i)
-                ] for i in range(mm_count)]
+                key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
+                      for i in range(mm_count)]
                 for key, target in target_by_key.items()
             }
 
@@ -731,10 +729,8 @@ def test_find_update_tokens(
     ) in expected_by_update_type_mm_count.items():
         for mm_count, expected in expected_by_mm_count.items():
             mm_prompt_updates = {
-                key: [[
-                    update_type(key, target,
-                                repl_by_key[key]).resolve(mock_tokenizer, i)
-                ] for i in range(mm_count)]
+                key: [[update_type(key, target, repl_by_key[key]).resolve(i)]
+                      for i in range(mm_count)]
                 for key, target in target_by_key.items()
            }
 
@@ -879,12 +875,11 @@ def test_find_mm_placeholders(
     mock_tokenizer = cast(AnyTokenizer, object())
 
     mm_prompt_updates = {
-        key: [[update_type(key, [], repl).resolve(mock_tokenizer, i)]
-              for i in range(3)]
+        key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
         for key, repl in repl_by_key.items()
     }
 
-    result = find_mm_placeholders(prompt, mm_prompt_updates)
+    result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer)
 
     # Only displayed on error
     print("result:", result)
@@ -28,7 +28,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PlaceholderFeaturesInfo,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails,
-                                        find_mm_placeholders,
                                         replace_token_matches)
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
@@ -401,7 +400,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             repl_token_ids.extend(repl_toks)
             repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
 
-        repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
+        repls = super()._find_mm_placeholders(repl_token_ids,
+                                              mm_prompt_updates)
 
         return {
             modality: [
@@ -35,7 +35,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PlaceholderFeaturesInfo,
                                         PromptReplacement, PromptUpdate,
                                         PromptUpdateDetails,
-                                        find_mm_placeholders,
                                         replace_token_matches)
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
@@ -318,7 +317,8 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
             repl_token_ids.extend(repl_toks)
             repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
 
-        repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates)
+        repls = super()._find_mm_placeholders(repl_token_ids,
+                                              mm_prompt_updates)
 
         return {
             modality: [
@@ -44,6 +44,44 @@ PromptSeq = Union[str, list[int]]
 """A token sequence (list of token IDs) or text."""
 
 
+@lru_cache(maxsize=2048)
+def _cached_encode(
+    tokenizer: AnyTokenizer,
+    text: str,
+    *,
+    add_special_tokens: Optional[bool] = None,
+) -> list[int]:
+    return encode_tokens(tokenizer,
+                         text,
+                         add_special_tokens=add_special_tokens)
+
+
+@lru_cache(maxsize=2048)
+def _cached_decode(
+    tokenizer: AnyTokenizer,
+    token_ids: tuple[int, ...],
+    *,
+    skip_special_tokens: Optional[bool] = None,
+) -> str:
+    return decode_tokens(tokenizer,
+                         list(token_ids),
+                         skip_special_tokens=skip_special_tokens)
+
+
+def _seq2text(tokenizer: AnyTokenizer, seq: PromptSeq) -> str:
+    if isinstance(seq, str):
+        return seq
+
+    return _cached_decode(tokenizer, tuple(seq))
+
+
+def _seq2tokens(tokenizer: AnyTokenizer, seq: PromptSeq) -> list[int]:
+    if isinstance(seq, str):
+        return _cached_encode(tokenizer, seq, add_special_tokens=False)
+
+    return seq
+
+
 class _GetMatchIndex(Protocol):
 
     def __call__(
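For orientation, a sketch of how these new helpers behave: they normalize a `PromptSeq` (text or token IDs) into the requested representation, memoizing tokenizer round-trips via `lru_cache`. The `tok` object and values below are illustrative placeholders, not from this diff:

    tok = ...  # any tokenizer accepted as AnyTokenizer

    _seq2tokens(tok, "<image>")   # text -> token IDs via the cached encode
    _seq2tokens(tok, [101, 102])  # already token IDs; returned unchanged
    _seq2text(tok, [101, 102])    # token IDs -> text via the cached decode
    _seq2text(tok, "<image>")     # already text; returned unchanged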
@@ -137,7 +175,8 @@ class PromptUpdateDetails(Generic[_S]):
     full: _S
     """The full content."""
 
-    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]] = None
+    is_embed: Optional[Callable[[AnyTokenizer, PromptSeq],
+                                torch.Tensor]] = None
     """
     Given [`full`][vllm.multimodal.processing.PromptUpdateDetails.full],
     return a boolean mask of shape `(len(full),)` indicating which positions
@@ -159,11 +198,12 @@ class PromptUpdateDetails(Generic[_S]):
         embed_text: str,
     ) -> "PromptUpdateDetails[_S]":
 
-        def is_embed(full: "_BoundPromptSequence") -> torch.Tensor:
-            embed_token_ids = encode_tokens(full.tokenizer, embed_text)
+        def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
+            embed_token_ids = encode_tokens(tokenizer, embed_text)
+            token_ids = _seq2tokens(tokenizer, full)
 
             return torch.isin(
-                torch.tensor(full.token_ids),
+                torch.tensor(token_ids),
                 torch.tensor(embed_token_ids),
             )
 
@@ -174,10 +214,13 @@ class PromptUpdateDetails(Generic[_S]):
         seq: _S,
         embed_token_id: int,
     ) -> "PromptUpdateDetails[_S]":
-        return PromptUpdateDetails(
-            full=seq,
-            is_embed=lambda f: torch.tensor(f.token_ids) == embed_token_id,
-        )
+
+        def is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
+            token_ids = _seq2tokens(tokenizer, full)
+
+            return torch.tensor(token_ids) == embed_token_id
+
+        return PromptUpdateDetails(full=seq, is_embed=is_embed)
 
 
 PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
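The net effect on `is_embed` callbacks, visible in both factory methods above: they now take `(tokenizer, full)` where `full` is a plain `PromptSeq`, rather than a `_BoundPromptSequence` that carried its own tokenizer. A hedged sketch of a custom callback under the new contract (the placeholder token ID and values are hypothetical):

    def my_is_embed(tokenizer: AnyTokenizer, full: PromptSeq) -> torch.Tensor:
        # Normalize to token IDs, then mark the positions that hold
        # embedding inputs (here: a single hypothetical placeholder ID).
        token_ids = _seq2tokens(tokenizer, full)
        return torch.tensor(token_ids) == 32000

    details = PromptUpdateDetails(full=[1, 32000, 32000, 2],
                                  is_embed=my_is_embed)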
@@ -230,25 +273,14 @@ class PromptUpdate(ABC):
         """Defines how to update the prompt."""
         raise NotImplementedError
 
-    def _resolve_target(
-        self,
-        tokenizer: AnyTokenizer,
-        item_idx: int,
-    ) -> Union["_BoundPromptSequence", PromptIndex]:
+    def _resolve_target(self, item_idx: int) -> UpdateTarget:
         target = self.target
         if callable(target):
             target = target(item_idx)
 
-        if isinstance(target, PromptIndex):
-            return target
-
-        return _BoundPromptSequence.from_seq(tokenizer, target)
+        return target
 
-    def _resolve_content(
-        self,
-        tokenizer: AnyTokenizer,
-        item_idx: int,
-    ) -> "_BoundPromptContent":
+    def _resolve_content(self, item_idx: int) -> PromptUpdateDetails:
         content = self.content
         if callable(content):
             content = content(item_idx)
@@ -256,17 +288,9 @@ class PromptUpdate(ABC):
         if not isinstance(content, PromptUpdateDetails):
             content = PromptUpdateDetails.from_seq(content)
 
-        bound_full = _BoundPromptSequence.from_seq(tokenizer, content.full)
-        bound_content = _BoundPromptContent(full=bound_full,
-                                            is_embed=content.is_embed)
-
-        return bound_content
+        return content
 
-    def resolve(
-        self,
-        tokenizer: AnyTokenizer,
-        item_idx: int,
-    ) -> "ResolvedPromptUpdate":
+    def resolve(self, item_idx: int) -> "ResolvedPromptUpdate":
         """
         Given the index of the processed item within
         [`modality`][vllm.multimodal.processing.PromptUpdate.modality],
@@ -276,8 +300,8 @@ class PromptUpdate(ABC):
             modality=self.modality,
             item_idx=item_idx,
             mode=self.mode,
-            target=self._resolve_target(tokenizer, item_idx),
-            content=self._resolve_content(tokenizer, item_idx),
+            target=self._resolve_target(item_idx),
+            content=self._resolve_content(item_idx),
         )
 
 
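To make the new flow concrete: resolving an update no longer touches the tokenizer at all; the target and content stay as plain sequences, and conversion is deferred to the match iterators, which receive the tokenizer. An illustrative sketch only — the constructor arguments are abbreviated and the placeholder tokens hypothetical:

    update = PromptReplacement(
        modality="image",
        target="<image>",            # remains a plain PromptSeq
        replacement=[32000] * 256,   # hypothetical placeholder token IDs
    )
    resolved = update.resolve(item_idx=0)  # no tokenizer needed here
    # resolved.target is still "<image>"; tokenization happens lazily inside
    # iter_token_matches()/iter_text_matches(), which take the tokenizer.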
@@ -424,30 +448,6 @@ class PromptReplacement(PromptUpdate):
         return UpdateMode.REPLACE
 
 
-@lru_cache(maxsize=2048)
-def _cached_encode(
-    tokenizer: AnyTokenizer,
-    text: str,
-    *,
-    add_special_tokens: Optional[bool] = None,
-) -> list[int]:
-    return encode_tokens(tokenizer,
-                         text,
-                         add_special_tokens=add_special_tokens)
-
-
-@lru_cache(maxsize=2048)
-def _cached_decode(
-    tokenizer: AnyTokenizer,
-    token_ids: tuple[int, ...],
-    *,
-    skip_special_tokens: Optional[bool] = None,
-) -> str:
-    return decode_tokens(tokenizer,
-                         list(token_ids),
-                         skip_special_tokens=skip_special_tokens)
-
-
 class _HasModalityAttr(Protocol):
     modality: str
 
@@ -468,59 +468,6 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
     return full_groupby(values, key=lambda x: x.modality)
 
 
-@dataclass
-class _BoundPromptSequence:
-    """
-    A [`_PromptSeq`][vllm.multimodal.processing.PromptSeq] bound
-    to a tokenizer to automatically
-    convert between token sequence and text representations.
-    """
-    tokenizer: AnyTokenizer = field(repr=False)
-
-    _text: Optional[str]
-    _token_ids: Optional[list[int]]
-
-    @staticmethod
-    def from_seq(
-        tokenizer: AnyTokenizer,
-        seq: PromptSeq,
-    ) -> "_BoundPromptSequence":
-        return _BoundPromptSequence(
-            tokenizer=tokenizer,
-            _text=seq if isinstance(seq, str) else None,
-            _token_ids=seq if isinstance(seq, list) else None,
-        )
-
-    def __post_init__(self) -> None:
-        if self._text is None and self._token_ids is None:
-            raise ValueError("At least one of 'text' and 'token_ids' must be "
-                             "specified")
-
-    @property
-    def text(self) -> str:
-        if self._text is None:
-            assert self._token_ids is not None
-            self._text = _cached_decode(self.tokenizer, tuple(self._token_ids))
-
-        return self._text
-
-    @property
-    def token_ids(self) -> list[int]:
-        if self._token_ids is None:
-            assert self._text is not None
-            self._token_ids = _cached_encode(self.tokenizer,
-                                             self._text,
-                                             add_special_tokens=False)
-
-        return self._token_ids
-
-
-@dataclass
-class _BoundPromptContent:
-    full: _BoundPromptSequence
-    is_embed: Optional[Callable[["_BoundPromptSequence"], torch.Tensor]]
-
-
 class PromptTargetMatch(NamedTuple):
     start_idx: int
     end_idx: int
@@ -542,10 +489,10 @@ class ResolvedPromptUpdate:
     mode: UpdateMode
     """Defines how to update the prompt."""
 
-    target: Union[_BoundPromptSequence, PromptIndex]
+    target: UpdateTarget
     """The token sequence (or text) to update."""
 
-    content: _BoundPromptContent = field(repr=False)
+    content: PromptUpdateDetails = field(repr=False)
     """The placeholder tokens that are part of the update."""
 
     def iter_token_matches(
@@ -565,8 +512,10 @@ class ResolvedPromptUpdate:
 
             return
 
+        target_token_ids = _seq2tokens(tokenizer, target)
+
         for match in iter_token_matches(prompt,
-                                        target.token_ids,
+                                        target_token_ids,
                                         start_idx=start_idx):
             yield PromptTargetMatch(match.start_idx, match.end_idx)
 
@@ -587,7 +536,9 @@ class ResolvedPromptUpdate:
 
             return
 
-        for match in re.finditer(re.escape(target.text), prompt,
+        target_text = _seq2text(tokenizer, target)
+
+        for match in re.finditer(re.escape(target_text), prompt,
                                  pos=start_idx):
             yield PromptTargetMatch(match.start(), match.end())
 
@@ -779,7 +730,7 @@ def _apply_matches(
 
             matched_update = mm_prompt_updates[modality][item_idx][
                 update_idx]
-            matched_content = matched_update.content
+            matched_content = matched_update.content.full
 
             if mode == UpdateMode.INSERT:
                 end_idx_to_insert = match.end_idx
@@ -789,8 +740,10 @@ def _apply_matches(
                 assert_never(mode)
 
             out_seqs.append(prompt[prev_end_idx:end_idx_to_insert])
-            out_seqs.append(matched_content.full.text if isinstance(
-                prompt, str) else matched_content.full.token_ids)
+            out_seqs.append(
+                _seq2text(tokenizer, matched_content
+                          ) if isinstance(prompt, str) else _seq2tokens(
+                              tokenizer, matched_content))
             out_result[modality][item_idx] = update_idx
 
             # Exclude overlapping matches
@@ -842,6 +795,7 @@ def apply_text_matches(
 def _iter_placeholders(
     prompt: list[int],
     mm_prompt_updates: "MultiModalPromptUpdates",
+    tokenizer: AnyTokenizer,
 ) -> Iterable[PlaceholderFeaturesInfo]:
     """
     Yield each set of placeholder tokens found in `prompt`.
@@ -868,7 +822,7 @@ def _iter_placeholders(
 
         for update in modality_updates[item_idx]:
             content = update.content
-            content_tokens_full = content.full.token_ids
+            content_tokens_full = _seq2tokens(tokenizer, content.full)
             content_len_full = len(content_tokens_full)
             end_idx_full = start_idx + content_len_full
 
@@ -878,7 +832,8 @@ def _iter_placeholders(
             if prompt[start_idx:end_idx_full] == content_tokens_full:
                 content_is_embed = content.is_embed
                 if content_is_embed is not None:
-                    content_is_embed = content_is_embed(content.full)
+                    content_is_embed = content_is_embed(
+                        tokenizer, content.full)
 
                 yield PlaceholderFeaturesInfo(
                     modality=modality,
@@ -904,8 +859,9 @@ def _iter_placeholders(
 def find_mm_placeholders(
     prompt: list[int],
     mm_prompt_updates: "MultiModalPromptUpdates",
+    tokenizer: AnyTokenizer,
 ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
-    it = _iter_placeholders(prompt, mm_prompt_updates)
+    it = _iter_placeholders(prompt, mm_prompt_updates, tokenizer)
     return dict(full_groupby_modality(it))
 
 
@@ -1160,11 +1116,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         prompt_updates: Sequence[PromptUpdate],
         mm_item_counts: Mapping[str, int],
     ) -> MultiModalPromptUpdates:
-        tokenizer = self.info.get_tokenizer()
-
         return {
-            modality:
-            [[update.resolve(tokenizer, item_idx) for update in updates]
+            modality: [[update.resolve(item_idx) for update in updates]
                        for item_idx in range(mm_item_counts.get(modality, 0))]
             for modality, updates in full_groupby_modality(prompt_updates)
         }
@@ -1208,7 +1161,10 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
         new_token_ids: list[int],
         mm_prompt_updates: MultiModalPromptUpdates,
     ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
-        return find_mm_placeholders(new_token_ids, mm_prompt_updates)
+        tokenizer = self.info.get_tokenizer()
+
+        return find_mm_placeholders(new_token_ids, mm_prompt_updates,
+                                    tokenizer)
 
     def _get_hf_mm_data(
         self,
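The corresponding pattern for model processors (see the Gemma3/Gemma3n hunks above): call `super()._find_mm_placeholders(...)` instead of the module-level function, so the base class can supply the tokenizer. Roughly, inside a `BaseMultiModalProcessor` subclass the two forms below are equivalent:

    # Via the base-class helper, which fetches the tokenizer itself:
    repls = super()._find_mm_placeholders(repl_token_ids, mm_prompt_updates)

    # Or explicitly, matching the new free-function signature:
    repls = find_mm_placeholders(repl_token_ids, mm_prompt_updates,
                                 self.info.get_tokenizer())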