[VLM] Generalized prompt updates for multi-modal processor (#13964)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2025-02-28 01:44:25 +08:00
committed by GitHub
parent 7864875879
commit f1579b229d
29 changed files with 629 additions and 486 deletions

View File

@ -720,13 +720,13 @@ def _get_mm_fields_config(
:::::
### Prompt replacements
### Prompt updates
Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to
return a list of {class}`~vllm.multimodal.processing.PromptReplacement` instances.
Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` to
return a list of {class}`~vllm.multimodal.processing.PromptUpdate` instances.
Each {class}`~vllm.multimodal.processing.PromptReplacement` instance specifies a find-and-replace
operation performed by the HF processor.
Each {class}`~vllm.multimodal.processing.PromptUpdate` instance specifies an update operation
(e.g.: insertion, replacement) performed by the HF processor.
::::{tab-set}
:::{tab-item} Basic example: LLaVA
@ -743,15 +743,15 @@ for sample in text:
```
It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`).
Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` as follows:
Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` as follows:
```python
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
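The hunk ends mid-method; as a rough sketch of how the replacement callback typically continues (the `get_replacement` helper and the `ImageProcessorItems` lookup mirror patterns used elsewhere in this commit, so the exact details of the real LLaVA processor may differ):

```python
    def get_replacement(item_idx: int):
        images = mm_items.get_items("image", ImageProcessorItems)

        image_size = images.get_image_size(item_idx)
        num_image_tokens = self.info.get_num_image_tokens(
            image_width=image_size.width,
            image_height=image_size.height,
        )

        # Expand the single placeholder into one token per image feature.
        return [image_token_id] * num_image_tokens

    return [
        PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        ),
    ]
```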
@ -859,7 +859,7 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
)
```
To accommodate this, instead of a string, you can return an instance of `PromptReplacementDetails`
To accommodate this, instead of a string, you can return an instance of `PromptUpdateDetails`
with different `full` and `features` attributes:
```python
@ -878,7 +878,7 @@ def get_replacement_fuyu(item_idx: int):
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows
return PromptReplacementDetails(
return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)
@ -888,12 +888,12 @@ Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the
we can search for it to conduct the replacement at the start of the string:
```python
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id
assert isinstance(bos_token_id, int)
@ -913,7 +913,7 @@ def _get_prompt_replacements(
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows
return PromptReplacementDetails(
return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)

View File

@ -6,11 +6,16 @@ To enable various optimizations in vLLM such as [chunked prefill](#chunked-prefi
Here are the main features of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`:
## Prompt Replacement Detection
## Prompt Update Detection
One of the main responsibilities of the HF processor is to replace input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals the feature size). The information about which tokens have been replaced is key to finding the correspondence between placeholder feature tokens and multi-modal inputs.
One of the main responsibilities of the HF processor is to update the prompt with placeholder tokens. For example:
In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptReplacement` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`. Given this specification, we can automatically detect whether HF has replaced the input placeholder tokens by checking whether the feature placeholder tokens exist in the prompt.
- Insert feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals the feature size) at the start of the string.
- Replace existing input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals the feature size).
The information about which tokens have been updated is key to finding the correspondence between placeholder feature tokens and multi-modal inputs.
In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptUpdate` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`. We can automatically detect whether HF has updated the prompt by checking for the existence of the updated tokens.
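To make the two update types concrete, here is a minimal sketch based on the usages introduced in this commit; `image_token_id` and `num_image_tokens` are illustrative placeholder values, and the empty-string target for insertion follows the Blip-2 and Florence-2 processors changed below:

```python
from vllm.multimodal.processing import PromptInsertion, PromptReplacement

image_token_id = 32000   # hypothetical placeholder token id
num_image_tokens = 576   # hypothetical feature size

# Replace each existing `<image>` placeholder with the feature placeholder tokens.
replacement = PromptReplacement(
    modality="image",
    target=[image_token_id],
    replacement=[image_token_id] * num_image_tokens,
)

# Insert the feature placeholder tokens at the start of the prompt
# (an empty target string matches the start of the string).
insertion = PromptInsertion(
    modality="image",
    target="",
    insertion=[image_token_id] * num_image_tokens,
)
```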
## Tokenized Prompt Inputs
@ -22,7 +27,7 @@ Consider that HF processors follow these main steps:
1. Tokenize the text
2. Process multi-modal inputs
3. Perform prompt replacement
3. Perform prompt updates
And we require that:
@ -44,16 +49,16 @@ Moreover, since the tokenized text has not passed through the HF processor, we h
We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`. This lets us generate dummy text corresponding to the multi-modal inputs and pass them in together to obtain the processed multi-modal data.
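As a rough sketch of what such a dummy inputs builder might look like (the `<image>` placeholder, the 336x336 image size, and the `MyDummyInputsBuilder` name are illustrative assumptions; the signature and the `ProcessorInputs` fields are taken from {class}`~vllm.multimodal.profiling.BaseDummyInputsBuilder` as used in this commit):

```python
from collections.abc import Mapping

from PIL import Image

from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs


class MyDummyInputsBuilder(BaseDummyInputsBuilder):

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        # One input placeholder token per image in the dummy text;
        # the HF processor expands these into feature placeholder tokens.
        dummy_text = "<image>" * num_images
        dummy_images = [Image.new("RGB", (336, 336), color=0)] * num_images

        return ProcessorInputs(
            prompt_text=dummy_text,
            mm_data={"image": dummy_images},
        )
```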
(mm-automatic-prompt-replacement)=
(mm-automatic-prompt-updating)=
### Automatic prompt replacement
### Automatic prompt updating
We address the second issue by implementing model-agnostic code in
{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_replacements` to automatically replace input placeholder tokens with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`.
{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_updates` to automatically update the prompt with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`.
### Summary
With the help of dummy text and automatic prompt replacement, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`.
With the help of dummy text and automatic prompt updating, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`.
## Processor Output Caching
@ -61,4 +66,4 @@ Some HF processors, such as the one for Qwen2-VL, are [very slow](gh-issue:9238)
When new data is passed in, we first check which items are in the cache, and which ones are missing. The missing items are passed into the HF processor in a single batch and cached, before being merged with the existing items in the cache.
Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of multi-modal inputs, so they can't be passed alongside the text prompt to the HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt replacement code, we apply [automatic prompt replacement](#mm-automatic-prompt-replacement) afterwards to keep the output tokens and multi-modal data consistent with each other.
Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of multi-modal inputs, so they can't be passed alongside the text prompt to the HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt updating code, we apply [automatic prompt updating](#mm-automatic-prompt-updating) afterwards to keep the output tokens and multi-modal data consistent with each other.

View File

@ -14,12 +14,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
PromptReplacement,
PromptInsertion, PromptReplacement,
apply_text_matches,
apply_token_matches,
find_mm_placeholders,
find_text_matches, find_token_matches,
iter_token_matches,
replace_text_matches,
replace_token_matches)
iter_token_matches)
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
@ -102,7 +102,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
{
"pattern_1": [],
"pattern_2": [],
}
},
),
(
[32000, 32000, 32000, 32000],
@ -147,16 +147,22 @@ def test_iter_token_matches(token_ids, match_ids, expected):
),
],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
def test_find_token_matches(prompt, target_by_key, expected_by_key):
def test_find_token_matches(
prompt,
target_by_key,
expected_by_key,
update_type,
):
# Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(key, target, []).bind(mock_tokenizer)
prompt_updates = [
update_type(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_token_matches(prompt, prompt_repls)
result = find_token_matches(prompt, prompt_updates)
# Only displayed on error
print("result:", result)
@ -254,16 +260,22 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
),
],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
def test_find_text_matches(prompt, target_by_key, expected_by_key):
def test_find_text_matches(
prompt,
target_by_key,
expected_by_key,
update_type,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(key, target, []).bind(mock_tokenizer)
prompt_updates = [
update_type(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_text_matches(prompt, prompt_repls)
result = find_text_matches(prompt, prompt_updates)
# Only displayed on error
print("result:", result)
@ -281,7 +293,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key"),
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[
(
"Image:<image>Image:<image><image>!",
@ -300,58 +312,66 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": "?!?",
},
{
PromptInsertion: {
0: "Image:<image>Image:<image><image>!",
1: "Image:<image><image><image>Image:<image><image>!?!?",
2: "Image:<image><image><image><image><image>Image:<image><image>!?!??!?", # noqa: E501
},
PromptReplacement: {
0: "Image:<image>Image:<image><image>!",
1: "<image><image>Image:<image><image>?!?",
2: "<image><image><image><image><image>?!?",
},
},
),
]
)
@pytest.mark.parametrize(
("mm_count", "expected"),
[
(0, "Image:<image>Image:<image><image>!"),
(1, "<image><image>Image:<image><image>?!?"),
(2, "<image><image><image><image><image>?!?"),
]
)
# yapf: enable
def test_find_replace_text(
def test_find_update_text(
prompt,
target_by_key,
repl_by_key,
mm_count,
expected,
expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_repls = {
key: [
PromptReplacement(key, target,
repl_by_key[key]).bind(mock_tokenizer)
]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_text_matches(prompt, prompt_repls)
for key, prompt_repls in mm_prompt_repls.items()
}
for (
update_type,
expected_by_mm_count,
) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_text_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
result = replace_text_matches(
prompt,
mm_matches,
{key: mm_count
for key in repl_by_key},
)
for mm_count, expected in expected_by_mm_count.items():
result = apply_text_matches(
prompt,
mm_matches,
{key: mm_count
for key in repl_by_key},
)
# Only displayed on error
print("mm_matches:", mm_matches)
print("result:", result)
# Only displayed on error
print("update_type:", update_type)
print("mm_count:", mm_count)
print("mm_matches:", mm_matches)
print("result:", result)
# Manually constructed results
assert result == expected
# Manually constructed results
assert result == expected
# yapf: disable
@pytest.mark.parametrize(
("prompt", "target_by_key", "repl_by_key"),
("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[
# Tokenized test cases of `test_find_replace_text`
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
@ -372,53 +392,61 @@ def test_find_replace_text(
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": [1550, 918, 1550],
},
{
PromptInsertion: {
0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
1: [1, 9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550], # noqa: E501
2: [1, 9833, 28747, 32000, 32000, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550, 1550, 918, 1550], # noqa: E501
},
PromptReplacement: {
0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
1: [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], # noqa: E501
2: [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
},
},
),
]
)
@pytest.mark.parametrize(
("mm_count", "expected"),
[
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
]
)
# yapf: enable
def test_find_replace_tokens(
def test_find_update_tokens(
prompt,
target_by_key,
repl_by_key,
mm_count,
expected,
expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_repls = {
key: [
PromptReplacement(key, target,
repl_by_key[key]).bind(mock_tokenizer)
]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_token_matches(prompt, prompt_repls)
for key, prompt_repls in mm_prompt_repls.items()
}
for (
update_type,
expected_by_mm_count,
) in expected_by_update_type_mm_count.items():
mm_prompt_updates = {
key:
[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
for key, target in target_by_key.items()
}
mm_matches = {
key: find_token_matches(prompt, updates)
for key, updates in mm_prompt_updates.items()
}
result = replace_token_matches(
prompt,
mm_matches,
{key: mm_count
for key in repl_by_key},
)
for mm_count, expected in expected_by_mm_count.items():
result = apply_token_matches(
prompt,
mm_matches,
{key: mm_count
for key in repl_by_key},
)
# Only displayed on error
print("mm_matches:", mm_matches)
print("result:", result)
# Only displayed on error
print("update_type:", update_type)
print("mm_count:", mm_count)
print("mm_matches:", mm_matches)
print("result:", result)
# Manually constructed results
assert result == expected
# Manually constructed results
assert result == expected
# yapf: disable
@ -524,22 +552,24 @@ def test_find_replace_tokens(
),
]
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
def test_find_mm_placeholders(
repl_by_key,
prompt,
expected,
update_type,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())
mm_prompt_repls = {
key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)]
mm_prompt_updates = {
key: [update_type(key, [], repl).bind(mock_tokenizer)]
for key, repl in repl_by_key.items()
}
result = find_mm_placeholders(
mm_prompt_repls,
mm_prompt_updates,
prompt,
# Effectively match all occurrences in the prompt
{key: 3

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@ -26,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -457,12 +457,12 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
pixel_mask=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index

View File

@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@ -19,8 +19,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
BaseProcessingInfo, PromptInsertion,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -474,30 +474,24 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)
image_token_id = vocab["<image>"]
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens
return [
PromptReplacement(
PromptInsertion(
modality="image",
target=[bos_token_id],
replacement=PromptReplacementDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
),
target="",
insertion=image_tokens,
)
]

View File

@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Any, Dict, Iterable, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
from typing import Any, Dict, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@ -35,7 +35,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -141,12 +141,12 @@ class ChameleonMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
@ -162,7 +162,7 @@ class ChameleonMultiModalProcessor(
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=PromptReplacementDetails(
replacement=PromptUpdateDetails(
full=([image_start_id] + image_tokens + [image_end_id]),
features=image_tokens,
),
@ -371,7 +371,7 @@ class ChameleonDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if residual is None:
residual = hidden_states

View File

@ -3,9 +3,9 @@
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@ -26,7 +26,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
@ -281,12 +281,12 @@ class DeepseekVL2MultiModalProcessor(
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token_id = hf_processor.image_token_id

View File

@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
import math
from collections import OrderedDict
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, OrderedDict,
Set, Tuple, TypedDict, Union)
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@ -24,8 +25,7 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement,
PromptReplacementDetails)
PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -803,7 +803,7 @@ class Florence2DummyInputsBuilder(
class Florence2MultiModalProcessor(
EncDecMultiModalProcessor[Florence2ProcessingInfo]):
def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
@ -850,26 +850,22 @@ class Florence2MultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
pad_token_id = hf_config.pad_token_id
bos_token_id = hf_config.bos_token_id
num_image_tokens = self.info.get_max_image_tokens()
image_tokens = [pad_token_id] * num_image_tokens
return [
PromptReplacement(
PromptInsertion(
modality="image",
target=[bos_token_id],
replacement=PromptReplacementDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
),
target="",
insertion=image_tokens,
)
]

View File

@ -17,8 +17,8 @@
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict)
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict
import torch
import torch.nn as nn
@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -203,12 +203,12 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id
assert isinstance(bos_token_id, int)
@ -228,7 +228,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows
return PromptReplacementDetails(
return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)

View File

@ -4,7 +4,8 @@
# https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace
from typing import Literal, Mapping, Optional, TypedDict, Union
from collections.abc import Mapping, Sequence
from typing import Literal, Optional, TypedDict, Union
import torch
from torch import nn
@ -32,7 +33,7 @@ from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BatchFeature,
MultiModalFieldConfig,
PromptReplacement)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig
@ -480,7 +481,7 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):
class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
@ -495,12 +496,12 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
boi_token_id = hf_config.boi_token_id

View File

@ -7,7 +7,8 @@
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from typing import Mapping, Optional
from collections.abc import Mapping, Sequence
from typing import Optional
import torch
from PIL import Image
@ -20,7 +21,7 @@ from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ -487,12 +488,12 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
f"{type(self).__name__} does not support processing cache with "
"multi-image support enabled.")
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs:
@ -527,7 +528,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
if num_patches is not None:
assert isinstance(num_patches, int)
return PromptReplacementDetails(
return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches),
features=hf_processor.get_image_repl_features(

View File

@ -16,8 +16,8 @@
"""Inference-only Idefics3 model compatible with HuggingFace weights."""
import math
from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import Dict, List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.utils.checkpoint
@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalDataItems,
MultiModalFieldConfig,
PromptReplacement)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -274,12 +274,12 @@ class Idefics3MultimodalProcessor(
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content

View File

@ -7,9 +7,10 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from abc import ABC, abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, TypeVar, Union)
from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar,
Union)
import torch
import torch.nn as nn
@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ -599,12 +600,12 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs:
@ -636,7 +637,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
if num_patches is not None:
assert isinstance(num_patches, int)
return PromptReplacementDetails(
return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches),
features=hf_processor.get_image_repl_features(

View File

@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, TypeVar, Union)
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
TypedDict, TypeVar, Union)
import torch
import torch.nn as nn
@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -222,12 +223,12 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
) -> Mapping[str, MultiModalFieldConfig]:
raise NotImplementedError
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
@ -328,12 +329,12 @@ class PixtralHFMultiModalProcessor(
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_config = self.info.get_hf_config()
tokenizer = self.info.get_tokenizer()
@ -789,7 +790,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
"</Image>)", # 3 tokens
])
mantis_mm_repls = self._bind_and_group_repls([
mantis_mm_repls = self._bind_and_group_updates([
PromptReplacement(
modality="image",
target=[image_token_id] * num_image_tokens,
@ -797,18 +798,18 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
)
])
prompt_ids, prompt, _ = self._apply_prompt_replacements(
prompt_ids, prompt, _ = self._apply_prompt_updates(
result["prompt_token_ids"],
mantis_mm_repls,
mm_item_counts,
)
unbound_orig_repls = self._get_prompt_replacements(
unbound_orig_repls = self._get_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
orig_repls = self._bind_and_group_repls(unbound_orig_repls)
orig_repls = self._bind_and_group_updates(unbound_orig_repls)
mm_placeholders = self._find_mm_placeholders(
orig_repls,

View File

@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@ -21,7 +21,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@ -183,12 +184,12 @@ class LlavaNextVideoMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_index

View File

@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
@ -22,7 +23,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import PromptReplacement
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@ -347,13 +348,13 @@ class LlavaOnevisionMultiModalProcessor(
)
return BatchFeature(combined_outputs)
def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
base_result = super()._hf_processor_applies_repl(
base_result = super()._hf_processor_applies_updates(
prompt_text=prompt_text,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@ -361,13 +362,13 @@ class LlavaOnevisionMultiModalProcessor(
return base_result and mm_items.get_count("video", strict=False) == 0
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
image_repls = super()._get_prompt_replacements(
) -> Sequence[PromptUpdate]:
image_repls = super()._get_prompt_updates(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
out_mm_kwargs=out_mm_kwargs,
@ -392,7 +393,8 @@ class LlavaOnevisionMultiModalProcessor(
return [video_token_id] * num_video_tokens
return image_repls + [
return [
*image_repls,
PromptReplacement(
modality="video",
target=[video_token_id],

View File

@ -22,9 +22,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Set, Tuple, TypedDict, Union)
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
TypedDict, Union)
import torch
from torch import nn
@ -356,10 +357,10 @@ class MiniCPMOMultiModalProcessor(
inputs["audio"]["audio_lens"][index])
return super().get_prompt_texts_by_modality(inputs, modality, index)
def _get_prompt_replacements(
def _get_prompt_updates(
self, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]:
out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]:
placeholder = {
"image": self.info.image_pattern,
"video": self.info.video_pattern,

View File

@ -25,9 +25,10 @@
import math
import re
from collections import Counter
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Set, Tuple, TypedDict, Union)
from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
TypedDict, Union)
import numpy as np
import torch
@ -732,7 +733,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
}
}
def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
@ -740,10 +741,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
) -> bool:
return False
def _get_prompt_replacements(
def _get_prompt_updates(
self, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]:
out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]:
placeholder = {
"image": self.info.image_pattern,
"video": self.info.video_pattern,

View File

@ -15,8 +15,8 @@
# limitations under the License.
"""PyTorch Mllama model."""
import math
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
import numpy as np
import torch
@ -59,7 +59,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataDict, MultiModalDataItems)
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .clip import CLIPMLP
@ -243,12 +243,12 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
image_token_id = self.info.get_hf_config().image_token_index
return [image_token_id] * num_images
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
token_per_chunk = self.info.get_token_per_chunk_from_config()
image_token_id = self.info.get_hf_config().image_token_index

View File

@ -1,10 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union, cast)
from typing import List, Optional, Set, Tuple, TypedDict, Union, cast
import numpy as np
import torch
@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
BaseProcessingInfo, PromptInsertion,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, json_map_leaves
@ -1190,6 +1190,8 @@ class MolmoProcessingInfo(BaseProcessingInfo):
return MolmoProcessorWrapper(processor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
# TODO: Investigate different `embed_is_patch` between cache/no-cache
# in multi-image case
return {"image": 1}
def get_mm_max_tokens_per_item(
@ -1328,25 +1330,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
image_token_length_w = processor.image_token_length_w
image_token_length_h = processor.image_token_length_h
pooling_size = processor.pooling_size
user_str = "User:"
if processor.always_start_with_space:
user_str = " " + user_str
user_tokens = tokenizer.encode(user_str, add_special_tokens=False)
img_patch_id = processor.image_patch_id
img_col_id = processor.im_col_id
img_start_id = processor.im_start_id
@ -1356,7 +1351,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
extra_joint = ([img_start_id] + extra_row * image_token_length_h +
[img_end_id])
def get_replacement_molmo(item_idx: int):
def get_insertion_molmo(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
@ -1371,17 +1366,13 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
((nrows + 1) // pooling_size) + [img_end_id])
image_tokens = extra_joint + joint
return PromptReplacementDetails(
full=image_tokens + user_tokens,
features=image_tokens,
)
return image_tokens
return [
PromptReplacement(
PromptInsertion(
modality="image",
target=user_str,
replacement=get_replacement_molmo,
target="<|endoftext|>",
insertion=get_insertion_molmo,
)
]

View File

@ -6,7 +6,8 @@
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from typing import Mapping, Optional
from collections.abc import Mapping, Sequence
from typing import Optional
import torch
import torch.nn as nn
@ -17,8 +18,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement,
PromptReplacementDetails)
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import ProcessorInputs
from .intern_vit import InternVisionModel
@ -142,12 +143,12 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs:
@ -179,7 +180,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
if num_patches is not None:
assert isinstance(num_patches, int)
return PromptReplacementDetails(
return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches) + "\n",
features=hf_processor.get_image_repl_features(

View File

@ -38,11 +38,10 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
BoundPromptReplacement,
BaseProcessingInfo, BoundPromptUpdate,
PlaceholderFeaturesInfo,
PromptReplacement,
PromptReplacementDetails)
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -420,12 +419,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
@ -449,7 +448,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens
return PromptReplacementDetails(
return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)
@ -464,15 +463,15 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
) for image_token in image_tokens[:num_images]
]
def _apply_prompt_replacements(
def _apply_prompt_updates(
self,
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
token_ids, text, placeholders = super()._apply_prompt_replacements(
token_ids, text, placeholders = super()._apply_prompt_updates(
token_ids=token_ids,
mm_prompt_repls=mm_prompt_repls,
mm_prompt_updates=mm_prompt_updates,
mm_item_counts=mm_item_counts,
)

View File

@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from typing import Iterable, Mapping, Optional, Set, Tuple, Union
from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, Set, Tuple, Union
import torch
import torch.nn as nn
@ -32,7 +33,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput)
@ -44,7 +45,7 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
pass
return {"image": 0}
class PrithviGeoSpatialMAEInputBuilder(
@ -78,20 +79,13 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
location_coords=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
pass
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
pass
) -> Sequence[PromptUpdate]:
return []
def apply(
self,
@ -120,7 +114,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
""" Prithvi Masked Autoencoder"""
def _instantiate_model(self, config: dict) -> nn.Module | None:
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:
# We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask":
@ -158,7 +152,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
"by PrithviGeospatialMAE.")
def _parse_and_validate_multimodal_data(
self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor | None]:
self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor):

View File

@ -21,9 +21,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Any, Iterable, Mapping, Optional, Set, Tuple, TypedDict,
Union)
from typing import Any, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@ -43,7 +43,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -188,12 +188,12 @@ class Qwen2AudioMultiModalProcessor(
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
@ -230,7 +230,7 @@ class Qwen2AudioMultiModalProcessor(
audio_tokens = [audio_token_id] * num_features
return PromptReplacementDetails(
return PromptUpdateDetails(
full=[audio_bos_id] + audio_tokens + [audio_eos_id],
features=audio_tokens,
)

View File

@ -23,9 +23,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial
from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set,
Tuple, Type, TypedDict, Union)
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Union)
import torch
import torch.nn as nn
@ -61,7 +62,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
@ -169,7 +171,7 @@ class Qwen2VisionMLP(nn.Module):
self,
in_features: int,
hidden_features: int,
act_layer: Type[nn.Module] = QuickGELU,
act_layer: type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@ -383,7 +385,7 @@ class Qwen2VisionBlock(nn.Module):
dim: int,
num_heads: int,
mlp_ratio: float,
act_layer: Type[nn.Module] = QuickGELU,
act_layer: type[nn.Module] = QuickGELU,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@ -987,12 +989,12 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
self.info._get_image_processor_kwargs(**mm_kwargs),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)

View File

@ -9,9 +9,10 @@ import copy
import math
import re
import unicodedata
from collections.abc import Collection, Mapping, Sequence
from collections.abc import Set as AbstractSet
from functools import lru_cache, partial
from typing import (AbstractSet, Callable, Collection, List, Literal, Mapping,
Optional, TypedDict, Union)
from typing import Callable, List, Literal, Optional, TypedDict, Union
import torch
from torch import nn
@ -36,7 +37,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@ -606,7 +607,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
mm_kwargs=mm_kwargs,
)
def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
@ -624,12 +625,12 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer()
special_tokens: dict[str,
int] = tokenizer.special_tokens # type: ignore
@ -646,7 +647,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
PromptReplacement(
modality="image",
target=[img_start_id, img_end_id],
replacement=PromptReplacementDetails(
replacement=PromptUpdateDetails(
full=[img_start_id] + image_tokens + [img_end_id],
features=image_tokens,
),

View File

@ -3,9 +3,9 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.utils.checkpoint
@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
@ -197,12 +198,12 @@ class UltravoxMultiModalProcessor(
audio_embeds=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()

View File

@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
import math
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
@ -31,7 +31,7 @@ from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .interfaces import SupportsMultiModal, SupportsTranscription
@ -623,12 +623,12 @@ class WhisperMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(input_features=MultiModalFieldConfig.batched("audio"))
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
num_tokens = self.info.get_max_audio_tokens()
return [
PromptReplacement(

View File

@ -6,11 +6,14 @@ from collections import defaultdict
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
Sequence)
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
from itertools import groupby
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union)
TypeVar, Union, cast)
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.inputs import InputProcessingContext
@ -38,35 +41,129 @@ PromptSeq = Union[str, list[int]]
@dataclass
class PromptReplacementDetails:
"""Details about the replacement token sequence or text."""
class PromptUpdateDetails:
"""Details about the token sequence or text that are part of the update."""
full: PromptSeq
"""The full replacement."""
"""The full content."""
features: PromptSeq
"""
The part of the replacement that corresponds to feature placeholders;
The part of the content that corresponds to feature placeholders;
this will be replaced by the output of the vision encoder during model
inference.
"""
@staticmethod
def from_seq(seq: PromptSeq) -> "PromptReplacementDetails":
return PromptReplacementDetails(full=seq, features=seq)
def from_seq(seq: PromptSeq) -> "PromptUpdateDetails":
return PromptUpdateDetails(full=seq, features=seq)
PromptRepl = Union[PromptSeq, PromptReplacementDetails]
PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
"""
The replacement token sequence or text.
The token sequence or text that is part of the update.
If only part of the replacement corresponds to feature placeholders, you can
use :class:`PromptReplacementDetails` to specify which part.
If only part of the content corresponds to feature placeholders, you can
use :class:`PromptUpdateDetails` to specify which part.
"""
PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
PromptUpdateInfo]
"""
Given the index of the processed item within :attr:`modality`,
output the corresponding token sequence (or text).
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""
class UpdateMode(str, Enum):
INSERT = "insert"
REPLACE = "replace"
@dataclass
class PromptReplacement:
class PromptUpdate:
"""
Defines how to update a prompt with placeholder tokens.
"""
modality: str
"""The modality for which the update is made."""
target: PromptSeq
"""The token sequence (or text) to update."""
@property
@abstractmethod
def content(self) -> PromptUpdateContent:
"""The placeholder tokens that are part of the update."""
raise NotImplementedError
@property
@abstractmethod
def mode(self) -> UpdateMode:
"""Defines how to update the prompt."""
raise NotImplementedError
def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate":
return BoundPromptUpdate(
_origin=self,
tokenizer=tokenizer,
)
@dataclass
class PromptInsertion(PromptUpdate):
"""
Defines how to insert placeholder tokens into a prompt.
Example:
For each image, insert a number of ``<image>`` feature placeholders
equal to the feature size of the vision encoder at the start of the
prompt:
.. code-block:: python
PromptInsertion(
modality="image",
target="",
insertion="<image>" * image_feature_size,
)
As above, but insert after the ``<s>`` token:
.. code-block:: python
PromptInsertion(
modality="image",
target="<s>",
insertion="<image>" * image_feature_size,
)
"""
insertion: PromptUpdateContent = field(repr=False)
"""
Given the index of the processed item within :attr:`modality`,
output the token sequence (or text) to insert right after :attr:`target`.
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""
@property
def content(self) -> PromptUpdateContent:
return self.insertion
@property
def mode(self) -> UpdateMode:
return UpdateMode.INSERT
@dataclass
class PromptReplacement(PromptUpdate):
"""
Defines how to replace portions of an input prompt with placeholder tokens.
@ -93,7 +190,7 @@ class PromptReplacement:
PromptReplacement(
modality="image",
target="<image>",
replacement=PromptReplacementDetails(
replacement=PromptUpdateDetails(
full="".join([
"<image_bos>",
"<image>" * image_feature_size,
@ -111,7 +208,7 @@ class PromptReplacement:
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=PromptReplacementDetails(
replacement=PromptUpdateDetails(
full=([image_bos_id] + [image_token_id] * image_feature_size
+ [image_eos_id]),
features=[image_token_id] * image_feature_size,
@ -119,29 +216,22 @@ class PromptReplacement:
)
"""
modality: str
"""The modality for which the replacement is made."""
target: PromptSeq
"""The token sequence (or text) to find and replace."""
replacement: Union[Callable[[int], PromptRepl],
PromptRepl] = field(repr=False)
replacement: PromptUpdateContent = field(repr=False)
"""
Given the index of the processed item within :attr:`modality`,
output the replacement token sequence (or text).
output the token sequence (or text) to replace :attr:`target`.
For convenience, you can directly pass in the replacement token sequence
(or text) instead of a function if it does not depend on the input.
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""
def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
return BoundPromptReplacement(
tokenizer=tokenizer,
modality=self.modality,
_target=self.target,
_replacement=self.replacement,
)
@property
def content(self) -> PromptUpdateContent:
return self.replacement
@property
def mode(self) -> UpdateMode:
return UpdateMode.REPLACE
@lru_cache(maxsize=2048)
@ -232,64 +322,73 @@ class _BoundPromptSequence:
@dataclass
class _BoundPromptReplacementGroup:
class _BoundPromptContent:
full: _BoundPromptSequence
features: _BoundPromptSequence
@dataclass
class BoundPromptReplacement:
class BoundPromptUpdate:
"""
A :class:`PromptReplacement` bound to a tokenizer to automatically
convert :attr:`target` and the result of :meth:`get_replacement` between
A :class:`PromptUpdate` bound to a tokenizer to automatically convert
:attr:`target` and the result of :meth:`get_content` between
token sequence and text representations.
"""
_origin: PromptUpdate
tokenizer: AnyTokenizer = field(repr=False)
modality: str
_target: PromptSeq
_replacement: Union[Callable[[int], PromptRepl],
PromptRepl] = field(repr=False)
def __post_init__(self) -> None:
self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()
self._content_cache = dict[int, _BoundPromptContent]()
@property
def modality(self) -> str:
return self._origin.modality
@property
def target(self) -> _BoundPromptSequence:
"""The token sequence (or text) to find and replace."""
return _BoundPromptSequence.from_seq(self.tokenizer, self._target)
"""The token sequence (or text) to update."""
return _BoundPromptSequence.from_seq(self.tokenizer,
self._origin.target)
def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup:
@property
def content(self) -> PromptUpdateContent:
"""The placeholder tokens that are part of the update."""
return self._origin.content
@property
def mode(self) -> UpdateMode:
"""Defines how to update the prompt."""
return self._origin.mode
def get_content(self, item_idx: int) -> _BoundPromptContent:
"""
Given the index of the processed item within :attr:`modality`,
output the replacement token sequence (or text).
output the token sequence (or text) to update.
"""
replacement = self._replacement
if callable(replacement):
content = self.content
if callable(content):
cache_key = item_idx
if cache_key in self._replacement_cache:
return self._replacement_cache[cache_key]
if cache_key in self._content_cache:
return self._content_cache[cache_key]
replacement = replacement(item_idx)
content = content(item_idx)
else:
cache_key = None
if not isinstance(replacement, PromptReplacementDetails):
replacement = PromptReplacementDetails.from_seq(replacement)
if not isinstance(content, PromptUpdateDetails):
content = PromptUpdateDetails.from_seq(content)
bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
replacement.full)
content.full)
bound_features = _BoundPromptSequence.from_seq(self.tokenizer,
replacement.features)
bound_replacement = _BoundPromptReplacementGroup(
full=bound_full,
features=bound_features,
)
content.features)
bound_content = _BoundPromptContent(full=bound_full,
features=bound_features)
if cache_key is not None:
self._replacement_cache[cache_key] = bound_replacement
self._content_cache[cache_key] = bound_content
return bound_replacement
return bound_content
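
A hedged sketch of the binding flow, which is normally performed internally by the processor via `bind`; the tokenizer checkpoint and feature size are assumptions:

```python
from transformers import AutoTokenizer

from vllm.multimodal.processing import PromptReplacement

tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")

update = PromptReplacement(
    modality="image",
    target="<image>",
    replacement="<image>" * 576,  # assumed feature size
)

bound = update.bind(tokenizer)
content = bound.get_content(0)          # content for the first image item
print(content.full.text)                # text form of the full content
print(len(content.features.token_ids))  # token IDs of the feature part only
```

When the content is a callable, `get_content` caches the bound result per item index (see `_content_cache` above), so per-item content is only computed once.
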
class _TokenMatch(NamedTuple):
@ -326,12 +425,12 @@ def iter_token_matches(
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
prompt_repl: BoundPromptReplacement
class _PromptTargetMatch(ABC):
_origin: BoundPromptUpdate
@property
def modality(self) -> str:
return self.prompt_repl.modality
return self._origin.modality
@property
@abstractmethod
@ -349,7 +448,7 @@ class _PromptReplacementMatch(ABC):
@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
class _PromptTargetTokenMatch(_PromptTargetMatch):
match: _TokenMatch
@property
@ -362,7 +461,7 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch):
@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
class _PromptTargetTextMatch(_PromptTargetMatch):
match: re.Match[str]
@property
@ -394,40 +493,37 @@ class PlaceholderFeaturesInfo:
def find_token_matches(
prompt: list[int],
prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
return [
_PromptReplacementTokenMatch(prompt_repl, match)
for prompt_repl in prompt_repls
for match in iter_token_matches(prompt, prompt_repl.target.token_ids)
_PromptTargetTokenMatch(update, match) for update in prompt_updates
for match in iter_token_matches(prompt, update.target.token_ids)
]
def find_text_matches(
prompt: str,
prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
return [
_PromptReplacementTextMatch(prompt_repl, match)
for prompt_repl in prompt_repls
for match in re.finditer(re.escape(prompt_repl.target.text), prompt)
_PromptTargetTextMatch(update, match) for update in prompt_updates
for match in re.finditer(re.escape(update.target.text), prompt)
]
def _resolve_matches(
prompt: PromptSeq,
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
) -> list[_PromptReplacementMatch]:
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
) -> list[_PromptTargetMatch]:
"""
Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones.
"""
matches = [m for matches in mm_matches.values() for m in matches]
seen_matches: list[Optional[_PromptReplacementMatch]] = [None
] * len(prompt)
seen_matches: list[Optional[_PromptTargetMatch]] = [None] * len(prompt)
for match in matches:
for idx in range(match.start_idx, match.end_idx):
@ -441,74 +537,91 @@ def _resolve_matches(
return sorted(matches, key=lambda x: x.start_idx)
def _replace_matches(
def _apply_matches(
prompt: _S,
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[_S]:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
out_seqs = list[_S]()
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
out_seqs = list[Union[str, list[int]]]()
prev_end_idx = 0
next_idx_by_modality = defaultdict[str, int](lambda: 0)
for match in _resolve_matches(prompt, mm_matches):
modality = match.modality
for (start_idx, end_idx), group in groupby(
_resolve_matches(prompt, mm_matches),
key=lambda x: (x.start_idx, x.end_idx),
):
matches = tuple(group)
assert len(matches) == 1
item_idx = next_idx_by_modality[modality]
if item_idx >= mm_item_counts.get(modality, 0):
continue
for match in matches:
modality = match.modality
start_idx = match.start_idx
end_idx = match.end_idx
item_idx = next_idx_by_modality[modality]
if item_idx >= mm_item_counts.get(modality, 0):
continue
repl_info = match.prompt_repl
replacement = repl_info.get_replacement(item_idx)
origin = match._origin
content = origin.get_content(item_idx)
mode = origin.mode
if isinstance(prompt, str):
repl_seq = replacement.full.text
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
else:
repl_seq = replacement.full.token_ids
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
if mode == UpdateMode.INSERT:
out_seqs.append(prompt[prev_end_idx:end_idx])
num_inserts = mm_item_counts.get(modality, 0)
elif mode == UpdateMode.REPLACE:
out_seqs.append(prompt[prev_end_idx:start_idx])
num_inserts = 1
else:
assert_never(mode)
prev_end_idx = end_idx
next_idx_by_modality[modality] += 1
for _ in range(num_inserts):
if item_idx >= mm_item_counts.get(modality, 0):
continue
if isinstance(prompt, str):
out_seqs.append(content.full.text)
else:
out_seqs.append(content.full.token_ids)
next_idx_by_modality[modality] += 1
prev_end_idx = end_idx
out_seqs.append(prompt[prev_end_idx:])
return out_seqs
return cast(list[_S], out_seqs)
def replace_token_matches(
def apply_token_matches(
prompt: list[int],
mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
if not mm_matches:
return prompt
token_id_seqs = _replace_matches(prompt, mm_matches, mm_item_counts)
token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts)
return flatten_2d_lists(token_id_seqs)
def replace_text_matches(
def apply_text_matches(
prompt: str,
mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> str:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
if not mm_matches:
return prompt
texts = _replace_matches(prompt, mm_matches, mm_item_counts)
texts = _apply_matches(prompt, mm_matches, mm_item_counts)
return "".join(texts)
def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
@ -517,7 +630,7 @@ def _iter_placeholders(
Matches are exclusive even when multiple modalities share
the same placeholder tokens. In that case, the modality that
appears earlier in `mm_prompt_repls` takes priority.
appears earlier in `mm_prompt_updates` takes priority.
Note that empty matches are ignored.
"""
@ -528,37 +641,37 @@ def _iter_placeholders(
while start_idx < prompt_len:
found = False
for modality, modality_repls in mm_prompt_repls.items():
for modality, modality_updates in mm_prompt_updates.items():
item_idx = item_idx_by_modality[modality]
if item_idx >= mm_item_counts.get(modality, 0):
continue
for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_idx)
repl_tokens_full = replacement.full.token_ids
repl_len_full = len(repl_tokens_full)
end_idx_full = start_idx + repl_len_full
for update_info in modality_updates:
content = update_info.get_content(item_idx)
content_tokens_full = content.full.token_ids
content_len_full = len(content_tokens_full)
end_idx_full = start_idx + content_len_full
if repl_len_full == 0 or end_idx_full > prompt_len:
if content_len_full == 0 or end_idx_full > prompt_len:
continue
if prompt[start_idx:end_idx_full] == repl_tokens_full:
repl_tokens_feat = replacement.features.token_ids
if prompt[start_idx:end_idx_full] == content_tokens_full:
content_tokens_feat = content.features.token_ids
try:
match = next(
iter_token_matches(repl_tokens_full,
repl_tokens_feat))
iter_token_matches(content_tokens_full,
content_tokens_feat))
yield PlaceholderFeaturesInfo(
modality=modality,
item_idx=item_idx,
start_idx=start_idx + match.start_idx,
tokens=repl_tokens_feat,
tokens=content_tokens_feat,
)
except StopIteration:
raise AssertionError(
f"{repl_tokens_feat=} should be a "
f"subsequence of {repl_tokens_full=}") from None
f"{content_tokens_feat=} should be a "
f"subsequence of {content_tokens_full=}") from None
# Exclude overlapping matches
start_idx = end_idx_full
@ -574,11 +687,11 @@ def _iter_placeholders(
def find_mm_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts)
return dict(full_groupby_modality(it))
@ -712,6 +825,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
if get_repls := getattr(self, "_get_prompt_replacements", None):
logger.warning_once("`_get_prompt_replacements` has been renamed "
"to `_get_prompt_updates`. The old name will "
"be removed in an upcoming release.")
self._get_prompt_updates = get_repls # type: ignore[method-assign]
super().__init__()
self.info = info
@ -770,34 +889,34 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
raise NotImplementedError
@abstractmethod
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> list[PromptUpdate]:
"""
Given the original multi-modal items for this modality
and HF-processed data, output the replacements to perform.
and HF-processed data, output the updates to perform.
Notes:
- You should not assume that the HF processor always performs prompt
replacement: in :meth:`_apply_hf_processor_missing`, this method
updates: in :meth:`_apply_hf_processor_missing`, this method
is called on text-only and multimodal-only inputs separately,
instead of passing them in the same call.
- The replacement information returned by this method is also used
to determine the placeholder token positions for each multi-modal
- The update information returned by this method is also used to
determine the placeholder token positions for each multi-modal
item.
"""
raise NotImplementedError
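
For comparison with the replacement-based overrides elsewhere in this change, a hedged sketch of an insertion-based override using `PromptInsertion`; the `image_token_index` attribute, the `get_num_image_tokens` helper, and the `<s>` target are hypothetical:

```python
def _get_prompt_updates(
    self,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
    out_mm_kwargs: MultiModalKwargs,
) -> Sequence[PromptUpdate]:
    hf_config = self.info.get_hf_config()
    image_token_id = hf_config.image_token_index  # assumed config attribute

    def get_insertion(item_idx: int) -> list[int]:
        # Hypothetical helper that returns the feature size of each image
        num_tokens = self.info.get_num_image_tokens(item_idx)
        return [image_token_id] * num_tokens

    return [
        PromptInsertion(
            modality="image",
            target="<s>",  # assumed BOS text; placeholders go right after it
            insertion=get_insertion,
        )
    ]
```
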
def _find_mm_placeholders(
self,
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
return find_mm_placeholders(mm_prompt_repls, new_token_ids,
return find_mm_placeholders(mm_prompt_updates, new_token_ids,
mm_item_counts)
def _get_hf_mm_data(
@ -831,14 +950,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs,
)
def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
"""
Return whether the HF processor applies prompt replacements.
Return whether the HF processor applies prompt updates.
For most HF processors, this should be :code:`True` when multi-modal
data items are passed, but :code:`False` when multi-modal embeddings
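
A hedged sketch of an override for a model that also accepts pre-computed image embeddings, in which case the HF processor only sees text and cannot apply the updates itself; `ImageEmbeddingItems` and `ImageProcessorItems` are assumed to be importable from `vllm.multimodal.parse`:

```python
def _hf_processor_applies_updates(
    self,
    prompt_text: str,
    mm_items: MultiModalDataItems,
    hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
    # When the user passes pre-computed embeddings, the HF processor only
    # receives text, so vLLM must apply the prompt updates itself.
    if "image" in mm_items:
        images = mm_items.get_items(
            "image", (ImageEmbeddingItems, ImageProcessorItems))
        if isinstance(images, ImageEmbeddingItems):
            return False

    return True
```

Returning `False` here makes `apply` fall back to `_apply_prompt_updates` on the tokenized prompt instead of assuming the HF processor already expanded the placeholders.
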
@ -858,7 +977,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Apply the HF processor on the prompt text and multi-modal data
together.
In addition, return whether prompt replacements have been applied.
In addition, return whether prompt updates have been applied.
"""
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
@ -876,13 +995,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
)
is_repl_applied = self._hf_processor_applies_repl(
is_update_applied = self._hf_processor_applies_updates(
prompt_text=prompt_text,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
return prompt_ids, mm_kwargs, is_repl_applied
return prompt_ids, mm_kwargs, is_update_applied
def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
"""
@ -948,21 +1067,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
*,
enable_hf_prompt_replacement: bool,
enable_hf_prompt_update: bool,
) -> tuple[list[int], MultiModalKwargs, bool]:
"""
Apply the HF processor on the prompt text and multi-modal data.
In addition, return whether prompt replacements have been applied
In addition, return whether prompt updates have been applied
(for most HF processors, this should be :code:`True`).
Note:
If :code:`enable_hf_prompt_replacement=False`, we use HF processor
to perform prompt replacement if available; HF processor requires
If :code:`enable_hf_prompt_update=False`, we use HF processor
to perform prompt updates if available; HF processor requires
that the prompt corresponds to multi-modal items.
"""
if isinstance(prompt, str):
if enable_hf_prompt_replacement:
if enable_hf_prompt_update:
return self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
@ -999,7 +1118,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_replacement=True,
enable_hf_prompt_update=True,
)
mm_maybe_cached_kw_items = {
@ -1022,17 +1141,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_data_items = self._to_mm_items(mm_missing_data)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt replacements until the new multimodal
# so we can't apply prompt updates until the new multimodal
# items are combined with the cached multimodal items
(
prompt_ids,
mm_missing_kwargs,
is_repl_applied,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_replacement=False,
enable_hf_prompt_update=False,
)
mm_missing_next_idx = {
@ -1071,28 +1190,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
return prompt_ids, mm_kwargs, is_repl_applied
return prompt_ids, mm_kwargs, is_update_applied
def _bind_and_group_repls(
def _bind_and_group_updates(
self,
prompt_repls: list[PromptReplacement],
) -> dict[str, list[BoundPromptReplacement]]:
prompt_updates: list[PromptUpdate],
) -> dict[str, list[BoundPromptUpdate]]:
tokenizer = self.info.get_tokenizer()
it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
it = (update.bind(tokenizer) for update in prompt_updates)
return dict(full_groupby_modality(it))
def _apply_prompt_replacements(
def _apply_prompt_updates(
self,
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
tokenizer = self.info.get_tokenizer()
mm_token_matches = {
modality: find_token_matches(token_ids, prompt_repls)
for modality, prompt_repls in mm_prompt_repls.items()
modality: find_token_matches(token_ids, updates)
for modality, updates in mm_prompt_updates.items()
}
mm_match_counts = {
modality: len(matches)
@ -1107,31 +1226,31 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# up a token, then the token ID of "foo" will not appear at all
# ----
# Since it is inefficient to search for all possible tokenizations
# of the search text in the prompt, we instead perform string
# replacement on the decoded token IDs, then encode them back.
# of the search text in the prompt, we instead perform string-based
# updates on the decoded token IDs, then encode them back.
if all(
mm_match_counts.get(modality, 0) >= item_count
for modality, item_count in mm_item_counts.items()
): # yapf: disable
token_ids = replace_token_matches(
token_ids = apply_token_matches(
token_ids,
mm_token_matches,
mm_item_counts,
)
text = decode_tokens(tokenizer, token_ids)
matched_repls = {
modality: [match.prompt_repl for match in token_matches]
matched_updates = {
modality: [match._origin for match in token_matches]
for modality, token_matches in mm_token_matches.items()
}
else:
text = decode_tokens(tokenizer, token_ids)
mm_text_matches = {
modality: find_text_matches(text, prompt_repls)
for modality, prompt_repls in mm_prompt_repls.items()
modality: find_text_matches(text, updates)
for modality, updates in mm_prompt_updates.items()
}
text = replace_text_matches(
text = apply_text_matches(
text,
mm_text_matches,
mm_item_counts,
@ -1140,13 +1259,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
token_ids = encode_tokens(tokenizer,
text,
add_special_tokens=False)
matched_repls = {
modality: [match.prompt_repl for match in token_matches]
matched_updates = {
modality: [match._origin for match in token_matches]
for modality, token_matches in mm_text_matches.items()
}
placeholders = self._find_mm_placeholders(
matched_repls,
matched_updates,
token_ids,
mm_item_counts,
)
@ -1184,14 +1303,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
if len(placeholders) != item_count:
raise RuntimeError(
f"Expected there to be {item_count} prompt replacements "
f"Expected there to be {item_count} prompt updates "
f"corresponding to {item_count} {modality} items, but "
f"instead found {len(placeholders)} prompt replacements! "
f"instead found {len(placeholders)} prompt updates! "
"Either the prompt text has missing/incorrect tokens for "
"multi-modal inputs, or there is a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_replacements`).")
"`_call_hf_processor` and `_get_prompt_updates`).")
def apply(
self,
@ -1206,7 +1325,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
1. Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
2. Find and replace sequences in the token IDs with placeholder tokens.
2. Find and update sequences in the token IDs with placeholder tokens.
The number of placeholder tokens equals the feature size of the
multi-modal data outputted by the multi-modal encoder.
3. Extract information about the placeholder tokens from the
@ -1235,26 +1354,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
(
prompt_ids,
mm_kwargs,
is_repl_applied,
is_update_applied,
) = self._cached_apply_hf_processor(
prompt,
mm_items,
hf_processor_mm_kwargs,
)
unbound_prompt_repls = self._get_prompt_replacements(
unbound_prompt_updates = self._get_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)
mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates)
mm_item_counts = mm_items.get_all_counts()
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
if is_repl_applied:
if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
mm_prompt_repls,
mm_prompt_updates,
prompt_ids,
mm_item_counts,
)
@ -1267,9 +1387,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_ids,
prompt,
mm_placeholders,
) = self._apply_prompt_replacements(
) = self._apply_prompt_updates(
prompt_ids,
mm_prompt_repls,
mm_prompt_updates,
mm_item_counts,
)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)