[VLM] Merged multi-modal processor for InternVL-based models (#12553)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Cyrus Leung
2025-02-04 16:44:52 +08:00
committed by GitHub
parent 96b23621c1
commit d1ca7df84d
34 changed files with 1469 additions and 1021 deletions

View File

@ -250,7 +250,11 @@ def get_max_image_tokens(self) -> int:
And thus, we can override the method as:
```python
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
```
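
For example, the `H2OVLProcessingInfo` added later in this diff consults `mm_counts` to report a smaller per-image budget when a prompt contains more than one image, because MSAC is only applied to single-image prompts. A condensed sketch of that pattern:

```python
def get_mm_max_tokens_per_item(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    # Single-image prompts may use MSAC (two preprocessing passes),
    # so they can consume more tokens per image.
    if mm_counts.get("image", 0) <= 1:
        return {"image": self.get_max_image_tokens(use_msac=None)}
    return {"image": self.get_max_image_tokens(use_msac=False)}
```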

View File

@ -726,7 +726,7 @@ See [this page](#generative-models) for more information on how to use generativ
* `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc.
*
* ✅︎
*
* \*
- * `Idefics3ForConditionalGeneration`
* Idefics3
* T + I
@ -799,7 +799,7 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
- * `NVLM_D_Model`
* NVLM-D 1.0
* T + I<sup>E+</sup>
* T + I<sup>+</sup>
* `nvidia/NVLM-D-72B`, etc.
*
* ✅︎
@ -859,7 +859,11 @@ See [this page](#generative-models) for more information on how to use generativ
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.
:::{note}
To use `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
To use DeepSeek-VL2 series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
:::
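
As a programmatic alternative to the CLI flag, the same override can be passed to the offline entrypoint; a minimal sketch, assuming `LLM` forwards `hf_overrides` to the engine in the same way as `--hf_overrides`:

```python
from vllm import LLM

llm = LLM(
    model="deepseek-ai/deepseek-vl2-tiny",
    hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]},
)
```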
:::{note}
H2O-VL series models will be available in V1 once we support backends other than FlashAttention.
:::
:::{note}

View File

@ -1,131 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import pytest
import torch
from PIL.Image import Image
from transformers import AutoConfig
# Import the functions to test
from vllm.model_executor.models.h2ovl import (calculate_num_blocks,
image_to_pixel_values_wrapper)
from vllm.multimodal.image import rescale_image_size
models = [
"h2oai/h2ovl-mississippi-800m", # Replace with your actual model names
"h2oai/h2ovl-mississippi-2b",
]
def run_preprocessing_test(
image: Image,
config,
max_dynamic_patch: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
"""Test the image preprocessing and calculate expected blocks."""
if max_dynamic_patch is None:
max_dynamic_patch = config.max_dynamic_patch
width, height = image.size
use_MSAC = config.use_msac
# Create the mapper function with the provided configuration
mapper = image_to_pixel_values_wrapper(config, max_dynamic_patch, use_MSAC)
pixel_values = mapper(image)
# Calculate the expected number of blocks
if use_MSAC:
# First pass
blocks1, _, _, aspect_ratio = calculate_num_blocks(
width,
height,
config.min_dynamic_patch,
max_dynamic_patch,
config.vision_config.image_size,
use_thumbnail=False, # Thumbnail is handled separately
prior_aspect_ratio=None,
)
# Second pass
blocks2, _, _, _ = calculate_num_blocks(
width,
height,
config.min_dynamic_patch,
max_dynamic_patch,
config.vision_config.image_size,
use_thumbnail=False,
prior_aspect_ratio=aspect_ratio,
)
# Add thumbnail if use_thumbnail is True and total_blocks > 1
if config.use_thumbnail:
blocks1 += 1 if blocks1 > 1 else 0
blocks2 += 1 if blocks2 > 1 else 0
# Total blocks is the sum of blocks from both passes minus overlapping
total_blocks = blocks1 + blocks2 - 1
expected_blocks = total_blocks
else:
blocks, _, _, _ = calculate_num_blocks(
width,
height,
config.min_dynamic_patch,
max_dynamic_patch,
config.vision_config.image_size,
use_thumbnail=False,
prior_aspect_ratio=None,
)
expected_blocks = blocks
if config.use_thumbnail and expected_blocks > 1:
expected_blocks += 1
return pixel_values, expected_blocks
@pytest.mark.parametrize("model_name", models)
@pytest.mark.parametrize(
"size_factors",
[
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("max_dynamic_patch", [None, 2, 4, 8])
def test_image_preprocessing(image_assets, model_name, size_factors,
max_dynamic_patch):
"""Test image preprocessing pipeline with different configurations."""
# Load the configuration from the model
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
for asset in image_assets:
image = asset.pil_image
for factor in size_factors:
scaled_image = rescale_image_size(image, factor)
# Test preprocessing and get expected number of blocks
pixel_values, expected_blocks = run_preprocessing_test(
scaled_image, config, max_dynamic_patch)
# Verify output shapes and properties
actual_blocks = pixel_values.shape[0]
assert actual_blocks == expected_blocks, (
f"Expected {expected_blocks} blocks, got {actual_blocks}")
# Check image dimensions
expected_size = (
3, # Number of channels (C, H, W)
config.vision_config.image_size,
config.vision_config.image_size,
)
for img in pixel_values:
assert img.shape == expected_size, (
f"Expected image size {expected_size}, got {img.shape}")

View File

@ -250,6 +250,7 @@ VLM_TEST_SETTINGS = {
max_model_len=8192,
dtype="bfloat16",
use_tokenizer_eos=True,
num_logprobs=10,
patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
),
"idefics3": VLMTestInfo(
@ -282,7 +283,6 @@ VLM_TEST_SETTINGS = {
dtype="bfloat16",
use_tokenizer_eos=True,
patch_hf_runner=model_utils.internvl_patch_hf_runner,
marks=[large_gpu_mark(min_gb=32)],
),
"llava_next": VLMTestInfo(
models=["llava-hf/llava-v1.6-mistral-7b-hf"],

View File

@ -334,12 +334,12 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
trust_remote_code=True)
self.vision_config = self.config.vision_config
self.use_thumbnail = self.config.use_thumbnail
self.use_msac = self.config.use_msac
self.min_num = self.config.min_dynamic_patch
self.max_num = self.config.max_dynamic_patch
self.image_size = self.vision_config.image_size
@ -348,18 +348,19 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
**kwargs):
# yapf: disable
from vllm.model_executor.models.h2ovl import (
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values_h2ovl)
# yapf: enable
images = [images] if isinstance(images, Image) else images
pixel_values = [
image_to_pixel_values(image,
self.image_size,
self.min_num,
self.max_num,
self.use_thumbnail,
use_MSAC=self.config.use_msac).to(
self.dtype) for image in images
image_to_pixel_values_h2ovl(
image,
input_size=self.image_size,
min_num=self.min_num,
max_num=self.max_num,
use_thumbnail=self.use_thumbnail,
use_msac=self.use_msac,
) for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
@ -394,7 +395,6 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
def __init__(self, hf_runner: HfRunner):
self.num_image_token = hf_runner.model.num_image_token
self.tokenizer = hf_runner.tokenizer
self.dtype = hf_runner.model.dtype
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
trust_remote_code=True)
@ -407,13 +407,17 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
def __call__(self, text: str, images: Union[Image, List[Image]],
**kwargs):
from vllm.model_executor.models.internvl import (
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
IMG_CONTEXT, IMG_END, IMG_START,
image_to_pixel_values_internvl)
images = [images] if isinstance(images, Image) else images
pixel_values = [
image_to_pixel_values(image, self.image_size, self.min_num,
self.max_num,
self.use_thumbnail).to(self.dtype)
for image in images
image_to_pixel_values_internvl(
image,
input_size=self.image_size,
min_num=self.min_num,
max_num=self.max_num,
use_thumbnail=self.use_thumbnail,
) for image in images
]
num_patches_list = [
pixel_value.shape[0] for pixel_value in pixel_values
@ -448,7 +452,8 @@ def _internvl_generate(
) -> torch.LongTensor:
"""Generate method for InternVL2 model without fixed use_cache."""
assert self.img_context_token_id is not None
vit_embeds = self.extract_feature(pixel_values)
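# pixel_values may arrive in a different dtype than the model weights (the
# patched processors no longer cast to the HF model's dtype), so convert to
# the parameter dtype before running the vision encoder.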
target_dtype = next(self.parameters()).dtype
vit_embeds = self.extract_feature(pixel_values.to(target_dtype))
input_embeds = self.language_model.get_input_embeddings()(input_ids)
B, N, C = input_embeds.shape
input_embeds = input_embeds.reshape(B * N, C)

View File

@ -141,13 +141,14 @@ def _test_processing_correctness(
# yapf: disable
# True if the model supports multiple data items of the modality per request
@pytest.mark.parametrize("model_id", [
"rhymes-ai/Aria",
"Salesforce/blip2-opt-2.7b",
"facebook/chameleon-7b",
"deepseek-ai/deepseek-vl2-tiny",
"adept/fuyu-8b",
"h2oai/h2ovl-mississippi-800m",
"OpenGVLab/InternVL2-1B",
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf",
@ -156,6 +157,7 @@ def _test_processing_correctness(
"mistral-community/pixtral-12b",
"openbmb/MiniCPM-o-2_6",
"openbmb/MiniCPM-V-2_6",
"nvidia/NVLM-D-72B",
"Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",

View File

@ -0,0 +1,142 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for H2OVL's multimodal preprocessing kwargs."""
from typing import Optional
import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.utils import cached_get_tokenizer
from ....conftest import _ImageAssets
from ...utils import build_model_context
@pytest.mark.parametrize("model_id", [
"h2oai/h2ovl-mississippi-800m",
"h2oai/h2ovl-mississippi-2b",
])
@pytest.mark.parametrize(
"size_factors",
[
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("max_dynamic_patch", [1, 2, 4, 8])
@pytest.mark.parametrize("dynamic_image_size", [True, False])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
model_id: str,
image_assets: _ImageAssets,
size_factors: list[float],
max_dynamic_patch: int,
dynamic_image_size: Optional[bool],
num_imgs: int,
):
from vllm.model_executor.models.h2ovl import (calculate_h2ovl_targets,
get_h2ovl_target_ratios)
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
trust_remote_code=True,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
)
config = processor.info.get_hf_config()
use_msac = config.use_msac
mm_processor_kwargs = {
"max_dynamic_patch": max_dynamic_patch,
}
if dynamic_image_size is not None:
mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size
min_num = config.min_dynamic_patch
max_num = max_dynamic_patch if dynamic_image_size else 1
# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs
for asset in image_assets:
for factor in size_factors:
image = rescale_image_size(asset.pil_image, factor)
mm_data = {"image": [image] * num_imgs}
width, height = image.size
# Calculate the expected number of blocks
if num_imgs == 1 and use_msac:
# First pass
blocks1, _, _, aspect_ratio = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=None,
),
image_size=config.vision_config.image_size,
use_thumbnail=False, # Thumbnail is handled separately
)
# Second pass
blocks2, _, _, _ = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=aspect_ratio,
),
image_size=config.vision_config.image_size,
use_thumbnail=False,
)
# Add thumbnail if use_thumbnail is True and total_blocks > 1
if config.use_thumbnail:
blocks1 += 1 if blocks1 > 1 else 0
blocks2 += 1 if blocks2 > 1 else 0
# Total blocks is the sum of blocks from both passes minus
# overlapping
total_blocks = blocks1 + blocks2 - 1
expected_num_patches = total_blocks
else:
blocks, _, _, _ = calculate_h2ovl_targets(
orig_width=width,
orig_height=height,
target_ratios=get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=None,
),
image_size=config.vision_config.image_size,
use_thumbnail=False,
)
expected_num_patches = blocks
if config.use_thumbnail and expected_num_patches != 1:
expected_num_patches += 1
processed_inputs = processor.apply(prompt, mm_data,
mm_processor_kwargs)
pixel_shape = (
processed_inputs["mm_kwargs"]["pixel_values_flat"].shape)
assert pixel_shape[0] == expected_num_patches * num_imgs

View File

@ -1,207 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for InternVL's multimodal preprocessing kwargs."""
from typing import Callable, Optional
from typing import Optional
import pytest
from transformers import AutoTokenizer
from vllm.inputs import InputContext, token_inputs
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.utils import cached_get_tokenizer
from ....conftest import _ImageAssets
from ...utils import build_model_context
models = ["OpenGVLab/InternVL2-2B"]
# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
def input_processor_for_internvl():
from vllm.model_executor.models.internvl import InternVLInputPipeline
pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
return pipeline.input_processor
@pytest.fixture()
def dummy_data_for_internvl():
from vllm.model_executor.models.internvl import InternVLInputPipeline
pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
return pipeline.dummy_data
@pytest.fixture()
def get_max_internvl_image_tokens():
from vllm.model_executor.models.internvl import (
get_max_internvl_image_tokens)
return get_max_internvl_image_tokens
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("model_id", ["OpenGVLab/InternVL2-2B"])
@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
def test_input_mapper_override(
model: str,
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
model_id: str,
image_assets: _ImageAssets,
max_dynamic_patch: int,
dynamic_image_size: Optional[bool],
num_imgs: int,
):
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
trust_remote_code=True,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,
)
mm_processor_kwargs = {
"max_dynamic_patch": max_dynamic_patch,
}
if dynamic_image_size is not None:
mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
)
mm_registry = MultiModalRegistry()
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
image = image_assets[0].pil_image.resize((448 * 2, 448 * 2))
vllm_result = mm_registry.map_input(
ctx.model_config,
{"image": image},
)
assert vllm_result["pixel_values"].size(1) == expected_num_patches
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
def test_max_tokens_override(
get_max_internvl_image_tokens: Callable,
model: str,
max_dynamic_patch: Optional[int],
dynamic_image_size: Optional[bool],
):
"""Ensure get_max_internvl_image_tokens handles mm_processor_kwargs."""
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)
if max_dynamic_patch is None:
max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1
expected_max_tokens = 256 * expected_num_patches
actual_max_tokens = get_max_internvl_image_tokens(
ctx=InputContext(ctx.model_config),
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
assert expected_max_tokens == actual_max_tokens
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
def test_dummy_data_override(
dummy_data_for_internvl: Callable,
model: str,
num_imgs: int,
max_dynamic_patch: Optional[int],
dynamic_image_size: Optional[bool],
):
"""Ensure dummy_data_for_internvl handles kwargs properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the dummy data func.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)
if max_dynamic_patch is None:
max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1
expected_max_tokens = 256 * expected_num_patches
dummy_data = dummy_data_for_internvl(
ctx=ctx,
seq_len=8192, # Should be bigger than num_imgs * toks_per_img
mm_counts={"image": num_imgs},
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
sequence_data = dummy_data.seq_data
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
image_token_id = tokenizer.encode('<IMG_CONTEXT>',
add_special_tokens=False)[0]
# Ensure we have the right number of placeholders per size
img_tok_count = sequence_data.get_token_ids().count(image_token_id)
assert img_tok_count == expected_max_tokens * num_imgs
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_input_processor_override(
input_processor_for_internvl: Callable,
image_assets: _ImageAssets,
model: str,
num_imgs: int,
max_dynamic_patch: int,
dynamic_image_size: Optional[bool],
):
"""Ensure input_processor_for_internvl handles kwargs properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
trust_remote_code=True,
mm_processor_kwargs=None,
)
expected_toks_per_img = 256 * expected_num_patches
# Build the image str / prompt based on the number of images we pass
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
placeholders = "<image>" if num_imgs == 1 else "\n".join(
f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1))
prompt = placeholders
images = [image_assets[0].pil_image.resize((448 * 2, 448 * 2))] * num_imgs
prompt = "<image>" * num_imgs
image = image_assets[0].pil_image.resize((448 * 2, 448 * 2))
mm_data = {"image": [image] * num_imgs}
inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
prompt=prompt,
multi_modal_data={"image": images})
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
if dynamic_image_size is False:
expected_num_patches = 1
processed_inputs = input_processor_for_internvl(
ctx,
inputs,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
# Ensure we have the right number of placeholders per num_crops size
image_token_id = tokenizer.encode('<IMG_CONTEXT>',
add_special_tokens=False)[0]
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
assert img_tok_count == expected_toks_per_img * num_imgs
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
assert img_tok_count == 256 * expected_num_patches * num_imgs
assert pixel_shape[0] == expected_num_patches * num_imgs

View File

@ -43,7 +43,10 @@ def test_processor_max_tokens(model_id):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
tokenizer=cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
),
)
info = processor.info
@ -143,7 +146,10 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
tokenizer=cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
),
)
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
@ -173,7 +179,10 @@ def test_processor_prompt_replacements_all(model_id, num_imgs):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
tokenizer=cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
),
)
seen_aspect_ratios = set[float]()

View File

@ -44,7 +44,10 @@ def test_processor_max_tokens(model_id):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
tokenizer=cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
),
)
info = processor.info
@ -143,7 +146,10 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
tokenizer=cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
),
)
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
@ -174,7 +180,10 @@ def test_processor_prompt_replacements_all(model_id, num_imgs):
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
tokenizer=cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
),
)
seen_aspect_ratios = set[float]()

View File

@ -38,7 +38,10 @@ def test_processor_override(
trust_remote_code=True,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,

View File

@ -33,7 +33,10 @@ def test_processor_override(
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
tokenizer = cached_get_tokenizer(
ctx.model_config.tokenizer,
trust_remote_code=ctx.model_config.trust_remote_code,
)
processor = MULTIMODAL_REGISTRY.create_processor(
ctx.model_config,
tokenizer=tokenizer,

View File

@ -399,7 +399,11 @@ class AriaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:

View File

@ -407,7 +407,11 @@ class Blip2ProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:

View File

@ -64,7 +64,11 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:

View File

@ -165,7 +165,11 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
image_width=x[1], image_height=x[0]))
return ImageSize(width=width, height=height)
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
max_image_size = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(
image_height=max_image_size.height,

View File

@ -80,7 +80,11 @@ class FuyuProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_ncols, max_nrows = self.get_image_feature_grid_size(

View File

@ -7,43 +7,55 @@
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
from typing import List, Optional, Tuple
from typing import Mapping, Optional
import torch
from PIL import Image
from transformers import PretrainedConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.utils import is_list_of
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .intern_vit import InternVisionModel
from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, InternVLChatModel,
InternVLInputPipeline, build_transform,
find_closest_aspect_ratio, get_internvl_num_patches)
from .internvl import (IMG_CONTEXT, IMG_END, IMG_START,
BaseInternVLProcessingInfo, BaseInternVLProcessor,
InternVLChatModel, InternVLDummyInputsBuilder,
InternVLMultiModalProcessor, build_transform,
find_closest_aspect_ratio, get_internvl_target_ratios)
logger = init_logger(__name__)
# modified to include blocks generated in second pass
def calculate_num_blocks(
orig_width: int,
orig_height: int,
def resolve_h2ovl_min_max_num(
*,
min_dynamic_patch: int,
max_dynamic_patch: int,
dynamic_image_size: bool,
use_thumbnail: bool,
) -> tuple[int, int]:
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if use_thumbnail and max_dynamic_patch != 1:
max_dynamic_patch += 1
return min_dynamic_patch, max_dynamic_patch
def get_h2ovl_target_ratios(
min_num: int,
max_num: int,
image_size: int,
use_thumbnail: bool,
prior_aspect_ratio=None,
) -> Tuple[int, int, int, Tuple[int, int]]:
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set((i, j) for n in range(min_num, max_num + 1)
for i in range(1, n + 1) for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
*,
prior_aspect_ratio: Optional[tuple[int, int]],
) -> list[tuple[int, int]]:
target_ratios = get_internvl_target_ratios(min_num, max_num)
# if prior_aspect_ratio is provided, filter the target ratios
if prior_aspect_ratio is not None:
@ -52,44 +64,66 @@ def calculate_num_blocks(
ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0
]
return target_ratios
# modified to include blocks generated in second pass
def calculate_h2ovl_targets(
*,
orig_width: int,
orig_height: int,
target_ratios: list[tuple[int, int]],
image_size: int,
use_thumbnail: bool,
) -> tuple[int, int, int, tuple[int, int]]:
aspect_ratio = orig_width / orig_height
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
target_ratios, orig_width,
orig_height, image_size)
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio,
target_ratios,
width=orig_width,
height=orig_height,
image_size=image_size,
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# add thumbnail image if num_blocks > 1
if use_thumbnail and blocks > 1:
# add thumbnail image if num_blocks != 1
if use_thumbnail and blocks != 1:
blocks += 1
return blocks, target_width, target_height, target_aspect_ratio
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
# refactored to handle prior_aspect_ratio as optional
def dynamic_preprocess(
# refactored to handle prior_aspect_ratio
def dynamic_preprocess_h2ovl(
image: Image.Image,
min_num: int,
max_num: int,
*,
target_ratios: list[tuple[int, int]],
image_size: int,
use_thumbnail: bool,
prior_aspect_ratio: Optional[Tuple[int, int]] = None,
) -> Tuple[List[Image.Image], Tuple[int, int]]:
) -> tuple[list[Image.Image], tuple[int, int]]:
orig_width, orig_height = image.size
# calculate the number of blocks based on prior aspect ratio if available
blocks, target_width, target_height, target_aspect_ratio = (
calculate_num_blocks(
orig_width,
orig_height,
min_num,
max_num,
image_size,
use_thumbnail=False,
prior_aspect_ratio=prior_aspect_ratio,
))
# calculate the number of blocks without thumbnail
(
blocks,
target_width,
target_height,
target_aspect_ratio,
) = calculate_h2ovl_targets(
orig_width=orig_width,
orig_height=orig_height,
target_ratios=target_ratios,
image_size=image_size,
use_thumbnail=False,
)
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
@ -103,276 +137,393 @@ def dynamic_preprocess(
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images, target_aspect_ratio
def load_image(
image: Image.Image,
input_size=448,
min_num=1,
max_num=6,
use_thumbnail=True,
prior_aspect_ratio: Optional[Tuple[int, int]] = None,
) -> Tuple[torch.Tensor, Tuple[int, int]]:
transform = build_transform(input_size=input_size)
images, target_aspect_ratio = dynamic_preprocess(
image,
image_size=input_size,
use_thumbnail=use_thumbnail,
min_num=min_num,
max_num=max_num,
prior_aspect_ratio=prior_aspect_ratio,
)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values, target_aspect_ratio
# refactored to use the combined load_image function
def image_to_pixel_values(
def _preprocess_image(
image: Image.Image,
*,
input_size: int,
min_num: int,
max_num: int,
use_thumbnail: bool,
use_MSAC: bool,
prior_aspect_ratio: Optional[tuple[int, int]],
) -> tuple[torch.Tensor, tuple[int, int]]:
target_ratios = get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=prior_aspect_ratio,
)
transform = build_transform(input_size=input_size)
images, target_aspect_ratio = dynamic_preprocess_h2ovl(
image,
image_size=input_size,
use_thumbnail=use_thumbnail,
target_ratios=target_ratios,
)
pixel_values = torch.stack([transform(image) for image in images])
return pixel_values, target_aspect_ratio
# refactored to use the _preprocess_image function
def image_to_pixel_values_h2ovl(
image: Image.Image,
*,
input_size: int,
min_num: int,
max_num: int,
use_thumbnail: bool,
use_msac: bool,
) -> torch.Tensor:
# when MSAC is turned on, we need to process the image twice
if use_MSAC:
if use_msac:
# first pass
pixel_values, target_aspect_ratio = load_image(
pixel_values1, aspect_ratio1 = _preprocess_image(
image,
input_size=input_size,
min_num=min_num,
max_num=max_num,
use_thumbnail=True,
prior_aspect_ratio=None,
)
# second pass
pixel_values2, _ = load_image(
pixel_values2, _ = _preprocess_image(
image,
input_size=input_size,
min_num=min_num,
min_num=3, # Hardcoded value
max_num=max_num,
prior_aspect_ratio=target_aspect_ratio,
use_thumbnail=True,
prior_aspect_ratio=aspect_ratio1,
)
# combine pixel values
pixel_values = torch.cat(
[pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0)
[pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0)
else:
pixel_values, _ = load_image(
pixel_values, _ = _preprocess_image(
image,
input_size=input_size,
min_num=min_num,
max_num=max_num,
use_thumbnail=use_thumbnail,
prior_aspect_ratio=None,
)
return pixel_values
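# Illustrative sketch (hypothetical helper; assumes `hf_config` is the model's
# `PretrainedConfig` and mirrors how the patched HF runner in the tests invokes
# this function):
def _example_image_to_pixel_values(image: Image.Image,
                                   hf_config: PretrainedConfig) -> torch.Tensor:
    # The first dimension of the result is the number of image-size crops
    # that will be fed to the vision tower.
    return image_to_pixel_values_h2ovl(
        image,
        input_size=hf_config.vision_config.image_size,
        min_num=hf_config.min_dynamic_patch,
        max_num=hf_config.max_dynamic_patch,
        use_thumbnail=hf_config.use_thumbnail,
        use_msac=hf_config.use_msac,
    )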
def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None,
use_MSAC: Optional[bool] = None):
image_size = hf_config.vision_config.image_size
min_num = hf_config.min_dynamic_patch
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
if use_MSAC is None:
use_MSAC = hf_config.use_msac
use_thumbnail = hf_config.use_thumbnail
return partial(
image_to_pixel_values,
input_size=image_size,
min_num=min_num,
max_num=max_dynamic_patch,
use_thumbnail=use_thumbnail,
use_MSAC=use_MSAC,
)
class H2OVLProcessor(BaseInternVLProcessor):
def get_max_internvl_image_tokens(ctx: InputContext,
*,
max_dynamic_patch: Optional[int] = None):
"""
Calculate the maximum number of tokens with/without MSAC and thumbnail
"""
hf_config = ctx.get_hf_config()
use_thumbnail = hf_config.use_thumbnail
use_MSAC = hf_config.use_msac
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
num_patches = get_internvl_num_patches(hf_config)
coefficient = 2 if use_MSAC else 1
num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0)
return num_blocks * num_patches
class H2OVLInputPipeline(InternVLInputPipeline):
"""
Input pipeline for processing image and text data for the H2OVL model.
"""
def input_processor(
def __init__(
self,
ctx: InputContext,
inputs: DecoderOnlyInputs,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
*,
max_dynamic_patch: Optional[int] = None,
) -> DecoderOnlyInputs:
# get multi_modal_data
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
model_config = ctx.model_config
hf_config = ctx.get_hf_config()
use_MSAC = hf_config.use_msac
image_data = multi_modal_data["image"]
num_patches = get_internvl_num_patches(hf_config)
image_pixel_values_mapper = image_to_pixel_values_wrapper(
hf_config, max_dynamic_patch=max_dynamic_patch)
# single image
if isinstance(image_data, Image.Image):
pixel_values = image_pixel_values_mapper(image_data,
use_MSAC=use_MSAC)
num_blocks = pixel_values.shape[0]
image_feature_sizes = [num_blocks * num_patches]
pixel_values = pixel_values.unsqueeze(0)
# multi images
elif is_list_of(image_data, Image.Image):
# Do not use MSAC for multi images
image_feature_sizes = []
pixel_values = [
image_pixel_values_mapper(image, use_MSAC=False)
for image in image_data
]
for pixel_value in pixel_values:
num_blocks = pixel_value.shape[0]
image_feature_sizes.append(num_blocks * num_patches)
# image embeddings as input
elif isinstance(image_data, torch.Tensor):
_, image_feature_size, _ = image_data.shape
image_feature_sizes = [image_feature_size]
pixel_values = None
# multi-image image embeddings
elif is_list_of(image_data, torch.Tensor):
image_feature_sizes = []
for image_embed in image_data:
_, image_feature_size, _ = image_embed.shape
image_feature_sizes.append(image_feature_size)
pixel_values = None
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
dynamic_image_size: Optional[bool] = None,
use_msac: Optional[bool] = None,
) -> None:
super().__init__(
config,
tokenizer,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
prompt = inputs.get("prompt")
prompt_token_ids = inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
if use_msac is None:
use_msac = config.use_msac
assert isinstance(use_msac, bool)
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
num_patches)
new_prompt_token_ids = tokenizer.encode(new_prompt)
self.use_msac = use_msac
# Wrap image processing in input_processor to avoid duplication
image_token_id = tokenizer.encode(
self.img_context_token,
add_special_tokens=False,
return_tensors="pt",
)[0]
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT]
# Update multi_modal_data to return
if pixel_values is not None:
multi_modal_data = {
"image": {
"pixel_values": pixel_values,
"image_token_id": image_token_id,
}
}
else:
multi_modal_data = {"image": {"image_embeds": image_data}}
return token_inputs(
prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data,
)
def input_mapper(
def get_image_repl_features(
self,
feature_size: int,
num_patches: Optional[int],
) -> str:
return IMG_CONTEXT * feature_size
def get_image_repl_full(
self,
feature_size: int,
num_patches: Optional[int],
) -> str:
features = self.get_image_repl_features(feature_size, num_patches)
return IMG_START + features + IMG_END
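# e.g. with feature_size=256, a single <image> placeholder expands to
# "<img>" + "<IMG_CONTEXT>" * 256 + "</img>" in the prompt.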
def resolve_min_max_num(
self,
ctx: InputContext,
data: object,
*,
max_dynamic_patch: Optional[int] = None,
) -> MultiModalKwargs:
dynamic_image_size: Optional[bool] = None,
use_thumbnail: Optional[bool] = None,
) -> tuple[int, int]:
min_dynamic_patch = self.min_dynamic_patch
max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
is None else max_dynamic_patch)
dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
is None else dynamic_image_size)
use_thumbnail = (self.use_thumbnail
if use_thumbnail is None else use_thumbnail)
# NOTE: Preprocessing for the image data is done in the
# 'input_processor' function during actual inference.
if isinstance(data, dict):
return MultiModalKwargs(data)
# The section below is only used with dummy data during
# memory profiling.
hf_config = ctx.get_hf_config()
image_pixel_values_mapper = image_to_pixel_values_wrapper(
hf_config, max_dynamic_patch)
if isinstance(data, Image.Image):
pixel_values = image_pixel_values_mapper(data)
pixel_values = pixel_values.unsqueeze(0)
elif is_list_of(data, Image.Image):
hf_config.use_msac = False
pixel_values = [image_pixel_values_mapper(img) for img in data]
else:
return MultiModalKwargs({"image_embeds": data})
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
return resolve_h2ovl_min_max_num(
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
use_thumbnail=use_thumbnail,
)
image_token_id = tokenizer.encode(
self.img_context_token,
add_special_tokens=False,
return_tensors="pt",
)[0]
return MultiModalKwargs({
"pixel_values": pixel_values,
"image_token_id": image_token_id
})
def resolve_target_ratios(
self,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
use_thumbnail: Optional[bool] = None,
prior_aspect_ratio: Optional[tuple[int, int]] = None,
) -> list[tuple[int, int]]:
min_num, max_num = self.resolve_min_max_num(
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
use_thumbnail=use_thumbnail,
)
if prior_aspect_ratio: # hardcoded value for second pass of use_msac
min_num = 3
return get_h2ovl_target_ratios(
min_num,
max_num,
prior_aspect_ratio=prior_aspect_ratio,
)
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
use_msac: Optional[bool] = None,
) -> int:
use_msac = (self.use_msac if use_msac is None else use_msac)
use_thumbnail = self.use_thumbnail
if use_msac:
target_ratios_1 = self.resolve_target_ratios(
use_thumbnail=False, # Applied in calculate_targets
)
num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets(
orig_width=image_width,
orig_height=image_height,
image_size=self.image_size,
target_ratios=target_ratios_1,
use_thumbnail=True,
)
target_ratios_2 = self.resolve_target_ratios(
use_thumbnail=False, # Applied in calculate_targets
prior_aspect_ratio=aspect_ratio_1,
)
num_patches_2, _, _, _ = calculate_h2ovl_targets(
orig_width=image_width,
orig_height=image_height,
image_size=self.image_size,
target_ratios=target_ratios_2,
use_thumbnail=True,
)
num_patches = num_patches_1 + num_patches_2 - 1
else:
target_ratios = self.resolve_target_ratios(
use_thumbnail=False, # Applied in calculate_targets
)
num_patches, _, _, _ = calculate_h2ovl_targets(
orig_width=image_width,
orig_height=image_height,
image_size=self.image_size,
target_ratios=target_ratios,
use_thumbnail=use_thumbnail,
)
return num_patches * self.num_image_token
def _images_to_pixel_values_lst(
self,
images: list[Image.Image],
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
) -> list[torch.Tensor]:
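# MSAC only applies to single-image prompts; multi-image inputs fall back
# to plain single-pass preprocessing.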
use_msac = self.use_msac if len(images) == 1 else False
min_num, max_num = self.resolve_min_max_num(
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
use_thumbnail=False, # Applied in image_to_pixel_values
)
return [
image_to_pixel_values_h2ovl(
image,
input_size=self.image_size,
min_num=min_num,
max_num=max_num,
use_thumbnail=self.use_thumbnail,
use_msac=use_msac,
) for image in images
]
input_pipeline = H2OVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
def get_hf_processor(
self,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
) -> H2OVLProcessor:
return H2OVLProcessor(
self.get_hf_config(),
self.get_tokenizer(),
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
max_tokens_one_image = self.get_max_image_tokens(use_msac=None)
if mm_counts.get("image", 0) <= 1:
max_tokens_per_image = max_tokens_one_image
else:
max_tokens_per_image = self.get_max_image_tokens(use_msac=False)
return {"image": max_tokens_per_image}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[H2OVLProcessor],
use_msac: Optional[bool] = None,
) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
use_msac=use_msac,
)
def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
use_msac=use_msac,
)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
):
def __init__(self,
info: H2OVLProcessingInfo,
dummy_inputs: "BaseDummyInputsBuilder[H2OVLProcessingInfo]",
*,
cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True) -> None:
super().__init__(
info,
dummy_inputs,
cache=cache,
enable_sanity_checks=enable_sanity_checks,
)
if self.cache is not None:
# The processor output depends on the number of images passed,
# making it incompatible with processing cache which is supposed
# to be invariant of how many images are passed per prompt
self.cache = None
logger.warning_once(
f"{type(self).__name__} does not support processing cache.")
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs:
image_num_patches = out_mm_kwargs["image_num_patches"]
assert isinstance(image_num_patches, torch.Tensor)
image_num_patches = image_num_patches.tolist()
elif "image_embeds" in out_mm_kwargs:
# TODO: Use image size information in dictionary embedding inputs
# to compute num_patches (similar to Qwen2-VL)
image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
else:
image_num_patches = []
num_images = len(image_num_patches)
def get_replacement_internvl(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
feature_size = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
use_msac=None if num_images == 1 else False,
)
num_patches = image_num_patches[item_idx]
if num_patches is not None:
assert isinstance(num_patches, int)
return PromptReplacementDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches),
features=hf_processor.get_image_repl_features(
feature_size, num_patches),
)
return [
PromptReplacement(
modality="image",
target="<image>",
replacement=get_replacement_internvl,
)
]
@MULTIMODAL_REGISTRY.register_processor(
H2OVLMultiModalProcessor,
info=H2OVLProcessingInfo,
dummy_inputs=InternVLDummyInputsBuilder)
class H2OVLChatModel(InternVLChatModel):
def _init_vision_model(

View File

@ -6,35 +6,37 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import re
from functools import cached_property, partial
from abc import ABC, abstractmethod
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
TypedDict, TypeVar, Union)
import torch
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from transformers import PretrainedConfig
from transformers import BatchFeature, PretrainedConfig, TensorType
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_num_patches)
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
@ -75,22 +77,27 @@ InternVLImageInputs = Union[InternVLImagePixelInputs,
InternVLImageEmbeddingInputs]
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size):
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def build_transform(input_size: int):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
return T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size),
interpolation=T.InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
image_size):
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def find_closest_aspect_ratio(
aspect_ratio: float,
target_ratios: list[tuple[int, int]],
*,
width: int,
height: int,
image_size: int,
) -> tuple[int, int]:
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
@ -106,67 +113,82 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
return best_ratio
def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
max_num: int, image_size: int,
use_thumbnail: bool) -> Tuple[int, int, int]:
def resolve_internvl_min_max_num(
*,
min_dynamic_patch: int,
max_dynamic_patch: int,
dynamic_image_size: bool,
use_thumbnail: bool,
) -> tuple[int, int]:
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if use_thumbnail and max_dynamic_patch != 1:
max_dynamic_patch += 1
return min_dynamic_patch, max_dynamic_patch
def get_internvl_target_ratios(
min_num: int,
max_num: int,
) -> list[tuple[int, int]]:
target_ratios = {(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1) if min_num <= i * j <= max_num}
return sorted(target_ratios, key=lambda x: x[0] * x[1])
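# Illustrative: get_internvl_target_ratios(1, 6) yields the 14 (i, j) grids
# with 1 <= i * j <= 6, sorted by total patch count i * j; ties come out of a
# set, so their relative order is not guaranteed.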
def calculate_internvl_targets(
*,
orig_width: int,
orig_height: int,
target_ratios: list[tuple[int, int]],
image_size: int,
use_thumbnail: bool,
) -> tuple[int, int, int]:
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set((i, j) for n in range(min_num, max_num + 1)
for i in range(1, n + 1) for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
target_ratios, orig_width,
orig_height, image_size)
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio,
target_ratios,
width=orig_width,
height=orig_height,
image_size=image_size,
)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# add thumbnail image if num_blocks > 1
if use_thumbnail and blocks > 1:
# add thumbnail image if num_blocks != 1
if use_thumbnail and blocks != 1:
blocks += 1
return blocks, target_width, target_height
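# Worked example (sketch): a 1344x448 image with image_size=448 and the
# min_num=1, max_num=6 ratios above maps to target_aspect_ratio=(3, 1), i.e.
# blocks=3 and a 1344x448 resize; with use_thumbnail=True this becomes 4
# blocks, which for InternVL2-2B (256 tokens per block, as in the tests earlier
# in this diff) corresponds to 4 * 256 = 1024 image tokens.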
def calculate_num_blocks_wrapper(
hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
):
if dynamic_image_size is None:
dynamic_image_size = hf_config.dynamic_image_size
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
min_num = hf_config.min_dynamic_patch
image_size = hf_config.vision_config.image_size
use_thumbnail = hf_config.use_thumbnail
return partial(calculate_num_blocks,
min_num=min_num,
max_num=max_dynamic_patch,
image_size=image_size,
use_thumbnail=use_thumbnail)
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
image_size: int,
use_thumbnail: bool) -> List[Image.Image]:
def dynamic_preprocess_internvl(
image: Image.Image,
*,
target_ratios: list[tuple[int, int]],
image_size: int,
use_thumbnail: bool,
) -> list[Image.Image]:
orig_width, orig_height = image.size
# calculate the number of blocks without thumbnail
blocks, target_width, target_height = calculate_num_blocks(
orig_width,
orig_height,
min_num,
max_num,
image_size,
use_thumbnail=False)
blocks, target_width, target_height = calculate_internvl_targets(
orig_width=orig_width,
orig_height=orig_height,
target_ratios=target_ratios,
image_size=image_size,
use_thumbnail=False,
)
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
@ -178,301 +200,463 @@ def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
max_num: int, use_thumbnail: bool) -> torch.Tensor:
def image_to_pixel_values_internvl(
image: Image.Image,
*,
input_size: int,
min_num: int,
max_num: int,
use_thumbnail: bool,
) -> torch.Tensor:
target_ratios = get_internvl_target_ratios(min_num, max_num)
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image,
min_num=min_num,
max_num=max_num,
image_size=input_size,
use_thumbnail=use_thumbnail)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
images = dynamic_preprocess_internvl(
image,
target_ratios=target_ratios,
image_size=input_size,
use_thumbnail=use_thumbnail,
)
pixel_values = torch.stack([transform(image) for image in images])
return pixel_values
def image_to_pixel_values_wrapper(
hf_config: PretrainedConfig,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
):
image_size = hf_config.vision_config.image_size
min_num = hf_config.min_dynamic_patch
if dynamic_image_size is None:
dynamic_image_size = hf_config.dynamic_image_size
class BaseInternVLProcessor(ABC):
"""
This model doesn't define its own HF processor,
so we implement our own here.
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
use_thumbnail = hf_config.use_thumbnail
return partial(image_to_pixel_values,
input_size=image_size,
min_num=min_num,
max_num=max_dynamic_patch,
use_thumbnail=use_thumbnail)
def get_internvl_num_patches(hf_config: PretrainedConfig):
vision_config = hf_config.vision_config
downsample_ratio = hf_config.downsample_ratio
image_size = vision_config.image_size
patch_size = vision_config.patch_size
return int(
get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
(downsample_ratio**2))
def get_max_internvl_image_tokens(
ctx: InputContext,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
):
hf_config = ctx.get_hf_config()
if dynamic_image_size is None:
dynamic_image_size = hf_config.dynamic_image_size
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
use_thumbnail = hf_config.use_thumbnail
if use_thumbnail and max_dynamic_patch > 1:
max_dynamic_patch += 1
num_patches = get_internvl_num_patches(hf_config)
return num_patches * max_dynamic_patch
def get_max_internvl_image_size(
ctx: InputContext,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
):
hf_config = ctx.get_hf_config()
image_size = hf_config.vision_config.image_size
if dynamic_image_size is None:
dynamic_image_size = hf_config.dynamic_image_size
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
if max_dynamic_patch is None:
max_dynamic_patch = hf_config.max_dynamic_patch
use_thumbnail = hf_config.use_thumbnail
if use_thumbnail and max_dynamic_patch > 1:
max_dynamic_patch += 1
width = image_size * max_dynamic_patch
height = image_size
return width, height
class InternVLInputPipeline:
The code to insert image tokens is based on:
https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
"""
def __init__(
self,
img_start_token: str,
img_end_token: str,
img_context_token: str,
config: PretrainedConfig,
tokenizer: AnyTokenizer,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
) -> None:
super().__init__()
self.img_start_token = img_start_token
self.img_end_token = img_end_token
self.img_context_token = img_context_token
self.config = config
self.tokenizer = tokenizer
def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
return (self.img_start_token + self.img_context_token * feature_size +
self.img_end_token)
image_size: int = config.vision_config.image_size
patch_size: int = config.vision_config.patch_size
def _expand_image_prompt(
if dynamic_image_size is None:
dynamic_image_size = config.dynamic_image_size
assert isinstance(dynamic_image_size, bool)
if max_dynamic_patch is None:
max_dynamic_patch = config.max_dynamic_patch
assert isinstance(max_dynamic_patch, int)
self.num_image_token = int(
(image_size // patch_size)**2 * (config.downsample_ratio**2))
self.image_size = image_size
self.min_dynamic_patch: int = config.min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail: bool = config.use_thumbnail
@property
@abstractmethod
def image_token_id(self) -> int:
raise NotImplementedError
@abstractmethod
def get_image_repl_features(
self,
prompt: str,
feature_sizes: List[int],
num_patches: int,
feature_size: int,
num_patches: Optional[int],
) -> str:
image_idx = sorted(
map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
raise NotImplementedError
new_prompt = prompt
for idx, feature_size in enumerate(feature_sizes, start=1):
image_prompt = self._create_image_prompt(feature_size, num_patches)
if not image_idx:
image_prompt = f"Image-{idx}: {image_prompt}"
new_prompt = new_prompt.replace('<image>', image_prompt, 1)
return new_prompt
def input_processor(
@abstractmethod
def get_image_repl_full(
self,
feature_size: int,
num_patches: Optional[int],
) -> str:
raise NotImplementedError
def resolve_min_max_num(
self,
ctx: InputContext,
inputs: DecoderOnlyInputs,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
) -> DecoderOnlyInputs:
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
use_thumbnail: Optional[bool] = None,
) -> tuple[int, int]:
min_dynamic_patch = self.min_dynamic_patch
max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
is None else max_dynamic_patch)
dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
is None else dynamic_image_size)
use_thumbnail = (self.use_thumbnail
if use_thumbnail is None else use_thumbnail)
model_config = ctx.model_config
hf_config = ctx.get_hf_config()
return resolve_internvl_min_max_num(
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
use_thumbnail=use_thumbnail,
)
image_data = multi_modal_data["image"]
num_patches = get_internvl_num_patches(hf_config)
num_blocks_calculator = calculate_num_blocks_wrapper(
hf_config, max_dynamic_patch, dynamic_image_size)
if isinstance(image_data, Image.Image):
width, height = image_data.size
num_blocks, _, _ = num_blocks_calculator(width, height)
image_feature_sizes = [num_blocks * num_patches]
elif is_list_of(image_data, Image.Image):
image_feature_sizes = []
for image in image_data:
width, height = image.size
num_blocks, _, _ = num_blocks_calculator(width, height)
image_feature_sizes.append(num_blocks * num_patches)
elif isinstance(image_data, torch.Tensor):
num_images, image_feature_size, hidden_size = image_data.shape
image_feature_sizes = [image_feature_size]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
prompt = inputs.get("prompt")
prompt_token_ids = inputs["prompt_token_ids"]
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
num_patches)
new_prompt_token_ids = tokenizer.encode(new_prompt)
img_context_token_id = tokenizer.encode(self.img_context_token,
add_special_tokens=False)
assert len(img_context_token_id) == 1, \
(f"Invalid image token '{self.img_context_token}': A valid image "
f"token encodes to a single token ID, got {img_context_token_id}.")
img_context_token_id = img_context_token_id[0]
# Get precise tracking of placeholder positions
token_idx = image_idx = 0
placeholder_ranges = []
while token_idx < len(new_prompt_token_ids):
if new_prompt_token_ids[token_idx] == img_context_token_id:
curr_image_feature_size = image_feature_sizes[image_idx]
placeholder_ranges.append(
PlaceholderRange(offset=token_idx,
length=curr_image_feature_size))
image_idx += 1
token_idx += curr_image_feature_size
else:
token_idx += 1
return token_inputs(
prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})
def input_mapper(
def resolve_target_ratios(
self,
ctx: InputContext,
data: object,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
):
hf_config = ctx.get_hf_config()
use_thumbnail: Optional[bool] = None,
) -> list[tuple[int, int]]:
min_num, max_num = self.resolve_min_max_num(
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
use_thumbnail=use_thumbnail,
)
image_pixel_values_mapper = image_to_pixel_values_wrapper(
hf_config, max_dynamic_patch, dynamic_image_size)
if isinstance(data, Image.Image):
data = image_pixel_values_mapper(data)
# Add an N dimension for number of images per prompt (currently 1).
data = data.unsqueeze(0)
elif is_list_of(data, Image.Image):
# we can't stack here because images may have different num_patches
data = [image_pixel_values_mapper(img) for img in data]
else:
return MultiModalKwargs({"image_embeds": data})
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_token_id = tokenizer.encode(self.img_context_token,
add_special_tokens=False,
return_tensors="pt")[0]
return get_internvl_target_ratios(min_num, max_num)
return MultiModalKwargs({
"pixel_values": data,
"image_token_id": image_token_id
})
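`get_internvl_target_ratios` is defined earlier in this file; conceptually it enumerates every (width, height) tile grid whose tile count falls in `[min_num, max_num]`. A standalone sketch of that behaviour (an assumption based on the upstream InternVL preprocessing, not a copy of the helper):

```python
def target_ratios_sketch(min_num: int, max_num: int) -> list[tuple[int, int]]:
    ratios = {(w, h)
              for n in range(min_num, max_num + 1)
              for w in range(1, n + 1)
              for h in range(1, n + 1)
              if min_num <= w * h <= max_num}
    # Order by total tile count, as the dynamic tiling expects.
    return sorted(ratios, key=lambda wh: wh[0] * wh[1])

print(target_ratios_sketch(1, 6))
# 14 grids, starting with (1, 1), then the 2-tile grids, up to the 6-tile grids
```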
def dummy_data(
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
) -> int:
target_ratios = self.resolve_target_ratios(
use_thumbnail=False, # Applied in calculate_targets
)
num_patches, _, _ = calculate_internvl_targets(
orig_width=image_width,
orig_height=image_height,
image_size=self.image_size,
target_ratios=target_ratios,
use_thumbnail=self.use_thumbnail,
)
return num_patches * self.num_image_token
def _images_to_pixel_values_lst(
self,
images: list[Image.Image],
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
) -> list[torch.Tensor]:
min_num, max_num = self.resolve_min_max_num(
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
use_thumbnail=False, # Applied in image_to_pixel_values
)
return [
image_to_pixel_values_internvl(
image,
input_size=self.image_size,
min_num=min_num,
max_num=max_num,
use_thumbnail=self.use_thumbnail,
) for image in images
]
def __call__(
self,
text: Optional[Union[str, list[str]]] = None,
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
if len(images) == 0:
image_inputs = {}
else:
pixel_values_lst = self._images_to_pixel_values_lst(
images,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
image_inputs = {
"pixel_values_flat": torch.cat(pixel_values_lst),
"image_num_patches": list(map(len, pixel_values_lst)),
}
for pixel_values in pixel_values_lst:
num_patches = pixel_values.shape[0]
feature_size = num_patches * self.num_image_token
image_repl = self.get_image_repl_full(feature_size,
num_patches)
text = [t.replace('<image>', image_repl, 1) for t in text]
text_inputs = self.tokenizer(text)
return BatchFeature(
{
**text_inputs,
**image_inputs,
},
tensor_type=return_tensors,
)
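Since this class mirrors the interface of a HF processor, a concrete subclass (such as the `InternVLProcessor` defined just below) can be exercised on its own. A minimal usage sketch, assuming the config and tokenizer are loaded via `transformers` (model name and image are purely illustrative):

```python
from PIL import Image
from transformers import AutoConfig, AutoTokenizer

# Hypothetical standalone use; inside vLLM the processor is created through
# InternVLProcessingInfo.get_hf_processor() instead.
model_id = "OpenGVLab/InternVL2-1B"
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

processor = InternVLProcessor(config, tokenizer)
out = processor(
    text="<image>\nDescribe this image.",
    images=Image.new("RGB", (1280, 640)),
)
print(out["pixel_values_flat"].shape)  # (total_num_tiles, 3, 448, 448)
print(out["image_num_patches"])        # one tile count per input image
```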
class InternVLProcessor(BaseInternVLProcessor):
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_CONTEXT]
def get_image_repl_features(
self,
feature_size: int,
num_patches: Optional[int],
) -> str:
return IMG_CONTEXT * feature_size
def get_image_repl_full(
self,
feature_size: int,
num_patches: Optional[int],
) -> str:
features = self.get_image_repl_features(feature_size, num_patches)
return IMG_START + features + IMG_END
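Assuming the usual InternVL special tokens defined earlier in this file (`<img>`, `</img>`, `<IMG_CONTEXT>`) and 256 context tokens per tile, a one-tile image expands as follows (a sketch, not taken from this diff):

```python
IMG_START, IMG_END, IMG_CONTEXT = "<img>", "</img>", "<IMG_CONTEXT>"  # assumed literals
feature_size = 1 * 256  # num_patches * num_image_token

features = IMG_CONTEXT * feature_size   # get_image_repl_features(...)
full = IMG_START + features + IMG_END   # get_image_repl_full(...)
```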
class BaseInternVLProcessingInfo(BaseProcessingInfo):
@abstractmethod
def get_hf_processor(
self,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
) -> BaseInternVLProcessor:
raise NotImplementedError
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(
self,
ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[BaseInternVLProcessor],
) -> int:
if processor is None:
processor = self.get_hf_processor()
return processor.get_num_image_tokens(
image_width=image_width,
image_height=image_height,
)
def get_max_image_tokens(self) -> int:
target_width, target_height = self.get_image_size_with_most_features()
return self.get_num_image_tokens(
image_width=target_width,
image_height=target_height,
processor=None,
)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
base_size = processor.image_size
target_ratios = processor.resolve_target_ratios()
largest_feature_size, largest_feature_pinpoint = 0, None
for wr, hr in target_ratios:
width, height = base_size * wr, base_size * hr
feat_size = self.get_num_image_tokens(
image_width=width,
image_height=height,
processor=processor,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
largest_feature_pinpoint = ImageSize(width=width,
height=height)
if largest_feature_size == 0 or largest_feature_pinpoint is None:
raise ValueError("Cannot have a largest feature size of 0!")
return largest_feature_pinpoint
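As a back-of-the-envelope check of the two methods above, using typical InternVL2 configuration values (assumed here, not read from this diff): `image_size=448`, `patch_size=14`, `downsample_ratio=0.5`, `max_dynamic_patch=12`, `use_thumbnail=True`:

```python
image_size, patch_size, downsample_ratio = 448, 14, 0.5
max_dynamic_patch, use_thumbnail = 12, True

tokens_per_tile = int((image_size // patch_size) ** 2 * downsample_ratio ** 2)  # 256
max_tiles = max_dynamic_patch + (1 if use_thumbnail and max_dynamic_patch > 1 else 0)
max_image_tokens = tokens_per_tile * max_tiles   # 256 * 13 = 3328
```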
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)
class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="<image>" * num_images,
mm_data=mm_data,
)
class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
image_data = mm_data.get("images", [])
assert isinstance(image_data, list)
# Since there may be extra tokens in the feature placeholders,
# we need to pass the image token ID to the model to select the
# tokens to merge from the vision encoder outputs
processed_outputs["image_token_id"] = [image_token_id
] * len(image_data)
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
return dict(
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_patches),
image_num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
image_token_id=MultiModalFieldConfig.batched("image"),
)
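The flat layout means `pixel_values_flat` concatenates the tiles of every image in the prompt, while `image_num_patches` records how many tiles belong to each image, so the field config can split them back per item. For example (shapes assumed for illustration):

```python
import torch

# Two images with 3 and 9 tiles of 448x448 each (illustrative).
pixel_values_flat = torch.zeros(3 + 9, 3, 448, 448)
image_num_patches = torch.tensor([3, 9])
# flat_from_sizes("image", image_num_patches) produces slices [0:3] and [3:12],
# recovering one block of tiles per image.
```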
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs:
image_num_patches = out_mm_kwargs["image_num_patches"]
assert isinstance(image_num_patches, torch.Tensor)
image_num_patches = image_num_patches.tolist()
elif "image_embeds" in out_mm_kwargs:
# TODO: Use image size information in dictionary embedding inputs
# to compute num_patches (similar to Qwen2-VL)
image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
else:
image_num_patches = []
def get_replacement_internvl(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
feature_size = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
num_patches = image_num_patches[item_idx]
if num_patches is not None:
assert isinstance(num_patches, int)
return PromptReplacementDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches),
features=hf_processor.get_image_repl_features(
feature_size, num_patches),
)
return [
PromptReplacement(
modality="image",
target="<image>",
replacement=get_replacement_internvl,
)
]
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
def get_hf_processor(
self,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
):
num_images = mm_counts["image"]
hf_config = ctx.get_hf_config()
image_feature_size = get_max_internvl_image_tokens(
ctx,
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
seq_data, ranges = dummy_seq_data_for_clip(
hf_config.vision_config,
seq_len,
num_images,
image_token_id=tokenizer.encode(self.img_context_token,
add_special_tokens=False)[0],
image_feature_size_override=image_feature_size,
)
max_image_width, max_image_height = get_max_internvl_image_size(
ctx,
) -> InternVLProcessor:
return InternVLProcessor(
self.get_hf_config(),
self.get_tokenizer(),
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
mm_data = dummy_image_for_clip(
hf_config.vision_config,
num_images,
image_width_override=max_image_width,
image_height_override=max_image_height,
)
return DummyData(seq_data, mm_data, ranges)
input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
@MULTIMODAL_REGISTRY.register_processor(
InternVLMultiModalProcessor,
info=InternVLProcessingInfo,
dummy_inputs=InternVLDummyInputsBuilder)
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
@ -621,11 +805,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[InternVLImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_token_id = kwargs.pop("image_token_id", None)
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None and image_embeds is None:
if pixel_values_flat is None and image_embeds is None:
return None
if image_embeds is not None:
@ -638,31 +822,30 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
data=flatten_bn(image_embeds),
)
self.img_context_token_id = image_token_id[0]
image_token_id = kwargs["image_token_id"]
assert isinstance(image_token_id, torch.Tensor)
self.img_context_token_id = image_token_id.flatten().unique().item()
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
if pixel_values_flat is not None:
if not isinstance(pixel_values_flat, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
f"Got type: {type(pixel_values_flat)}")
assert isinstance(image_num_patches, (torch.Tensor, list))
patches_per_image = []
for request_pixel_values in pixel_values:
for image_pixel_values in request_pixel_values:
patches_per_image.append(image_pixel_values.shape[0])
# We need to flatten (B, N, P) to (B*N*P),
# so we call flatten_bn twice.
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(flatten_bn(pixel_values), concat=True)),
patches_per_image=patches_per_image)
flatten_bn(pixel_values_flat, concat=True)),
patches_per_image=flatten_bn(image_num_patches,
concat=True).tolist())
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: InternVLImageInputs,
) -> Tuple[torch.Tensor]:
) -> tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
return image_input["data"]
@ -689,7 +872,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
image_embeds = image_embeds.split(image_feature_sizes)
return image_embeds
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
if self.is_mono:
self.visual_token_mask = (
input_ids == self.img_context_token_id).reshape(-1, 1)

View File

@ -125,7 +125,11 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_max_image_tokens()}
def _apply_feature_select_strategy(

View File

@ -62,7 +62,11 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"video": 1}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_video_tokens = self.get_num_video_tokens(

View File

@ -103,7 +103,11 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),

View File

@ -23,7 +23,6 @@
# limitations under the License.
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
from functools import partial
from itertools import accumulate
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
@ -138,11 +137,15 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None, "audio": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"audio": self.get_max_audio_tokens(),
"video": self.get_max_video_tokens(seq_len)
"video": self.get_max_video_tokens(seq_len),
}
def get_default_audio_pool_step(self) -> int:
@ -369,23 +372,18 @@ class MiniCPMOMultiModalProcessor(
hf_inputs,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
def get_slices(num_slices: List[int]) -> List[int]:
slice_indices = [0] + list(accumulate(num_slices))
slices = [(slice_indices[i], slice_indices[i + 1])
for i in range(len(num_slices))]
return [slice(*slice_item) for slice_item in slices]
audio_slices = get_slices(
hf_inputs.get("audio_num_slices", torch.empty(0)))
return dict(
**super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
audio_features=MultiModalFieldConfig.flat("audio", audio_slices),
audio_feature_lens=MultiModalFieldConfig.flat(
"audio", audio_slices),
audio_features=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices),
audio_num_slices=MultiModalFieldConfig.batched("audio"),
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.flat("audio", audio_slices))
audio_embeds=MultiModalFieldConfig.flat_from_sizes(
"audio", audio_num_slices))
class MultiModalProjector(nn.Module):

View File

@ -26,7 +26,6 @@ import math
import re
from collections import Counter
from functools import cached_property, partial
from itertools import accumulate
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Set, Tuple, TypedDict, Union)
@ -365,7 +364,11 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
else:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
mm_max_tokens = {"image": self.get_max_image_tokens()}
if self.get_model_version() == (2, 6):
mm_max_tokens["video"] = self.get_max_video_tokens(seq_len)
@ -761,30 +764,25 @@ class MiniCPMVMultiModalProcessor(
hf_inputs,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0))
video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
def get_slices(num_slices: List[int]) -> List[int]:
slice_indices = [0] + list(accumulate(num_slices))
slices = [(slice_indices[i], slice_indices[i + 1])
for i in range(len(num_slices))]
return [slice(*slice_item) for slice_item in slices]
image_slices = get_slices(
hf_inputs.get("image_num_slices", torch.empty(0)))
video_slices = get_slices(
hf_inputs.get("video_num_slices", torch.empty(0)))
return dict(
pixel_values=MultiModalFieldConfig.flat("image", image_slices),
image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.flat("image", image_slices),
image_num_slices=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.flat("image", image_slices),
video_pixel_values=MultiModalFieldConfig.flat(
"video", video_slices),
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.flat("video", video_slices),
video_embeds=MultiModalFieldConfig.flat("video", video_slices),
video_num_slices=MultiModalFieldConfig.batched("video"))
return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
image_sizes=MultiModalFieldConfig.batched("image"),
tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
image_num_slices=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_num_slices),
video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_image_sizes=MultiModalFieldConfig.batched("video"),
video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_num_slices),
video_num_slices=MultiModalFieldConfig.batched("video"))
def apply(
self,

View File

@ -6,44 +6,190 @@
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from typing import Optional
from typing import Mapping, Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement,
PromptReplacementDetails)
from vllm.multimodal.profiling import ProcessorInputs
from .intern_vit import InternVisionModel
from .internvl import (InternVLChatModel, InternVLInputPipeline,
get_max_internvl_image_tokens)
from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor,
InternVLChatModel, InternVLDummyInputsBuilder,
InternVLMultiModalProcessor)
IMG_START = '<|vision_start|>'
IMG_END = '<|vision_end|>'
IMG_CONTEXT = '<|vision_pad|>'
IMG_PAD = "<|vision_pad|>"
class NVLMInputPipeline(InternVLInputPipeline):
class NVLMProcessor(BaseInternVLProcessor):
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[IMG_PAD]
def get_image_repl_features(
self,
feature_size: int,
num_patches: Optional[int],
) -> str:
if num_patches is None:
raise NotImplementedError("Embedding inputs are not supported")
tile_pos_identifiers = [f"<tile_{i}>" for i in range(1, num_patches)]
if self.use_thumbnail and num_patches != 1:
tile_pos_identifiers += ["<tile_global_thumbnail>"]
def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
tile_pos_identifiers = ([f"<tile_{i}>"
for i in range(1, num_patches)] +
["<tile_global_thumbnail>"])
context_size = feature_size // num_patches
features = "".join(identifier + IMG_PAD * context_size
for identifier in tile_pos_identifiers)
return '<Image>' + ''.join(
tile_pos_identifier + self.img_context_token * context_size
for tile_pos_identifier in tile_pos_identifiers) + '</Image>'
# We include the start and end as well because "<Image><tile" is
# tokenized as ["<Image", "><", "tile"], resulting in assertion error
# when trying to find "<tile" as a subsequence of "<Image><tile"
return "<Image>" + features + "</Image>"
def get_image_repl_full(
self,
feature_size: int,
num_patches: Optional[int],
) -> str:
return self.get_image_repl_features(feature_size, num_patches)
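Concretely, the tile position identifiers interleave with the pad tokens. With illustrative values `num_patches=3` (two regular tiles plus the thumbnail, i.e. `use_thumbnail=True`) and four context tokens per tile:

```python
IMG_PAD = "<|vision_pad|>"
num_patches, feature_size = 3, 3 * 4
identifiers = [f"<tile_{i}>" for i in range(1, num_patches)] + ["<tile_global_thumbnail>"]
context_size = feature_size // num_patches  # 4
repl = "<Image>" + "".join(t + IMG_PAD * context_size for t in identifiers) + "</Image>"
# "<Image><tile_1>" + 4 pads + "<tile_2>" + 4 pads + "<tile_global_thumbnail>" + 4 pads + "</Image>"
```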
input_pipeline = NVLMInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
class NVLMProcessingInfo(BaseInternVLProcessingInfo):
def get_hf_processor(
self,
*,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
) -> NVLMProcessor:
return NVLMProcessor(
self.get_hf_config(),
self.get_tokenizer(),
max_dynamic_patch=max_dynamic_patch,
dynamic_image_size=dynamic_image_size,
)
def get_max_image_tokens(self) -> int:
hf_processor = self.get_hf_processor()
tokenizer = hf_processor.tokenizer
max_num_patches = hf_processor.max_dynamic_patch
# we need +1 here because max_dynamic_patch in config doesn't
# include the thumbnail patch
tile_pos_identifiers = [
f"<tile_{i+1}>" for i in range(max_num_patches)
]
if hf_processor.use_thumbnail and max_num_patches != 1:
tile_pos_identifiers += ["<tile_global_thumbnail>"]
# "<Image><tile" is tokenized as ["<Image", "><", "tile"]
# so we include <tile_1> in the start_str
start_str = "<Image>" + tile_pos_identifiers.pop(0)
end_str = "</Image>"
start_token_len = len(tokenizer.encode(start_str))
end_token_len = len(tokenizer.encode(end_str))
tile_token_len = sum(
len(tokenizer.encode(identifier))
for identifier in tile_pos_identifiers)
non_image_tokens_num = start_token_len + end_token_len + tile_token_len
return super().get_max_image_tokens() + non_image_tokens_num
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
mm_data = {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}
return ProcessorInputs(
# The newline is necessary to separate ">" of the current item
# and "<" of the next item
prompt_text="<image>\n" * num_images,
mm_data=mm_data,
)
class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
if "image_num_patches" in out_mm_kwargs:
image_num_patches = out_mm_kwargs["image_num_patches"]
assert isinstance(image_num_patches, torch.Tensor)
image_num_patches = image_num_patches.tolist()
elif "image_embeds" in out_mm_kwargs:
# TODO: Use image size information in dictionary embedding inputs
# to compute num_patches (similar to Qwen2-VL)
image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
else:
image_num_patches = []
def get_replacement_nvlm(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems))
if isinstance(images, ImageEmbeddingItems):
feature_size = images.get_feature_size(item_idx)
else:
image_size = images.get_image_size(item_idx)
feature_size = self.info.get_num_image_tokens(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
num_patches = image_num_patches[item_idx]
if num_patches is not None:
assert isinstance(num_patches, int)
return PromptReplacementDetails(
full=hf_processor.get_image_repl_full(feature_size,
num_patches) + "\n",
features=hf_processor.get_image_repl_features(
feature_size, num_patches) + "\n",
)
# See note in dummy data regarding why we have the extra newline
return [
PromptReplacement(
modality="image",
target="<image>\n",
replacement=get_replacement_nvlm,
)
]
@MULTIMODAL_REGISTRY.register_processor(NVLMMultiModalProcessor,
info=NVLMProcessingInfo,
dummy_inputs=NVLMDummyInputsBuilder)
class NVLM_D_Model(InternVLChatModel):
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:

View File

@ -322,7 +322,11 @@ class Phi3VProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
target_width, target_height = self.get_image_size_with_most_features()
max_image_tokens = self.get_num_image_tokens(

View File

@ -779,7 +779,11 @@ class QWenVLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
@ -799,13 +803,13 @@ class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
vision_config = hf_config.visual
max_image_size = vision_config["image_size"]
target_width = target_height = vision_config["image_size"]
num_images = mm_counts.get("image", 0)
mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
}

View File

@ -110,7 +110,11 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
hf_config = self.get_hf_config()
max_source_positions = hf_config.audio_config.max_source_positions
max_output_lengths = (max_source_positions - 2) // 2 + 1
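For reference, with a Whisper-style `max_source_positions` of 1500 (an assumed value, not read from this diff), the formula above works out to:

```python
max_source_positions = 1500
max_output_lengths = (max_source_positions - 2) // 2 + 1  # 750 audio tokens per item
```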

View File

@ -758,7 +758,11 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None, "video": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {
"image": self.get_max_image_tokens(),
"video": self.get_max_video_tokens(seq_len),
@ -989,26 +993,21 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
image_slice_idxs = [0] + image_grid_thw.prod(-1).cumsum_(0).tolist()
image_slices = [
slice(image_slice_idxs[i], image_slice_idxs[i + 1])
for i in range(len(image_grid_thw))
]
image_grid_sizes = image_grid_thw.prod(-1)
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_slice_idxs = [0] + video_grid_thw.prod(-1).cumsum_(0).tolist()
video_slices = [
slice(video_slice_idxs[i], video_slice_idxs[i + 1])
for i in range(len(video_grid_thw))
]
video_grid_sizes = video_grid_thw.prod(-1)
return dict(
pixel_values=MultiModalFieldConfig.flat("image", image_slices),
image_embeds=MultiModalFieldConfig.flat("image", image_slices),
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_embeds=MultiModalFieldConfig.flat_from_sizes(
"image", image_grid_sizes),
image_grid_thw=MultiModalFieldConfig.batched("image"),
pixel_values_videos=MultiModalFieldConfig.flat(
"video", video_slices),
video_embeds=MultiModalFieldConfig.flat("video", video_slices),
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_embeds=MultiModalFieldConfig.flat_from_sizes(
"video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
)
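The per-item sizes here are just the products of the (t, h, w) grids, so each image's or video's flattened patches can be sliced back out. A small illustration with assumed grids:

```python
import torch

image_grid_thw = torch.tensor([[1, 4, 6], [1, 2, 2]])  # two images (illustrative)
image_grid_sizes = image_grid_thw.prod(-1)              # tensor([24, 4])
# A flattened pixel_values tensor with 28 rows is then split into [0:24] and [24:28].
```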

View File

@ -92,7 +92,11 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
feature_extractor = self.get_feature_extractor()
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
_AUDIO_TOKENS_PER_SECOND)
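As a rough illustration, if the feature extractor used a 30-second chunk length and the model emitted about 6.25 audio tokens per second (both values assumed here, not taken from this diff), the cap would be:

```python
import math

max_audio_tokens = math.ceil(30 * 6.25)  # 188 tokens per audio chunk
```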

View File

@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
Union, cast, final)
@ -258,6 +259,16 @@ class MultiModalFieldConfig:
slices=slices,
)
@staticmethod
def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
slice_idxs = [0, *accumulate(size_per_item)]
slices = [
slice(slice_idxs[i], slice_idxs[i + 1])
for i in range(len(size_per_item))
]
return MultiModalFieldConfig.flat(modality, slices)
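The new helper converts per-item sizes into contiguous slices over the flattened dimension. A self-contained sketch of the same computation:

```python
from itertools import accumulate

import torch

size_per_item = torch.tensor([3, 2, 4])                 # e.g. tiles per image
slice_idxs = [0, *accumulate(size_per_item.tolist())]   # [0, 3, 5, 9]
slices = [slice(slice_idxs[i], slice_idxs[i + 1])
          for i in range(len(size_per_item))]
print(slices)  # [slice(0, 3), slice(3, 5), slice(5, 9)]
```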
def __init__(
self,
field_cls: type[BaseMultiModalField],

View File

@ -680,7 +680,11 @@ class BaseProcessingInfo:
raise NotImplementedError
@abstractmethod
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
"""
Get the maximum possible number of tokens per data item
for each modality.

View File

@ -151,7 +151,8 @@ class MultiModalProfiler(Generic[_I]):
mm_counts = self.get_mm_limits()
info = self.processing_info
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len)
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
seq_len, mm_counts)
if mm_counts.keys() != mm_max_tokens_per_item.keys():
raise AssertionError(

View File

@ -264,7 +264,9 @@ class MultiModalRegistry:
)
processor = self.create_processor(model_config, tokenizer)
seq_len = model_config.max_model_len
return processor.info.get_mm_max_tokens_per_item(seq_len)
mm_limits = self.get_mm_limits_per_prompt(model_config)
return processor.info.get_mm_max_tokens_per_item(
seq_len, mm_limits)
return {
key: plugin.get_max_multimodal_tokens(model_config)