[VLM] Fully dynamic prompt replacement in merged input processor (#11199)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2024-12-15 01:52:18 +08:00
committed by GitHub
parent 9c3dadd1c9
commit 93abf23a64
12 changed files with 565 additions and 506 deletions
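
For orientation before the per-file diffs: a minimal sketch (not part of the commit) of the dynamic replacement API the changes below converge on. The replacement may now be a callable of the item index rather than a fixed `unit * count`; the token id and count here are illustrative.

from vllm.multimodal.processing import PromptReplacement

IMAGE_TOKEN_ID = 32000  # illustrative placeholder token id

def get_replacement(item_idx: int) -> list[int]:
    # The replacement can depend on the individual multi-modal item,
    # e.g. a token count derived from that image's size.
    return [IMAGE_TOKEN_ID] * (4 * (item_idx + 1))

repl = PromptReplacement(
    modality="image",
    target=[IMAGE_TOKEN_ID],      # text or token sequence to find and replace
    replacement=get_replacement,  # callable of item_idx, or a fixed sequence
)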

View File

@ -97,9 +97,6 @@ def run_phi3v(question: str, modality: str):
# max_model_len (128k) for this model may cause OOM.
# You may lower either to run this example on lower-end GPUs.
# In this example, we override max_num_seqs to 5 while
# keeping the original context length of 128k.
# num_crops is an override kwarg to the multimodal image processor;
# For some models, e.g., Phi-3.5-vision-instruct, it is recommended
# to use 16 for single frame scenarios, and 4 for multi-frame.
@ -113,7 +110,7 @@ def run_phi3v(question: str, modality: str):
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally
# https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194
llm = LLM(
model="microsoft/Phi-3-vision-128k-instruct",
model="microsoft/Phi-3.5-vision-instruct",
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,

View File

@ -16,8 +16,8 @@ models = ["microsoft/Phi-3.5-vision-instruct"]
# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_phi3v():
from vllm.model_executor.models.phi3v import Phi3VProcessor
return Phi3VProcessor
from vllm.model_executor.models.phi3v import Phi3VMultiModalProcessor
return Phi3VMultiModalProcessor
@pytest.fixture()

View File

@ -1,11 +1,11 @@
from typing import cast
import pytest
from transformers import BatchFeature
from vllm.multimodal.processing import (PromptReplacement, _PlaceholderInfo,
find_text_matches, find_token_matches,
iter_placeholders, iter_token_matches,
from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
_PlaceholderInfo, find_text_matches,
find_token_matches, iter_placeholders,
iter_token_matches,
replace_text_matches,
replace_token_matches)
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ -16,7 +16,7 @@ from vllm.utils import full_groupby
@pytest.mark.parametrize(
("token_ids", "match_ids", "expected"),
[
([], [], [{ "start_idx": 0, "end_idx": 0 }]),
([], [], []),
([], [32000], []),
(
[32000, 32000, 32000],
@ -83,7 +83,7 @@ def test_iter_token_matches(token_ids, match_ids, expected):
"pattern_2": [32000],
},
{
"pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
"pattern_1": [],
"pattern_2": [],
}
),
@ -136,7 +136,7 @@ def test_find_token_matches(prompt, target_by_key, expected_by_key):
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
PromptReplacement(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_token_matches(prompt, prompt_repls)
@ -243,7 +243,7 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(target, [], 0).bind(key, mock_tokenizer)
PromptReplacement(key, target, []).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
result = find_text_matches(prompt, prompt_repls)
@ -276,12 +276,12 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
"pattern_3": "!",
},
{
# Test whether target is confused with repl_unit
"pattern_1": ("<image><image>", 1),
# Test empty repl_unit
"pattern_2": ("", 1),
# Test multiple repl_count
"pattern_3": ("?", 2),
# Test whether target is confused with replacement
"pattern_1": "<image><image>",
# Test empty replacement
"pattern_2": "",
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": "?!?",
},
),
]
@ -290,8 +290,8 @@ def test_find_text_matches(prompt, target_by_key, expected_by_key):
("mm_count", "expected"),
[
(0, "Image:<image>Image:<image><image>!"),
(1, "<image><image>Image:<image><image>??"),
(2, "<image><image><image><image><image>??"),
(1, "<image><image>Image:<image><image>?!?"),
(2, "<image><image><image><image><image>?!?"),
]
)
# yapf: enable
@ -306,7 +306,7 @@ def test_find_replace_text(
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
matches = find_text_matches(prompt, prompt_repls)
@ -314,9 +314,8 @@ def test_find_replace_text(
result = replace_text_matches(
prompt,
matches,
{key: list(range(mm_count))
for key in repl_by_key},
BatchFeature(),
MultiModalDataItems({key: [None] * mm_count
for key in repl_by_key}),
)
# Only displayed on error
@ -343,12 +342,12 @@ def test_find_replace_text(
"pattern_3": [918],
},
{
# Test whether target is confused with repl_unit
"pattern_1": ([32000, 32000], 1),
# Test empty repl_unit
"pattern_2": ([], 1),
# Test multiple repl_count
"pattern_3": ([1550], 2),
# Test whether target is confused with replacement
"pattern_1": [32000, 32000],
# Test empty replacement
"pattern_2": [],
# Test dynamic replacement (beyond the form of `unit * count`)
"pattern_3": [1550, 918, 1550],
},
),
]
@ -357,8 +356,8 @@ def test_find_replace_text(
("mm_count", "expected"),
[
(0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 1550]),
(1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
(2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
]
)
# yapf: enable
@ -373,7 +372,7 @@ def test_find_replace_tokens(
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement(target, *repl_by_key[key]).bind(key, mock_tokenizer)
PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
for key, target in target_by_key.items()
]
matches = find_token_matches(prompt, prompt_repls)
@ -381,9 +380,8 @@ def test_find_replace_tokens(
result = replace_token_matches(
prompt,
matches,
{key: list(range(mm_count))
for key in repl_by_key},
BatchFeature(),
MultiModalDataItems({key: [None] * mm_count
for key in repl_by_key}),
)
# Only displayed on error
@ -399,9 +397,9 @@ def test_find_replace_tokens(
"repl_by_key",
[
{
"pattern_1": ([32000, 32000], 1),
"pattern_2": ([], 1),
"pattern_3": ([1550], 2),
"pattern_1": [32000, 32000],
"pattern_2": [],
"pattern_3": [1550, 918, 1550],
},
],
)
@ -414,48 +412,47 @@ def test_find_replace_tokens(
_PlaceholderInfo(
modality="pattern_1",
start_idx=6,
unit=[32000, 32000],
unit_count=1,
replacement=[32000, 32000],
),
],
),
(
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 1550],
[1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=1,
unit=[32000, 32000],
unit_count=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_1",
start_idx=5,
unit=[32000, 32000],
unit_count=1,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_3",
start_idx=7,
unit=[1550],
unit_count=2,
replacement=[1550, 918, 1550],
),
],
),
(
[1, 32000, 32000, 32000, 32000, 32000, 1550, 1550],
[1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
[
_PlaceholderInfo(
modality="pattern_1",
start_idx=1,
unit=[32000, 32000],
unit_count=2,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_1",
start_idx=3,
replacement=[32000, 32000],
),
_PlaceholderInfo(
modality="pattern_3",
start_idx=6,
unit=[1550],
unit_count=2,
replacement=[1550, 918, 1550],
),
],
),
@ -470,11 +467,17 @@ def test_iter_placeholders(
mock_tokenizer = cast(AnyTokenizer, object())
prompt_repls = [
PromptReplacement([], *repl).bind(key, mock_tokenizer)
PromptReplacement(key, [], repl).bind(mock_tokenizer)
for key, repl in repl_by_key.items()
]
result = list(iter_placeholders(prompt_repls, prompt))
result = list(
iter_placeholders(
prompt_repls,
prompt,
# Effectively match all occurrences in the prompt
MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
))
# Only displayed on error
print("result:", result)

View File

@ -3,14 +3,14 @@ from typing import Optional
import torch
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
LlavaProcessor,
LlavaMultiModalProcessor,
get_max_llava_image_tokens)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class MyLlava(LlavaForConditionalGeneration):
def compute_logits(

View File

@ -2,7 +2,7 @@ import functools
from collections import UserDict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, NamedTuple,
Optional, Protocol, Type, cast)
Optional, Protocol, Type)
from torch import nn
from transformers import PretrainedConfig, ProcessorMixin
@ -47,7 +47,6 @@ class InputContext:
Raises:
TypeError: If the model is not of the specified type.
"""
hf_config = self.model_config.hf_config
if not isinstance(hf_config, hf_config_type):
raise TypeError("Invalid type of HuggingFace config. "
@ -60,21 +59,70 @@ class InputContext:
"""
Get the HuggingFace image processor configuration of the model.
"""
return self.model_config.hf_image_processor_config
def get_mm_config(self):
"""
Get the multimodal config of the model.
Raises:
RuntimeError: If the model is not a multimodal model.
"""
mm_config = self.model_config.multimodal_config
if mm_config is None:
raise RuntimeError("Not a multimodal model")
return mm_config
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
return cached_get_processor(
self.model_config.model,
trust_remote_code=self.model_config.trust_remote_code,
**merged_kwargs,
)
@dataclass(frozen=True)
class InputProcessingContext(InputContext):
tokenizer: AnyTokenizer
"""The tokenizer used to tokenize the inputs."""
def get_hf_processor(self, **kwargs) -> ProcessorMixin:
def get_hf_processor(self, **kwargs: object) -> ProcessorMixin:
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
merged_kwargs = {**base_kwargs, **kwargs}
return cached_get_processor(
self.model_config.tokenizer,
self.model_config.model,
tokenizer=self.tokenizer, # Override the tokenizer with ours
trust_remote_code=self.model_config.trust_remote_code,
**kwargs)
**merged_kwargs,
)
def resolve_hf_processor_call_kwargs(
self,
hf_processor: ProcessorMixin,
inference_kwargs: Mapping[str, object],
) -> Mapping[str, object]:
assert callable(hf_processor)
base_kwargs = self.model_config.mm_processor_kwargs
if base_kwargs is None:
base_kwargs = {}
return resolve_mm_processor_kwargs(
base_kwargs,
inference_kwargs,
hf_processor,
)
N = TypeVar("N", bound=Type[nn.Module])
@ -171,7 +219,8 @@ class InputRegistry:
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_factories_by_model_type:
if self._dummy_factories_by_model_type.contains(model_cls,
strict=True):
logger.warning(
"Model class %s already has dummy data "
"registered to %s. It is overwritten by the new one.",
@ -195,7 +244,8 @@ class InputRegistry:
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._dummy_encoder_factories_by_model_type:
if self._dummy_encoder_factories_by_model_type.contains(
model_cls, strict=True):
logger.warning(
"Model class %s already has dummy encoder data "
"registered to %s. It is overwritten by the new one.",
@ -305,7 +355,8 @@ class InputRegistry:
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._input_processors_by_model_type:
if self._input_processors_by_model_type.contains(model_cls,
strict=True):
logger.warning(
"Model class %s already has input processor "
"registered to %s. It is overwritten by the new one.",
@ -357,7 +408,7 @@ class InputRegistry:
# If it's empty, it'll fall back to the default kwarg values
mm_processor_kwargs = resolve_mm_processor_kwargs(
model_config.mm_processor_kwargs,
cast(Dict[str, Any], inputs.get("mm_processor_kwargs")),
inputs.get("mm_processor_kwargs", {}), # type: ignore
processor,
)
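
To make the new kwargs plumbing concrete, here is a small sketch of the merging behaviour added to get_hf_processor: per-call kwargs take precedence over the model config's mm_processor_kwargs. The values are illustrative.

base_kwargs = {"num_crops": 16}   # e.g. model_config.mm_processor_kwargs
call_kwargs = {"num_crops": 4}    # e.g. overrides passed for a single request

merged_kwargs = {**base_kwargs, **call_kwargs}
assert merged_kwargs == {"num_crops": 4}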

View File

@ -5,10 +5,10 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
import torch
import torch.nn as nn
from PIL.Image import Image
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
PixtralVisionConfig, PretrainedConfig,
ProcessorMixin, SiglipVisionConfig)
from transformers.models.llava import LlavaProcessor
from transformers.models.pixtral import PixtralProcessor
from vllm.attention import AttentionMetadata
@ -21,11 +21,9 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ModalityProcessingMetadata,
MultiModalProcessingMetadata,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
@ -33,7 +31,8 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip,
get_max_clip_image_tokens)
from .interfaces import SupportsMultiModal, SupportsPP
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
get_max_pixtral_hf_image_tokens)
get_max_pixtral_hf_image_tokens,
get_pixtral_hf_image_feature_size)
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
get_max_siglip_image_tokens)
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
@ -115,62 +114,7 @@ def get_max_llava_image_tokens(ctx: InputContext):
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def dummy_mm_kwargs_for_llava(ctx: InputProcessingContext,
mm_counts: Mapping[str, int]):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
if isinstance(vision_config, CLIPVisionConfig):
data = dummy_image_for_clip(vision_config, num_images)
elif isinstance(vision_config, SiglipVisionConfig):
data = dummy_image_for_siglip(vision_config, num_images)
elif isinstance(vision_config, PixtralVisionConfig):
data = dummy_image_for_pixtral_hf(vision_config, num_images)
else:
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
is_pixtral = isinstance(hf_processor, PixtralProcessor)
return MultiModalKwargs(
**hf_inputs,
is_pixtral=torch.tensor(is_pixtral),
)
def create_metadata_for_llava(
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
hf_config = ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
def get_repl_count(
mm_items: list[Image],
hf_inputs: BatchFeature,
item_idx: int,
) -> int:
return get_max_llava_image_tokens(ctx)
return {
"image":
ModalityProcessingMetadata(prompt_repls=[
PromptReplacement(target=[image_token_id],
repl_unit=[image_token_id],
repl_count=get_repl_count),
]),
}
class LlavaProcessor(BaseMultiModalProcessor):
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(
ctx=ctx,
metadata=create_metadata_for_llava(ctx),
)
class LlavaMultiModalProcessor(BaseMultiModalProcessor):
def _patch_pixtral_processor(self, hf_processor: PixtralProcessor):
if getattr(hf_processor, "__is_patched__", False):
@ -188,18 +132,72 @@ class LlavaProcessor(BaseMultiModalProcessor):
hf_processor.__is_patched__ = True # type: ignore
def _get_hf_processor(self) -> ProcessorMixin:
def _get_hf_processor(self) -> Union[LlavaProcessor, PixtralProcessor]:
hf_processor = self.ctx.get_hf_processor()
assert isinstance(hf_processor, (LlavaProcessor, PixtralProcessor))
if isinstance(hf_processor, PixtralProcessor):
self._patch_pixtral_processor(hf_processor)
return hf_processor
def _get_dummy_mm_kwargs(
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(LlavaConfig)
image_token_id = hf_config.image_token_index
processor = self._get_hf_processor()
if isinstance(processor, PixtralProcessor):
image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
def get_replacement_pixtral(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
(
num_width_tokens,
num_height_tokens,
) = get_pixtral_hf_image_feature_size(
vision_config,
image_width=image_size.width,
image_height=image_size.height,
)
tokens = ([image_token] * num_width_tokens +
[image_break_token]) * num_height_tokens
tokens[-1] = image_end_token
return "".join(tokens)
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=get_replacement_pixtral,
),
]
max_image_tokens = get_max_llava_image_tokens(self.ctx)
return [
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=[image_token_id] * max_image_tokens,
)
]
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
) -> ProcessorInputs:
hf_config = self.ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
num_images = mm_counts["image"]
@ -215,11 +213,13 @@ class LlavaProcessor(BaseMultiModalProcessor):
raise NotImplementedError(msg)
hf_processor = self._get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'],
return_tensors="pt")
image_token = hf_processor.image_token
return MultiModalKwargs(**hf_inputs)
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
mm_processor_kwargs={},
)
class LlavaLikeConfig(Protocol):
@ -303,7 +303,7 @@ def init_vision_tower_for_llava(
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(LlavaProcessor)
@MULTIMODAL_REGISTRY.register_processor(LlavaMultiModalProcessor)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
# BitandBytes specific attributes
bitsandbytes_stacked_params_mapping = {
@ -584,7 +584,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return loader.load_weights(weights)
class MantisProcessor(LlavaProcessor):
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
def _get_hf_processor(self) -> ProcessorMixin:
try:
@ -604,6 +604,6 @@ class MantisProcessor(LlavaProcessor):
# To use this model, please use
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(MantisProcessor)
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
pass
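
A worked example (standalone, values illustrative) of the replacement string built by get_replacement_pixtral above, for an image mapping to 3 patch tokens per row over 2 rows:

image_token = "[IMG]"
image_break_token = "[IMG_BREAK]"
image_end_token = "[IMG_END]"

num_width_tokens, num_height_tokens = 3, 2
tokens = ([image_token] * num_width_tokens +
          [image_break_token]) * num_height_tokens
tokens[-1] = image_end_token

print("".join(tokens))
# [IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]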

View File

@ -32,13 +32,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ModalityProcessingMetadata,
MultiModalDataDict,
MultiModalProcessingMetadata,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@ -305,64 +302,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline
def get_max_phi3v_image_tokens(ctx: InputContext,
*,
num_crops: Optional[int] = None):
mm_processor_kwargs = {}
if num_crops is not None:
mm_processor_kwargs["num_crops"] = num_crops
def get_max_phi3v_image_tokens(ctx: InputContext) -> int:
processor = ctx.get_hf_processor()
image_processor = processor.image_processor # type: ignore
model_config = ctx.model_config
image_processor = cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs,
)
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
return image_processor.calc_num_image_tokens_from_image_size(
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
return num_tokens
def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
return MultiModalKwargs(**hf_inputs)
def create_metadata_for_phi3v(
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
return {
"image":
ModalityProcessingMetadata(prompt_repls=[
PromptReplacement(target=[_IMAGE_TOKEN_ID],
repl_unit=[_IMAGE_TOKEN_ID],
repl_count=get_max_phi3v_image_tokens(ctx)),
]),
}
class Phi3VProcessor(BaseMultiModalProcessor):
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(
ctx=ctx,
metadata=create_metadata_for_phi3v(ctx),
)
class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
def _get_hf_processor(
self,
@ -389,15 +339,61 @@ class Phi3VProcessor(BaseMultiModalProcessor):
processed_outputs['input_ids'] = token_ids
return processed_outputs
def _get_dummy_mm_kwargs(
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
image_processor = hf_processor.image_processor # type: ignore
mm_config = self.ctx.get_mm_config()
max_images = mm_config.limit_per_prompt.get("image", 1)
def get_replacement_phi3v(item_idx: int):
image_size = mm_items.get_image_size(item_idx)
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=image_size.width,
height=image_size.height,
)
return [_IMAGE_TOKEN_ID] * num_tokens
return [
PromptReplacement(
modality="image",
target=image_token,
replacement=get_replacement_phi3v,
) for image_token in image_tokens[:max_images]
]
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts)
) -> ProcessorInputs:
num_images = mm_counts["image"]
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
hf_processor = self._get_hf_processor()
image_tokens: list[str] = hf_processor.img_tokens # type: ignore
return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=data,
mm_processor_kwargs={},
)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
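
A standalone sketch of the Phi-3-vision replacement logic above: each image placeholder string is swapped for a run of image tokens whose length depends on that image's size. The token id, placeholder strings, and token-count formula below are illustrative stand-ins for _IMAGE_TOKEN_ID and the HF image processor's calc_num_image_tokens_from_image_size().

IMAGE_TOKEN_ID = -1  # stand-in for _IMAGE_TOKEN_ID
image_tokens = ["<|image_1|>", "<|image_2|>"]  # stand-in for hf_processor.img_tokens
image_sizes = [(336, 336), (672, 1344)]        # (width, height) of each item

def num_image_tokens(width: int, height: int) -> int:
    # stand-in for image_processor.calc_num_image_tokens_from_image_size()
    return (width // 336) * (height // 336) * 144

def get_replacement_phi3v(item_idx: int) -> list[int]:
    width, height = image_sizes[item_idx]
    return [IMAGE_TOKEN_ID] * num_image_tokens(width, height)

# One PromptReplacement target per placeholder string, as in
# _get_prompt_replacements:
replacements = [(token, get_replacement_phi3v) for token in image_tokens]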

View File

@ -72,7 +72,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder
image_token_id = mm_encoder.special_ids.img
mm_config = ctx.model_config.multimodal_config
mm_config = ctx.get_mm_config()
num_images = mm_config.limit_per_prompt.get("image", 1)
# dummy size
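
The only change here is that the multimodal config is now fetched through InputContext.get_mm_config(), which raises instead of silently yielding None for non-multimodal models. A minimal sketch of the call site, assuming InputContext is importable from vllm.inputs:

from vllm.inputs import InputContext

def dummy_num_images(ctx: InputContext) -> int:
    mm_config = ctx.get_mm_config()  # RuntimeError if not a multimodal model
    return mm_config.limit_per_prompt.get("image", 1)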

View File

@ -99,7 +99,7 @@ class MultiModalPlugin(ABC):
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._input_mappers:
if self._input_mappers.contains(model_cls, strict=True):
logger.warning(
"Model class %s already has an input mapper "
"registered to %s. It is overwritten by the new one.",
@ -194,7 +194,7 @@ class MultiModalPlugin(ABC):
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._max_mm_tokens:
if self._max_mm_tokens.contains(model_cls, strict=True):
logger.warning(
"Model class %s already calculates maximum number of "
"tokens in %s. It is overwritten by the new one.",

View File

@ -1,116 +1,59 @@
import re
from abc import ABC, abstractmethod
from collections import UserDict
from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import lru_cache
from typing import (Any, Dict, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union, cast)
from typing import Any, NamedTuple, Optional, Protocol, TypeVar, Union
import numpy as np
import torch
from PIL.Image import Image
from transformers import BatchFeature, ProcessorMixin
from typing_extensions import TypeAlias, TypedDict
from typing_extensions import assert_never
from vllm.inputs import DummyData, InputProcessingContext
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import (flatten_2d_lists, full_groupby, is_list_of,
resolve_mm_processor_kwargs)
from vllm.utils import flatten_2d_lists, full_groupby, is_list_of
from .inputs import (AudioItem, ImageItem, MultiModalDataDict,
MultiModalInputsV2, MultiModalKwargs, PlaceholderRange,
VideoItem)
logger = init_logger(__name__)
def bind_prompt_sequence(
seq: Union[str, list[int]],
tokenizer: AnyTokenizer,
) -> "_BoundPromptSequence":
"""
Bind a text or token sequence to a tokenizer so that it can be
lazily converted into the other format on demand.
"""
return _BoundPromptSequence(
tokenizer=tokenizer,
_text=seq if isinstance(seq, str) else None,
_token_ids=seq if isinstance(seq, list) else None,
)
_T = TypeVar("_T")
_S = TypeVar("_S", str, list[int])
_PromptSeq = Union[str, list[int]]
@dataclass
class PromptReplacement(Generic[_S, _T]):
target: _S
class PromptReplacement:
modality: str
"""The modality for which the replacement is made"""
target: _PromptSeq
"""The text or token sequence to find and replace."""
repl_unit: _S
replacement: Union[Callable[[int], _PromptSeq],
_PromptSeq] = field(repr=False)
"""
The unit making up the replacement text or token sequence.
See :code:`repl_count` for more details.
Given the index of the processed item within :attr:`modality`, output the
replacement text or token sequence.
For convenience, you can pass in the replacement instead of a function
if it does not depend on the input.
"""
repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]
"""
Given the original multi-modal items for this modality, HF-processed data,
and index of the processed item, output the number of repetitions of
:code:`repl_unit` to build up the replacement text or token sequence.
For convenience, you can pass in an integer if the number of repetitions is
a constant.
"""
def __repr__(self) -> str:
return (f"{type(self).__name__}(target={self.target!r}, "
f"repl_unit={self.repl_unit!r})")
def bind(
self,
modality: str,
tokenizer: AnyTokenizer,
) -> "_BoundPromptReplacement[_T]":
def bind(self, tokenizer: AnyTokenizer) -> "_BoundPromptReplacement":
return _BoundPromptReplacement(
modality=modality,
target=bind_prompt_sequence(self.target, tokenizer),
repl_unit=bind_prompt_sequence(self.repl_unit, tokenizer),
repl_count=self.repl_count,
tokenizer=tokenizer,
modality=self.modality,
_target=self.target,
_replacement=self.replacement,
)
@dataclass
class ModalityProcessingMetadata(Generic[_T]):
prompt_repls: Sequence[Union[PromptReplacement[str, _T],
PromptReplacement[list[int], _T]]]
"""
Defines each text or token sequence to replace in the HF-processed prompt.
This is skipped if the HF-processed prompt is found to already contain
the replacement prompts.
"""
class MultiModalProcessingMetadataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
image: ModalityProcessingMetadata[ImageItem]
video: ModalityProcessingMetadata[VideoItem]
audio: ModalityProcessingMetadata[AudioItem]
MultiModalProcessingMetadata: TypeAlias = \
Mapping[str, ModalityProcessingMetadata[Any]]
"""
A dictionary containing an entry for each modality type to process.
Note:
This dictionary also accepts modality keys defined outside
:class:`MultiModalProcessingMetadataBuiltins` as long as a customized plugin
is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`.
Read more on that :ref:`here <adding_multimodal_plugin>`.
"""
def _encode(
tokenizer: AnyTokenizer,
text: str,
@ -185,7 +128,8 @@ def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]:
@dataclass
class _BoundPromptSequence:
tokenizer: AnyTokenizer
tokenizer: AnyTokenizer = field(repr=False)
_text: Optional[str]
_token_ids: Optional[list[int]]
@ -210,38 +154,92 @@ class _BoundPromptSequence:
return self._token_ids
def __repr__(self) -> str:
return (f"{type(self).__name__}(_text={self._text!r}, "
f"_token_ids={self._token_ids!r})")
@dataclass
class _BoundPromptReplacement(Generic[_T]):
class _BoundPromptReplacement:
tokenizer: AnyTokenizer = field(repr=False)
modality: str
target: _BoundPromptSequence
repl_unit: _BoundPromptSequence
repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int]
def get_count(
self,
mm_items: list[_T],
hf_inputs: BatchFeature,
item_idx: int,
) -> int:
repl_count = self.repl_count
if isinstance(repl_count, int):
return repl_count
_target: _PromptSeq
_replacement: Union[Callable[[int], _PromptSeq],
_PromptSeq] = field(repr=False)
return repl_count(mm_items, hf_inputs, item_idx)
def __post_init__(self) -> None:
self._replacement_cache = dict[int, _BoundPromptSequence]()
@property
def target(self) -> _BoundPromptSequence:
target = self._target
return _BoundPromptSequence(
tokenizer=self.tokenizer,
_text=target if isinstance(target, str) else None,
_token_ids=target if isinstance(target, list) else None,
)
def get_replacement(self, item_idx: int) -> _BoundPromptSequence:
replacement = self._replacement
if callable(replacement):
cache_key = item_idx
if cache_key in self._replacement_cache:
return self._replacement_cache[cache_key]
replacement = replacement(item_idx)
else:
cache_key = None
bound_replacement = _BoundPromptSequence(
tokenizer=self.tokenizer,
_text=replacement if isinstance(replacement, str) else None,
_token_ids=replacement if isinstance(replacement, list) else None,
)
if cache_key is not None:
self._replacement_cache[cache_key] = bound_replacement
return bound_replacement
def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]:
class ImageSize(NamedTuple):
width: int
height: int
class MultiModalDataItems(UserDict[str, list[Any]]):
"""
Convert a :class:`MultiModalDataDict` containing single data items
to a :class:`MultiModalMultiDataDict` containing multiple data items
per entry.
As :class:`MultiModalDataDict`, but normalized such that each entry
corresponds to a list.
"""
multi_data = dict[str, list[Any]]()
@property
def image(self) -> list[ImageItem]:
return self["image"]
@property
def video(self) -> list[VideoItem]:
return self["video"]
@property
def audio(self) -> list[AudioItem]:
return self["audio"]
def get_image_size(self, item_idx: int) -> ImageSize:
image = self.image[item_idx]
if isinstance(image, Image):
return ImageSize(*image.size)
if isinstance(image, (np.ndarray, torch.Tensor)):
_, h, w = image.shape
return ImageSize(w, h)
assert_never(image)
def to_multi_format(data: MultiModalDataDict) -> MultiModalDataItems:
"""
Normalize :class:`MultiModalDataDict` to :class:`MultiModalDataItems`.
"""
multi_data = MultiModalDataItems()
for k, v in data.items():
# yapf: disable
@ -266,22 +264,33 @@ def iter_token_matches(
token_ids: list[int],
match_ids: list[int],
) -> Iterable[_TokenMatch]:
"""Yield each occurrence of :code:`match_ids` in :code:`token_ids`."""
"""
Yield each occurrence of :code:`match_ids` in :code:`token_ids`.
Note that empty matches are ignored.
"""
prompt_len = len(token_ids)
match_len = len(match_ids)
last_end_idx = 0
for start_idx in range(len(token_ids) - match_len + 1):
if start_idx < last_end_idx:
continue # Exclude overlapping matches
if match_len == 0:
return
start_idx = 0
while start_idx < prompt_len - match_len + 1:
end_idx = start_idx + match_len
if token_ids[start_idx:end_idx] == match_ids:
yield _TokenMatch(start_idx=start_idx, end_idx=end_idx)
last_end_idx = end_idx
# Exclude overlapping matches
start_idx = end_idx
else:
start_idx += 1
class _PromptReplacementMatch(ABC, Generic[_T, _S]):
prompt_repl: _BoundPromptReplacement[_T]
@dataclass(repr=False)
class _PromptReplacementMatch(ABC):
prompt_repl: _BoundPromptReplacement
@property
def modality(self) -> str:
@ -297,19 +306,13 @@ class _PromptReplacementMatch(ABC, Generic[_T, _S]):
def end_idx(self) -> int:
raise NotImplementedError
@property
@abstractmethod
def repl_unit(self) -> _S:
raise NotImplementedError
def __repr__(self) -> str:
return (f"{type(self).__name__}(modality={self.modality!r}, "
f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})")
@dataclass(repr=False)
class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]):
prompt_repl: _BoundPromptReplacement[_T]
class _PromptReplacementTokenMatch(_PromptReplacementMatch):
match: _TokenMatch
@property
@ -320,14 +323,9 @@ class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]):
def end_idx(self) -> int:
return self.match.end_idx
@property
def repl_unit(self) -> list[int]:
return self.prompt_repl.repl_unit.token_ids
@dataclass(repr=False)
class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]):
prompt_repl: _BoundPromptReplacement[_T]
class _PromptReplacementTextMatch(_PromptReplacementMatch):
match: re.Match[str]
@property
@ -338,20 +336,15 @@ class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]):
def end_idx(self) -> int:
return self.match.end()
@property
def repl_unit(self) -> str:
return self.prompt_repl.repl_unit.text
class _PlaceholderInfo(NamedTuple):
modality: str
start_idx: int
unit: list[int]
unit_count: int
replacement: list[int]
@property
def length(self) -> int:
return len(self.unit) * self.unit_count
return len(self.replacement)
def to_range(self) -> PlaceholderRange:
return PlaceholderRange(
@ -362,8 +355,8 @@ class _PlaceholderInfo(NamedTuple):
def find_token_matches(
prompt: list[int],
prompt_repls: Sequence[_BoundPromptReplacement[_T]],
) -> list[_PromptReplacementTokenMatch[_T]]:
prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTokenMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [
_PromptReplacementTokenMatch(prompt_repl, match)
@ -374,8 +367,8 @@ def find_token_matches(
def find_text_matches(
prompt: str,
prompt_repls: Sequence[_BoundPromptReplacement[_T]],
) -> list[_PromptReplacementTextMatch[_T]]:
prompt_repls: Sequence[_BoundPromptReplacement],
) -> list[_PromptReplacementTextMatch]:
"""Return each target of :code:`prompt_repls` found in :code:`prompt`."""
return [
_PromptReplacementTextMatch(prompt_repl, match)
@ -385,15 +378,15 @@ def find_text_matches(
def _resolve_matches(
prompt: _S,
matches: Sequence[_PromptReplacementMatch[_T, _S]],
) -> list[_PromptReplacementMatch[_T, _S]]:
prompt: _PromptSeq,
matches: Sequence[_PromptReplacementMatch],
) -> list[_PromptReplacementMatch]:
"""
Resolve :code:`matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones.
"""
seen_matches: list[Optional[_PromptReplacementMatch[_T, _S]]] \
= [None] * len(prompt)
seen_matches: list[Optional[_PromptReplacementMatch]] = [None
] * len(prompt)
for match in matches:
for idx in range(match.start_idx, match.end_idx):
@ -409,30 +402,34 @@ def _resolve_matches(
def _replace_matches(
prompt: _S,
matches: Sequence[_PromptReplacementMatch[_T, _S]],
mm_items_by_modality: Mapping[str, list[_T]],
hf_inputs: BatchFeature,
matches: Sequence[_PromptReplacementMatch],
mm_items: MultiModalDataItems,
) -> list[_S]:
out_seqs = list[_S]()
prev_end_idx = 0
next_idx_by_modality = {modality: 0 for modality in mm_items_by_modality}
next_idx_by_modality = {modality: 0 for modality in mm_items}
for match in _resolve_matches(prompt, matches):
modality = match.modality
mm_items = mm_items_by_modality[modality]
modal_items = mm_items[modality]
item_idx = next_idx_by_modality[modality]
if item_idx >= len(mm_items):
if item_idx >= len(modal_items):
continue
start_idx = match.start_idx
end_idx = match.end_idx
repl_unit = match.repl_unit
repl_info = match.prompt_repl
repl_count = repl_info.get_count(mm_items, hf_inputs, item_idx)
out_seqs.append(prompt[prev_end_idx:start_idx] +
repl_unit * repl_count)
repl_info = match.prompt_repl
replacement = repl_info.get_replacement(item_idx)
if isinstance(prompt, str):
repl_seq = replacement.text
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
else:
repl_seq = replacement.token_ids
out_seqs.append(prompt[prev_end_idx:start_idx] + repl_seq)
prev_end_idx = end_idx
next_idx_by_modality[modality] += 1
@ -443,92 +440,104 @@ def _replace_matches(
def replace_token_matches(
prompt: list[int],
matches: Sequence[_PromptReplacementMatch[_T, list[int]]],
mm_items_by_modality: Mapping[str, list[_T]],
hf_inputs: BatchFeature,
matches: Sequence[_PromptReplacementTokenMatch],
mm_items: MultiModalDataItems,
) -> list[int]:
"""Apply :code:`prompt_repls` to :code:`prompt`."""
if not matches:
return prompt
token_id_seqs = _replace_matches(
prompt,
matches,
mm_items_by_modality,
hf_inputs,
)
token_id_seqs = _replace_matches(prompt, matches, mm_items)
return flatten_2d_lists(token_id_seqs)
def replace_text_matches(
prompt: str,
matches: Sequence[_PromptReplacementMatch[_T, str]],
mm_items_by_modality: Mapping[str, list[_T]],
hf_inputs: BatchFeature,
matches: Sequence[_PromptReplacementTextMatch],
mm_items: MultiModalDataItems,
) -> str:
"""Apply :code:`prompt_repls` to :code:`prompt`."""
if not matches:
return prompt
texts = _replace_matches(
prompt,
matches,
mm_items_by_modality,
hf_inputs,
)
texts = _replace_matches(prompt, matches, mm_items)
return "".join(texts)
def _merge_placeholder_matches(
matches: Iterable[_PromptReplacementTokenMatch],
) -> Iterable[_PromptReplacementTokenMatch]:
current_match = None
def _iter_modality_placeholders(
prompt: list[int],
modality: str,
modality_repls: Sequence[_BoundPromptReplacement],
modal_items: list[Any],
) -> Iterable[_PlaceholderInfo]:
if len(modal_items) == 0:
return
for match in sorted(matches, key=lambda x: x.start_idx):
if current_match is None:
current_match = match
elif (current_match.prompt_repl == match.prompt_repl
and current_match.end_idx == match.start_idx):
current_match = _PromptReplacementTokenMatch(
current_match.prompt_repl,
match=_TokenMatch(current_match.start_idx, match.end_idx),
)
else:
yield current_match
current_match = match
prompt_len = len(prompt)
item_index = 0
if current_match is not None:
yield current_match
start_idx = 0
while start_idx < prompt_len:
found = False
for repl_info in modality_repls:
replacement = repl_info.get_replacement(item_index)
repl_tokens = replacement.token_ids
repl_len = len(repl_tokens)
end_idx = start_idx + repl_len
if repl_len == 0 or end_idx > prompt_len:
continue
if prompt[start_idx:end_idx] == repl_tokens:
yield _PlaceholderInfo(
modality=modality,
start_idx=start_idx,
replacement=repl_tokens,
)
item_index += 1
if item_index >= len(modal_items):
return
# Exclude overlapping matches
start_idx = end_idx
found = True
break
if not found:
start_idx += 1
def iter_placeholders(
prompt_repls: Sequence[_BoundPromptReplacement[Any]],
prompt_repls: Sequence[_BoundPromptReplacement],
prompt: list[int],
*,
min_unit_count: int = 1,
mm_items: MultiModalDataItems,
) -> Iterable[_PlaceholderInfo]:
"""Yield each set of placeholder tokens found in :code:`token_ids`."""
if min_unit_count <= 0:
raise ValueError("`min_unit_count` must be a positive integer")
"""
Yield each set of placeholder tokens found in :code:`prompt`.
matches = (_PromptReplacementTokenMatch(prompt_repl, match)
for prompt_repl in prompt_repls
if len(repl_unit := prompt_repl.repl_unit.token_ids) > 0
for match in iter_token_matches(prompt, repl_unit))
Note that empty matches are ignored.
"""
repls_by_modality = dict(full_groupby_modality(prompt_repls))
for match in _merge_placeholder_matches(matches):
unit = match.repl_unit
placeholder = _PlaceholderInfo(
modality=match.modality,
start_idx=match.start_idx,
unit=unit,
unit_count=(match.end_idx - match.start_idx) // len(unit),
)
for modality, modal_items in mm_items.items():
if modality in repls_by_modality:
yield from _iter_modality_placeholders(
prompt,
modality,
repls_by_modality[modality],
modal_items,
)
if placeholder.unit_count >= min_unit_count:
yield placeholder
class ProcessorInputs(NamedTuple):
"""Keyword arguments to :meth:`BaseMultiModalProcessor`"""
prompt_text: str
mm_data: MultiModalDataDict
mm_processor_kwargs: Mapping[str, object]
class BaseMultiModalProcessor(ABC):
@ -536,27 +545,10 @@ class BaseMultiModalProcessor(ABC):
Abstract base class to process multi-modal inputs to be used in vLLM.
"""
def __init__(
self,
ctx: InputProcessingContext,
metadata: MultiModalProcessingMetadata,
) -> None:
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__()
self.ctx = ctx
self.metadata = metadata
self.init_mm_processor_kwargs = (ctx.model_config.mm_processor_kwargs
or {})
def _get_hf_processor(
self,
**mm_processor_kwargs: Mapping[str, object],
) -> ProcessorMixin:
# by default, we won't pass any kwargs to the processor initialization
return self.ctx.get_hf_processor()
def _get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer
def __call__(
self,
@ -566,22 +558,42 @@ class BaseMultiModalProcessor(ABC):
) -> MultiModalInputsV2:
return self.apply(prompt, mm_data, mm_processor_kwargs)
def _get_hf_processor(self) -> ProcessorMixin:
"""
Subclasses can add keyword arguments to this method to accept
additional kwargs from model config or user inputs.
"""
return self.ctx.get_hf_processor()
def _get_tokenizer(self) -> AnyTokenizer:
return self.ctx.tokenizer
@abstractmethod
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
mm_processor_kwargs: Mapping[str, object],
) -> list[PromptReplacement]:
"""
Given the original multi-modal items for this modality
and HF-processed data, output the replacements to perform.
Note:
Even when the HF processor already performs replacement for us,
we still use this replacement information to determine
the placeholder token positions for each multi-modal item.
"""
raise NotImplementedError
def _find_placeholders(
self,
all_prompt_repls: Sequence[_BoundPromptReplacement[Any]],
all_prompt_repls: Sequence[_BoundPromptReplacement],
new_token_ids: list[int],
*,
# To avoid false positives from multi-input when detecting
# whether placeholder tokens have been inserted, in case
# the target sequence is a subset of the replacement tokens
min_unit_count: int = 16,
mm_items: MultiModalDataItems,
) -> list[_PlaceholderInfo]:
return list(
iter_placeholders(
all_prompt_repls,
new_token_ids,
min_unit_count=min_unit_count,
))
iter_placeholders(all_prompt_repls, new_token_ids, mm_items))
def _apply_hf_processor(
self,
@ -589,13 +601,7 @@ class BaseMultiModalProcessor(ABC):
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
# some mm_processor_kwargs may be used in processor initialization
# instead of processor call
processor_init_kwargs = {
**self.init_mm_processor_kwargs,
**mm_processor_kwargs,
}
hf_processor = self._get_hf_processor(**processor_init_kwargs)
hf_processor = self._get_hf_processor(**mm_processor_kwargs)
processor_data = dict[str, Any]()
passthrough_data = dict[str, Any]()
@ -615,11 +621,10 @@ class BaseMultiModalProcessor(ABC):
else:
processor_data[k] = v
# filter mm_processor_kwargs used in processor call
mm_processor_kwargs = resolve_mm_processor_kwargs(
self.init_mm_processor_kwargs,
cast(Dict[str, Any], mm_processor_kwargs),
assert callable(hf_processor)
mm_processor_kwargs = self.ctx.resolve_hf_processor_call_kwargs(
hf_processor,
mm_processor_kwargs,
)
try:
@ -642,26 +647,21 @@ class BaseMultiModalProcessor(ABC):
def _bind_prompt_replacements(
self,
mm_data: MultiModalDataDict,
) -> list[_BoundPromptReplacement[Any]]:
prompt_repls: list[PromptReplacement],
) -> list[_BoundPromptReplacement]:
tokenizer = self._get_tokenizer()
return [
prompt_repl.bind(modality, tokenizer)
for modality, metadata in self.metadata.items()
if modality in mm_data for prompt_repl in metadata.prompt_repls
]
return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls]
def _apply_prompt_replacements(
self,
mm_data: MultiModalDataDict,
mm_items: MultiModalDataItems,
hf_inputs: BatchFeature,
token_ids: list[int],
prompt_repls: Sequence[_BoundPromptReplacement[Any]],
prompt_repls: Sequence[_BoundPromptReplacement],
) -> tuple[list[int], str, list[_PlaceholderInfo]]:
tokenizer = self._get_tokenizer()
mm_items = to_multi_format(mm_data)
token_matches = find_token_matches(token_ids, prompt_repls)
# If the search text does not represent a special token,
@ -682,7 +682,6 @@ class BaseMultiModalProcessor(ABC):
token_ids,
token_matches,
mm_items,
hf_inputs,
)
text = _decode(tokenizer, token_ids)
@ -695,13 +694,13 @@ class BaseMultiModalProcessor(ABC):
text,
text_matches,
mm_items,
hf_inputs,
)
token_ids = _encode(tokenizer, text)
matched_repls = [match.prompt_repl for match in text_matches]
placeholders = self._find_placeholders(matched_repls, token_ids)
placeholders = self._find_placeholders(matched_repls, token_ids,
mm_items)
return token_ids, text, placeholders
@ -731,12 +730,16 @@ class BaseMultiModalProcessor(ABC):
prompt_ids, = hf_inputs.pop("input_ids").tolist()
mm_kwargs = MultiModalKwargs(hf_inputs)
all_prompt_repls = self._bind_prompt_replacements(mm_data)
mm_items = to_multi_format(mm_data)
prompt_repls = self._get_prompt_replacements(mm_items, hf_inputs,
mm_processor_kwargs)
all_prompt_repls = self._bind_prompt_replacements(prompt_repls)
# If HF processor already inserts placeholder tokens,
# there is no need for us to insert them
all_placeholders = self._find_placeholders(all_prompt_repls,
prompt_ids)
prompt_ids, mm_items)
if all_placeholders:
prompt_text = _decode(tokenizer, prompt_ids)
else:
@ -745,7 +748,7 @@ class BaseMultiModalProcessor(ABC):
prompt_text,
all_placeholders,
) = self._apply_prompt_replacements(
mm_data,
mm_items,
hf_inputs,
prompt_ids,
all_prompt_repls,
@ -765,13 +768,13 @@ class BaseMultiModalProcessor(ABC):
)
@abstractmethod
def _get_dummy_mm_kwargs(
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
) -> ProcessorInputs:
"""
Build the input that corresponds to `mm_max_tokens` in
:meth:`get_dummy_data`.
Build the multi-modal portion of the input which, after processing,
results in `mm_max_tokens` in :meth:`get_dummy_data`.
"""
raise NotImplementedError
@ -784,38 +787,41 @@ class BaseMultiModalProcessor(ABC):
# Avoid circular import
from vllm.sequence import SequenceData
tokenizer = self._get_tokenizer()
processor_inputs = self._get_dummy_mm_inputs(mm_counts)
mm_inputs = self.apply(*processor_inputs)
mm_placeholders = dict[str, _PlaceholderInfo]()
offset = 0
prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"]
for modality, max_tokens in mm_max_tokens.items():
if max_tokens == 0:
continue
total_placeholders_by_modality = dict[str, int]()
for modality, placeholders in placeholders_by_modality.items():
num_placeholders = sum(item["length"] for item in placeholders)
max_tokens = mm_max_tokens[modality]
metadata = self.metadata[modality]
repl = metadata.prompt_repls[0].bind(modality, tokenizer)
repl_token_ids = repl.repl_unit.token_ids
if num_placeholders != max_tokens:
logger.warning(
"The processed dummy data has a total of %d placeholder "
"tokens for the '%s' modality, which is not the expected "
"%d tokens.", num_placeholders, modality, max_tokens)
placeholders = _PlaceholderInfo(
modality=modality,
start_idx=offset,
unit=repl_token_ids,
unit_count=max_tokens // len(repl_token_ids),
)
total_placeholders_by_modality[modality] = num_placeholders
mm_placeholders[modality] = placeholders
offset += placeholders.length
total_len = len(prompt_token_ids)
if total_len > seq_len:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
"(%d tokens in total, out of which %s are reserved for "
"multi-modal embeddings). This may cause certain multi-modal "
"inputs to fail during inference, even when the input text is "
"short. To avoid this, you should increase `max_model_len`, "
"reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len,
total_len, total_placeholders_by_modality)
prompt_token_ids = flatten_2d_lists(
[p.unit * p.unit_count for p in mm_placeholders.values()])
prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids)))
return DummyData(
seq_data=SequenceData.from_seqs(prompt_token_ids),
multi_modal_data=self._get_dummy_mm_kwargs(mm_counts),
multi_modal_placeholders={
modality: [p.to_range()]
for modality, p in mm_placeholders.items()
},
multi_modal_data=mm_inputs["mm_kwargs"],
multi_modal_placeholders=placeholders_by_modality,
)
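
A standalone illustration of the new placeholder detection (mirroring the updated test_iter_placeholders): placeholders are found by scanning the prompt for each bound replacement's token sequence, bounded by the number of items in MultiModalDataItems, replacing the old unit/unit_count accounting.

from typing import cast

from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
                                        iter_placeholders)
from vllm.transformers_utils.tokenizer import AnyTokenizer

mock_tokenizer = cast(AnyTokenizer, object())  # never used on token-only paths

repls = [
    PromptReplacement("image", [], [32000, 32000]).bind(mock_tokenizer),
]
prompt = [1, 32000, 32000, 9833]

placeholders = list(
    iter_placeholders(repls, prompt, MultiModalDataItems({"image": [None]})))
# -> [_PlaceholderInfo(modality="image", start_idx=1,
#                      replacement=[32000, 32000])]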

View File

@ -299,9 +299,9 @@ class MultiModalRegistry:
"""
def wrapper(model_cls: N) -> N:
if model_cls in self._processor_factories:
if self._processor_factories.contains(model_cls, strict=True):
logger.warning(
"Model class %s already has an input mapper "
"Model class %s already has a multi-modal processor "
"registered to %s. It is overwritten by the new one.",
model_cls, self)

View File

@ -1370,8 +1370,8 @@ def supports_kw(
def resolve_mm_processor_kwargs(
init_kwargs: Optional[Dict[str, Any]],
inference_kwargs: Optional[Dict[str, Any]],
init_kwargs: Optional[Mapping[str, object]],
inference_kwargs: Optional[Mapping[str, object]],
callable: Callable[..., object],
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
@ -1405,7 +1405,7 @@ def resolve_mm_processor_kwargs(
def get_allowed_kwarg_only_overrides(
callable: Callable[..., object],
overrides: Optional[Dict[str, Any]],
overrides: Optional[Mapping[str, object]],
allow_var_kwargs: bool = False,
) -> Dict[str, Any]:
"""
@ -1524,9 +1524,15 @@ class ClassRegistry(UserDict[Type[T], _V]):
raise KeyError(key)
def __contains__(self, key: object) -> bool:
return self.contains(key)
def contains(self, key: object, *, strict: bool = False) -> bool:
if not isinstance(key, type):
return False
if strict:
return key in self.data
return any(cls in self.data for cls in key.mro())
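
A short sketch of the new ClassRegistry.contains(strict=...) semantics assumed from the diff above: the default lookup (and `in`) walks the MRO so subclasses resolve to a base class entry, while strict=True only matches the exact class, which is what the registration warnings above now use.

from vllm.utils import ClassRegistry

class Base:
    pass

class Derived(Base):
    pass

registry = ClassRegistry()
registry[Base] = "base-entry"

assert Derived in registry                            # MRO-based lookup
assert registry.contains(Derived, strict=True) is False
assert registry.contains(Base, strict=True) is True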