mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[VLM] Generalized prompt updates for multi-modal processor (#13964)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -720,13 +720,13 @@ def _get_mm_fields_config(

:::::

-### Prompt replacements
+### Prompt updates

-Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` to
-return a list of {class}`~vllm.multimodal.processing.PromptReplacement` instances.
+Override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` to
+return a list of {class}`~vllm.multimodal.processing.PromptUpdate` instances.

-Each {class}`~vllm.multimodal.processing.PromptReplacement` instance specifies a find-and-replace
-operation performed by the HF processor.
+Each {class}`~vllm.multimodal.processing.PromptUpdate` instance specifies an update operation
+(e.g.: insertion, replacement) performed by the HF processor.
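As an editorial aside (not part of this diff), the two concrete update types can be sketched roughly as follows; the token ID and feature size below are made-up placeholders, and the constructor arguments mirror the signatures that appear later in this commit:

```python
from vllm.multimodal.processing import PromptInsertion, PromptReplacement

image_token_id = 32000                    # placeholder: the model's <image> token ID
feature_tokens = [image_token_id] * 576   # placeholder: one token per image feature

# Find-and-replace: swap each input placeholder for the feature tokens.
replacement = PromptReplacement(
    modality="image",
    target=[image_token_id],
    replacement=feature_tokens,
)

# Insertion: add the feature tokens relative to a target without removing anything
# (an empty-string target inserts at the start of the prompt, as BLIP-2 and
# Florence-2 do elsewhere in this commit).
insertion = PromptInsertion(
    modality="image",
    target="",
    insertion=feature_tokens,
)
```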
::::{tab-set}
:::{tab-item} Basic example: LLaVA

@@ -743,15 +743,15 @@ for sample in text:
```

It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`).
-Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements` as follows:
+Based on this, we override {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates` as follows:

```python
-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index
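# --- Editorial sketch, not part of the diff: a rough idea of how the truncated
# LLaVA snippet above typically continues. The helper below and the feature-size
# arithmetic are assumptions for illustration; consult the full file for the
# exact code.
    def get_replacement(item_idx: int):
        # Assumption: the number of feature tokens per image follows from the
        # vision tower's patch grid.
        vision_config = hf_config.vision_config
        num_image_tokens = (vision_config.image_size //
                            vision_config.patch_size)**2
        return [image_token_id] * num_image_tokens

    return [
        PromptReplacement(
            modality="image",
            target=[image_token_id],
            replacement=get_replacement,
        )
    ]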
@@ -859,7 +859,7 @@ prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
)
```

-To accommodate this, instead of a string you can return an instance of `PromptReplacementDetails`
+To accommodate this, instead of a string you can return an instance of `PromptUpdateDetails`
with different `full` and `feature` attributes:

```python
@@ -878,7 +878,7 @@ def get_replacement_fuyu(item_idx: int):
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows

-return PromptReplacementDetails(
+return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)
@@ -888,12 +888,12 @@ Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the
we can search for it to conduct the replacement at the start of the string:

```python
-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id
assert isinstance(bos_token_id, int)

@@ -913,7 +913,7 @@ def _get_prompt_replacements(
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows

-return PromptReplacementDetails(
+return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)
@@ -6,11 +6,16 @@ To enable various optimizations in vLLM such as [chunked prefill](#chunked-prefi

Here are the main features of {class}`~vllm.multimodal.processing.BaseMultiModalProcessor`:

-## Prompt Replacement Detection
+## Prompt Update Detection

-One of the main responsibilities of the HF processor is to replace input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals the feature size). The information about which tokens have been replaced is key to finding the correspondence between placeholder feature tokens and multi-modal inputs.
+One of the main responsibilities of the HF processor is to update the prompt with placeholder tokens. For example:

-In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptReplacement` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`. Given this specification, we can automatically detect whether HF has replaced the input placeholder tokens by checking whether the feature placeholder tokens exist in the prompt.
+- Insert feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals the feature size) at the start of the string.
+- Replace existing input placeholder tokens (e.g. `<image>` for a single image) with feature placeholder tokens (e.g. `<image><image>...<image>`, the number of which equals the feature size).

+The information about which tokens have been updated is key to finding the correspondence between placeholder feature tokens and multi-modal inputs.

+In vLLM, this information is specified using {class}`~vllm.multimodal.processing.PromptUpdate` in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`. We can automatically detect whether HF has updated the prompt by checking the existence of the updated tokens.
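As a rough illustration (editor's addition, not vLLM code), "checking the existence of the updated tokens" amounts to a subsequence test over the tokenized prompt:

```python
def update_already_applied(prompt_ids: list[int],
                           feature_ids: list[int]) -> bool:
    """Return True if the feature placeholder tokens already occur in the
    prompt, i.e. the HF processor has applied the update itself."""
    n = len(feature_ids)
    return any(prompt_ids[i:i + n] == feature_ids
               for i in range(len(prompt_ids) - n + 1))
```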
## Tokenized Prompt Inputs

@@ -22,7 +27,7 @@ Consider that HF processors follow these main steps:

1. Tokenize the text
2. Process multi-modal inputs
-3. Perform prompt replacement
+3. Perform prompt updates

And we require that:

@@ -44,16 +49,16 @@ Moreover, since the tokenized text has not passed through the HF processor, we h

We work around the first issue by requiring each model to define how to generate dummy text based on the number of multi-modal inputs, via {meth}`~vllm.multimodal.profiling.BaseDummyInputsBuilder.get_dummy_processor_inputs`. This lets us generate dummy text corresponding to the multi-modal inputs and input them together to obtain the processed multi-modal data.
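For orientation, here is a minimal sketch (editor's addition) of such a dummy-inputs builder. The placeholder string, the 336x336 image size, and the exact `ProcessorInputs` fields are assumptions that should be checked against the real `BaseDummyInputsBuilder` API:

```python
from collections.abc import Mapping

from PIL import Image

from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs


class MyDummyInputsBuilder(BaseDummyInputsBuilder):

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        num_images = mm_counts.get("image", 0)

        # Assumption: "<image>" is the model's input placeholder and 336x336 is
        # a representative image size; real builders query the processing info.
        dummy_images = [Image.new("RGB", (336, 336))] * num_images

        return ProcessorInputs(
            prompt_text="<image>" * num_images,
            mm_data={"image": dummy_images},
        )
```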
-(mm-automatic-prompt-replacement)=
+(mm-automatic-prompt-updating)=

-### Automatic prompt replacement
+### Automatic prompt updating

We address the second issue by implementing model-agnostic code in
-{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_replacements` to automatically replace input placeholder tokens with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_replacements`.
+{meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_prompt_updates` to automatically update the prompt with feature placeholder tokens based on the specification outputted by {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates`.
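To make the mechanism concrete, here is a hedged, self-contained sketch (editor's addition, not the actual `_apply_prompt_updates` implementation) of what applying a replacement-style update to a token prompt involves:

```python
# Conceptual sketch only: expand each occurrence of the input placeholder token
# into the feature placeholder tokens, tracking where the features ended up.
def apply_replacement(prompt_ids: list[int],
                      target_id: int,
                      feature_ids: list[int]) -> tuple[list[int], list[range]]:
    out: list[int] = []
    feature_ranges: list[range] = []
    for tok in prompt_ids:
        if tok == target_id:
            start = len(out)
            out.extend(feature_ids)
            feature_ranges.append(range(start, start + len(feature_ids)))
        else:
            out.append(tok)
    return out, feature_ranges

# Example: one <image> placeholder (id 32000) expanded to three feature tokens.
tokens, ranges = apply_replacement([1, 32000, 918], 32000, [32000] * 3)
assert tokens == [1, 32000, 32000, 32000, 918]
```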
### Summary

-With the help of dummy text and automatic prompt replacement, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`.
+With the help of dummy text and automatic prompt updating, our multi-modal processor can finally accept both text and token prompts with multi-modal data. The detailed logic is shown in {meth}`~vllm.multimodal.processing.BaseMultiModalProcessor._apply_hf_processor_main`.

## Processor Output Caching

@@ -61,4 +66,4 @@ Some HF processors, such as the one for Qwen2-VL, are [very slow](gh-issue:9238)

When new data is passed in, we first check which items are in the cache, and which ones are missing. The missing items are passed into the HF processor in a single batch and cached, before being merged with the existing items in the cache.

-Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of multi-modal inputs, so they can't be passed alongside the text prompt to the HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt replacement code, we apply [automatic prompt replacement](#mm-automatic-prompt-replacement) afterwards to keep the output tokens and multi-modal data consistent with each other.
+Since we only process the missing multi-modal data items, the number of input placeholder tokens no longer corresponds to the number of multi-modal inputs, so they can't be passed alongside the text prompt to the HF processor. Therefore, we process the text and multi-modal inputs separately, using [dummy text](#mm-dummy-text) to avoid HF errors. Since this skips HF's prompt updating code, we apply [automatic prompt updating](#mm-automatic-prompt-updating) afterwards to keep the output tokens and multi-modal data consistent with each other.
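A simplified sketch (editor's addition) of the per-item caching idea described above; the real `ProcessingCache` keys on more than just the item, roughly the model ID, modality, item data, and processor kwargs:

```python
class SimpleProcessingCache:
    """Toy stand-in for vLLM's ProcessingCache to illustrate the flow."""

    def __init__(self) -> None:
        self._cache: dict[str, object] = {}

    def get_or_process(self, key: str, process_fn):
        # Only items missing from the cache hit the (slow) HF processor.
        if key not in self._cache:
            self._cache[key] = process_fn()
        return self._cache[key]
```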
@@ -14,12 +14,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
-PromptReplacement,
+PromptInsertion, PromptReplacement,
+apply_text_matches,
+apply_token_matches,
find_mm_placeholders,
find_text_matches, find_token_matches,
-iter_token_matches,
-replace_text_matches,
-replace_token_matches)
+iter_token_matches)
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
@@ -102,7 +102,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
{
"pattern_1": [],
"pattern_2": [],
}
},
),
(
[32000, 32000, 32000, 32000],

@@ -147,16 +147,22 @@ def test_iter_token_matches(token_ids, match_ids, expected):
),
],
)
+@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
-def test_find_token_matches(prompt, target_by_key, expected_by_key):
+def test_find_token_matches(
+prompt,
+target_by_key,
+expected_by_key,
+update_type,
+):
# Should not be used since there is nothing to convert to token IDs
mock_tokenizer = cast(AnyTokenizer, object())

-prompt_repls = [
-PromptReplacement(key, target, []).bind(mock_tokenizer)
+prompt_updates = [
+update_type(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
-result = find_token_matches(prompt, prompt_repls)
+result = find_token_matches(prompt, prompt_updates)

# Only displayed on error
print("result:", result)
@@ -254,16 +260,22 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
),
],
)
+@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
-def test_find_text_matches(prompt, target_by_key, expected_by_key):
+def test_find_text_matches(
+prompt,
+target_by_key,
+expected_by_key,
+update_type,
+):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())

-prompt_repls = [
-PromptReplacement(key, target, []).bind(mock_tokenizer)
+prompt_updates = [
+update_type(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
-result = find_text_matches(prompt, prompt_repls)
+result = find_text_matches(prompt, prompt_updates)

# Only displayed on error
print("result:", result)

@@ -281,7 +293,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):

# yapf: disable
@pytest.mark.parametrize(
-("prompt", "target_by_key", "repl_by_key"),
+("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[
(
"Image:<image>Image:<image><image>!",
@@ -300,58 +312,66 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": "?!?",
},
+{
+PromptInsertion: {
+0: "Image:<image>Image:<image><image>!",
+1: "Image:<image><image><image>Image:<image><image>!?!?",
+2: "Image:<image><image><image><image><image>Image:<image><image>!?!??!?", # noqa: E501
+},
+PromptReplacement: {
+0: "Image:<image>Image:<image><image>!",
+1: "<image><image>Image:<image><image>?!?",
+2: "<image><image><image><image><image>?!?",
+},
+},
),
]
)
-@pytest.mark.parametrize(
-("mm_count", "expected"),
-[
-(0, "Image:<image>Image:<image><image>!"),
-(1, "<image><image>Image:<image><image>?!?"),
-(2, "<image><image><image><image><image>?!?"),
-]
-)
# yapf: enable
-def test_find_replace_text(
+def test_find_update_text(
prompt,
target_by_key,
repl_by_key,
-mm_count,
-expected,
+expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to text
mock_tokenizer = cast(AnyTokenizer, object())

-mm_prompt_repls = {
-key: [
-PromptReplacement(key, target,
-repl_by_key[key]).bind(mock_tokenizer)
-]
-for key, target in target_by_key.items()
-}
-mm_matches = {
-key: find_text_matches(prompt, prompt_repls)
-for key, prompt_repls in mm_prompt_repls.items()
-}
+for (
+update_type,
+expected_by_mm_count,
+) in expected_by_update_type_mm_count.items():
+mm_prompt_updates = {
+key:
+[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
+for key, target in target_by_key.items()
+}
+mm_matches = {
+key: find_text_matches(prompt, updates)
+for key, updates in mm_prompt_updates.items()
+}

-result = replace_text_matches(
-prompt,
-mm_matches,
-{key: mm_count
-for key in repl_by_key},
-)
+for mm_count, expected in expected_by_mm_count.items():
+result = apply_text_matches(
+prompt,
+mm_matches,
+{key: mm_count
+for key in repl_by_key},
+)

-# Only displayed on error
-print("mm_matches:", mm_matches)
-print("result:", result)
+# Only displayed on error
+print("update_type:", update_type)
+print("mm_count:", mm_count)
+print("mm_matches:", mm_matches)
+print("result:", result)

-# Manually constructed results
-assert result == expected
+# Manually constructed results
+assert result == expected

# yapf: disable
@pytest.mark.parametrize(
-("prompt", "target_by_key", "repl_by_key"),
+("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"), # noqa: E501
[
# Tokenized test cases of `test_find_replace_text`
# using the vocab of llava-hf/llava-v1.6-mistral-7b-hf

@@ -372,53 +392,61 @@ def test_find_replace_text(
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": [1550, 918, 1550],
},
+{
+PromptInsertion: {
+0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
+1: [1, 9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550], # noqa: E501
+2: [1, 9833, 28747, 32000, 32000, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918, 1550, 918, 1550, 1550, 918, 1550], # noqa: E501
+},
+PromptReplacement: {
+0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
+1: [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], # noqa: E501
+2: [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
+},
+},
),
]
)
-@pytest.mark.parametrize(
-("mm_count", "expected"),
-[
-(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
-(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
-(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
-]
-)
# yapf: enable
-def test_find_replace_tokens(
+def test_find_update_tokens(
prompt,
target_by_key,
repl_by_key,
-mm_count,
-expected,
+expected_by_update_type_mm_count,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())

-mm_prompt_repls = {
-key: [
-PromptReplacement(key, target,
-repl_by_key[key]).bind(mock_tokenizer)
-]
-for key, target in target_by_key.items()
-}
-mm_matches = {
-key: find_token_matches(prompt, prompt_repls)
-for key, prompt_repls in mm_prompt_repls.items()
-}
+for (
+update_type,
+expected_by_mm_count,
+) in expected_by_update_type_mm_count.items():
+mm_prompt_updates = {
+key:
+[update_type(key, target, repl_by_key[key]).bind(mock_tokenizer)]
+for key, target in target_by_key.items()
+}
+mm_matches = {
+key: find_token_matches(prompt, updates)
+for key, updates in mm_prompt_updates.items()
+}

-result = replace_token_matches(
-prompt,
-mm_matches,
-{key: mm_count
-for key in repl_by_key},
-)
+for mm_count, expected in expected_by_mm_count.items():
+result = apply_token_matches(
+prompt,
+mm_matches,
+{key: mm_count
+for key in repl_by_key},
+)

-# Only displayed on error
-print("mm_matches:", mm_matches)
-print("result:", result)
+# Only displayed on error
+print("update_type:", update_type)
+print("mm_count:", mm_count)
+print("mm_matches:", mm_matches)
+print("result:", result)

-# Manually constructed results
-assert result == expected
+# Manually constructed results
+assert result == expected

# yapf: disable
@@ -524,22 +552,24 @@ def test_find_replace_tokens(
),
]
)
+@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
# yapf: enable
def test_find_mm_placeholders(
repl_by_key,
prompt,
expected,
+update_type,
):
# Should not be used since there is nothing to convert to tokens
mock_tokenizer = cast(AnyTokenizer, object())

-mm_prompt_repls = {
-key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)]
+mm_prompt_updates = {
+key: [update_type(key, [], repl).bind(mock_tokenizer)]
for key, repl in repl_by_key.items()
}

result = find_mm_placeholders(
-mm_prompt_repls,
+mm_prompt_updates,
prompt,
# Effectively match all occurrences in the prompt
{key: 3
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

-from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
-Union)
+from collections.abc import Iterable, Mapping, Sequence
+from typing import List, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn

@@ -26,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
-BaseProcessingInfo, PromptReplacement)
+BaseProcessingInfo, PromptReplacement,
+PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

@@ -457,12 +457,12 @@ class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
pixel_mask=MultiModalFieldConfig.batched("image"),
)

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index

@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

+from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
-from typing import (Iterable, Literal, Mapping, Optional, Set, Tuple,
-TypedDict, Union)
+from typing import Literal, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn

@@ -19,8 +19,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
-BaseProcessingInfo, PromptReplacement,
-PromptReplacementDetails)
+BaseProcessingInfo, PromptInsertion,
+PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

@@ -474,30 +474,24 @@ class Blip2MultiModalProcessor(BaseMultiModalProcessor[Blip2ProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"),
)

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()

bos_token_id = tokenizer.bos_token_id
assert isinstance(bos_token_id, int)

image_token_id = vocab["<image>"]
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [image_token_id] * num_image_tokens

return [
-PromptReplacement(
+PromptInsertion(
modality="image",
-target=[bos_token_id],
-replacement=PromptReplacementDetails(
-full=image_tokens + [bos_token_id],
-features=image_tokens,
-),
+target="",
+insertion=image_tokens,
)
]

@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

+from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
-from typing import (Any, Dict, Iterable, Literal, Mapping, Optional, Set,
-Tuple, TypedDict, Union)
+from typing import Any, Dict, Literal, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn

@@ -35,7 +35,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
-PromptReplacementDetails)
+PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

@@ -141,12 +141,12 @@ class ChameleonMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()

@@ -162,7 +162,7 @@ class ChameleonMultiModalProcessor(
PromptReplacement(
modality="image",
target=[image_token_id],
-replacement=PromptReplacementDetails(
+replacement=PromptUpdateDetails(
full=([image_start_id] + image_tokens + [image_end_id]),
features=image_tokens,
),

@@ -371,7 +371,7 @@ class ChameleonDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
-) -> Tuple[torch.Tensor, torch.Tensor]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

if residual is None:
residual = hidden_states
@@ -3,9 +3,9 @@
# adapted from https://github.com/deepseek-ai/DeepSeek-VL2/blob/faf18023f24b962b32d9f0a2d89e402a8d383a78/deepseek_vl2/models/modeling_deepseek_vl_v2.py
"""Inference-only Deepseek-VL2 model compatible with HuggingFace weights."""
import math
+from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-TypedDict, Union)
+from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn

@@ -26,7 +26,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
-PromptReplacement)
+PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,

@@ -281,12 +281,12 @@ class DeepseekVL2MultiModalProcessor(
image_embeds=MultiModalFieldConfig.batched("image"),
)

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

image_token_id = hf_processor.image_token_id
@@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

import math
+from collections import OrderedDict
+from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, OrderedDict,
-Set, Tuple, TypedDict, Union)
+from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn

@@ -24,8 +25,7 @@ from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import MultiModalDataDict, MultiModalDataItems
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
-PromptReplacement,
-PromptReplacementDetails)
+PromptInsertion, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

@@ -803,7 +803,7 @@ class Florence2DummyInputsBuilder(
class Florence2MultiModalProcessor(
EncDecMultiModalProcessor[Florence2ProcessingInfo]):

-def _hf_processor_applies_repl(
+def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,

@@ -850,26 +850,22 @@ class Florence2MultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
pad_token_id = hf_config.pad_token_id
bos_token_id = hf_config.bos_token_id
num_image_tokens = self.info.get_max_image_tokens()
image_tokens = [pad_token_id] * num_image_tokens

return [
-PromptReplacement(
+PromptInsertion(
modality="image",
-target=[bos_token_id],
-replacement=PromptReplacementDetails(
-full=image_tokens + [bos_token_id],
-features=image_tokens,
-),
+target="",
+insertion=image_tokens,
)
]

@@ -17,8 +17,8 @@
# limitations under the License.
""" PyTorch Fuyu model."""
import math
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-TypedDict)
+from collections.abc import Iterable, Mapping, Sequence
+from typing import List, Literal, Optional, Set, Tuple, TypedDict

import torch
import torch.nn as nn

@@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
-PromptReplacementDetails)
+PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

@@ -203,12 +203,12 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image"))

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
bos_token_id = hf_config.bos_token_id
assert isinstance(bos_token_id, int)

@@ -228,7 +228,7 @@ class FuyuMultiModalProcessor(BaseMultiModalProcessor[FuyuProcessingInfo]):
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
[_NEWLINE_TOKEN_ID]) * nrows

-return PromptReplacementDetails(
+return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)
@@ -4,7 +4,8 @@
# https://github.com/THUDM/CogAgent
"""Inference-only CogAgent model compatible with THUDM weights."""
from argparse import Namespace
-from typing import Literal, Mapping, Optional, TypedDict, Union
+from collections.abc import Mapping, Sequence
+from typing import Literal, Optional, TypedDict, Union

import torch
from torch import nn

@@ -32,7 +33,7 @@ from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BatchFeature,
MultiModalFieldConfig,
-PromptReplacement)
+PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig

@@ -480,7 +481,7 @@ class GLM4VDummyInputsBuilder(BaseDummyInputsBuilder[GLM4VProcessingInfo]):

class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):

-def _hf_processor_applies_repl(
+def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,

@@ -495,12 +496,12 @@ class GLM4VMultiModalProcessor(BaseMultiModalProcessor[GLM4VProcessingInfo]):
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()

boi_token_id = hf_config.boi_token_id
@@ -7,7 +7,8 @@
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
-from typing import Mapping, Optional
+from collections.abc import Mapping, Sequence
+from typing import Optional

import torch
from PIL import Image

@@ -20,7 +21,7 @@ from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
-PromptReplacementDetails)
+PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.tokenizer import AnyTokenizer

@@ -487,12 +488,12 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
f"{type(self).__name__} does not support processing cache with "
"multi-image support enabled.")

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

if "image_num_patches" in out_mm_kwargs:

@@ -527,7 +528,7 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
if num_patches is not None:
assert isinstance(num_patches, int)

-return PromptReplacementDetails(
+return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches),
features=hf_processor.get_image_repl_features(
@@ -16,8 +16,8 @@
"""Inference-only Idefics3 model compatible with HuggingFace weights."""

import math
-from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Set,
-Tuple, TypedDict, Union)
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Dict, List, Literal, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.utils.checkpoint

@@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalDataItems,
MultiModalFieldConfig,
-PromptReplacement)
+PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

@@ -274,12 +274,12 @@ class Idefics3MultimodalProcessor(
image_embeds=MultiModalFieldConfig.batched("image"),
)

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

image_token = hf_processor.image_token.content
@@ -7,9 +7,10 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from abc import ABC, abstractmethod
+from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-TypedDict, TypeVar, Union)
+from typing import (List, Literal, Optional, Set, Tuple, TypedDict, TypeVar,
+Union)

import torch
import torch.nn as nn

@@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
-PromptReplacementDetails)
+PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import AnyTokenizer

@@ -599,12 +600,12 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
image_token_id=MultiModalFieldConfig.shared("image", num_images),
)

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

if "image_num_patches" in out_mm_kwargs:

@@ -636,7 +637,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
if num_patches is not None:
assert isinstance(num_patches, int)

-return PromptReplacementDetails(
+return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches),
features=hf_processor.get_image_repl_features(
@@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

from abc import abstractmethod
+from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
-from typing import (Final, Iterable, List, Literal, Mapping, Optional,
-Protocol, Set, Tuple, TypedDict, TypeVar, Union)
+from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
+TypedDict, TypeVar, Union)

import torch
import torch.nn as nn

@@ -31,7 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
-PromptReplacement)
+PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

@@ -222,12 +223,12 @@ class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor[_I]):
) -> Mapping[str, MultiModalFieldConfig]:
raise NotImplementedError

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
image_token_id = hf_config.image_token_index

@@ -328,12 +329,12 @@ class PixtralHFMultiModalProcessor(
image_embeds=MultiModalFieldConfig.batched("image"),
)

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
hf_config = self.info.get_hf_config()
tokenizer = self.info.get_tokenizer()

@@ -789,7 +790,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
"</Image>)", # 3 tokens
])

-mantis_mm_repls = self._bind_and_group_repls([
+mantis_mm_repls = self._bind_and_group_updates([
PromptReplacement(
modality="image",
target=[image_token_id] * num_image_tokens,

@@ -797,18 +798,18 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
)
])

-prompt_ids, prompt, _ = self._apply_prompt_replacements(
+prompt_ids, prompt, _ = self._apply_prompt_updates(
result["prompt_token_ids"],
mantis_mm_repls,
mm_item_counts,
)

-unbound_orig_repls = self._get_prompt_replacements(
+unbound_orig_repls = self._get_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
-orig_repls = self._bind_and_group_repls(unbound_orig_repls)
+orig_repls = self._bind_and_group_updates(unbound_orig_repls)

mm_placeholders = self._find_mm_placeholders(
orig_repls,
@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

import math
+from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-TypedDict, Union)
+from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn

@@ -21,7 +21,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
-BaseProcessingInfo, PromptReplacement)
+BaseProcessingInfo, PromptReplacement,
+PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

@@ -183,12 +184,12 @@ class LlavaNextVideoMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values_videos=MultiModalFieldConfig.batched("video"))

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_config = self.info.get_hf_config()
video_token_id = hf_config.video_token_index

@@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

import math
+from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
-from typing import (Final, Iterable, List, Literal, Mapping, Optional,
-Protocol, Set, Tuple, TypedDict, Union)
+from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
+TypedDict, Union)

import torch
import torch.nn as nn

@@ -22,7 +23,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
-from vllm.multimodal.processing import PromptReplacement
+from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

@@ -347,13 +348,13 @@ class LlavaOnevisionMultiModalProcessor(
)
return BatchFeature(combined_outputs)

-def _hf_processor_applies_repl(
+def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
-base_result = super()._hf_processor_applies_repl(
+base_result = super()._hf_processor_applies_updates(
prompt_text=prompt_text,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,

@@ -361,13 +362,13 @@ class LlavaOnevisionMultiModalProcessor(

return base_result and mm_items.get_count("video", strict=False) == 0

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
-image_repls = super()._get_prompt_replacements(
+) -> Sequence[PromptUpdate]:
+image_repls = super()._get_prompt_updates(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
out_mm_kwargs=out_mm_kwargs,

@@ -392,7 +393,8 @@ class LlavaOnevisionMultiModalProcessor(

return [video_token_id] * num_video_tokens

-return image_repls + [
+return [
+*image_repls,
PromptReplacement(
modality="video",
target=[video_token_id],
@@ -22,9 +22,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
+from collections.abc import Iterable, Mapping, Sequence
from functools import partial
-from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
-Optional, Set, Tuple, TypedDict, Union)
+from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
+TypedDict, Union)

import torch
from torch import nn

@@ -356,10 +357,10 @@ class MiniCPMOMultiModalProcessor(
inputs["audio"]["audio_lens"][index])
return super().get_prompt_texts_by_modality(inputs, modality, index)

-def _get_prompt_replacements(
+def _get_prompt_updates(
self, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
-out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]:
+out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]:
placeholder = {
"image": self.info.image_pattern,
"video": self.info.video_pattern,
@@ -25,9 +25,10 @@
import math
import re
from collections import Counter
+from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial
-from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
-Optional, Set, Tuple, TypedDict, Union)
+from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
+TypedDict, Union)

import numpy as np
import torch

@@ -732,7 +733,7 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
}
}

-def _hf_processor_applies_repl(
+def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,

@@ -740,10 +741,10 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
) -> bool:
return False

-def _get_prompt_replacements(
+def _get_prompt_updates(
self, mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
-out_mm_kwargs: MultiModalKwargs) -> List[PromptReplacement]:
+out_mm_kwargs: MultiModalKwargs) -> Sequence[PromptReplacement]:
placeholder = {
"image": self.info.image_pattern,
"video": self.info.video_pattern,
@@ -15,8 +15,8 @@
# limitations under the License.
"""PyTorch Mllama model."""
import math
-from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
-TypedDict, Union)
+from collections.abc import Iterable, Mapping, Sequence
+from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union

import numpy as np
import torch

@@ -59,7 +59,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataDict, MultiModalDataItems)
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
-PromptReplacement)
+PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs

from .clip import CLIPMLP

@@ -243,12 +243,12 @@ class MllamaMultiModalProcessor(EncDecMultiModalProcessor[MllamaProcessingInfo]
image_token_id = self.info.get_hf_config().image_token_index
return [image_token_id] * num_images

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
token_per_chunk = self.info.get_token_per_chunk_from_config()
image_token_id = self.info.get_hf_config().image_token_index

@@ -1,10 +1,10 @@
# SPDX-License-Identifier: Apache-2.0

import math
+from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
-from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
-Union, cast)
+from typing import List, Optional, Set, Tuple, TypedDict, Union, cast

import numpy as np
import torch

@@ -46,8 +46,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
-BaseProcessingInfo, PromptReplacement,
-PromptReplacementDetails)
+BaseProcessingInfo, PromptInsertion,
+PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import JSONTree, json_map_leaves

@@ -1190,6 +1190,8 @@ class MolmoProcessingInfo(BaseProcessingInfo):
return MolmoProcessorWrapper(processor)

def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+# TODO: Investigate different `embed_is_patch` between cache/no-cache
+# in multi-image case
return {"image": 1}

def get_mm_max_tokens_per_item(

@@ -1328,25 +1330,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
img_patch_id=MultiModalFieldConfig.shared("image", num_images),
)

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()

image_token_length_w = processor.image_token_length_w
image_token_length_h = processor.image_token_length_h
pooling_size = processor.pooling_size

-user_str = "User:"
-if processor.always_start_with_space:
-user_str = " " + user_str

-user_tokens = tokenizer.encode(user_str, add_special_tokens=False)

img_patch_id = processor.image_patch_id
img_col_id = processor.im_col_id
img_start_id = processor.im_start_id

@@ -1356,7 +1351,7 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
extra_joint = ([img_start_id] + extra_row * image_token_length_h +
[img_end_id])

-def get_replacement_molmo(item_idx: int):
+def get_insertion_molmo(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)

@@ -1371,17 +1366,13 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
((nrows + 1) // pooling_size) + [img_end_id])

image_tokens = extra_joint + joint

-return PromptReplacementDetails(
-full=image_tokens + user_tokens,
-features=image_tokens,
-)
+return image_tokens

return [
-PromptReplacement(
+PromptInsertion(
modality="image",
-target=user_str,
-replacement=get_replacement_molmo,
+target="<|endoftext|>",
+insertion=get_insertion_molmo,
)
]

@@ -6,7 +6,8 @@
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
-from typing import Mapping, Optional
+from collections.abc import Mapping, Sequence
+from typing import Optional

import torch
import torch.nn as nn

@@ -17,8 +18,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
-from vllm.multimodal.processing import (PromptReplacement,
-PromptReplacementDetails)
+from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
+PromptUpdateDetails)
from vllm.multimodal.profiling import ProcessorInputs

from .intern_vit import InternVisionModel

@@ -142,12 +143,12 @@ class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):

class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

if "image_num_patches" in out_mm_kwargs:

@@ -179,7 +180,7 @@ class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
if num_patches is not None:
assert isinstance(num_patches, int)

-return PromptReplacementDetails(
+return PromptUpdateDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches) + "\n",
features=hf_processor.get_image_repl_features(
@@ -38,11 +38,10 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
-BaseProcessingInfo,
-BoundPromptReplacement,
+BaseProcessingInfo, BoundPromptUpdate,
PlaceholderFeaturesInfo,
-PromptReplacement,
-PromptReplacementDetails)
+PromptReplacement, PromptUpdate,
+PromptUpdateDetails)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

@@ -420,12 +419,12 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"),
)

-def _get_prompt_replacements(
+def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
-) -> list[PromptReplacement]:
+) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_tokens: list[str] = hf_processor.img_tokens # type: ignore

@@ -449,7 +448,7 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):

image_tokens = [_IMAGE_TOKEN_ID] * num_image_tokens

-return PromptReplacementDetails(
+return PromptUpdateDetails(
full=image_tokens + [bos_token_id],
features=image_tokens,
)

@@ -464,15 +463,15 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
) for image_token in image_tokens[:num_images]
]

-def _apply_prompt_replacements(
+def _apply_prompt_updates(
self,
token_ids: list[int],
-mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
+mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
-token_ids, text, placeholders = super()._apply_prompt_replacements(
+token_ids, text, placeholders = super()._apply_prompt_updates(
token_ids=token_ids,
-mm_prompt_repls=mm_prompt_repls,
+mm_prompt_updates=mm_prompt_updates,
mm_item_counts=mm_item_counts,
)

@ -15,7 +15,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM/NASA Prithvi Geospatial model."""
from typing import Iterable, Mapping, Optional, Set, Tuple, Union
from collections.abc import Iterable, Mapping, Sequence
from typing import Optional, Set, Tuple, Union

import torch
import torch.nn as nn
@ -32,7 +33,7 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargs)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import (IntermediateTensors, PoolerOutput,
PoolingSequenceGroupOutput)
@ -44,7 +45,7 @@ class PrithviGeoSpatialMAEProcessingInfo(BaseProcessingInfo):
return {"image": None}

def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
pass
return {"image": 0}

class PrithviGeoSpatialMAEInputBuilder(
@ -78,20 +79,13 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
location_coords=MultiModalFieldConfig.batched("image"),
)

def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
pass

def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
pass
) -> Sequence[PromptUpdate]:
return []

def apply(
self,
@ -120,7 +114,7 @@ class PrithviGeoSpatialMAEMultiModalProcessor(BaseMultiModalProcessor):
class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
""" Prithvi Masked Autoencoder"""

def _instantiate_model(self, config: dict) -> nn.Module | None:
def _instantiate_model(self, config: dict) -> Optional[nn.Module]:

# We might be able/need to support different tasks with this same model
if config["task_args"]["task"] == "SemanticSegmentationTask":
@ -158,7 +152,7 @@ class PrithviGeoSpatialMAE(nn.Module, IsAttentionFree, SupportsMultiModal):
"by PrithviGeospatialMAE.")

def _parse_and_validate_multimodal_data(
self, **kwargs) -> Tuple[torch.Tensor, torch.Tensor | None]:
self, **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

pixel_values = kwargs.pop("pixel_values", None)
if not isinstance(pixel_values, torch.Tensor):

@ -21,9 +21,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Any, Iterable, Mapping, Optional, Set, Tuple, TypedDict,
Union)
from typing import Any, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn
@ -43,7 +43,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

@ -188,12 +188,12 @@ class Qwen2AudioMultiModalProcessor(
feature_attention_mask=MultiModalFieldConfig.batched("audio"),
)

def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
@ -230,7 +230,7 @@ class Qwen2AudioMultiModalProcessor(

audio_tokens = [audio_token_id] * num_features

return PromptReplacementDetails(
return PromptUpdateDetails(
full=[audio_bos_id] + audio_tokens + [audio_eos_id],
features=audio_tokens,
)

@ -23,9 +23,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial
from typing import (Any, Callable, Iterable, Literal, Mapping, Optional, Set,
Tuple, Type, TypedDict, Union)
from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
Union)

import torch
import torch.nn as nn
@ -61,7 +62,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageSize,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors
@ -169,7 +171,7 @@ class Qwen2VisionMLP(nn.Module):
self,
in_features: int,
hidden_features: int,
act_layer: Type[nn.Module] = QuickGELU,
act_layer: type[nn.Module] = QuickGELU,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@ -383,7 +385,7 @@ class Qwen2VisionBlock(nn.Module):
dim: int,
num_heads: int,
mlp_ratio: float,
act_layer: Type[nn.Module] = QuickGELU,
act_layer: type[nn.Module] = QuickGELU,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
@ -987,12 +989,12 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
self.info._get_image_processor_kwargs(**mm_kwargs),
)

def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_processor = self.info.get_image_processor(
**hf_processor_mm_kwargs)

@ -9,9 +9,10 @@ import copy
import math
import re
import unicodedata
from collections.abc import Collection, Mapping, Sequence
from collections.abc import Set as AbstractSet
from functools import lru_cache, partial
from typing import (AbstractSet, Callable, Collection, List, Literal, Mapping,
Optional, TypedDict, Union)
from typing import Callable, List, Literal, Optional, TypedDict, Union

import torch
from torch import nn
@ -36,7 +37,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors

@ -606,7 +607,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
mm_kwargs=mm_kwargs,
)

def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
@ -624,12 +625,12 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
image_embeds=MultiModalFieldConfig.batched("image"),
)

def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer()
special_tokens: dict[str,
int] = tokenizer.special_tokens # type: ignore
@ -646,7 +647,7 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
PromptReplacement(
modality="image",
target=[img_start_id, img_end_id],
replacement=PromptReplacementDetails(
replacement=PromptUpdateDetails(
full=[img_start_id] + image_tokens + [img_end_id],
features=image_tokens,
),

@ -3,9 +3,9 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Any, Iterable, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.utils.checkpoint
@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
@ -197,12 +198,12 @@ class UltravoxMultiModalProcessor(
audio_embeds=MultiModalFieldConfig.batched("audio"),
)

def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, Any],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()

@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0

import math
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import List, Optional, Set, Tuple, TypedDict, Union

import torch
from torch import nn
@ -31,7 +31,7 @@ from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs

from .interfaces import SupportsMultiModal, SupportsTranscription
@ -623,12 +623,12 @@ class WhisperMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(input_features=MultiModalFieldConfig.batched("audio"))

def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
num_tokens = self.info.get_max_audio_tokens()
return [
PromptReplacement(

@ -6,11 +6,14 @@ from collections import defaultdict
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
Sequence)
from dataclasses import dataclass, field
from enum import Enum
from functools import lru_cache
from itertools import groupby
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union)
TypeVar, Union, cast)

from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import assert_never

import vllm.envs as envs
from vllm.inputs import InputProcessingContext
@ -38,35 +41,129 @@ PromptSeq = Union[str, list[int]]

@dataclass
class PromptReplacementDetails:
"""Details about the replacement token sequence or text."""
class PromptUpdateDetails:
"""Details about the token sequence or text that are part of the update."""

full: PromptSeq
"""The full replacement."""
"""The full content."""

features: PromptSeq
"""
The part of the replacement that corresponds to feature placeholders;
The part of the content that corresponds to feature placeholders;
this will be replaced by the output of the vision encoder during model
inference.
"""

@staticmethod
def from_seq(seq: PromptSeq) -> "PromptReplacementDetails":
return PromptReplacementDetails(full=seq, features=seq)
def from_seq(seq: PromptSeq) -> "PromptUpdateDetails":
return PromptUpdateDetails(full=seq, features=seq)

PromptRepl = Union[PromptSeq, PromptReplacementDetails]
PromptUpdateInfo = Union[PromptSeq, PromptUpdateDetails]
"""
The replacement token sequence or text.
The token sequence or text that are part of the update.

If only part of the replacement corresponds to feature placeholders, you can
use :class:`PromptReplacementDetails` to specify which part.
If only part of the content corresponds to feature placeholders, you can
use :class:`PromptUpdateDetails` to specify which part.
"""

PromptUpdateContent = Union[Callable[[int], PromptUpdateInfo],
PromptUpdateInfo]
"""
Given the index of the processed item within :attr:`modality`,
output the corresponding token sequence (or text).

For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""

class UpdateMode(str, Enum):
INSERT = "insert"
REPLACE = "replace"

@dataclass
class PromptReplacement:
class PromptUpdate:
"""
Defines how to update a prompt with placeholder tokens.
"""

modality: str
"""The modality for which the update is made."""

target: PromptSeq
"""The token sequence (or text) to update."""

@property
@abstractmethod
def content(self) -> PromptUpdateContent:
"""The placeholder tokens that are part of the update."""
raise NotImplementedError

@property
@abstractmethod
def mode(self) -> UpdateMode:
"""Defines how to update the prompt."""
raise NotImplementedError

def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptUpdate":
return BoundPromptUpdate(
_origin=self,
tokenizer=tokenizer,
)

@dataclass
class PromptInsertion(PromptUpdate):
"""
Defines how to insert placeholder tokens into a prompt.

Example:

For each image, insert a number of ``<image>`` feature placeholders
equal to the feature size of the vision encoder at the start of the
prompt:

.. code-block:: python

PromptInsertion(
modality="image",
target="",
insertion="<image>" * image_feature_size,
)

As above, but insert after the ``<s>`` token:

.. code-block:: python

PromptInsertion(
modality="image",
target="<s>",
insertion="<image>" * image_feature_size,
)
"""

insertion: PromptUpdateContent = field(repr=False)
"""
Given the index of the processed item within :attr:`modality`,
output the token sequence (or text) to insert right after :attr:`target`.

For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""

@property
def content(self) -> PromptUpdateContent:
return self.insertion

@property
def mode(self) -> UpdateMode:
return UpdateMode.INSERT

@dataclass
class PromptReplacement(PromptUpdate):
"""
Defines how to replace portions of an input prompt with placeholder tokens.

@ -93,7 +190,7 @@ class PromptReplacement:
PromptReplacement(
modality="image",
target="<image>",
replacement=PromptReplacementDetails(
replacement=PromptUpdateDetails(
full="".join([
"<image_bos>",
"<image>" * image_feature_size,
@ -111,7 +208,7 @@ class PromptReplacement:
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=PromptReplacementDetails(
replacement=PromptUpdateDetails(
full=([image_bos_id] + [image_token_id] * image_feature_size
+ [image_eos_id]),
features=[image_token_id] * image_feature_size,
@ -119,29 +216,22 @@ class PromptReplacement:
)
"""

modality: str
"""The modality for which the replacement is made."""

target: PromptSeq
"""The token sequence (or text) to find and replace."""

replacement: Union[Callable[[int], PromptRepl],
PromptRepl] = field(repr=False)
replacement: PromptUpdateContent = field(repr=False)
"""
Given the index of the processed item within :attr:`modality`,
output the replacement token sequence (or text).
output the token sequence (or text) to replace :attr:`target`.

For convenience, you can directly pass in the replacement token sequence
(or text) instead of a function if it does not depend on the input.
For convenience, you can directly pass in the token sequence (or text)
instead of a function if it does not depend on the input.
"""

def bind(self, tokenizer: AnyTokenizer) -> "BoundPromptReplacement":
return BoundPromptReplacement(
tokenizer=tokenizer,
modality=self.modality,
_target=self.target,
_replacement=self.replacement,
)
@property
def content(self) -> PromptUpdateContent:
return self.replacement

@property
def mode(self) -> UpdateMode:
return UpdateMode.REPLACE

@lru_cache(maxsize=2048)

@ -232,64 +322,73 @@ class _BoundPromptSequence:

@dataclass
class _BoundPromptReplacementGroup:
class _BoundPromptContent:
full: _BoundPromptSequence
features: _BoundPromptSequence

@dataclass
class BoundPromptReplacement:
class BoundPromptUpdate:
"""
A :class:`PromptReplacement` bound to a tokenizer to automatically
convert :attr:`target` and the result of :meth:`get_replacement` between
A :class:`PromptUpdate` bound to a tokenizer to automatically convert
:attr:`target` and the result of :meth:`get_content` between
token sequence and text representations.
"""
_origin: PromptUpdate
tokenizer: AnyTokenizer = field(repr=False)
modality: str

_target: PromptSeq
_replacement: Union[Callable[[int], PromptRepl],
PromptRepl] = field(repr=False)

def __post_init__(self) -> None:
self._replacement_cache = dict[int, _BoundPromptReplacementGroup]()
self._content_cache = dict[int, _BoundPromptContent]()

@property
def modality(self) -> str:
return self._origin.modality

@property
def target(self) -> _BoundPromptSequence:
"""The token sequence (or text) to find and replace."""
return _BoundPromptSequence.from_seq(self.tokenizer, self._target)
"""The token sequence (or text) to update."""
return _BoundPromptSequence.from_seq(self.tokenizer,
self._origin.target)

def get_replacement(self, item_idx: int) -> _BoundPromptReplacementGroup:
@property
def content(self) -> PromptUpdateContent:
"""The placeholder tokens that are part of the update."""
return self._origin.content

@property
def mode(self) -> UpdateMode:
"""Defines how to update the prompt."""
return self._origin.mode

def get_content(self, item_idx: int) -> _BoundPromptContent:
"""
Given the index of the processed item within :attr:`modality`,
output the replacement token sequence (or text).
output the token sequence (or text) to update.
"""
replacement = self._replacement
if callable(replacement):
content = self.content
if callable(content):
cache_key = item_idx
if cache_key in self._replacement_cache:
return self._replacement_cache[cache_key]
if cache_key in self._content_cache:
return self._content_cache[cache_key]

replacement = replacement(item_idx)
content = content(item_idx)
else:
cache_key = None

if not isinstance(replacement, PromptReplacementDetails):
replacement = PromptReplacementDetails.from_seq(replacement)
if not isinstance(content, PromptUpdateDetails):
content = PromptUpdateDetails.from_seq(content)

bound_full = _BoundPromptSequence.from_seq(self.tokenizer,
replacement.full)
content.full)
bound_features = _BoundPromptSequence.from_seq(self.tokenizer,
replacement.features)
bound_replacement = _BoundPromptReplacementGroup(
full=bound_full,
features=bound_features,
)
content.features)
bound_content = _BoundPromptContent(full=bound_full,
features=bound_features)

if cache_key is not None:
self._replacement_cache[cache_key] = bound_replacement
self._content_cache[cache_key] = bound_content

return bound_replacement
return bound_content

class _TokenMatch(NamedTuple):
@ -326,12 +425,12 @@ def iter_token_matches(

@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
prompt_repl: BoundPromptReplacement
class _PromptTargetMatch(ABC):
_origin: BoundPromptUpdate

@property
def modality(self) -> str:
return self.prompt_repl.modality
return self._origin.modality

@property
@abstractmethod
@ -349,7 +448,7 @@ class _PromptReplacementMatch(ABC):

@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
class _PromptTargetTokenMatch(_PromptTargetMatch):
match: _TokenMatch

@property
@ -362,7 +461,7 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch):

@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch):
class _PromptTargetTextMatch(_PromptTargetMatch):
match: re.Match[str]

@property
@ -394,40 +493,37 @@ class PlaceholderFeaturesInfo:

def find_token_matches(
prompt: list[int],
prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
return [
_PromptReplacementTokenMatch(prompt_repl, match)
for prompt_repl in prompt_repls
for match in iter_token_matches(prompt, prompt_repl.target.token_ids)
_PromptTargetTokenMatch(update, match) for update in prompt_updates
for match in iter_token_matches(prompt, update.target.token_ids)
]

def find_text_matches(
prompt: str,
prompt_repls: Sequence[BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
prompt_updates: Sequence[BoundPromptUpdate],
) -> Sequence[_PromptTargetMatch]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
return [
_PromptReplacementTextMatch(prompt_repl, match)
for prompt_repl in prompt_repls
for match in re.finditer(re.escape(prompt_repl.target.text), prompt)
_PromptTargetTextMatch(update, match) for update in prompt_updates
for match in re.finditer(re.escape(update.target.text), prompt)
]

def _resolve_matches(
prompt: PromptSeq,
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
) -> list[_PromptReplacementMatch]:
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
) -> list[_PromptTargetMatch]:
"""
Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones.
"""
matches = [m for matches in mm_matches.values() for m in matches]

seen_matches: list[Optional[_PromptReplacementMatch]] = [None
] * len(prompt)
seen_matches: list[Optional[_PromptTargetMatch]] = [None] * len(prompt)

for match in matches:
for idx in range(match.start_idx, match.end_idx):

@ -441,74 +537,91 @@ def _resolve_matches(
return sorted(matches, key=lambda x: x.start_idx)

def _replace_matches(
def _apply_matches(
prompt: _S,
mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]],
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[_S]:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
out_seqs = list[_S]()
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
out_seqs = list[Union[str, list[int]]]()
prev_end_idx = 0
next_idx_by_modality = defaultdict[str, int](lambda: 0)

for match in _resolve_matches(prompt, mm_matches):
modality = match.modality
for (start_idx, end_idx), group in groupby(
_resolve_matches(prompt, mm_matches),
key=lambda x: (x.start_idx, x.end_idx),
):
matches = tuple(group)
assert len(matches) == 1

item_idx = next_idx_by_modality[modality]
if item_idx >= mm_item_counts.get(modality, 0):
continue
for match in matches:
modality = match.modality

start_idx = match.start_idx
end_idx = match.end_idx
item_idx = next_idx_by_modality[modality]
if item_idx >= mm_item_counts.get(modality, 0):
continue

repl_info = match.prompt_repl
replacement = repl_info.get_replacement(item_idx)
origin = match._origin
content = origin.get_content(item_idx)
mode = origin.mode

if isinstance(prompt, str):
repl_seq = replacement.full.text
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
else:
repl_seq = replacement.full.token_ids
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
if mode == UpdateMode.INSERT:
out_seqs.append(prompt[prev_end_idx:end_idx])
num_inserts = mm_item_counts.get(modality, 0)
elif mode == UpdateMode.REPLACE:
out_seqs.append(prompt[prev_end_idx:start_idx])
num_inserts = 1
else:
assert_never(mode)

prev_end_idx = end_idx
next_idx_by_modality[modality] += 1
for _ in range(num_inserts):
if item_idx >= mm_item_counts.get(modality, 0):
continue

if isinstance(prompt, str):
out_seqs.append(content.full.text)
else:
out_seqs.append(content.full.token_ids)

next_idx_by_modality[modality] += 1

prev_end_idx = end_idx

out_seqs.append(prompt[prev_end_idx:])

return out_seqs
return cast(list[_S], out_seqs)

def replace_token_matches(
def apply_token_matches(
prompt: list[int],
mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]],
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
if not mm_matches:
return prompt

token_id_seqs = _replace_matches(prompt, mm_matches, mm_item_counts)
token_id_seqs = _apply_matches(prompt, mm_matches, mm_item_counts)

return flatten_2d_lists(token_id_seqs)

def replace_text_matches(
def apply_text_matches(
prompt: str,
mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]],
mm_matches: Mapping[str, Sequence[_PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> str:
"""Apply the replacements in :code:`mm_matches` to :code:`prompt`."""
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
if not mm_matches:
return prompt

texts = _replace_matches(prompt, mm_matches, mm_item_counts)
texts = _apply_matches(prompt, mm_matches, mm_item_counts)

return "".join(texts)

def _iter_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Iterable[PlaceholderFeaturesInfo]:
@ -517,7 +630,7 @@ def _iter_placeholders(

Matches are exclusive even when multiple modalities share
the same placeholder tokens. In that case, the modality that
appears earlier in `mm_prompt_repls` takes priority.
appears earlier in `mm_prompt_updates` takes priority.

Note that empty matches are ignored.
"""
@ -528,37 +641,37 @@ def _iter_placeholders(
while start_idx < prompt_len:
found = False

for modality, modality_repls in mm_prompt_repls.items():
for modality, modality_updates in mm_prompt_updates.items():
item_idx = item_idx_by_modality[modality]
if item_idx >= mm_item_counts.get(modality, 0):
continue

for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_idx)
repl_tokens_full = replacement.full.token_ids
repl_len_full = len(repl_tokens_full)
end_idx_full = start_idx + repl_len_full
for update_info in modality_updates:
content = update_info.get_content(item_idx)
content_tokens_full = content.full.token_ids
content_len_full = len(content_tokens_full)
end_idx_full = start_idx + content_len_full

if repl_len_full == 0 or end_idx_full > prompt_len:
if content_len_full == 0 or end_idx_full > prompt_len:
continue

if prompt[start_idx:end_idx_full] == repl_tokens_full:
repl_tokens_feat = replacement.features.token_ids
if prompt[start_idx:end_idx_full] == content_tokens_full:
content_tokens_feat = content.features.token_ids

try:
match = next(
iter_token_matches(repl_tokens_full,
repl_tokens_feat))
iter_token_matches(content_tokens_full,
content_tokens_feat))
yield PlaceholderFeaturesInfo(
modality=modality,
item_idx=item_idx,
start_idx=start_idx + match.start_idx,
tokens=repl_tokens_feat,
tokens=content_tokens_feat,
)
except StopIteration:
raise AssertionError(
f"{repl_tokens_feat=} should be a "
f"subsequence of {repl_tokens_full=}") from None
f"{content_tokens_feat=} should be a "
f"subsequence of {content_tokens_full=}") from None

# Exclude overlapping matches
start_idx = end_idx_full
@ -574,11 +687,11 @@ def _iter_placeholders(

def find_mm_placeholders(
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
prompt: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts)
it = _iter_placeholders(mm_prompt_updates, prompt, mm_item_counts)
return dict(full_groupby_modality(it))

@ -712,6 +825,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
if get_repls := getattr(self, "_get_prompt_replacements", None):
logger.warning_once("`_get_prompt_replacements` has been renamed "
"to `_get_prompt_updates`. The old name will "
"be removed in an upcoming release.")
self._get_prompt_updates = get_repls # type: ignore[method-assign]

super().__init__()

self.info = info
@ -770,34 +889,34 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
raise NotImplementedError

@abstractmethod
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> list[PromptUpdate]:
"""
Given the original multi-modal items for this modality
and HF-processed data, output the replacements to perform.
and HF-processed data, output the updates to perform.

Notes:
- You should not assume that HF processor always performs prompt
replacement: in :meth:`_apply_hf_processor_missing`, this method
updates: in :meth:`_apply_hf_processor_missing`, this method
is called on text-only and multimodal-only inputs separately,
instead of passing them in the same call.
- The replacement information returned by this method is also used
to determine the placeholder token positions for each multi-modal
- The update information returned by this method is also used to
determine the placeholder token positions for each multi-modal
item.
"""
raise NotImplementedError

def _find_mm_placeholders(
self,
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
return find_mm_placeholders(mm_prompt_repls, new_token_ids,
return find_mm_placeholders(mm_prompt_updates, new_token_ids,
mm_item_counts)

def _get_hf_mm_data(
@ -831,14 +950,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_kwargs,
)

def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
"""
Return whether the HF processor applies prompt replacements.
Return whether the HF processor applies prompt updates.

For most HF processors, this should be :code:`True` when multi-modal
data items are passed, but :code:`False` when multi-modal embeddings
@ -858,7 +977,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Apply the HF processor on the prompt text and multi-modal data
together.

In addition, return whether prompt replacements have been applied.
In addition, return whether prompt updates have been applied.
"""
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)

@ -876,13 +995,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
)

is_repl_applied = self._hf_processor_applies_repl(
is_update_applied = self._hf_processor_applies_updates(
prompt_text=prompt_text,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)

return prompt_ids, mm_kwargs, is_repl_applied
return prompt_ids, mm_kwargs, is_update_applied

def _apply_hf_processor_text_only(self, prompt_text: str) -> list[int]:
"""
@ -948,21 +1067,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
*,
enable_hf_prompt_replacement: bool,
enable_hf_prompt_update: bool,
) -> tuple[list[int], MultiModalKwargs, bool]:
"""
Apply the HF processor on the prompt text and multi-modal data.

In addition, return whether prompt replacements have been applied
In addition, return whether prompt updates have been applied
(for most HF processors, this should be :code:`True`).

Note:
If :code:`enable_hf_prompt_replacement=False`, we use HF processor
to perform prompt replacement if available; HF processor requires
If :code:`enable_hf_prompt_update=False`, we use HF processor
to perform prompt updates if available; HF processor requires
that the prompt corresponds to multi-modal items.
"""
if isinstance(prompt, str):
if enable_hf_prompt_replacement:
if enable_hf_prompt_update:
return self._apply_hf_processor_text_mm(
prompt_text=prompt,
mm_items=mm_items,
@ -999,7 +1118,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_replacement=True,
enable_hf_prompt_update=True,
)

mm_maybe_cached_kw_items = {
@ -1022,17 +1141,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_missing_data_items = self._to_mm_items(mm_missing_data)

# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt replacements until the new multimodal
# so we can't apply prompt updates until the new multimodal
# items are combined with the cached multimodal items
(
prompt_ids,
mm_missing_kwargs,
is_repl_applied,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_missing_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_replacement=False,
enable_hf_prompt_update=False,
)

mm_missing_next_idx = {
@ -1071,28 +1190,28 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):

mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)

return prompt_ids, mm_kwargs, is_repl_applied
return prompt_ids, mm_kwargs, is_update_applied

def _bind_and_group_repls(
def _bind_and_group_updates(
self,
prompt_repls: list[PromptReplacement],
) -> dict[str, list[BoundPromptReplacement]]:
prompt_updates: list[PromptUpdate],
) -> dict[str, list[BoundPromptUpdate]]:
tokenizer = self.info.get_tokenizer()

it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls)
it = (update.bind(tokenizer) for update in prompt_updates)
return dict(full_groupby_modality(it))

def _apply_prompt_replacements(
def _apply_prompt_updates(
self,
token_ids: list[int],
mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
mm_item_counts: Mapping[str, int],
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
tokenizer = self.info.get_tokenizer()

mm_token_matches = {
modality: find_token_matches(token_ids, prompt_repls)
for modality, prompt_repls in mm_prompt_repls.items()
modality: find_token_matches(token_ids, updates)
for modality, updates in mm_prompt_updates.items()
}
mm_match_counts = {
modality: len(matches)

@ -1107,31 +1226,31 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# up a token, then the token ID of "foo" will not appear at all
# ----
# Since it is inefficient to search for all possible tokenizations
# of the search text in the prompt, we instead perform string
# replacement on the decoded token IDs, then encode them back.
# of the search text in the prompt, we instead perform string-based
# updates on the decoded token IDs, then encode them back.
if all(
mm_match_counts.get(modality, 0) >= item_count
for modality, item_count in mm_item_counts.items()
): # yapf: disable
token_ids = replace_token_matches(
token_ids = apply_token_matches(
token_ids,
mm_token_matches,
mm_item_counts,
)

text = decode_tokens(tokenizer, token_ids)
matched_repls = {
modality: [match.prompt_repl for match in token_matches]
matched_updates = {
modality: [match._origin for match in token_matches]
for modality, token_matches in mm_token_matches.items()
}
else:
text = decode_tokens(tokenizer, token_ids)

mm_text_matches = {
modality: find_text_matches(text, prompt_repls)
for modality, prompt_repls in mm_prompt_repls.items()
modality: find_text_matches(text, updates)
for modality, updates in mm_prompt_updates.items()
}
text = replace_text_matches(
text = apply_text_matches(
text,
mm_text_matches,
mm_item_counts,
@ -1140,13 +1259,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
token_ids = encode_tokens(tokenizer,
text,
add_special_tokens=False)
matched_repls = {
modality: [match.prompt_repl for match in token_matches]
matched_updates = {
modality: [match._origin for match in token_matches]
for modality, token_matches in mm_text_matches.items()
}

placeholders = self._find_mm_placeholders(
matched_repls,
matched_updates,
token_ids,
mm_item_counts,
)
@ -1184,14 +1303,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):

if len(placeholders) != item_count:
raise RuntimeError(
f"Expected there to be {item_count} prompt replacements "
f"Expected there to be {item_count} prompt updates "
f"corresponding to {item_count} {modality} items, but "
f"instead found {len(placeholders)} prompt replacements! "
f"instead found {len(placeholders)} prompt updates! "
"Either the prompt text has missing/incorrect tokens for "
"multi-modal inputs, or there is a problem with your "
"implementation of merged multi-modal processor for this "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_replacements`).")
"`_call_hf_processor` and `_get_prompt_updates`).")

def apply(
self,
@ -1206,7 +1325,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):

1. Apply HF Processor on prompt text and multi-modal data together,
outputting token IDs and processed tensors.
2. Find and replace sequences in the token IDs with placeholder tokens.
2. Find and update sequences in the token IDs with placeholder tokens.
The number of placeholder tokens equals the feature size of the
multi-modal data outputted by the multi-modal encoder.
3. Extract information about the placeholder tokens from the
@ -1235,26 +1354,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
(
prompt_ids,
mm_kwargs,
is_repl_applied,
is_update_applied,
) = self._cached_apply_hf_processor(
prompt,
mm_items,
hf_processor_mm_kwargs,
)

unbound_prompt_repls = self._get_prompt_replacements(
unbound_prompt_updates = self._get_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls)
mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates)

mm_item_counts = mm_items.get_all_counts()
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)

if is_repl_applied:
if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
mm_prompt_repls,
mm_prompt_updates,
prompt_ids,
mm_item_counts,
)
@ -1267,9 +1387,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_ids,
prompt,
mm_placeholders,
) = self._apply_prompt_replacements(
) = self._apply_prompt_updates(
prompt_ids,
mm_prompt_repls,
mm_prompt_updates,
mm_item_counts,
)
self._validate_mm_placeholders(mm_placeholders, mm_item_counts)