Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[VLM] Fully dynamic prompt replacement in merged input processor (#11199)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
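For orientation, a minimal sketch of the interface this commit moves to (not taken verbatim from the diff; the token id 32000 and the per-item count are made-up stand-ins): replacements are now declared per modality, and the replacement itself may be a callable evaluated per multi-modal item rather than a fixed `repl_unit * repl_count`.

from vllm.multimodal.processing import PromptReplacement

# Old (removed) form: a fixed unit repeated a fixed number of times, e.g.
#   PromptReplacement(target=[32000], repl_unit=[32000], repl_count=576)

def get_replacement(item_idx: int) -> list[int]:
    # May depend on the item, e.g. on the size of the item_idx-th image.
    return [32000] * (576 + item_idx)

repl = PromptReplacement(
    modality="image",
    target=[32000],
    replacement=get_replacement,
)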
@ -97,9 +97,6 @@ def run_phi3v(question: str, modality: str):
# max_model_len (128k) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs.

# In this example, we override max_num_seqs to 5 while
# keeping the original context length of 128k.

# num_crops is an override kwarg to the multimodal image processor;
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
# to use 16 for single frame scenarios, and 4 for multi-frame.
@ -113,7 +110,7 @@ def run_phi3v(question: str, modality: str):
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
llm = LLM(
model="microsoft/Phi-3-vision-128k-instruct",
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,
@ -16,8 +16,8 @@ models = ["microsoft/Phi-3.5-vision-instruct"]
# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_phi3v():
from vllm.model_executor.models.phi3v import Phi3VProcessor
return Phi3VProcessor
from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
return Phi3VMultiModalProcessor


@pytest.fixture()
@ -1,11 +1,11 @@
from typing import cast

import pytest
from transformers import BatchFeature

from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
find_text_matches, find_token_matches,
iter_placeholders, iter_token_matches,
from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
_PlaceholderInfo, find_text_matches,
find_token_matches, iter_placeholders,
iter_token_matches,
replace_text_matches,
replace_token_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ -16,7 +16,7 @@ from vllm.utils import full_groupby
@pytest.mark.parametrize(
("token_ids", "match_ids", "expected"),
[
([], [], [{ "start_idx": 0, "end_idx": 0 }]),
([], [], []),
([], [32000], []),
(
[32000, 32000, 32000],
@ -83,7 +83,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_2": [32000],
},
{
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
"pattern_1": [],
"pattern_2": [],
}
),
@ -136,7 +136,7 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
mock_tokenizer = cast(AnyTokenizer, object())

prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
PromptReplacement(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_token_matches(prompt, prompt_repls)
@ -243,7 +243,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
mock_tokenizer = cast(AnyTokenizer, object())

prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
PromptReplacement(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_text_matches(prompt, prompt_repls)
@ -276,12 +276,12 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
"pattern_3": "!",
},
{
# Test whether target is confused with repl_unit
"pattern_1": ("<image><image>", 1),
# Test empty repl_unit
"pattern_2": ("", 1),
# Test multiple repl_count
"pattern_3": ("?", 2),
# Test whether target is confused with replacement
"pattern_1": "<image><image>",
# Test empty replacement
"pattern_2": "",
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": "?!?",
},
),
]
@ -290,8 +290,8 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
("mm_count", "expected"),
[
(0, "Image:<image>Image:<image><image>!"),
(1, "<image><image>Image:<image><image>??"),
(2, "<image><image><image><image><image>??"),
(1, "<image><image>Image:<image><image>?!?"),
(2, "<image><image><image><image><image>?!?"),
]
)
# yapf: enable
@ -306,7 +306,7 @@ def test_find_replace_text(
mock_tokenizer = cast(AnyTokenizer, object())

prompt_repls = [
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
matches = find_text_matches(prompt, prompt_repls)
@ -314,9 +314,8 @@ def test_find_replace_text(
result = replace_text_matches(
prompt,
matches,
{key: list(range(mm_count))
for key in repl_by_key},
BatchFeature(),
MultiModalDataItems({key: [None] * mm_count
for key in repl_by_key}),
)

# Only displayed on error
@ -343,12 +342,12 @@ def test_find_replace_text(
"pattern_3": [918],
},
{
# Test whether target is confused with repl_unit
"pattern_1": ([32000, 32000], 1),
# Test empty repl_unit
"pattern_2": ([], 1),
# Test multiple repl_count
"pattern_3": ([1550], 2),
# Test whether target is confused with replacement
"pattern_1": [32000, 32000],
# Test empty replacement
"pattern_2": [],
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": [1550, 918, 1550],
},
),
]
@ -357,8 +356,8 @@ def test_find_replace_text(
("mm_count", "expected"),
[
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
]
)
# yapf: enable
@ -373,7 +372,7 @@ def test_find_replace_tokens(
mock_tokenizer = cast(AnyTokenizer, object())

prompt_repls = [
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
matches = find_token_matches(prompt, prompt_repls)
@ -381,9 +380,8 @@ def test_find_replace_tokens(
result = replace_token_matches(
prompt,
matches,
{key: list(range(mm_count))
for key in repl_by_key},
BatchFeature(),
MultiModalDataItems({key: [None] * mm_count
for key in repl_by_key}),
)

# Only displayed on error
@ -399,9 +397,9 @@ def test_find_replace_tokens(
"repl_by_key",
[
{
"pattern_1": ([32000, 32000], 1),
"pattern_2": ([], 1),
"pattern_3": ([1550], 2),
"pattern_1": [32000, 32000],
"pattern_2": [],
"pattern_3": [1550, 918, 1550],
},
],
)
@ -414,48 +412,47 @@ def test_find_replace_tokens(
_PlaceholderInfo(
modality="pattern_1",
start_idx=6,
unit=[32000, 32000],
unit_count=1,
replacement=[32000, 32000],
),
],
),
(
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550],
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=1,
unit=[32000, 32000],
unit_count=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_1",
start_idx=5,
unit=[32000, 32000],
unit_count=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_3",
start_idx=7,
unit=[1550],
unit_count=2,
replacement=[1550, 918, 1550],
),
],
),
(
[1, 32000, 32000, 32000, 32000, 32000, 1550, 1550],
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=1,
unit=[32000, 32000],
unit_count=2,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_1",
start_idx=3,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_3",
start_idx=6,
unit=[1550],
unit_count=2,
replacement=[1550, 918, 1550],
),
],
),
@ -470,11 +467,17 @@ def test_iter_placeholders(
mock_tokenizer = cast(AnyTokenizer, object())

prompt_repls = [
PromptReplacement([], *repl).bind(key, mock_tokenizer)
PromptReplacement(key, [], repl).bind(mock_tokenizer)
for key, repl in repl_by_key.items()
]

result = list(iter_placeholders(prompt_repls, prompt))
result = list(
iter_placeholders(
prompt_repls,
prompt,
# Effectively match all occurrences in the prompt
MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
))

# Only displayed on error
print("result:", result)
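A rough usage sketch of the helpers these tests now exercise, with the same mock-tokenizer trick the tests use and a single assumed image item:

from typing import cast

from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
                                        find_text_matches,
                                        replace_text_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer

mock_tokenizer = cast(AnyTokenizer, object())

# Dynamic text replacement, no longer restricted to the `unit * count` form.
repl = PromptReplacement("image", "<image>", "?!?").bind(mock_tokenizer)

prompt = "Image:<image>!"
matches = find_text_matches(prompt, [repl])
result = replace_text_matches(
    prompt,
    matches,
    MultiModalDataItems({"image": [None]}),  # one image item available
)
assert result == "Image:?!?!"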
@ -3,14 +3,14 @@ from typing import Optional
import torch

from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
LlavaProcessor,
LlavaMultiModalProcessor,
get_max_llava_image_tokens)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class MyLlava(LlavaForConditionalGeneration):

def compute_logits(
@ -2,7 +2,7 @@ import functools
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
Optional, Protocol, Type, cast)
Optional, Protocol, Type)

from torch import nn
from transformers import PretrainedConfig, ProcessorMixin
@ -47,7 +47,6 @@ class InputContext:
Raises:
TypeError: If the model is not of the specified type.
"""

hf_config = self.model_config.hf_config
if not isinstance(hf_config, hf_config_type):
raise TypeError("Invalid type of HuggingFace config. "
@ -60,21 +59,70 @@ class InputContext:
"""
Get the HuggingFace image processor configuration of the model.
"""

return self.model_config.hf_image_processor_config

def get_mm_config(self):
"""
Get the multimodal config of the model.

Raises:
RuntimeError: If the model is not a multimodal model.
"""
mm_config = self.model_config.multimodal_config
if mm_config is None:
raise RuntimeError("Not a multimodal model")

return mm_config

def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

merged_kwargs = {**base_kwargs, **kwargs}

return cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
)


@dataclass(frozen=True)
class InputProcessingContext(InputContext):
tokenizer: AnyTokenizer
"""The tokenizer used to tokenize the inputs."""

def get_hf_processor(self, **kwargs) -> ProcessorMixin:
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

merged_kwargs = {**base_kwargs, **kwargs}

return cached_get_processor(
self.model_config.tokenizer,
self.model_config.model,
tokenizer=self.tokenizer, # Override the tokenizer with ours
trust_remote_code=self.model_config.trust_remote_code,
**kwargs)
**merged_kwargs,
)

def resolve_hf_processor_call_kwargs(
self,
hf_processor: ProcessorMixin,
inference_kwargs: Mapping[str, object],
) -> Mapping[str, object]:
assert callable(hf_processor)

base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}

return resolve_mm_processor_kwargs(
base_kwargs,
inference_kwargs,
hf_processor,
)


N = TypeVar("N", bound=Type[nn.Module])
@ -171,7 +219,8 @@ class InputRegistry:
"""

def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_factories_by_model_type:
if self._dummy_factories_by_model_type.contains(model_cls,
strict=True):
logger.warning(
"Model class %s already has dummy data "
"registered to %s. It is overwritten by the new one.",
@ -195,7 +244,8 @@ class InputRegistry:
"""

def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_encoder_factories_by_model_type:
if self._dummy_encoder_factories_by_model_type.contains(
model_cls, strict=True):
logger.warning(
"Model class %s already has dummy encoder data "
"registered to %s. It is overwritten by the new one.",
@ -305,7 +355,8 @@ class InputRegistry:
"""

def wrapper(model_cls: N) -> N:
if model_cls in self._input_processors_by_model_type:
if self._input_processors_by_model_type.contains(model_cls,
strict=True):
logger.warning(
"Model class %s already has input processor "
"registered to %s. It is overwritten by the new one.",
@ -357,7 +408,7 @@ class InputRegistry:
# If it's empty, it'll fall back to the default kwarg values
mm_processor_kwargs = resolve_mm_processor_kwargs(
model_config.mm_processor_kwargs,
cast(Dict[str, Any], inputs.get("mm_processor_kwargs")),
inputs.get("mm_processor_kwargs", {}), # type: ignore
processor,
)
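The kwarg precedence added to `get_hf_processor` is plain dictionary merging; a self-contained sketch of the rule (values are made up):

base_kwargs = {"num_crops": 16}   # model_config.mm_processor_kwargs
call_kwargs = {"num_crops": 4}    # keyword arguments of this call

merged_kwargs = {**base_kwargs, **call_kwargs}
assert merged_kwargs == {"num_crops": 4}  # the call-time value wins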
@ -5,10 +5,10 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,

import torch
import torch.nn as nn
from PIL.Image import Image
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
PixtralVisionConfig, PretrainedConfig,
ProcessorMixin, SiglipVisionConfig)
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor

from vllm.attention import AttentionMetadata
@ -21,11 +21,9 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ModalityProcessingMetadata,
MultiModalProcessingMetadata,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors

@ -33,7 +31,8 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
get_max_clip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
get_max_pixtral_hf_image_tokens)
get_max_pixtral_hf_image_tokens,
get_pixtral_hf_image_feature_size)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
get_max_siglip_image_tokens)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
@ -115,62 +114,7 @@ def get_max_llava_image_tokens(ctx: InputContext):
raise ValueError(f"Unexpected select feature strategy: {strategy}")


def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]

if isinstance(vision_config, CLIPVisionConfig):
data = dummy_image_for_clip(vision_config, num_images)
elif isinstance(vision_config, SiglipVisionConfig):
data = dummy_image_for_siglip(vision_config, num_images)
elif isinstance(vision_config, PixtralVisionConfig):
data = dummy_image_for_pixtral_hf(vision_config, num_images)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)

hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
is_pixtral = isinstance(hf_processor, PixtralProcessor)

return MultiModalKwargs(
**hf_inputs,
is_pixtral=torch.tensor(is_pixtral),
)


def create_metadata_for_llava(
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
hf_config = ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index

def get_repl_count(
mm_items: list[Image],
hf_inputs: BatchFeature,
item_idx: int,
) -> int:
return get_max_llava_image_tokens(ctx)

return {
"image":
ModalityProcessingMetadata(prompt_repls=[
PromptReplacement(target=[image_token_id],
repl_unit=[image_token_id],
repl_count=get_repl_count),
]),
}


class LlavaProcessor(BaseMultiModalProcessor):

def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(
ctx=ctx,
metadata=create_metadata_for_llava(ctx),
)
class LlavaMultiModalProcessor(BaseMultiModalProcessor):

def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
if getattr(hf_processor, "__is_patched__", False):
@ -188,18 +132,72 @@ class LlavaProcessor(BaseMultiModalProcessor):

hf_processor.__is_patched__ = True # type: ignore

def _get_hf_processor(self) -> ProcessorMixin:
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
hf_processor = self.ctx.get_hf_processor()
assert isinstance(hf_processor, (LlavaProcessor, PixtralProcessor))

if isinstance(hf_processor, PixtralProcessor):
self._patch_pixtral_processor(hf_processor)

return hf_processor

def _get_dummy_mm_kwargs(
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index

processor = self._get_hf_processor()
if isinstance(processor, PixtralProcessor):
image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token

vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)

def get_replacement_pixtral(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
(
num_width_tokens,
num_height_tokens,
) = get_pixtral_hf_image_feature_size(
vision_config,
image_width=image_size.width,
image_height=image_size.height,
)

tokens = ([image_token] * num_width_tokens +
[image_break_token]) * num_height_tokens
tokens[-1] = image_end_token

return "".join(tokens)

return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_pixtral,
),
]

max_image_tokens = get_max_llava_image_tokens(self.ctx)

return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=[image_token_id] * max_image_tokens,
)
]

def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
@ -215,11 +213,13 @@ class LlavaProcessor(BaseMultiModalProcessor):
raise NotImplementedError(msg)

hf_processor = self._get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'],
return_tensors="pt")
image_token = hf_processor.image_token

return MultiModalKwargs(**hf_inputs)
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
mm_processor_kwargs={},
)


class LlavaLikeConfig(Protocol):
@ -303,7 +303,7 @@ def init_vision_tower_for_llava(


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
@ -584,7 +584,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return loader.load_weights(weights)


class MantisProcessor(LlavaProcessor):
class MantisMultiModalProcessor(LlavaMultiModalProcessor):

def _get_hf_processor(self) -> ProcessorMixin:
try:
@ -604,6 +604,6 @@ class MantisProcessor(LlavaProcessor):
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(MantisProcessor)
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass
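The Pixtral branch above assembles its replacement row by row; a standalone sketch of that layout for an assumed 2-row by 3-column feature grid (the token strings are stand-ins):

image_token = "[IMG]"
image_break_token = "[IMG_BREAK]"
image_end_token = "[IMG_END]"

num_width_tokens, num_height_tokens = 3, 2

tokens = ([image_token] * num_width_tokens +
          [image_break_token]) * num_height_tokens
tokens[-1] = image_end_token  # the final break marker becomes the end marker

assert "".join(tokens) == (
    "[IMG][IMG][IMG][IMG_BREAK]"
    "[IMG][IMG][IMG][IMG_END]")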
@ -32,13 +32,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ModalityProcessingMetadata,
MultiModalDataDict,
MultiModalProcessingMetadata,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@ -305,64 +302,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline


def get_max_phi3v_image_tokens(ctx: InputContext,
*,
num_crops: Optional[int] = None):
mm_processor_kwargs = {}
if num_crops is not None:
mm_processor_kwargs["num_crops"] = num_crops
def get_max_phi3v_image_tokens(ctx: InputContext) -> int:
processor = ctx.get_hf_processor()
image_processor = processor.image_processor # type: ignore

model_config = ctx.model_config
image_processor = cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs,
)

num_tokens = image_processor.calc_num_image_tokens_from_image_size(
return image_processor.calc_num_image_tokens_from_image_size(
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
return num_tokens


def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]

data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)

hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")

return MultiModalKwargs(**hf_inputs)


def create_metadata_for_phi3v(
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
return {
"image":
ModalityProcessingMetadata(prompt_repls=[
PromptReplacement(target=[_IMAGE_TOKEN_ID],
repl_unit=[_IMAGE_TOKEN_ID],
repl_count=get_max_phi3v_image_tokens(ctx)),
]),
}


class Phi3VProcessor(BaseMultiModalProcessor):

def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(
ctx=ctx,
metadata=create_metadata_for_phi3v(ctx),
)
class Phi3VMultiModalProcessor(BaseMultiModalProcessor):

def _get_hf_processor(
self,
@ -389,15 +339,61 @@ class Phi3VProcessor(BaseMultiModalProcessor):
processed_outputs['input_ids'] = token_ids
return processed_outputs

def _get_dummy_mm_kwargs(
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
image_processor = hf_processor.image_processor # type: ignore

mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)

def get_replacement_phi3v(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=image_size.width,
height=image_size.height,
)

return [_IMAGE_TOKEN_ID] * num_tokens

return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_phi3v,
) for image_token in image_tokens[:max_images]
]

def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts)
) -> ProcessorInputs:
num_images = mm_counts["image"]

data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)

hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore

return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=data,
mm_processor_kwargs={},
)


@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
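For Phi-3.5-vision, each per-image placeholder string becomes a target whose replacement length depends on that image's size. A hedged sketch of how that list is built (the token id, the `<|image_i|>` strings, and the fixed counts are assumptions for illustration only):

from vllm.multimodal.processing import PromptReplacement

_IMAGE_TOKEN_ID = 32044                        # assumed image token id
image_tokens = ["<|image_1|>", "<|image_2|>"]  # assumed hf_processor.img_tokens

def get_replacement_phi3v(item_idx: int) -> list[int]:
    # The real processor derives this from the item's width and height via
    # image_processor.calc_num_image_tokens_from_image_size(...).
    num_tokens = (757, 1921)[item_idx]
    return [_IMAGE_TOKEN_ID] * num_tokens

prompt_repls = [
    PromptReplacement(
        modality="image",
        target=image_token,
        replacement=get_replacement_phi3v,
    ) for image_token in image_tokens
]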
@ -72,7 +72,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img

mm_config = ctx.model_config.multimodal_config
mm_config = ctx.get_mm_config()
num_images = mm_config.limit_per_prompt.get("image", 1)

# dummy size
@ -99,7 +99,7 @@ class MultiModalPlugin(ABC):
"""

def wrapper(model_cls: N) -> N:
if model_cls in self._input_mappers:
if self._input_mappers.contains(model_cls, strict=True):
logger.warning(
"Model class %s already has an input mapper "
"registered to %s. It is overwritten by the new one.",
@ -194,7 +194,7 @@ class MultiModalPlugin(ABC):
"""

def wrapper(model_cls: N) -> N:
if model_cls in self._max_mm_tokens:
if self._max_mm_tokens.contains(model_cls, strict=True):
logger.warning(
"Model class %s already calculates maximum number of "
"tokens in %s. It is overwritten by the new one.",
@ -1,116 +1,59 @@
import re
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (Any, Dict, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union, cast)
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union

import numpy as np
import torch
from PIL.Image import Image
from transformers import BatchFeature, ProcessorMixin
from typing_extensions import TypeAlias, TypedDict
from typing_extensions import assert_never

from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import (flatten_2d_lists, full_groupby, is_list_of,
resolve_mm_processor_kwargs)
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of

from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
VideoItem)

logger = init_logger(__name__)


def bind_prompt_sequence(
seq: Union[str, list[int]],
tokenizer: AnyTokenizer,
) -> "_BoundPromptSequence":
"""
Bind a text or token sequence to a tokenizer so that it can be
lazily converted into the other format on demand.
"""
return _BoundPromptSequence(
tokenizer=tokenizer,
_text=seq if isinstance(seq, str) else None,
_token_ids=seq if isinstance(seq, list) else None,
)


_T = TypeVar("_T")
_S = TypeVar("_S", str, list[int])
_PromptSeq = Union[str, list[int]]


@dataclass
class PromptReplacement(Generic[_S, _T]):
target: _S
class PromptReplacement:
modality: str
"""The modality for which the replacement is made"""

target: _PromptSeq
"""The text or token sequence to find and replace."""

repl_unit: _S
replacement: Union[Callable[[int], _PromptSeq],
_PromptSeq] = field(repr=False)
"""
The unit making up the replacement text or token sequence.

See :code:`repl_count` for more details.
Given the index of the processed item within :attr:`modality`, output the
replacement text or token sequence.

For convenience, you can pass in the replacement instead of a function
if it does not depend on the input.
"""

repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]
"""
Given the original multi-modal items for this modality, HF-processed data,
and index of the processed item, output the number of repetitions of
:code:`repl_unit` to build up the replacement text or token sequence.

For convenience, you can pass in an integer if the number of repetitions is
a constant.
"""

def __repr__(self) -> str:
return (f"{type(self).__name__}(target={self.target!r}, "
f"repl_unit={self.repl_unit!r})")

def bind(
self,
modality: str,
tokenizer: AnyTokenizer,
) -> "_BoundPromptReplacement[_T]":
def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement":
return _BoundPromptReplacement(
modality=modality,
target=bind_prompt_sequence(self.target, tokenizer),
repl_unit=bind_prompt_sequence(self.repl_unit, tokenizer),
repl_count=self.repl_count,
tokenizer=tokenizer,
modality=self.modality,
_target=self.target,
_replacement=self.replacement,
)


@dataclass
class ModalityProcessingMetadata(Generic[_T]):
prompt_repls: Sequence[Union[PromptReplacement[str, _T],
PromptReplacement[list[int], _T]]]
"""
Defines each text or token sequence to replace in the HF-processed prompt.

This is skipped if the HF-processed prompt is found to already contain
the replacement prompts.
"""


class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""

image: ModalityProcessingMetadata[ImageItem]
video: ModalityProcessingMetadata[VideoItem]
audio: ModalityProcessingMetadata[AudioItem]


MultiModalProcessingMetadata: TypeAlias = \
Mapping[str, ModalityProcessingMetadata[Any]]
"""
A dictionary containing an entry for each modality type to process.

Note:
This dictionary also accepts modality keys defined outside
:class:`MultiModalProcessingMetadataBuiltins` as long as a customized plugin
is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""


def _encode(
tokenizer: AnyTokenizer,
text: str,
@ -185,7 +128,8 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:

@dataclass
class _BoundPromptSequence:
tokenizer: AnyTokenizer
tokenizer: AnyTokenizer = field(repr=False)

_text: Optional[str]
_token_ids: Optional[list[int]]

@ -210,38 +154,92 @@ class _BoundPromptSequence:

return self._token_ids

def __repr__(self) -> str:
return (f"{type(self).__name__}(_text={self._text!r}, "
f"_token_ids={self._token_ids!r})")


@dataclass
class _BoundPromptReplacement(Generic[_T]):
class _BoundPromptReplacement:
tokenizer: AnyTokenizer = field(repr=False)
modality: str
target: _BoundPromptSequence
repl_unit: _BoundPromptSequence
repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]

def get_count(
self,
mm_items: list[_T],
hf_inputs: BatchFeature,
item_idx: int,
) -> int:
repl_count = self.repl_count
if isinstance(repl_count, int):
return repl_count
_target: _PromptSeq
_replacement: Union[Callable[[int], _PromptSeq],
_PromptSeq] = field(repr=False)

return repl_count(mm_items, hf_inputs, item_idx)
def __post_init__(self) -> None:
self._replacement_cache = dict[int, _BoundPromptSequence]()

@property
def target(self) -> _BoundPromptSequence:
target = self._target

return _BoundPromptSequence(
tokenizer=self.tokenizer,
_text=target if isinstance(target, str) else None,
_token_ids=target if isinstance(target, list) else None,
)

def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
replacement = self._replacement
if callable(replacement):
cache_key = item_idx
if cache_key in self._replacement_cache:
return self._replacement_cache[cache_key]

replacement = replacement(item_idx)
else:
cache_key = None

bound_replacement = _BoundPromptSequence(
tokenizer=self.tokenizer,
_text=replacement if isinstance(replacement, str) else None,
_token_ids=replacement if isinstance(replacement, list) else None,
)

if cache_key is not None:
self._replacement_cache[cache_key] = bound_replacement

return bound_replacement


def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
class ImageSize(NamedTuple):
width: int
height: int


class MultiModalDataItems(UserDict[str, list[Any]]):
"""
Convert a :class:`MultiModalDataDict` containing single data items
to a :class:`MultiModalMultiDataDict` containing multiple data items
per entry.
As :class:`MultiModalDataDict`, but normalized such that each entry
corresponds to a list.
"""
multi_data = dict[str, list[Any]]()

@property
def image(self) -> list[ImageItem]:
return self["image"]

@property
def video(self) -> list[VideoItem]:
return self["video"]

@property
def audio(self) -> list[AudioItem]:
return self["audio"]

def get_image_size(self, item_idx: int) -> ImageSize:
image = self.image[item_idx]

if isinstance(image, Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
return ImageSize(w, h)

assert_never(image)


def to_multi_format(data: MultiModalDataDict) -> MultiModalDataItems:
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = MultiModalDataItems()

for k, v in data.items():
# yapf: disable
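A small sketch of what the normalization and size lookup above provide, assuming Pillow is installed:

from PIL import Image

from vllm.multimodal.processing import to_multi_format

image = Image.new("RGB", (336, 336))

# A bare item is wrapped into a one-element list for its modality.
mm_items = to_multi_format({"image": image})
assert mm_items.image == [image]

# Per-item size lookup, used by dynamic replacement callables.
assert tuple(mm_items.get_image_size(0)) == (336, 336)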
@ -266,22 +264,33 @@ def iter_token_matches(
token_ids: list[int],
match_ids: list[int],
) -> Iterable[_TokenMatch]:
"""Yield each occurrence of :code:`match_ids` in :code:`token_ids`."""
"""
Yield each occurrence of :code:`match_ids` in :code:`token_ids`.

Note that empty matches are ignored.
"""
prompt_len = len(token_ids)
match_len = len(match_ids)

last_end_idx = 0
for start_idx in range(len(token_ids) - match_len + 1):
if start_idx < last_end_idx:
continue # Exclude overlapping matches
if match_len == 0:
return

start_idx = 0
while start_idx < prompt_len - match_len + 1:
end_idx = start_idx + match_len

if token_ids[start_idx:end_idx] == match_ids:
yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)
last_end_idx = end_idx

# Exclude overlapping matches
start_idx = end_idx
else:
start_idx += 1


class _PromptReplacementMatch(ABC, Generic[_T, _S]):
prompt_repl: _BoundPromptReplacement[_T]
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
prompt_repl: _BoundPromptReplacement

@property
def modality(self) -> str:
@ -297,19 +306,13 @@ class _PromptReplacementMatch(ABC, Generic[_T, _S]):
def end_idx(self) -> int:
raise NotImplementedError

@property
@abstractmethod
def repl_unit(self) -> _S:
raise NotImplementedError

def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r}, "
f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")


@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]):
prompt_repl: _BoundPromptReplacement[_T]
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
match: _TokenMatch

@property
@ -320,14 +323,9 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]):
def end_idx(self) -> int:
return self.match.end_idx

@property
def repl_unit(self) -> list[int]:
return self.prompt_repl.repl_unit.token_ids


@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]):
prompt_repl: _BoundPromptReplacement[_T]
class _PromptReplacementTextMatch(_PromptReplacementMatch):
match: re.Match[str]

@property
@ -338,20 +336,15 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]):
def end_idx(self) -> int:
return self.match.end()

@property
def repl_unit(self) -> str:
return self.prompt_repl.repl_unit.text


class _PlaceholderInfo(NamedTuple):
modality: str
start_idx: int
unit: list[int]
unit_count: int
replacement: list[int]

@property
def length(self) -> int:
return len(self.unit) * self.unit_count
return len(self.replacement)

def to_range(self) -> PlaceholderRange:
return PlaceholderRange(
@ -362,8 +355,8 @@ class _PlaceholderInfo(NamedTuple):

def find_token_matches(
prompt: list[int],
prompt_repls: Sequence[_BoundPromptReplacement[_T]],
) -> list[_PromptReplacementTokenMatch[_T]]:
prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [
_PromptReplacementTokenMatch(prompt_repl, match)
@ -374,8 +367,8 @@ def find_token_matches(

def find_text_matches(
prompt: str,
prompt_repls: Sequence[_BoundPromptReplacement[_T]],
) -> list[_PromptReplacementTextMatch[_T]]:
prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [
_PromptReplacementTextMatch(prompt_repl, match)
@ -385,15 +378,15 @@ def find_text_matches(


def _resolve_matches(
prompt: _S,
matches: Sequence[_PromptReplacementMatch[_T, _S]],
) -> list[_PromptReplacementMatch[_T, _S]]:
prompt: _PromptSeq,
matches: Sequence[_PromptReplacementMatch],
) -> list[_PromptReplacementMatch]:
"""
Resolve :code:`matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones.
"""
seen_matches: list[Optional[_PromptReplacementMatch[_T, _S]]] \
= [None] * len(prompt)
seen_matches: list[Optional[_PromptReplacementMatch]] = [None
] * len(prompt)

for match in matches:
for idx in range(match.start_idx, match.end_idx):
@ -409,30 +402,34 @@ def _resolve_matches(

def _replace_matches(
prompt: _S,
matches: Sequence[_PromptReplacementMatch[_T, _S]],
mm_items_by_modality: Mapping[str, list[_T]],
hf_inputs: BatchFeature,
matches: Sequence[_PromptReplacementMatch],
mm_items: MultiModalDataItems,
) -> list[_S]:
out_seqs = list[_S]()
prev_end_idx = 0
next_idx_by_modality = {modality: 0 for modality in mm_items_by_modality}
next_idx_by_modality = {modality: 0 for modality in mm_items}

for match in _resolve_matches(prompt, matches):
modality = match.modality
mm_items = mm_items_by_modality[modality]
modal_items = mm_items[modality]

item_idx = next_idx_by_modality[modality]
if item_idx >= len(mm_items):
if item_idx >= len(modal_items):
continue

start_idx = match.start_idx
end_idx = match.end_idx
repl_unit = match.repl_unit
repl_info = match.prompt_repl
repl_count = repl_info.get_count(mm_items, hf_inputs, item_idx)

out_seqs.append(prompt[prev_end_idx:start_idx] +
repl_unit * repl_count)
repl_info = match.prompt_repl
replacement = repl_info.get_replacement(item_idx)

if isinstance(prompt, str):
repl_seq = replacement.text
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
else:
repl_seq = replacement.token_ids
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)

prev_end_idx = end_idx
next_idx_by_modality[modality] += 1
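The rewritten matcher above never yields overlapping occurrences and skips empty patterns; a quick sketch of the expected behaviour:

from vllm.multimodal.processing import iter_token_matches

token_ids = [32000, 32000, 32000]
match_ids = [32000, 32000]

# Only one, non-overlapping match is reported.
matches = list(iter_token_matches(token_ids, match_ids))
assert [(m.start_idx, m.end_idx) for m in matches] == [(0, 2)]

# Empty patterns are ignored entirely.
assert list(iter_token_matches(token_ids, [])) == []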
@ -443,92 +440,104 @@ def _replace_matches(

def replace_token_matches(
prompt: list[int],
matches: Sequence[_PromptReplacementMatch[_T, list[int]]],
mm_items_by_modality: Mapping[str, list[_T]],
hf_inputs: BatchFeature,
matches: Sequence[_PromptReplacementTokenMatch],
mm_items: MultiModalDataItems,
) -> list[int]:
"""Apply :code:`prompt_repls` to :code:`prompt`."""
if not matches:
return prompt

token_id_seqs = _replace_matches(
prompt,
matches,
mm_items_by_modality,
hf_inputs,
)
token_id_seqs = _replace_matches(prompt, matches, mm_items)

return flatten_2d_lists(token_id_seqs)


def replace_text_matches(
prompt: str,
matches: Sequence[_PromptReplacementMatch[_T, str]],
mm_items_by_modality: Mapping[str, list[_T]],
hf_inputs: BatchFeature,
matches: Sequence[_PromptReplacementTextMatch],
mm_items: MultiModalDataItems,
) -> str:
"""Apply :code:`prompt_repls` to :code:`prompt`."""
if not matches:
return prompt

texts = _replace_matches(
prompt,
matches,
mm_items_by_modality,
hf_inputs,
)
texts = _replace_matches(prompt, matches, mm_items)

return "".join(texts)


def _merge_placeholder_matches(
matches: Iterable[_PromptReplacementTokenMatch],
) -> Iterable[_PromptReplacementTokenMatch]:
current_match = None
def _iter_modality_placeholders(
prompt: list[int],
modality: str,
modality_repls: Sequence[_BoundPromptReplacement],
modal_items: list[Any],
) -> Iterable[_PlaceholderInfo]:
if len(modal_items) == 0:
return

for match in sorted(matches, key=lambda x: x.start_idx):
if current_match is None:
current_match = match
elif (current_match.prompt_repl == match.prompt_repl
and current_match.end_idx == match.start_idx):
current_match = _PromptReplacementTokenMatch(
current_match.prompt_repl,
match=_TokenMatch(current_match.start_idx, match.end_idx),
)
else:
yield current_match
current_match = match
prompt_len = len(prompt)
item_index = 0

if current_match is not None:
yield current_match
start_idx = 0
while start_idx < prompt_len:
found = False

for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_index)
repl_tokens = replacement.token_ids
repl_len = len(repl_tokens)
end_idx = start_idx + repl_len

if repl_len == 0 or end_idx > prompt_len:
continue

if prompt[start_idx:end_idx] == repl_tokens:
yield _PlaceholderInfo(
modality=modality,
start_idx=start_idx,
replacement=repl_tokens,
)

item_index += 1
if item_index >= len(modal_items):
return

# Exclude overlapping matches
start_idx = end_idx
found = True
break

if not found:
start_idx += 1


def iter_placeholders(
prompt_repls: Sequence[_BoundPromptReplacement[Any]],
prompt_repls: Sequence[_BoundPromptReplacement],
prompt: list[int],
*,
min_unit_count: int = 1,
mm_items: MultiModalDataItems,
) -> Iterable[_PlaceholderInfo]:
"""Yield each set of placeholder tokens found in :code:`token_ids`."""
if min_unit_count <= 0:
raise ValueError("`min_unit_count` must be a positive integer")
"""
Yield each set of placeholder tokens found in :code:`prompt`.

matches = (_PromptReplacementTokenMatch(prompt_repl, match)
for prompt_repl in prompt_repls
if len(repl_unit := prompt_repl.repl_unit.token_ids) > 0
for match in iter_token_matches(prompt, repl_unit))
Note that empty matches are ignored.
"""
repls_by_modality = dict(full_groupby_modality(prompt_repls))

for match in _merge_placeholder_matches(matches):
unit = match.repl_unit
placeholder = _PlaceholderInfo(
modality=match.modality,
start_idx=match.start_idx,
unit=unit,
unit_count=(match.end_idx - match.start_idx) // len(unit),
)
for modality, modal_items in mm_items.items():
if modality in repls_by_modality:
yield from _iter_modality_placeholders(
prompt,
modality,
repls_by_modality[modality],
modal_items,
)

if placeholder.unit_count >= min_unit_count:
yield placeholder

class ProcessorInputs(NamedTuple):
"""Keyword arguments to :meth:`BaseMultiModalProcessor`"""
prompt_text: str
mm_data: MultiModalDataDict
mm_processor_kwargs: Mapping[str, object]


class BaseMultiModalProcessor(ABC):
@ -536,27 +545,10 @@ class BaseMultiModalProcessor(ABC):
Abstract base class to process multi-modal inputs to be used in vLLM.
"""

def __init__(
self,
ctx: InputProcessingContext,
metadata: MultiModalProcessingMetadata,
) -> None:
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__()

self.ctx = ctx
self.metadata = metadata
self.init_mm_processor_kwargs = (ctx.model_config.mm_processor_kwargs
or {})

def _get_hf_processor(
self,
**mm_processor_kwargs: Mapping[str, object],
) -> ProcessorMixin:
# by default, we won't pass any kwargs to the processor initialization
return self.ctx.get_hf_processor()

def _get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer

def __call__(
self,
@ -566,22 +558,42 @@ class BaseMultiModalProcessor(ABC):
) -> MultiModalInputsV2:
return self.apply(prompt, mm_data, mm_processor_kwargs)

def _get_hf_processor(self) -> ProcessorMixin:
"""
Subclasses can add keyword arguments to this method to accept
additional kwargs from model config or user inputs.
"""
return self.ctx.get_hf_processor()

def _get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer

@abstractmethod
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
"""
Given the original multi-modal items for this modality
and HF-processed data, output the replacements to perform.

Note:
Even when the HF processor already performs replacement for us,
we still use this replacement information to determine
the placeholder token positions for each multi-modal item.
"""
raise NotImplementedError

def _find_placeholders(
self,
all_prompt_repls: Sequence[_BoundPromptReplacement[Any]],
all_prompt_repls: Sequence[_BoundPromptReplacement],
new_token_ids: list[int],
*,
# To avoid false positives from multi-input when detecting
# whether placeholder tokens have been inserted, in case
# the target sequence is a subset of the replacement tokens
min_unit_count: int = 16,
mm_items: MultiModalDataItems,
) -> list[_PlaceholderInfo]:
return list(
iter_placeholders(
all_prompt_repls,
new_token_ids,
min_unit_count=min_unit_count,
))
iter_placeholders(all_prompt_repls, new_token_ids, mm_items))

def _apply_hf_processor(
self,
@ -589,13 +601,7 @@ class BaseMultiModalProcessor(ABC):
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
# some mm_processor_kwargs may be used in processor initialization
# instead of processor call
processor_init_kwargs = {
**self.init_mm_processor_kwargs,
**mm_processor_kwargs,
}
hf_processor = self._get_hf_processor(**processor_init_kwargs)
hf_processor = self._get_hf_processor(**mm_processor_kwargs)

processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
@ -615,11 +621,10 @@ class BaseMultiModalProcessor(ABC):
else:
processor_data[k] = v

# filter mm_processor_kwargs used in processor call
mm_processor_kwargs = resolve_mm_processor_kwargs(
self.init_mm_processor_kwargs,
cast(Dict[str, Any], mm_processor_kwargs),
assert callable(hf_processor)
mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
hf_processor,
mm_processor_kwargs,
)

try:
@ -642,26 +647,21 @@ class BaseMultiModalProcessor(ABC):

def _bind_prompt_replacements(
self,
mm_data: MultiModalDataDict,
) -> list[_BoundPromptReplacement[Any]]:
prompt_repls: list[PromptReplacement],
) -> list[_BoundPromptReplacement]:
tokenizer = self._get_tokenizer()

return [
prompt_repl.bind(modality, tokenizer)
for modality, metadata in self.metadata.items()
if modality in mm_data for prompt_repl in metadata.prompt_repls
]
return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]

def _apply_prompt_replacements(
self,
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement[Any]],
prompt_repls: Sequence[_BoundPromptReplacement],
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
tokenizer = self._get_tokenizer()

mm_items = to_multi_format(mm_data)
token_matches = find_token_matches(token_ids, prompt_repls)

# If the search text does not represent a special token,
@ -682,7 +682,6 @@ class BaseMultiModalProcessor(ABC):
token_ids,
token_matches,
mm_items,
hf_inputs,
)

text = _decode(tokenizer, token_ids)
@ -695,13 +694,13 @@ class BaseMultiModalProcessor(ABC):
text,
text_matches,
mm_items,
hf_inputs,
)

token_ids = _encode(tokenizer, text)
matched_repls = [match.prompt_repl for match in text_matches]

placeholders = self._find_placeholders(matched_repls, token_ids)
placeholders = self._find_placeholders(matched_repls, token_ids,
mm_items)

return token_ids, text, placeholders

@ -731,12 +730,16 @@ class BaseMultiModalProcessor(ABC):
prompt_ids, = hf_inputs.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs(hf_inputs)

all_prompt_repls = self._bind_prompt_replacements(mm_data)
mm_items = to_multi_format(mm_data)
prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs,
mm_processor_kwargs)
all_prompt_repls = self._bind_prompt_replacements(prompt_repls)

# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
all_placeholders = self._find_placeholders(all_prompt_repls,
prompt_ids)
prompt_ids, mm_items)

if all_placeholders:
prompt_text = _decode(tokenizer, prompt_ids)
else:
@ -745,7 +748,7 @@ class BaseMultiModalProcessor(ABC):
prompt_text,
all_placeholders,
) = self._apply_prompt_replacements(
mm_data,
mm_items,
hf_inputs,
prompt_ids,
all_prompt_repls,
@ -765,13 +768,13 @@ class BaseMultiModalProcessor(ABC):
)

@abstractmethod
def _get_dummy_mm_kwargs(
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
) -> ProcessorInputs:
"""
Build the input that corresponds to `mm_max_tokens` in
:meth:`get_dummy_data`.
Build the multi-modal portion of the input which, after processing,
results in `mm_max_tokens` in :meth:`get_dummy_data`.
"""
raise NotImplementedError
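Putting the two abstract hooks together, a subclass now only describes its replacements and its dummy inputs. A skeletal sketch (made-up token id and count; the dummy-input body is omitted):

from collections.abc import Mapping

from transformers import BatchFeature

from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)


class MyModelMultiModalProcessor(BaseMultiModalProcessor):

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        mm_processor_kwargs: Mapping[str, object],
    ) -> list[PromptReplacement]:
        # Expand each occurrence of the image token into a fixed number of
        # placeholders; a real model would compute the count per item.
        return [
            PromptReplacement(
                modality="image",
                target=[32000],
                replacement=[32000] * 576,
            )
        ]

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        # A real implementation returns ProcessorInputs(prompt_text=...,
        # mm_data=..., mm_processor_kwargs={}) for the worst-case input.
        raise NotImplementedError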
@ -784,38 +787,41 @@ class BaseMultiModalProcessor(ABC):
# Avoid circular import
from vllm.sequence import SequenceData

tokenizer = self._get_tokenizer()
processor_inputs = self._get_dummy_mm_inputs(mm_counts)
mm_inputs = self.apply(*processor_inputs)

mm_placeholders = dict[str, _PlaceholderInfo]()
offset = 0
prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"]

for modality, max_tokens in mm_max_tokens.items():
if max_tokens == 0:
continue
total_placeholders_by_modality = dict[str, int]()
for modality, placeholders in placeholders_by_modality.items():
num_placeholders = sum(item["length"] for item in placeholders)
max_tokens = mm_max_tokens[modality]

metadata = self.metadata[modality]
repl = metadata.prompt_repls[0].bind(modality, tokenizer)
repl_token_ids = repl.repl_unit.token_ids
if num_placeholders != max_tokens:
logger.warning(
"The processed dummy data has a total of %d placeholder "
"tokens for the '%s' modality, which is not the expected "
"%d tokens.", num_placeholders, modality, max_tokens)

placeholders = _PlaceholderInfo(
modality=modality,
start_idx=offset,
unit=repl_token_ids,
unit_count=max_tokens // len(repl_token_ids),
)
total_placeholders_by_modality[modality] = num_placeholders

mm_placeholders[modality] = placeholders
offset += placeholders.length
total_len = len(prompt_token_ids)
if total_len > seq_len:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality)

prompt_token_ids = flatten_2d_lists(
[p.unit * p.unit_count for p in mm_placeholders.values()])
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))

return DummyData(
seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=self._get_dummy_mm_kwargs(mm_counts),
multi_modal_placeholders={
modality: [p.to_range()]
for modality, p in mm_placeholders.items()
},
multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=placeholders_by_modality,
)
@ -299,9 +299,9 @@ class MultiModalRegistry:
"""

def wrapper(model_cls: N) -> N:
if model_cls in self._processor_factories:
if self._processor_factories.contains(model_cls, strict=True):
logger.warning(
"Model class %s already has an input mapper "
"Model class %s already has a multi-modal processor "
"registered to %s. It is overwritten by the new one.",
model_cls, self)
@ -1370,8 +1370,8 @@ def supports_kw(


def resolve_mm_processor_kwargs(
init_kwargs: Optional[Dict[str, Any]],
inference_kwargs: Optional[Dict[str, Any]],
init_kwargs: Optional[Mapping[str, object]],
inference_kwargs: Optional[Mapping[str, object]],
callable: Callable[..., object],
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
@ -1405,7 +1405,7 @@

def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Optional[Dict[str, Any]],
overrides: Optional[Mapping[str, object]],
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
"""
@ -1524,9 +1524,15 @@ class ClassRegistry(UserDict[Type[T], _V]):
raise KeyError(key)

def __contains__(self, key: object) -> bool:
return self.contains(key)

def contains(self, key: object, *, strict: bool = False) -> bool:
if not isinstance(key, type):
return False

if strict:
return key in self.data

return any(cls in self.data for cls in key.mro())
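The new `contains(..., strict=True)` distinguishes exact registration from the MRO-based lookup used by `in`; a compact sketch:

from vllm.utils import ClassRegistry


class Base:
    pass


class Derived(Base):
    pass


registry = ClassRegistry()
registry[Base] = "base-entry"

assert Derived in registry                          # walks Derived.mro()
assert not registry.contains(Derived, strict=True)  # exact class only
assert registry.contains(Base, strict=True)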