[VLM] Merged multi-modal processor for InternVL-based models (#12553)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@ -250,7 +250,11 @@ def get_max_image_tokens(self) -> int:
|
||||
And thus, we can override the method as:
|
||||
|
||||
```python
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"image": self.get_max_image_tokens()}
|
||||
```
|
||||
|
||||
|
@ -726,7 +726,7 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
* `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc.
|
||||
*
|
||||
* ✅︎
|
||||
*
|
||||
* \*
|
||||
- * `Idefics3ForConditionalGeneration`
|
||||
* Idefics3
|
||||
* T + I
|
||||
@ -799,7 +799,7 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
* ✅︎
|
||||
- * `NVLM_D_Model`
|
||||
* NVLM-D 1.0
|
||||
* T + I<sup>E+</sup>
|
||||
* T + I<sup>+</sup>
|
||||
* `nvidia/NVLM-D-72B`, etc.
|
||||
*
|
||||
* ✅︎
|
||||
@ -859,7 +859,11 @@ See [this page](#generative-models) for more information on how to use generativ
|
||||
<sup>+</sup> Multiple items can be inputted per text prompt for this modality.
|
||||
|
||||
:::{note}
|
||||
To use `DeepSeek-VL2` series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
|
||||
To use DeepSeek-VL2 series models, you have to pass `--hf_overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'` when running vLLM.
|
||||
:::
|
||||
|
||||
:::{note}
|
||||
H2O-VL series models will be available in V1 once we support backends other than FlashAttention.
|
||||
:::
|
||||
|
||||
:::{note}
|
||||
|
@ -1,131 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from transformers import AutoConfig
|
||||
|
||||
# Import the functions to test
|
||||
from vllm.model_executor.models.h2ovl import (calculate_num_blocks,
|
||||
image_to_pixel_values_wrapper)
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
|
||||
models = [
|
||||
"h2oai/h2ovl-mississippi-800m", # Replace with your actual model names
|
||||
"h2oai/h2ovl-mississippi-2b",
|
||||
]
|
||||
|
||||
|
||||
def run_preprocessing_test(
|
||||
image: Image,
|
||||
config,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, int]:
|
||||
"""Test the image preprocessing and calculate expected blocks."""
|
||||
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = config.max_dynamic_patch
|
||||
|
||||
width, height = image.size
|
||||
use_MSAC = config.use_msac
|
||||
|
||||
# Create the mapper function with the provided configuration
|
||||
mapper = image_to_pixel_values_wrapper(config, max_dynamic_patch, use_MSAC)
|
||||
pixel_values = mapper(image)
|
||||
|
||||
# Calculate the expected number of blocks
|
||||
if use_MSAC:
|
||||
# First pass
|
||||
blocks1, _, _, aspect_ratio = calculate_num_blocks(
|
||||
width,
|
||||
height,
|
||||
config.min_dynamic_patch,
|
||||
max_dynamic_patch,
|
||||
config.vision_config.image_size,
|
||||
use_thumbnail=False, # Thumbnail is handled separately
|
||||
prior_aspect_ratio=None,
|
||||
)
|
||||
|
||||
# Second pass
|
||||
blocks2, _, _, _ = calculate_num_blocks(
|
||||
width,
|
||||
height,
|
||||
config.min_dynamic_patch,
|
||||
max_dynamic_patch,
|
||||
config.vision_config.image_size,
|
||||
use_thumbnail=False,
|
||||
prior_aspect_ratio=aspect_ratio,
|
||||
)
|
||||
|
||||
# Add thumbnail if use_thumbnail is True and total_blocks > 1
|
||||
if config.use_thumbnail:
|
||||
blocks1 += 1 if blocks1 > 1 else 0
|
||||
blocks2 += 1 if blocks2 > 1 else 0
|
||||
|
||||
# Total blocks is the sum of blocks from both passes minus overlapping
|
||||
total_blocks = blocks1 + blocks2 - 1
|
||||
|
||||
expected_blocks = total_blocks
|
||||
|
||||
else:
|
||||
blocks, _, _, _ = calculate_num_blocks(
|
||||
width,
|
||||
height,
|
||||
config.min_dynamic_patch,
|
||||
max_dynamic_patch,
|
||||
config.vision_config.image_size,
|
||||
use_thumbnail=False,
|
||||
prior_aspect_ratio=None,
|
||||
)
|
||||
expected_blocks = blocks
|
||||
|
||||
if config.use_thumbnail and expected_blocks > 1:
|
||||
expected_blocks += 1
|
||||
|
||||
return pixel_values, expected_blocks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", models)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
# Single-scale
|
||||
[1.0],
|
||||
# Single-scale, batched
|
||||
[1.0, 1.0, 1.0],
|
||||
# Multi-scale
|
||||
[0.25, 0.5, 1.0],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("max_dynamic_patch", [None, 2, 4, 8])
|
||||
def test_image_preprocessing(image_assets, model_name, size_factors,
|
||||
max_dynamic_patch):
|
||||
"""Test image preprocessing pipeline with different configurations."""
|
||||
# Load the configuration from the model
|
||||
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
for asset in image_assets:
|
||||
image = asset.pil_image
|
||||
for factor in size_factors:
|
||||
scaled_image = rescale_image_size(image, factor)
|
||||
|
||||
# Test preprocessing and get expected number of blocks
|
||||
pixel_values, expected_blocks = run_preprocessing_test(
|
||||
scaled_image, config, max_dynamic_patch)
|
||||
|
||||
# Verify output shapes and properties
|
||||
actual_blocks = pixel_values.shape[0]
|
||||
assert actual_blocks == expected_blocks, (
|
||||
f"Expected {expected_blocks} blocks, got {actual_blocks}")
|
||||
|
||||
# Check image dimensions
|
||||
expected_size = (
|
||||
3, # Number of channels (C, H, W)
|
||||
config.vision_config.image_size,
|
||||
config.vision_config.image_size,
|
||||
)
|
||||
for img in pixel_values:
|
||||
assert img.shape == expected_size, (
|
||||
f"Expected image size {expected_size}, got {img.shape}")
|
@ -250,6 +250,7 @@ VLM_TEST_SETTINGS = {
|
||||
max_model_len=8192,
|
||||
dtype="bfloat16",
|
||||
use_tokenizer_eos=True,
|
||||
num_logprobs=10,
|
||||
patch_hf_runner=model_utils.h2ovl_patch_hf_runner,
|
||||
),
|
||||
"idefics3": VLMTestInfo(
|
||||
@ -282,7 +283,6 @@ VLM_TEST_SETTINGS = {
|
||||
dtype="bfloat16",
|
||||
use_tokenizer_eos=True,
|
||||
patch_hf_runner=model_utils.internvl_patch_hf_runner,
|
||||
marks=[large_gpu_mark(min_gb=32)],
|
||||
),
|
||||
"llava_next": VLMTestInfo(
|
||||
models=["llava-hf/llava-v1.6-mistral-7b-hf"],
|
||||
|
@ -334,12 +334,12 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
def __init__(self, hf_runner: HfRunner):
|
||||
self.num_image_token = hf_runner.model.num_image_token
|
||||
self.tokenizer = hf_runner.tokenizer
|
||||
self.dtype = hf_runner.model.dtype
|
||||
|
||||
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
|
||||
trust_remote_code=True)
|
||||
self.vision_config = self.config.vision_config
|
||||
self.use_thumbnail = self.config.use_thumbnail
|
||||
self.use_msac = self.config.use_msac
|
||||
self.min_num = self.config.min_dynamic_patch
|
||||
self.max_num = self.config.max_dynamic_patch
|
||||
self.image_size = self.vision_config.image_size
|
||||
@ -348,18 +348,19 @@ def h2ovl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
**kwargs):
|
||||
# yapf: disable
|
||||
from vllm.model_executor.models.h2ovl import (
|
||||
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
|
||||
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values_h2ovl)
|
||||
|
||||
# yapf: enable
|
||||
images = [images] if isinstance(images, Image) else images
|
||||
pixel_values = [
|
||||
image_to_pixel_values(image,
|
||||
self.image_size,
|
||||
self.min_num,
|
||||
self.max_num,
|
||||
self.use_thumbnail,
|
||||
use_MSAC=self.config.use_msac).to(
|
||||
self.dtype) for image in images
|
||||
image_to_pixel_values_h2ovl(
|
||||
image,
|
||||
input_size=self.image_size,
|
||||
min_num=self.min_num,
|
||||
max_num=self.max_num,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
use_msac=self.use_msac,
|
||||
) for image in images
|
||||
]
|
||||
num_patches_list = [
|
||||
pixel_value.shape[0] for pixel_value in pixel_values
|
||||
@ -394,7 +395,6 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
def __init__(self, hf_runner: HfRunner):
|
||||
self.num_image_token = hf_runner.model.num_image_token
|
||||
self.tokenizer = hf_runner.tokenizer
|
||||
self.dtype = hf_runner.model.dtype
|
||||
|
||||
self.config = AutoConfig.from_pretrained(hf_runner.model_name,
|
||||
trust_remote_code=True)
|
||||
@ -407,13 +407,17 @@ def internvl_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
|
||||
def __call__(self, text: str, images: Union[Image, List[Image]],
|
||||
**kwargs):
|
||||
from vllm.model_executor.models.internvl import (
|
||||
IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values)
|
||||
IMG_CONTEXT, IMG_END, IMG_START,
|
||||
image_to_pixel_values_internvl)
|
||||
images = [images] if isinstance(images, Image) else images
|
||||
pixel_values = [
|
||||
image_to_pixel_values(image, self.image_size, self.min_num,
|
||||
self.max_num,
|
||||
self.use_thumbnail).to(self.dtype)
|
||||
for image in images
|
||||
image_to_pixel_values_internvl(
|
||||
image,
|
||||
input_size=self.image_size,
|
||||
min_num=self.min_num,
|
||||
max_num=self.max_num,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
) for image in images
|
||||
]
|
||||
num_patches_list = [
|
||||
pixel_value.shape[0] for pixel_value in pixel_values
|
||||
@ -448,7 +452,8 @@ def _internvl_generate(
|
||||
) -> torch.LongTensor:
|
||||
"""Generate method for InternVL2 model without fixed use_cache."""
|
||||
assert self.img_context_token_id is not None
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
target_dtype = next(self.parameters()).dtype
|
||||
vit_embeds = self.extract_feature(pixel_values.to(target_dtype))
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
B, N, C = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(B * N, C)
|
||||
|
@ -141,13 +141,14 @@ def _test_processing_correctness(
|
||||
|
||||
|
||||
# yapf: disable
|
||||
# True if the model supports multiple data items of the modality per request
|
||||
@pytest.mark.parametrize("model_id", [
|
||||
"rhymes-ai/Aria",
|
||||
"Salesforce/blip2-opt-2.7b",
|
||||
"facebook/chameleon-7b",
|
||||
"deepseek-ai/deepseek-vl2-tiny",
|
||||
"adept/fuyu-8b",
|
||||
"h2oai/h2ovl-mississippi-800m",
|
||||
"OpenGVLab/InternVL2-1B",
|
||||
"llava-hf/llava-1.5-7b-hf",
|
||||
"llava-hf/llava-v1.6-mistral-7b-hf",
|
||||
"llava-hf/LLaVA-NeXT-Video-7B-hf",
|
||||
@ -156,6 +157,7 @@ def _test_processing_correctness(
|
||||
"mistral-community/pixtral-12b",
|
||||
"openbmb/MiniCPM-o-2_6",
|
||||
"openbmb/MiniCPM-V-2_6",
|
||||
"nvidia/NVLM-D-72B",
|
||||
"Qwen/Qwen-VL-Chat",
|
||||
"Qwen/Qwen2-VL-2B-Instruct",
|
||||
"Qwen/Qwen2-Audio-7B-Instruct",
|
||||
|
142
tests/models/multimodal/processing/test_h2ovl.py
Normal file
142
tests/models/multimodal/processing/test_h2ovl.py
Normal file
@ -0,0 +1,142 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Tests for H2OVL's multimodal preprocessing kwargs."""
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
|
||||
from ....conftest import _ImageAssets
|
||||
from ...utils import build_model_context
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", [
|
||||
"h2oai/h2ovl-mississippi-800m",
|
||||
"h2oai/h2ovl-mississippi-2b",
|
||||
])
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
# Single-scale
|
||||
[1.0],
|
||||
# Single-scale, batched
|
||||
[1.0, 1.0, 1.0],
|
||||
# Multi-scale
|
||||
[0.25, 0.5, 1.0],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("max_dynamic_patch", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("dynamic_image_size", [True, False])
|
||||
@pytest.mark.parametrize("num_imgs", [1, 2])
|
||||
def test_processor_override(
|
||||
model_id: str,
|
||||
image_assets: _ImageAssets,
|
||||
size_factors: list[int],
|
||||
max_dynamic_patch: int,
|
||||
dynamic_image_size: Optional[bool],
|
||||
num_imgs: int,
|
||||
):
|
||||
from vllm.model_executor.models.h2ovl import (calculate_h2ovl_targets,
|
||||
get_h2ovl_target_ratios)
|
||||
|
||||
ctx = build_model_context(
|
||||
model_name=model_id,
|
||||
tokenizer_name=model_id,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs=None,
|
||||
limit_mm_per_prompt={"image": num_imgs},
|
||||
)
|
||||
tokenizer = cached_get_tokenizer(
|
||||
ctx.model_config.tokenizer,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code,
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(
|
||||
ctx.model_config,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
config = processor.info.get_hf_config()
|
||||
use_msac = config.use_msac
|
||||
|
||||
mm_processor_kwargs = {
|
||||
"max_dynamic_patch": max_dynamic_patch,
|
||||
}
|
||||
if dynamic_image_size is not None:
|
||||
mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size
|
||||
|
||||
min_num = config.min_dynamic_patch
|
||||
max_num = max_dynamic_patch if dynamic_image_size else 1
|
||||
|
||||
# Build the image str / prompt based on the number of images we pass
|
||||
prompt = "<image>" * num_imgs
|
||||
|
||||
for asset in image_assets:
|
||||
for factor in size_factors:
|
||||
image = rescale_image_size(asset.pil_image, factor)
|
||||
mm_data = {"image": [image] * num_imgs}
|
||||
|
||||
width, height = image.size
|
||||
|
||||
# Calculate the expected number of blocks
|
||||
if num_imgs == 1 and use_msac:
|
||||
# First pass
|
||||
blocks1, _, _, aspect_ratio = calculate_h2ovl_targets(
|
||||
orig_width=width,
|
||||
orig_height=height,
|
||||
target_ratios=get_h2ovl_target_ratios(
|
||||
min_num,
|
||||
max_num,
|
||||
prior_aspect_ratio=None,
|
||||
),
|
||||
image_size=config.vision_config.image_size,
|
||||
use_thumbnail=False, # Thumbnail is handled separately
|
||||
)
|
||||
|
||||
# Second pass
|
||||
blocks2, _, _, _ = calculate_h2ovl_targets(
|
||||
orig_width=width,
|
||||
orig_height=height,
|
||||
target_ratios=get_h2ovl_target_ratios(
|
||||
min_num,
|
||||
max_num,
|
||||
prior_aspect_ratio=aspect_ratio,
|
||||
),
|
||||
image_size=config.vision_config.image_size,
|
||||
use_thumbnail=False,
|
||||
)
|
||||
|
||||
# Add thumbnail if use_thumbnail is True and total_blocks > 1
|
||||
if config.use_thumbnail:
|
||||
blocks1 += 1 if blocks1 > 1 else 0
|
||||
blocks2 += 1 if blocks2 > 1 else 0
|
||||
|
||||
# Total blocks is the sum of blocks from both passes minus
|
||||
# overlapping
|
||||
total_blocks = blocks1 + blocks2 - 1
|
||||
|
||||
expected_num_patches = total_blocks
|
||||
else:
|
||||
blocks, _, _, _ = calculate_h2ovl_targets(
|
||||
orig_width=width,
|
||||
orig_height=height,
|
||||
target_ratios=get_h2ovl_target_ratios(
|
||||
min_num,
|
||||
max_num,
|
||||
prior_aspect_ratio=None,
|
||||
),
|
||||
image_size=config.vision_config.image_size,
|
||||
use_thumbnail=False,
|
||||
)
|
||||
expected_num_patches = blocks
|
||||
|
||||
if config.use_thumbnail and expected_num_patches != 1:
|
||||
expected_num_patches += 1
|
||||
|
||||
processed_inputs = processor.apply(prompt, mm_data,
|
||||
mm_processor_kwargs)
|
||||
pixel_shape = (
|
||||
processed_inputs["mm_kwargs"]["pixel_values_flat"].shape)
|
||||
|
||||
assert pixel_shape[0] == expected_num_patches * num_imgs
|
@ -1,207 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Tests for InternVL's multimodal preprocessing kwargs."""
|
||||
from typing import Callable, Optional
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.inputs import InputContext, token_inputs
|
||||
from vllm.multimodal import MultiModalRegistry
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
|
||||
from ....conftest import _ImageAssets
|
||||
from ...utils import build_model_context
|
||||
|
||||
models = ["OpenGVLab/InternVL2-2B"]
|
||||
|
||||
|
||||
# Wrap lazy imports to avoid initializing CUDA during test collection
|
||||
@pytest.fixture()
|
||||
def input_processor_for_internvl():
|
||||
from vllm.model_executor.models.internvl import InternVLInputPipeline
|
||||
|
||||
pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
|
||||
return pipeline.input_processor
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def dummy_data_for_internvl():
|
||||
from vllm.model_executor.models.internvl import InternVLInputPipeline
|
||||
|
||||
pipeline = InternVLInputPipeline('<img>', '</img>', '<IMG_CONTEXT>')
|
||||
return pipeline.dummy_data
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def get_max_internvl_image_tokens():
|
||||
from vllm.model_executor.models.internvl import (
|
||||
get_max_internvl_image_tokens)
|
||||
return get_max_internvl_image_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("model_id", ["OpenGVLab/InternVL2-2B"])
|
||||
@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
|
||||
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
|
||||
def test_input_mapper_override(
|
||||
model: str,
|
||||
@pytest.mark.parametrize("num_imgs", [1, 2])
|
||||
def test_processor_override(
|
||||
model_id: str,
|
||||
image_assets: _ImageAssets,
|
||||
max_dynamic_patch: int,
|
||||
dynamic_image_size: Optional[bool],
|
||||
num_imgs: int,
|
||||
):
|
||||
ctx = build_model_context(
|
||||
model_name=model_id,
|
||||
tokenizer_name=model_id,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs=None,
|
||||
limit_mm_per_prompt={"image": num_imgs},
|
||||
)
|
||||
tokenizer = cached_get_tokenizer(
|
||||
ctx.model_config.tokenizer,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code,
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(
|
||||
ctx.model_config,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
mm_processor_kwargs = {
|
||||
"max_dynamic_patch": max_dynamic_patch,
|
||||
}
|
||||
if dynamic_image_size is not None:
|
||||
mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size
|
||||
|
||||
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
|
||||
if dynamic_image_size is False:
|
||||
expected_num_patches = 1
|
||||
|
||||
ctx = build_model_context(
|
||||
model_name=model,
|
||||
tokenizer_name=model,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
mm_registry = MultiModalRegistry()
|
||||
mm_registry.init_mm_limits_per_prompt(ctx.model_config)
|
||||
|
||||
image = image_assets[0].pil_image.resize((448 * 2, 448 * 2))
|
||||
vllm_result = mm_registry.map_input(
|
||||
ctx.model_config,
|
||||
{"image": image},
|
||||
)
|
||||
assert vllm_result["pixel_values"].size(1) == expected_num_patches
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
|
||||
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
|
||||
def test_max_tokens_override(
|
||||
get_max_internvl_image_tokens: Callable,
|
||||
model: str,
|
||||
max_dynamic_patch: Optional[int],
|
||||
dynamic_image_size: Optional[bool],
|
||||
):
|
||||
"""Ensure get_max_internvl_image_tokens handles mm_processor_kwargs."""
|
||||
ctx = build_model_context(
|
||||
model_name=model,
|
||||
tokenizer_name=model,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs=None,
|
||||
)
|
||||
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
|
||||
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
|
||||
if dynamic_image_size is False:
|
||||
expected_num_patches = 1
|
||||
expected_max_tokens = 256 * expected_num_patches
|
||||
|
||||
actual_max_tokens = get_max_internvl_image_tokens(
|
||||
ctx=InputContext(ctx.model_config),
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
assert expected_max_tokens == actual_max_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("num_imgs", [1, 2])
|
||||
@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None])
|
||||
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
|
||||
def test_dummy_data_override(
|
||||
dummy_data_for_internvl: Callable,
|
||||
model: str,
|
||||
num_imgs: int,
|
||||
max_dynamic_patch: Optional[int],
|
||||
dynamic_image_size: Optional[bool],
|
||||
):
|
||||
"""Ensure dummy_data_for_internvl handles kwargs properly."""
|
||||
# Same as the previous test - don't initialize mm_processor_kwargs
|
||||
# in this test and assume that the kwargs will be correctly expanded by
|
||||
# the partial when calling the dummy data func.
|
||||
ctx = build_model_context(
|
||||
model_name=model,
|
||||
tokenizer_name=model,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs=None,
|
||||
)
|
||||
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch
|
||||
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
|
||||
if dynamic_image_size is False:
|
||||
expected_num_patches = 1
|
||||
expected_max_tokens = 256 * expected_num_patches
|
||||
|
||||
dummy_data = dummy_data_for_internvl(
|
||||
ctx=ctx,
|
||||
seq_len=8192, # Should be bigger than num_imgs * toks_per_img
|
||||
mm_counts={"image": num_imgs},
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
sequence_data = dummy_data.seq_data
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
|
||||
image_token_id = tokenizer.encode('<IMG_CONTEXT>',
|
||||
add_special_tokens=False)[0]
|
||||
|
||||
# Ensure we have the right number of placeholders per size
|
||||
img_tok_count = sequence_data.get_token_ids().count(image_token_id)
|
||||
assert img_tok_count == expected_max_tokens * num_imgs
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("max_dynamic_patch", [1, 4])
|
||||
@pytest.mark.parametrize("dynamic_image_size", [True, False, None])
|
||||
@pytest.mark.parametrize("num_imgs", [1, 2])
|
||||
def test_input_processor_override(
|
||||
input_processor_for_internvl: Callable,
|
||||
image_assets: _ImageAssets,
|
||||
model: str,
|
||||
num_imgs: int,
|
||||
max_dynamic_patch: int,
|
||||
dynamic_image_size: Optional[bool],
|
||||
):
|
||||
"""Ensure input_processor_for_internvl handles kwargs properly."""
|
||||
# Same as the previous test - don't initialize mm_processor_kwargs
|
||||
# in this test and assume that the kwargs will be correctly expanded by
|
||||
# the partial when calling the custom input processor.
|
||||
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
|
||||
if dynamic_image_size is False:
|
||||
expected_num_patches = 1
|
||||
|
||||
ctx = build_model_context(
|
||||
model_name=model,
|
||||
tokenizer_name=model,
|
||||
trust_remote_code=True,
|
||||
mm_processor_kwargs=None,
|
||||
)
|
||||
expected_toks_per_img = 256 * expected_num_patches
|
||||
|
||||
# Build the image str / prompt based on the number of images we pass
|
||||
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
|
||||
placeholders = "<image>" if num_imgs == 1 else "\n".join(
|
||||
f"Image-{i}: <image>\n" for i in range(1, num_imgs + 1))
|
||||
prompt = placeholders
|
||||
images = [image_assets[0].pil_image.resize((448 * 2, 448 * 2))] * num_imgs
|
||||
prompt = "<image>" * num_imgs
|
||||
image = image_assets[0].pil_image.resize((448 * 2, 448 * 2))
|
||||
mm_data = {"image": [image] * num_imgs}
|
||||
|
||||
inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt),
|
||||
prompt=prompt,
|
||||
multi_modal_data={"image": images})
|
||||
expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1
|
||||
if dynamic_image_size is False:
|
||||
expected_num_patches = 1
|
||||
|
||||
processed_inputs = input_processor_for_internvl(
|
||||
ctx,
|
||||
inputs,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
|
||||
|
||||
# Ensure we have the right number of placeholders per num_crops size
|
||||
image_token_id = tokenizer.encode('<IMG_CONTEXT>',
|
||||
add_special_tokens=False)[0]
|
||||
image_token_id = tokenizer.convert_tokens_to_ids("<IMG_CONTEXT>")
|
||||
img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id)
|
||||
assert img_tok_count == expected_toks_per_img * num_imgs
|
||||
pixel_shape = processed_inputs["mm_kwargs"]["pixel_values_flat"].shape
|
||||
|
||||
assert img_tok_count == 256 * expected_num_patches * num_imgs
|
||||
assert pixel_shape[0] == expected_num_patches * num_imgs
|
||||
|
@ -43,7 +43,10 @@ def test_processor_max_tokens(model_id):
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(
|
||||
ctx.model_config,
|
||||
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
|
||||
tokenizer=cached_get_tokenizer(
|
||||
ctx.model_config.tokenizer,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code,
|
||||
),
|
||||
)
|
||||
info = processor.info
|
||||
|
||||
@ -143,7 +146,10 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs):
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(
|
||||
ctx.model_config,
|
||||
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
|
||||
tokenizer=cached_get_tokenizer(
|
||||
ctx.model_config.tokenizer,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code,
|
||||
),
|
||||
)
|
||||
|
||||
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
|
||||
@ -173,7 +179,10 @@ def test_processor_prompt_replacements_all(model_id, num_imgs):
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(
|
||||
ctx.model_config,
|
||||
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
|
||||
tokenizer=cached_get_tokenizer(
|
||||
ctx.model_config.tokenizer,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code,
|
||||
),
|
||||
)
|
||||
|
||||
seen_aspect_ratios = set[float]()
|
||||
|
@ -44,7 +44,10 @@ def test_processor_max_tokens(model_id):
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(
|
||||
ctx.model_config,
|
||||
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
|
||||
tokenizer=cached_get_tokenizer(
|
||||
ctx.model_config.tokenizer,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code,
|
||||
),
|
||||
)
|
||||
info = processor.info
|
||||
|
||||
@ -143,7 +146,10 @@ def test_processor_prompt_replacements_regression(model_id, num_imgs):
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(
|
||||
ctx.model_config,
|
||||
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
|
||||
tokenizer=cached_get_tokenizer(
|
||||
ctx.model_config.tokenizer,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code,
|
||||
),
|
||||
)
|
||||
|
||||
image_ratios = [(171, 152), (184, 161), (198, 176), (333, 296), (369, 328),
|
||||
@ -174,7 +180,10 @@ def test_processor_prompt_replacements_all(model_id, num_imgs):
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(
|
||||
ctx.model_config,
|
||||
tokenizer=cached_get_tokenizer(ctx.model_config.tokenizer),
|
||||
tokenizer=cached_get_tokenizer(
|
||||
ctx.model_config.tokenizer,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code,
|
||||
),
|
||||
)
|
||||
|
||||
seen_aspect_ratios = set[float]()
|
||||
|
@ -38,7 +38,10 @@ def test_processor_override(
|
||||
trust_remote_code=True,
|
||||
limit_mm_per_prompt={"image": num_imgs},
|
||||
)
|
||||
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
|
||||
tokenizer = cached_get_tokenizer(
|
||||
ctx.model_config.tokenizer,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code,
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(
|
||||
ctx.model_config,
|
||||
tokenizer=tokenizer,
|
||||
|
@ -33,7 +33,10 @@ def test_processor_override(
|
||||
mm_processor_kwargs=None,
|
||||
limit_mm_per_prompt={"image": num_imgs},
|
||||
)
|
||||
tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
|
||||
tokenizer = cached_get_tokenizer(
|
||||
ctx.model_config.tokenizer,
|
||||
trust_remote_code=ctx.model_config.trust_remote_code,
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(
|
||||
ctx.model_config,
|
||||
tokenizer=tokenizer,
|
||||
|
@ -399,7 +399,11 @@ class AriaProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"image": self.get_num_image_tokens()}
|
||||
|
||||
def get_num_image_tokens(self) -> int:
|
||||
|
@ -407,7 +407,11 @@ class Blip2ProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": 1}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"image": self.get_num_image_tokens()}
|
||||
|
||||
def get_num_image_tokens(self) -> int:
|
||||
|
@ -64,7 +64,11 @@ class ChameleonProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": 1}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"image": self.get_num_image_tokens()}
|
||||
|
||||
def get_num_image_tokens(self) -> int:
|
||||
|
@ -165,7 +165,11 @@ class DeepseekVL2ProcessingInfo(BaseProcessingInfo):
|
||||
image_width=x[1], image_height=x[0]))
|
||||
return ImageSize(width=width, height=height)
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
max_image_size = self.get_image_size_with_most_features()
|
||||
max_image_tokens = self.get_num_image_tokens(
|
||||
image_height=max_image_size.height,
|
||||
|
@ -80,7 +80,11 @@ class FuyuProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": 1}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
max_ncols, max_nrows = self.get_image_feature_grid_size(
|
||||
|
@ -7,43 +7,55 @@
|
||||
# Copyright (c) 2024 H2O.AI
|
||||
# Licensed under Apache 2.0 License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from functools import partial
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Mapping, Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalKwargs
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (ProcessingCache, PromptReplacement,
|
||||
PromptReplacementDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from .intern_vit import InternVisionModel
|
||||
from .internvl import (IMG_CONTEXT, IMG_END, IMG_START, InternVLChatModel,
|
||||
InternVLInputPipeline, build_transform,
|
||||
find_closest_aspect_ratio, get_internvl_num_patches)
|
||||
from .internvl import (IMG_CONTEXT, IMG_END, IMG_START,
|
||||
BaseInternVLProcessingInfo, BaseInternVLProcessor,
|
||||
InternVLChatModel, InternVLDummyInputsBuilder,
|
||||
InternVLMultiModalProcessor, build_transform,
|
||||
find_closest_aspect_ratio, get_internvl_target_ratios)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# modified to include blocks generated in second pass
|
||||
def calculate_num_blocks(
|
||||
orig_width: int,
|
||||
orig_height: int,
|
||||
def resolve_h2ovl_min_max_num(
|
||||
*,
|
||||
min_dynamic_patch: int,
|
||||
max_dynamic_patch: int,
|
||||
dynamic_image_size: bool,
|
||||
use_thumbnail: bool,
|
||||
) -> tuple[int, int]:
|
||||
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
|
||||
|
||||
if use_thumbnail and max_dynamic_patch != 1:
|
||||
max_dynamic_patch += 1
|
||||
|
||||
return min_dynamic_patch, max_dynamic_patch
|
||||
|
||||
|
||||
def get_h2ovl_target_ratios(
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
image_size: int,
|
||||
use_thumbnail: bool,
|
||||
prior_aspect_ratio=None,
|
||||
) -> Tuple[int, int, int, Tuple[int, int]]:
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = set((i, j) for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1) for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num)
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
*,
|
||||
prior_aspect_ratio: Optional[tuple[int, int]],
|
||||
) -> list[tuple[int, int]]:
|
||||
target_ratios = get_internvl_target_ratios(min_num, max_num)
|
||||
|
||||
# if prior_aspect_ratio is provided, filter the target ratios
|
||||
if prior_aspect_ratio is not None:
|
||||
@ -52,44 +64,66 @@ def calculate_num_blocks(
|
||||
ratio[0] != 0 and prior_aspect_ratio[1] % ratio[1] != 0
|
||||
]
|
||||
|
||||
return target_ratios
|
||||
|
||||
|
||||
# modified to include blocks generated in second pass
|
||||
def calculate_h2ovl_targets(
|
||||
*,
|
||||
orig_width: int,
|
||||
orig_height: int,
|
||||
target_ratios: list[tuple[int, int]],
|
||||
image_size: int,
|
||||
use_thumbnail: bool,
|
||||
) -> tuple[int, int, int, tuple[int, int]]:
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
|
||||
target_ratios, orig_width,
|
||||
orig_height, image_size)
|
||||
target_aspect_ratio = find_closest_aspect_ratio(
|
||||
aspect_ratio,
|
||||
target_ratios,
|
||||
width=orig_width,
|
||||
height=orig_height,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = image_size * target_aspect_ratio[0]
|
||||
target_height = image_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
# add thumbnail image if num_blocks > 1
|
||||
if use_thumbnail and blocks > 1:
|
||||
|
||||
# add thumbnail image if num_blocks != 1
|
||||
if use_thumbnail and blocks != 1:
|
||||
blocks += 1
|
||||
|
||||
return blocks, target_width, target_height, target_aspect_ratio
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
# refactored to handle prior_aspect_ratio as optional
|
||||
def dynamic_preprocess(
|
||||
# refactored to handle prior_aspect_ratio
|
||||
def dynamic_preprocess_h2ovl(
|
||||
image: Image.Image,
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
*,
|
||||
target_ratios: list[tuple[int, int]],
|
||||
image_size: int,
|
||||
use_thumbnail: bool,
|
||||
prior_aspect_ratio: Optional[Tuple[int, int]] = None,
|
||||
) -> Tuple[List[Image.Image], Tuple[int, int]]:
|
||||
) -> tuple[list[Image.Image], tuple[int, int]]:
|
||||
orig_width, orig_height = image.size
|
||||
|
||||
# calculate the number of blocks based on prior aspect ratio if available
|
||||
blocks, target_width, target_height, target_aspect_ratio = (
|
||||
calculate_num_blocks(
|
||||
orig_width,
|
||||
orig_height,
|
||||
min_num,
|
||||
max_num,
|
||||
image_size,
|
||||
use_thumbnail=False,
|
||||
prior_aspect_ratio=prior_aspect_ratio,
|
||||
))
|
||||
# calculate the number of blocks without thumbnail
|
||||
(
|
||||
blocks,
|
||||
target_width,
|
||||
target_height,
|
||||
target_aspect_ratio,
|
||||
) = calculate_h2ovl_targets(
|
||||
orig_width=orig_width,
|
||||
orig_height=orig_height,
|
||||
target_ratios=target_ratios,
|
||||
image_size=image_size,
|
||||
use_thumbnail=False,
|
||||
)
|
||||
|
||||
# resize the image
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
@ -103,276 +137,393 @@ def dynamic_preprocess(
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
|
||||
assert len(processed_images) == blocks
|
||||
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((image_size, image_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
|
||||
return processed_images, target_aspect_ratio
|
||||
|
||||
|
||||
def load_image(
|
||||
image: Image.Image,
|
||||
input_size=448,
|
||||
min_num=1,
|
||||
max_num=6,
|
||||
use_thumbnail=True,
|
||||
prior_aspect_ratio: Optional[Tuple[int, int]] = None,
|
||||
) -> Tuple[torch.Tensor, Tuple[int, int]]:
|
||||
transform = build_transform(input_size=input_size)
|
||||
images, target_aspect_ratio = dynamic_preprocess(
|
||||
image,
|
||||
image_size=input_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
min_num=min_num,
|
||||
max_num=max_num,
|
||||
prior_aspect_ratio=prior_aspect_ratio,
|
||||
)
|
||||
pixel_values = [transform(image) for image in images]
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
return pixel_values, target_aspect_ratio
|
||||
|
||||
|
||||
# refactored to use the combined load_image function
|
||||
def image_to_pixel_values(
|
||||
def _preprocess_image(
|
||||
image: Image.Image,
|
||||
*,
|
||||
input_size: int,
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
use_thumbnail: bool,
|
||||
use_MSAC: bool,
|
||||
prior_aspect_ratio: Optional[tuple[int, int]],
|
||||
) -> tuple[torch.Tensor, tuple[int, int]]:
|
||||
target_ratios = get_h2ovl_target_ratios(
|
||||
min_num,
|
||||
max_num,
|
||||
prior_aspect_ratio=prior_aspect_ratio,
|
||||
)
|
||||
|
||||
transform = build_transform(input_size=input_size)
|
||||
images, target_aspect_ratio = dynamic_preprocess_h2ovl(
|
||||
image,
|
||||
image_size=input_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
target_ratios=target_ratios,
|
||||
)
|
||||
|
||||
pixel_values = torch.stack([transform(image) for image in images])
|
||||
return pixel_values, target_aspect_ratio
|
||||
|
||||
|
||||
# refactored to use the _preprocess_image function
|
||||
def image_to_pixel_values_h2ovl(
|
||||
image: Image.Image,
|
||||
*,
|
||||
input_size: int,
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
use_thumbnail: bool,
|
||||
use_msac: bool,
|
||||
) -> torch.Tensor:
|
||||
# when MSAC is turned on, we need to process the image twice
|
||||
if use_MSAC:
|
||||
if use_msac:
|
||||
# first pass
|
||||
pixel_values, target_aspect_ratio = load_image(
|
||||
pixel_values1, aspect_ratio1 = _preprocess_image(
|
||||
image,
|
||||
input_size=input_size,
|
||||
min_num=min_num,
|
||||
max_num=max_num,
|
||||
use_thumbnail=True,
|
||||
prior_aspect_ratio=None,
|
||||
)
|
||||
# second pass
|
||||
pixel_values2, _ = load_image(
|
||||
pixel_values2, _ = _preprocess_image(
|
||||
image,
|
||||
input_size=input_size,
|
||||
min_num=min_num,
|
||||
min_num=3, # Hardcoded value
|
||||
max_num=max_num,
|
||||
prior_aspect_ratio=target_aspect_ratio,
|
||||
use_thumbnail=True,
|
||||
prior_aspect_ratio=aspect_ratio1,
|
||||
)
|
||||
# combine pixel values
|
||||
pixel_values = torch.cat(
|
||||
[pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0)
|
||||
[pixel_values2[:-1], pixel_values1[:-1], pixel_values2[-1:]], 0)
|
||||
|
||||
else:
|
||||
pixel_values, _ = load_image(
|
||||
pixel_values, _ = _preprocess_image(
|
||||
image,
|
||||
input_size=input_size,
|
||||
min_num=min_num,
|
||||
max_num=max_num,
|
||||
use_thumbnail=use_thumbnail,
|
||||
prior_aspect_ratio=None,
|
||||
)
|
||||
|
||||
return pixel_values
|
||||
|
||||
|
||||
def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
use_MSAC: Optional[bool] = None):
|
||||
image_size = hf_config.vision_config.image_size
|
||||
min_num = hf_config.min_dynamic_patch
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = hf_config.max_dynamic_patch
|
||||
if use_MSAC is None:
|
||||
use_MSAC = hf_config.use_msac
|
||||
use_thumbnail = hf_config.use_thumbnail
|
||||
return partial(
|
||||
image_to_pixel_values,
|
||||
input_size=image_size,
|
||||
min_num=min_num,
|
||||
max_num=max_dynamic_patch,
|
||||
use_thumbnail=use_thumbnail,
|
||||
use_MSAC=use_MSAC,
|
||||
)
|
||||
class H2OVLProcessor(BaseInternVLProcessor):
|
||||
|
||||
|
||||
def get_max_internvl_image_tokens(ctx: InputContext,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None):
|
||||
"""
|
||||
Calculate the maximum number of tokens with/without MSAC and thumbnail
|
||||
"""
|
||||
hf_config = ctx.get_hf_config()
|
||||
use_thumbnail = hf_config.use_thumbnail
|
||||
use_MSAC = hf_config.use_msac
|
||||
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = hf_config.max_dynamic_patch
|
||||
|
||||
num_patches = get_internvl_num_patches(hf_config)
|
||||
|
||||
coefficient = 2 if use_MSAC else 1
|
||||
num_blocks = coefficient * max_dynamic_patch + (1 if use_thumbnail else 0)
|
||||
|
||||
return num_blocks * num_patches
|
||||
|
||||
|
||||
class H2OVLInputPipeline(InternVLInputPipeline):
|
||||
"""
|
||||
Input pipeline for processing image and text data for the H2OVL model.
|
||||
"""
|
||||
|
||||
def input_processor(
|
||||
def __init__(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
) -> DecoderOnlyInputs:
|
||||
# get multi_modal_data
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config()
|
||||
use_MSAC = hf_config.use_msac
|
||||
|
||||
image_data = multi_modal_data["image"]
|
||||
num_patches = get_internvl_num_patches(hf_config)
|
||||
|
||||
image_pixel_values_mapper = image_to_pixel_values_wrapper(
|
||||
hf_config, max_dynamic_patch=max_dynamic_patch)
|
||||
|
||||
# single image
|
||||
if isinstance(image_data, Image.Image):
|
||||
pixel_values = image_pixel_values_mapper(image_data,
|
||||
use_MSAC=use_MSAC)
|
||||
num_blocks = pixel_values.shape[0]
|
||||
image_feature_sizes = [num_blocks * num_patches]
|
||||
pixel_values = pixel_values.unsqueeze(0)
|
||||
|
||||
# multi images
|
||||
elif is_list_of(image_data, Image.Image):
|
||||
# Do not use MSAC for multi images
|
||||
image_feature_sizes = []
|
||||
pixel_values = [
|
||||
image_pixel_values_mapper(image, use_MSAC=False)
|
||||
for image in image_data
|
||||
]
|
||||
for pixel_value in pixel_values:
|
||||
num_blocks = pixel_value.shape[0]
|
||||
image_feature_sizes.append(num_blocks * num_patches)
|
||||
|
||||
# image embeddings as input
|
||||
elif isinstance(image_data, torch.Tensor):
|
||||
_, image_feature_size, _ = image_data.shape
|
||||
image_feature_sizes = [image_feature_size]
|
||||
pixel_values = None
|
||||
|
||||
# multi-image image embeddings
|
||||
elif is_list_of(image_data, torch.Tensor):
|
||||
|
||||
image_feature_sizes = []
|
||||
for image_embed in image_data:
|
||||
_, image_feature_size, _ = image_embed.shape
|
||||
image_feature_sizes.append(image_feature_size)
|
||||
pixel_values = None
|
||||
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
|
||||
tokenizer = cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
use_msac: Optional[bool] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
config,
|
||||
tokenizer,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
|
||||
prompt = inputs.get("prompt")
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
if prompt is None:
|
||||
prompt = tokenizer.decode(prompt_token_ids)
|
||||
if use_msac is None:
|
||||
use_msac = config.use_msac
|
||||
assert isinstance(use_msac, bool)
|
||||
|
||||
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
|
||||
num_patches)
|
||||
new_prompt_token_ids = tokenizer.encode(new_prompt)
|
||||
self.use_msac = use_msac
|
||||
|
||||
# Wrap image processing in input_processor to avoid duplication
|
||||
image_token_id = tokenizer.encode(
|
||||
self.img_context_token,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)[0]
|
||||
@property
|
||||
def image_token_id(self) -> int:
|
||||
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
||||
|
||||
# Update multi_modal_data to return
|
||||
if pixel_values is not None:
|
||||
multi_modal_data = {
|
||||
"image": {
|
||||
"pixel_values": pixel_values,
|
||||
"image_token_id": image_token_id,
|
||||
}
|
||||
}
|
||||
else:
|
||||
multi_modal_data = {"image": {"image_embeds": image_data}}
|
||||
|
||||
return token_inputs(
|
||||
prompt=prompt,
|
||||
prompt_token_ids=new_prompt_token_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
|
||||
def input_mapper(
|
||||
def get_image_repl_features(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> str:
|
||||
return IMG_CONTEXT * feature_size
|
||||
|
||||
def get_image_repl_full(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> str:
|
||||
features = self.get_image_repl_features(feature_size, num_patches)
|
||||
return IMG_START + features + IMG_END
|
||||
|
||||
def resolve_min_max_num(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
data: object,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
) -> MultiModalKwargs:
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
use_thumbnail: Optional[bool] = None,
|
||||
) -> tuple[int, int]:
|
||||
min_dynamic_patch = self.min_dynamic_patch
|
||||
max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
|
||||
is None else max_dynamic_patch)
|
||||
dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
|
||||
is None else dynamic_image_size)
|
||||
use_thumbnail = (self.use_thumbnail
|
||||
if use_thumbnail is None else use_thumbnail)
|
||||
|
||||
# NOTE: Preprocessing for the image data is done in the
|
||||
# 'input_processor' function during actual inference.
|
||||
if isinstance(data, dict):
|
||||
return MultiModalKwargs(data)
|
||||
|
||||
# The section below is only used with dummy data during
|
||||
# memory profiling.
|
||||
hf_config = ctx.get_hf_config()
|
||||
|
||||
image_pixel_values_mapper = image_to_pixel_values_wrapper(
|
||||
hf_config, max_dynamic_patch)
|
||||
|
||||
if isinstance(data, Image.Image):
|
||||
pixel_values = image_pixel_values_mapper(data)
|
||||
pixel_values = pixel_values.unsqueeze(0)
|
||||
|
||||
elif is_list_of(data, Image.Image):
|
||||
hf_config.use_msac = False
|
||||
pixel_values = [image_pixel_values_mapper(img) for img in data]
|
||||
|
||||
else:
|
||||
return MultiModalKwargs({"image_embeds": data})
|
||||
model_config = ctx.model_config
|
||||
tokenizer = cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
return resolve_h2ovl_min_max_num(
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
)
|
||||
image_token_id = tokenizer.encode(
|
||||
self.img_context_token,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)[0]
|
||||
|
||||
return MultiModalKwargs({
|
||||
"pixel_values": pixel_values,
|
||||
"image_token_id": image_token_id
|
||||
})
|
||||
def resolve_target_ratios(
|
||||
self,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
use_thumbnail: Optional[bool] = None,
|
||||
prior_aspect_ratio: Optional[tuple[int, int]] = None,
|
||||
) -> list[tuple[int, int]]:
|
||||
min_num, max_num = self.resolve_min_max_num(
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
)
|
||||
if prior_aspect_ratio: # hardcoded value for second pass of use_msac
|
||||
min_num = 3
|
||||
|
||||
return get_h2ovl_target_ratios(
|
||||
min_num,
|
||||
max_num,
|
||||
prior_aspect_ratio=prior_aspect_ratio,
|
||||
)
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
use_msac: Optional[bool] = None,
|
||||
) -> int:
|
||||
use_msac = (self.use_msac if use_msac is None else use_msac)
|
||||
|
||||
use_thumbnail = self.use_thumbnail
|
||||
|
||||
if use_msac:
|
||||
target_ratios_1 = self.resolve_target_ratios(
|
||||
use_thumbnail=False, # Applied in calculate_targets
|
||||
)
|
||||
num_patches_1, _, _, aspect_ratio_1 = calculate_h2ovl_targets(
|
||||
orig_width=image_width,
|
||||
orig_height=image_height,
|
||||
image_size=self.image_size,
|
||||
target_ratios=target_ratios_1,
|
||||
use_thumbnail=True,
|
||||
)
|
||||
|
||||
target_ratios_2 = self.resolve_target_ratios(
|
||||
use_thumbnail=False, # Applied in calculate_targets
|
||||
prior_aspect_ratio=aspect_ratio_1,
|
||||
)
|
||||
num_patches_2, _, _, _ = calculate_h2ovl_targets(
|
||||
orig_width=image_width,
|
||||
orig_height=image_height,
|
||||
image_size=self.image_size,
|
||||
target_ratios=target_ratios_2,
|
||||
use_thumbnail=True,
|
||||
)
|
||||
|
||||
num_patches = num_patches_1 + num_patches_2 - 1
|
||||
else:
|
||||
target_ratios = self.resolve_target_ratios(
|
||||
use_thumbnail=False, # Applied in calculate_targets
|
||||
)
|
||||
num_patches, _, _, _ = calculate_h2ovl_targets(
|
||||
orig_width=image_width,
|
||||
orig_height=image_height,
|
||||
image_size=self.image_size,
|
||||
target_ratios=target_ratios,
|
||||
use_thumbnail=use_thumbnail,
|
||||
)
|
||||
|
||||
return num_patches * self.num_image_token
|
||||
|
||||
def _images_to_pixel_values_lst(
|
||||
self,
|
||||
images: list[Image.Image],
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> list[torch.Tensor]:
|
||||
use_msac = self.use_msac if len(images) == 1 else False
|
||||
|
||||
min_num, max_num = self.resolve_min_max_num(
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
use_thumbnail=False, # Applied in image_to_pixel_values
|
||||
)
|
||||
|
||||
return [
|
||||
image_to_pixel_values_h2ovl(
|
||||
image,
|
||||
input_size=self.image_size,
|
||||
min_num=min_num,
|
||||
max_num=max_num,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
use_msac=use_msac,
|
||||
) for image in images
|
||||
]
|
||||
|
||||
|
||||
input_pipeline = H2OVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
|
||||
class H2OVLProcessingInfo(BaseInternVLProcessingInfo):
|
||||
|
||||
def get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> H2OVLProcessor:
|
||||
return H2OVLProcessor(
|
||||
self.get_hf_config(),
|
||||
self.get_tokenizer(),
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
max_tokens_one_image = self.get_max_image_tokens(use_msac=None)
|
||||
if mm_counts.get("image", 0) <= 1:
|
||||
max_tokens_per_image = max_tokens_one_image
|
||||
else:
|
||||
max_tokens_per_image = self.get_max_image_tokens(use_msac=False)
|
||||
|
||||
return {"image": max_tokens_per_image}
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: Optional[H2OVLProcessor],
|
||||
use_msac: Optional[bool] = None,
|
||||
) -> int:
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
return processor.get_num_image_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
use_msac=use_msac,
|
||||
)
|
||||
|
||||
def get_max_image_tokens(self, use_msac: Optional[bool] = None) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
return self.get_num_image_tokens(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
processor=None,
|
||||
use_msac=use_msac,
|
||||
)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
|
||||
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
|
||||
class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
|
||||
):
|
||||
|
||||
def __init__(self,
|
||||
info: H2OVLProcessingInfo,
|
||||
dummy_inputs: "BaseDummyInputsBuilder[H2OVLProcessingInfo]",
|
||||
*,
|
||||
cache: Optional[ProcessingCache] = None,
|
||||
enable_sanity_checks: bool = True) -> None:
|
||||
super().__init__(
|
||||
info,
|
||||
dummy_inputs,
|
||||
cache=cache,
|
||||
enable_sanity_checks=enable_sanity_checks,
|
||||
)
|
||||
|
||||
if self.cache is not None:
|
||||
# The processor output depends on the number of images passed,
|
||||
# making it incompatible with processing cache which is supposed
|
||||
# to be invariant of how many images are passed per prompt
|
||||
self.cache = None
|
||||
logger.warning_once(
|
||||
f"{type(self).__name__} does not support processing cache.")
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
if "image_num_patches" in out_mm_kwargs:
|
||||
image_num_patches = out_mm_kwargs["image_num_patches"]
|
||||
assert isinstance(image_num_patches, torch.Tensor)
|
||||
image_num_patches = image_num_patches.tolist()
|
||||
elif "image_embeds" in out_mm_kwargs:
|
||||
# TODO: Use image size information in dictionary embedding inputs
|
||||
# to compute num_patches (similar to Qwen2-VL)
|
||||
image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
|
||||
else:
|
||||
image_num_patches = []
|
||||
|
||||
num_images = len(image_num_patches)
|
||||
|
||||
def get_replacement_internvl(item_idx: int):
|
||||
images = mm_items.get_items(
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
||||
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
feature_size = images.get_feature_size(item_idx)
|
||||
else:
|
||||
image_size = images.get_image_size(item_idx)
|
||||
feature_size = self.info.get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
processor=hf_processor,
|
||||
use_msac=None if num_images == 1 else False,
|
||||
)
|
||||
|
||||
num_patches = image_num_patches[item_idx]
|
||||
if num_patches is not None:
|
||||
assert isinstance(num_patches, int)
|
||||
|
||||
return PromptReplacementDetails(
|
||||
full=hf_processor.get_image_repl_full(feature_size,
|
||||
num_patches),
|
||||
features=hf_processor.get_image_repl_features(
|
||||
feature_size, num_patches),
|
||||
)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target="<image>",
|
||||
replacement=get_replacement_internvl,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
H2OVLMultiModalProcessor,
|
||||
info=H2OVLProcessingInfo,
|
||||
dummy_inputs=InternVLDummyInputsBuilder)
|
||||
class H2OVLChatModel(InternVLChatModel):
|
||||
|
||||
def _init_vision_model(
|
||||
|
@ -6,35 +6,37 @@
|
||||
# Copyright (c) 2023 OpenGVLab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
import re
|
||||
from functools import cached_property, partial
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
TypedDict, TypeVar, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from transformers import PretrainedConfig
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.models.intern_vit import (InternVisionModel,
|
||||
InternVisionPatchModel)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptReplacementDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
|
||||
get_clip_num_patches)
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
@ -75,22 +77,27 @@ InternVLImageInputs = Union[InternVLImagePixelInputs,
|
||||
InternVLImageEmbeddingInputs]
|
||||
|
||||
|
||||
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def build_transform(input_size):
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def build_transform(input_size: int):
|
||||
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
||||
transform = T.Compose([
|
||||
return T.Compose([
|
||||
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
||||
T.Resize((input_size, input_size),
|
||||
interpolation=T.InterpolationMode.BICUBIC),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=MEAN, std=STD)
|
||||
])
|
||||
return transform
|
||||
|
||||
|
||||
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
|
||||
image_size):
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def find_closest_aspect_ratio(
|
||||
aspect_ratio: float,
|
||||
target_ratios: list[tuple[int, int]],
|
||||
*,
|
||||
width: int,
|
||||
height: int,
|
||||
image_size: int,
|
||||
) -> tuple[int, int]:
|
||||
best_ratio_diff = float('inf')
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
@ -106,67 +113,82 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
|
||||
return best_ratio
|
||||
|
||||
|
||||
def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
|
||||
max_num: int, image_size: int,
|
||||
use_thumbnail: bool) -> Tuple[int, int, int]:
|
||||
def resolve_internvl_min_max_num(
|
||||
*,
|
||||
min_dynamic_patch: int,
|
||||
max_dynamic_patch: int,
|
||||
dynamic_image_size: bool,
|
||||
use_thumbnail: bool,
|
||||
) -> tuple[int, int]:
|
||||
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
|
||||
|
||||
if use_thumbnail and max_dynamic_patch != 1:
|
||||
max_dynamic_patch += 1
|
||||
|
||||
return min_dynamic_patch, max_dynamic_patch
|
||||
|
||||
|
||||
def get_internvl_target_ratios(
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
) -> list[tuple[int, int]]:
|
||||
target_ratios = {(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1) if min_num <= i * j <= max_num}
|
||||
return sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
|
||||
def calculate_internvl_targets(
|
||||
*,
|
||||
orig_width: int,
|
||||
orig_height: int,
|
||||
target_ratios: list[tuple[int, int]],
|
||||
image_size: int,
|
||||
use_thumbnail: bool,
|
||||
) -> tuple[int, int, int]:
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = set((i, j) for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1) for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num)
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
|
||||
target_ratios, orig_width,
|
||||
orig_height, image_size)
|
||||
target_aspect_ratio = find_closest_aspect_ratio(
|
||||
aspect_ratio,
|
||||
target_ratios,
|
||||
width=orig_width,
|
||||
height=orig_height,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = image_size * target_aspect_ratio[0]
|
||||
target_height = image_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
# add thumbnail image if num_blocks > 1
|
||||
if use_thumbnail and blocks > 1:
|
||||
|
||||
# add thumbnail image if num_blocks != 1
|
||||
if use_thumbnail and blocks != 1:
|
||||
blocks += 1
|
||||
|
||||
return blocks, target_width, target_height
|
||||
|
||||
|
||||
def calculate_num_blocks_wrapper(
|
||||
hf_config: PretrainedConfig,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
):
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = hf_config.dynamic_image_size
|
||||
|
||||
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = hf_config.max_dynamic_patch
|
||||
min_num = hf_config.min_dynamic_patch
|
||||
image_size = hf_config.vision_config.image_size
|
||||
use_thumbnail = hf_config.use_thumbnail
|
||||
return partial(calculate_num_blocks,
|
||||
min_num=min_num,
|
||||
max_num=max_dynamic_patch,
|
||||
image_size=image_size,
|
||||
use_thumbnail=use_thumbnail)
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
|
||||
image_size: int,
|
||||
use_thumbnail: bool) -> List[Image.Image]:
|
||||
def dynamic_preprocess_internvl(
|
||||
image: Image.Image,
|
||||
*,
|
||||
target_ratios: list[tuple[int, int]],
|
||||
image_size: int,
|
||||
use_thumbnail: bool,
|
||||
) -> list[Image.Image]:
|
||||
orig_width, orig_height = image.size
|
||||
|
||||
# calculate the number of blocks without thumbnail
|
||||
blocks, target_width, target_height = calculate_num_blocks(
|
||||
orig_width,
|
||||
orig_height,
|
||||
min_num,
|
||||
max_num,
|
||||
image_size,
|
||||
use_thumbnail=False)
|
||||
blocks, target_width, target_height = calculate_internvl_targets(
|
||||
orig_width=orig_width,
|
||||
orig_height=orig_height,
|
||||
target_ratios=target_ratios,
|
||||
image_size=image_size,
|
||||
use_thumbnail=False,
|
||||
)
|
||||
|
||||
# resize the image
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
@ -178,301 +200,463 @@ def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
|
||||
assert len(processed_images) == blocks
|
||||
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((image_size, image_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
|
||||
return processed_images
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
|
||||
max_num: int, use_thumbnail: bool) -> torch.Tensor:
|
||||
def image_to_pixel_values_internvl(
|
||||
image: Image.Image,
|
||||
*,
|
||||
input_size: int,
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
use_thumbnail: bool,
|
||||
) -> torch.Tensor:
|
||||
target_ratios = get_internvl_target_ratios(min_num, max_num)
|
||||
|
||||
transform = build_transform(input_size=input_size)
|
||||
images = dynamic_preprocess(image,
|
||||
min_num=min_num,
|
||||
max_num=max_num,
|
||||
image_size=input_size,
|
||||
use_thumbnail=use_thumbnail)
|
||||
pixel_values = [transform(image) for image in images]
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
images = dynamic_preprocess_internvl(
|
||||
image,
|
||||
target_ratios=target_ratios,
|
||||
image_size=input_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
)
|
||||
|
||||
pixel_values = torch.stack([transform(image) for image in images])
|
||||
return pixel_values
|
||||
|
||||
|
||||
def image_to_pixel_values_wrapper(
|
||||
hf_config: PretrainedConfig,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
):
|
||||
image_size = hf_config.vision_config.image_size
|
||||
min_num = hf_config.min_dynamic_patch
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = hf_config.dynamic_image_size
|
||||
class BaseInternVLProcessor(ABC):
|
||||
"""
|
||||
This model doesn't define its own HF processor,
|
||||
so we implement our own one here.
|
||||
|
||||
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = hf_config.max_dynamic_patch
|
||||
use_thumbnail = hf_config.use_thumbnail
|
||||
return partial(image_to_pixel_values,
|
||||
input_size=image_size,
|
||||
min_num=min_num,
|
||||
max_num=max_dynamic_patch,
|
||||
use_thumbnail=use_thumbnail)
|
||||
|
||||
|
||||
def get_internvl_num_patches(hf_config: PretrainedConfig):
|
||||
vision_config = hf_config.vision_config
|
||||
downsample_ratio = hf_config.downsample_ratio
|
||||
image_size = vision_config.image_size
|
||||
patch_size = vision_config.patch_size
|
||||
return int(
|
||||
get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
|
||||
(downsample_ratio**2))
|
||||
|
||||
|
||||
def get_max_internvl_image_tokens(
|
||||
ctx: InputContext,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
):
|
||||
hf_config = ctx.get_hf_config()
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = hf_config.dynamic_image_size
|
||||
|
||||
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = hf_config.max_dynamic_patch
|
||||
use_thumbnail = hf_config.use_thumbnail
|
||||
if use_thumbnail and max_dynamic_patch > 1:
|
||||
max_dynamic_patch += 1
|
||||
|
||||
num_patches = get_internvl_num_patches(hf_config)
|
||||
return num_patches * max_dynamic_patch
|
||||
|
||||
|
||||
def get_max_internvl_image_size(
|
||||
ctx: InputContext,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
):
|
||||
hf_config = ctx.get_hf_config()
|
||||
image_size = hf_config.vision_config.image_size
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = hf_config.dynamic_image_size
|
||||
|
||||
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = hf_config.max_dynamic_patch
|
||||
use_thumbnail = hf_config.use_thumbnail
|
||||
if use_thumbnail and max_dynamic_patch > 1:
|
||||
max_dynamic_patch += 1
|
||||
width = image_size * max_dynamic_patch
|
||||
height = image_size
|
||||
return width, height
|
||||
|
||||
|
||||
class InternVLInputPipeline:
|
||||
The code to insert image tokens is based on:
|
||||
https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_start_token: str,
|
||||
img_end_token: str,
|
||||
img_context_token: str,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.img_start_token = img_start_token
|
||||
self.img_end_token = img_end_token
|
||||
self.img_context_token = img_context_token
|
||||
self.config = config
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
|
||||
return (self.img_start_token + self.img_context_token * feature_size +
|
||||
self.img_end_token)
|
||||
image_size: int = config.vision_config.image_size
|
||||
patch_size: int = config.vision_config.patch_size
|
||||
|
||||
def _expand_image_prompt(
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = config.dynamic_image_size
|
||||
assert isinstance(dynamic_image_size, bool)
|
||||
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = config.max_dynamic_patch
|
||||
assert isinstance(max_dynamic_patch, int)
|
||||
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size)**2 * (config.downsample_ratio**2))
|
||||
self.image_size = image_size
|
||||
self.min_dynamic_patch: int = config.min_dynamic_patch
|
||||
self.max_dynamic_patch = max_dynamic_patch
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail: bool = config.use_thumbnail
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def image_token_id(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_image_repl_features(
|
||||
self,
|
||||
prompt: str,
|
||||
feature_sizes: List[int],
|
||||
num_patches: int,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> str:
|
||||
image_idx = sorted(
|
||||
map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
|
||||
raise NotImplementedError
|
||||
|
||||
new_prompt = prompt
|
||||
for idx, feature_size in enumerate(feature_sizes, start=1):
|
||||
image_prompt = self._create_image_prompt(feature_size, num_patches)
|
||||
if not image_idx:
|
||||
image_prompt = f"Image-{idx}: {image_prompt}"
|
||||
|
||||
new_prompt = new_prompt.replace('<image>', image_prompt, 1)
|
||||
|
||||
return new_prompt
|
||||
|
||||
def input_processor(
|
||||
@abstractmethod
|
||||
def get_image_repl_full(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def resolve_min_max_num(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> DecoderOnlyInputs:
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
use_thumbnail: Optional[bool] = None,
|
||||
) -> tuple[int, int]:
|
||||
min_dynamic_patch = self.min_dynamic_patch
|
||||
max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
|
||||
is None else max_dynamic_patch)
|
||||
dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
|
||||
is None else dynamic_image_size)
|
||||
use_thumbnail = (self.use_thumbnail
|
||||
if use_thumbnail is None else use_thumbnail)
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config()
|
||||
return resolve_internvl_min_max_num(
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
)
|
||||
|
||||
image_data = multi_modal_data["image"]
|
||||
num_patches = get_internvl_num_patches(hf_config)
|
||||
num_blocks_calculator = calculate_num_blocks_wrapper(
|
||||
hf_config, max_dynamic_patch, dynamic_image_size)
|
||||
if isinstance(image_data, Image.Image):
|
||||
width, height = image_data.size
|
||||
num_blocks, _, _ = num_blocks_calculator(width, height)
|
||||
image_feature_sizes = [num_blocks * num_patches]
|
||||
elif is_list_of(image_data, Image.Image):
|
||||
image_feature_sizes = []
|
||||
for image in image_data:
|
||||
width, height = image.size
|
||||
num_blocks, _, _ = num_blocks_calculator(width, height)
|
||||
image_feature_sizes.append(num_blocks * num_patches)
|
||||
elif isinstance(image_data, torch.Tensor):
|
||||
num_images, image_feature_size, hidden_size = image_data.shape
|
||||
image_feature_sizes = [image_feature_size]
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
|
||||
tokenizer = cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
|
||||
prompt = inputs.get("prompt")
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
if prompt is None:
|
||||
prompt = tokenizer.decode(prompt_token_ids)
|
||||
|
||||
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
|
||||
num_patches)
|
||||
new_prompt_token_ids = tokenizer.encode(new_prompt)
|
||||
img_context_token_id = tokenizer.encode(self.img_context_token,
|
||||
add_special_tokens=False)
|
||||
assert len(img_context_token_id) == 1, \
|
||||
(f"Invalid image token '{self.img_context_token}': A valid image "
|
||||
f"token encodes to a single token ID, got {img_context_token_id}.")
|
||||
img_context_token_id = img_context_token_id[0]
|
||||
|
||||
# Get precise tracking of placeholder positions
|
||||
token_idx = image_idx = 0
|
||||
placeholder_ranges = []
|
||||
while token_idx < len(new_prompt_token_ids):
|
||||
if new_prompt_token_ids[token_idx] == img_context_token_id:
|
||||
curr_image_featue_size = image_feature_sizes[image_idx]
|
||||
placeholder_ranges.append(
|
||||
PlaceholderRange(offset=token_idx,
|
||||
length=curr_image_featue_size))
|
||||
image_idx += 1
|
||||
token_idx += curr_image_featue_size
|
||||
else:
|
||||
token_idx += 1
|
||||
|
||||
return token_inputs(
|
||||
prompt=prompt,
|
||||
prompt_token_ids=new_prompt_token_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
multi_modal_placeholders={"image": placeholder_ranges})
|
||||
|
||||
def input_mapper(
|
||||
def resolve_target_ratios(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
data: object,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
):
|
||||
hf_config = ctx.get_hf_config()
|
||||
use_thumbnail: Optional[bool] = None,
|
||||
) -> list[tuple[int, int]]:
|
||||
min_num, max_num = self.resolve_min_max_num(
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
)
|
||||
|
||||
image_pixel_values_mapper = image_to_pixel_values_wrapper(
|
||||
hf_config, max_dynamic_patch, dynamic_image_size)
|
||||
if isinstance(data, Image.Image):
|
||||
data = image_pixel_values_mapper(data)
|
||||
# Add an N dimension for number of images per prompt (currently 1).
|
||||
data = data.unsqueeze(0)
|
||||
elif is_list_of(data, Image.Image):
|
||||
# we can't stack here because images may have different num_patches
|
||||
data = [image_pixel_values_mapper(img) for img in data]
|
||||
else:
|
||||
return MultiModalKwargs({"image_embeds": data})
|
||||
model_config = ctx.model_config
|
||||
tokenizer = cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
image_token_id = tokenizer.encode(self.img_context_token,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt")[0]
|
||||
return get_internvl_target_ratios(min_num, max_num)
|
||||
|
||||
return MultiModalKwargs({
|
||||
"pixel_values": data,
|
||||
"image_token_id": image_token_id
|
||||
})
|
||||
|
||||
def dummy_data(
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
target_ratios = self.resolve_target_ratios(
|
||||
use_thumbnail=False, # Applied in calculate_targets
|
||||
)
|
||||
|
||||
num_patches, _, _ = calculate_internvl_targets(
|
||||
orig_width=image_width,
|
||||
orig_height=image_height,
|
||||
image_size=self.image_size,
|
||||
target_ratios=target_ratios,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
)
|
||||
|
||||
return num_patches * self.num_image_token
|
||||
|
||||
def _images_to_pixel_values_lst(
|
||||
self,
|
||||
images: list[Image.Image],
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> list[torch.Tensor]:
|
||||
min_num, max_num = self.resolve_min_max_num(
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
use_thumbnail=False, # Applied in image_to_pixel_values
|
||||
)
|
||||
|
||||
return [
|
||||
image_to_pixel_values_internvl(
|
||||
image,
|
||||
input_size=self.image_size,
|
||||
min_num=min_num,
|
||||
max_num=max_num,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
) for image in images
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Optional[Union[str, list[str]]] = None,
|
||||
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
) -> BatchFeature:
|
||||
if text is None:
|
||||
text = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
if images is None:
|
||||
images = []
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
|
||||
if len(images) == 0:
|
||||
image_inputs = {}
|
||||
else:
|
||||
pixel_values_lst = self._images_to_pixel_values_lst(
|
||||
images,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
image_inputs = {
|
||||
"pixel_values_flat": torch.cat(pixel_values_lst),
|
||||
"image_num_patches": list(map(len, pixel_values_lst)),
|
||||
}
|
||||
|
||||
for pixel_values in pixel_values_lst:
|
||||
num_patches = pixel_values.shape[0]
|
||||
feature_size = num_patches * self.num_image_token
|
||||
|
||||
image_repl = self.get_image_repl_full(feature_size,
|
||||
num_patches)
|
||||
text = [t.replace('<image>', image_repl, 1) for t in text]
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
|
||||
return BatchFeature(
|
||||
{
|
||||
**text_inputs,
|
||||
**image_inputs,
|
||||
},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
|
||||
class InternVLProcessor(BaseInternVLProcessor):
|
||||
|
||||
@property
|
||||
def image_token_id(self) -> int:
|
||||
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
||||
|
||||
def get_image_repl_features(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> str:
|
||||
return IMG_CONTEXT * feature_size
|
||||
|
||||
def get_image_repl_full(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> str:
|
||||
features = self.get_image_repl_features(feature_size, num_patches)
|
||||
return IMG_START + features + IMG_END
|
||||
|
||||
|
||||
class BaseInternVLProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
@abstractmethod
|
||||
def get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> BaseInternVLProcessor:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"image": self.get_max_image_tokens()}
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: Optional[BaseInternVLProcessor],
|
||||
) -> int:
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
return processor.get_num_image_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
)
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
return self.get_num_image_tokens(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
processor=None,
|
||||
)
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
base_size = processor.image_size
|
||||
target_ratios = processor.resolve_target_ratios()
|
||||
|
||||
largest_feature_size, largest_feature_pinpoint = 0, None
|
||||
for wr, hr in target_ratios:
|
||||
width, height = base_size * wr, base_size * hr
|
||||
|
||||
feat_size = self.get_num_image_tokens(
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
processor=processor,
|
||||
)
|
||||
if feat_size > largest_feature_size:
|
||||
largest_feature_size = feat_size
|
||||
largest_feature_pinpoint = ImageSize(width=width,
|
||||
height=height)
|
||||
|
||||
if largest_feature_size == 0 or largest_feature_pinpoint is None:
|
||||
raise ValueError("Cannot have a largest feature size of 0!")
|
||||
|
||||
return largest_feature_pinpoint
|
||||
|
||||
|
||||
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)
|
||||
|
||||
|
||||
class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="<image>" * num_images,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
|
||||
image_data = mm_data.get("images", [])
|
||||
assert isinstance(image_data, list)
|
||||
|
||||
# Since there may be extra tokens in the feature placeholders,
|
||||
# we need to pass the image token ID to the model to select the
|
||||
# tokens to merge from the vision encoder outputs
|
||||
processed_outputs["image_token_id"] = [image_token_id
|
||||
] * len(image_data)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||
|
||||
return dict(
|
||||
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_num_patches),
|
||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
image_token_id=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
if "image_num_patches" in out_mm_kwargs:
|
||||
image_num_patches = out_mm_kwargs["image_num_patches"]
|
||||
assert isinstance(image_num_patches, torch.Tensor)
|
||||
image_num_patches = image_num_patches.tolist()
|
||||
elif "image_embeds" in out_mm_kwargs:
|
||||
# TODO: Use image size information in dictionary embedding inputs
|
||||
# to compute num_patches (similar to Qwen2-VL)
|
||||
image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
|
||||
else:
|
||||
image_num_patches = []
|
||||
|
||||
def get_replacement_internvl(item_idx: int):
|
||||
images = mm_items.get_items(
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
||||
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
feature_size = images.get_feature_size(item_idx)
|
||||
else:
|
||||
image_size = images.get_image_size(item_idx)
|
||||
feature_size = self.info.get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
processor=hf_processor,
|
||||
)
|
||||
|
||||
num_patches = image_num_patches[item_idx]
|
||||
if num_patches is not None:
|
||||
assert isinstance(num_patches, int)
|
||||
|
||||
return PromptReplacementDetails(
|
||||
full=hf_processor.get_image_repl_full(feature_size,
|
||||
num_patches),
|
||||
features=hf_processor.get_image_repl_features(
|
||||
feature_size, num_patches),
|
||||
)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target="<image>",
|
||||
replacement=get_replacement_internvl,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
|
||||
|
||||
def get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
):
|
||||
num_images = mm_counts["image"]
|
||||
|
||||
hf_config = ctx.get_hf_config()
|
||||
|
||||
image_feature_size = get_max_internvl_image_tokens(
|
||||
ctx,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
model_config = ctx.model_config
|
||||
tokenizer = cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
|
||||
seq_data, ranges = dummy_seq_data_for_clip(
|
||||
hf_config.vision_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
image_token_id=tokenizer.encode(self.img_context_token,
|
||||
add_special_tokens=False)[0],
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
max_image_width, max_image_height = get_max_internvl_image_size(
|
||||
ctx,
|
||||
) -> InternVLProcessor:
|
||||
return InternVLProcessor(
|
||||
self.get_hf_config(),
|
||||
self.get_tokenizer(),
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_clip(
|
||||
hf_config.vision_config,
|
||||
num_images,
|
||||
image_width_override=max_image_width,
|
||||
image_height_override=max_image_height,
|
||||
)
|
||||
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
|
||||
|
||||
input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
|
||||
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
InternVLMultiModalProcessor,
|
||||
info=InternVLProcessingInfo,
|
||||
dummy_inputs=InternVLDummyInputsBuilder)
|
||||
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
@ -621,11 +805,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[InternVLImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_token_id = kwargs.pop("image_token_id", None)
|
||||
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
||||
image_num_patches = kwargs.pop("image_num_patches", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values is None and image_embeds is None:
|
||||
if pixel_values_flat is None and image_embeds is None:
|
||||
return None
|
||||
|
||||
if image_embeds is not None:
|
||||
@ -638,31 +822,30 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
data=flatten_bn(image_embeds),
|
||||
)
|
||||
|
||||
self.img_context_token_id = image_token_id[0]
|
||||
image_token_id = kwargs["image_token_id"]
|
||||
assert isinstance(image_token_id, torch.Tensor)
|
||||
self.img_context_token_id = image_token_id.flatten().unique().item()
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
if pixel_values_flat is not None:
|
||||
if not isinstance(pixel_values_flat, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
f"Got type: {type(pixel_values_flat)}")
|
||||
|
||||
assert isinstance(image_num_patches, (torch.Tensor, list))
|
||||
|
||||
patches_per_image = []
|
||||
for request_pixel_values in pixel_values:
|
||||
for image_pixel_values in request_pixel_values:
|
||||
patches_per_image.append(image_pixel_values.shape[0])
|
||||
# We need to flatten (B, N, P) to (B*N*P),
|
||||
# so we call flatten_bn twice.
|
||||
return InternVLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(
|
||||
flatten_bn(flatten_bn(pixel_values), concat=True)),
|
||||
patches_per_image=patches_per_image)
|
||||
flatten_bn(pixel_values_flat, concat=True)),
|
||||
patches_per_image=flatten_bn(image_num_patches,
|
||||
concat=True).tolist())
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: InternVLImageInputs,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
@ -689,7 +872,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
image_embeds = image_embeds.split(image_feature_sizes)
|
||||
return image_embeds
|
||||
|
||||
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
|
||||
if self.is_mono:
|
||||
self.visual_token_mask = (
|
||||
input_ids == self.img_context_token_id).reshape(-1, 1)
|
||||
|
@ -125,7 +125,11 @@ class BaseLlavaProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"image": self.get_max_image_tokens()}
|
||||
|
||||
def _apply_feature_select_strategy(
|
||||
|
@ -62,7 +62,11 @@ class LlavaNextVideoProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"video": 1}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
max_video_tokens = self.get_num_video_tokens(
|
||||
|
@ -103,7 +103,11 @@ class LlavaOnevisionProcessingInfo(LlavaNextProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None, "video": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {
|
||||
"image": self.get_max_image_tokens(),
|
||||
"video": self.get_max_video_tokens(seq_len),
|
||||
|
@ -23,7 +23,6 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
|
||||
from functools import partial
|
||||
from itertools import accumulate
|
||||
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
|
||||
Tuple, TypedDict, Union)
|
||||
|
||||
@ -138,11 +137,15 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None, "video": None, "audio": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {
|
||||
"image": self.get_max_image_tokens(),
|
||||
"audio": self.get_max_audio_tokens(),
|
||||
"video": self.get_max_video_tokens(seq_len)
|
||||
"video": self.get_max_video_tokens(seq_len),
|
||||
}
|
||||
|
||||
def get_default_audio_pool_step(self) -> int:
|
||||
@ -369,23 +372,18 @@ class MiniCPMOMultiModalProcessor(
|
||||
hf_inputs,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
|
||||
|
||||
def get_slices(num_slices: List[int]) -> List[int]:
|
||||
slice_indices = [0] + list(accumulate(num_slices))
|
||||
slices = [(slice_indices[i], slice_indices[i + 1])
|
||||
for i in range(len(num_slices))]
|
||||
return [slice(*slice_item) for slice_item in slices]
|
||||
|
||||
audio_slices = get_slices(
|
||||
hf_inputs.get("audio_num_slices", torch.empty(0)))
|
||||
return dict(
|
||||
**super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
|
||||
audio_features=MultiModalFieldConfig.flat("audio", audio_slices),
|
||||
audio_feature_lens=MultiModalFieldConfig.flat(
|
||||
"audio", audio_slices),
|
||||
audio_features=MultiModalFieldConfig.flat_from_sizes(
|
||||
"audio", audio_num_slices),
|
||||
audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
|
||||
"audio", audio_num_slices),
|
||||
audio_num_slices=MultiModalFieldConfig.batched("audio"),
|
||||
audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
|
||||
audio_embeds=MultiModalFieldConfig.flat("audio", audio_slices))
|
||||
audio_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"audio", audio_num_slices))
|
||||
|
||||
|
||||
class MultiModalProjector(nn.Module):
|
||||
|
@ -26,7 +26,6 @@ import math
|
||||
import re
|
||||
from collections import Counter
|
||||
from functools import cached_property, partial
|
||||
from itertools import accumulate
|
||||
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
|
||||
Optional, Set, Tuple, TypedDict, Union)
|
||||
|
||||
@ -365,7 +364,11 @@ class MiniCPMVProcessingInfo(BaseProcessingInfo):
|
||||
else:
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
mm_max_tokens = {"image": self.get_max_image_tokens()}
|
||||
if self.get_model_version() == (2, 6):
|
||||
mm_max_tokens["video"] = self.get_max_video_tokens(seq_len)
|
||||
@ -761,30 +764,25 @@ class MiniCPMVMultiModalProcessor(
|
||||
hf_inputs,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_num_slices = hf_inputs.get("image_num_slices", torch.empty(0))
|
||||
video_num_slices = hf_inputs.get("video_num_slices", torch.empty(0))
|
||||
|
||||
def get_slices(num_slices: List[int]) -> List[int]:
|
||||
slice_indices = [0] + list(accumulate(num_slices))
|
||||
slices = [(slice_indices[i], slice_indices[i + 1])
|
||||
for i in range(len(num_slices))]
|
||||
return [slice(*slice_item) for slice_item in slices]
|
||||
|
||||
image_slices = get_slices(
|
||||
hf_inputs.get("image_num_slices", torch.empty(0)))
|
||||
video_slices = get_slices(
|
||||
hf_inputs.get("video_num_slices", torch.empty(0)))
|
||||
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.flat("image", image_slices),
|
||||
image_sizes=MultiModalFieldConfig.batched("image"),
|
||||
tgt_sizes=MultiModalFieldConfig.flat("image", image_slices),
|
||||
image_num_slices=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.flat("image", image_slices),
|
||||
video_pixel_values=MultiModalFieldConfig.flat(
|
||||
"video", video_slices),
|
||||
video_image_sizes=MultiModalFieldConfig.batched("video"),
|
||||
video_tgt_sizes=MultiModalFieldConfig.flat("video", video_slices),
|
||||
video_embeds=MultiModalFieldConfig.flat("video", video_slices),
|
||||
video_num_slices=MultiModalFieldConfig.batched("video"))
|
||||
return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_num_slices),
|
||||
image_sizes=MultiModalFieldConfig.batched("image"),
|
||||
tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_num_slices),
|
||||
image_num_slices=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_num_slices),
|
||||
video_pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_num_slices),
|
||||
video_image_sizes=MultiModalFieldConfig.batched("video"),
|
||||
video_tgt_sizes=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_num_slices),
|
||||
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_num_slices),
|
||||
video_num_slices=MultiModalFieldConfig.batched("video"))
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
@ -6,44 +6,190 @@
|
||||
# Copyright (c) 2024 NVIDIA
|
||||
# Licensed under Apache 2.0 License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from typing import Optional
|
||||
from typing import Mapping, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.inputs import INPUT_REGISTRY
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalKwargs
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (PromptReplacement,
|
||||
PromptReplacementDetails)
|
||||
from vllm.multimodal.profiling import ProcessorInputs
|
||||
|
||||
from .intern_vit import InternVisionModel
|
||||
from .internvl import (InternVLChatModel, InternVLInputPipeline,
|
||||
get_max_internvl_image_tokens)
|
||||
from .internvl import (BaseInternVLProcessingInfo, BaseInternVLProcessor,
|
||||
InternVLChatModel, InternVLDummyInputsBuilder,
|
||||
InternVLMultiModalProcessor)
|
||||
|
||||
IMG_START = '<|vision_start|>'
|
||||
IMG_END = '<|vision_end|>'
|
||||
IMG_CONTEXT = '<|vision_pad|>'
|
||||
IMG_PAD = "<|vision_pad|>"
|
||||
|
||||
|
||||
class NVLMInputPipeline(InternVLInputPipeline):
|
||||
class NVLMProcessor(BaseInternVLProcessor):
|
||||
|
||||
@property
|
||||
def image_token_id(self) -> int:
|
||||
return self.tokenizer.get_vocab()[IMG_PAD]
|
||||
|
||||
def get_image_repl_features(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> str:
|
||||
if num_patches is None:
|
||||
raise NotImplementedError("Embedding inputs are not supported")
|
||||
|
||||
tile_pos_identifiers = [f"<tile_{i}>" for i in range(1, num_patches)]
|
||||
if self.use_thumbnail and num_patches != 1:
|
||||
tile_pos_identifiers += ["<tile_global_thumbnail>"]
|
||||
|
||||
def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
|
||||
tile_pos_identifiers = ([f"<tile_{i}>"
|
||||
for i in range(1, num_patches)] +
|
||||
["<tile_global_thumbnail>"])
|
||||
context_size = feature_size // num_patches
|
||||
features = "".join(identifier + IMG_PAD * context_size
|
||||
for identifier in tile_pos_identifiers)
|
||||
|
||||
return '<Image>' + ''.join(
|
||||
tile_pos_identifier + self.img_context_token * context_size
|
||||
for tile_pos_identifier in tile_pos_identifiers) + '</Image>'
|
||||
# We include the start and end as well because "<Image><tile" is
|
||||
# tokenized as ["<Image", "><", "tile"], resulting in assertion error
|
||||
# when trying to find "<tile" as a subsequence of "<Image><tile"
|
||||
return "<Image>" + features + "</Image>"
|
||||
|
||||
def get_image_repl_full(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> str:
|
||||
return self.get_image_repl_features(feature_size, num_patches)
|
||||
|
||||
|
||||
input_pipeline = NVLMInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
|
||||
class NVLMProcessingInfo(BaseInternVLProcessingInfo):
|
||||
|
||||
def get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> NVLMProcessor:
|
||||
return NVLMProcessor(
|
||||
self.get_hf_config(),
|
||||
self.get_tokenizer(),
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
hf_processor = self.get_hf_processor()
|
||||
tokenizer = hf_processor.tokenizer
|
||||
|
||||
max_num_patches = hf_processor.max_dynamic_patch
|
||||
# we need +1 here because max_dynamic_patch in config doesn't
|
||||
# include the thumbnail patch
|
||||
tile_pos_identifiers = [
|
||||
f"<tile_{i+1}>" for i in range(max_num_patches)
|
||||
]
|
||||
if hf_processor.use_thumbnail and max_num_patches != 1:
|
||||
tile_pos_identifiers += ["<tile_global_thumbnail>"]
|
||||
|
||||
# "<Image><tile" is tokenized as ["<Image", "><", "tile"]
|
||||
# so we include <tile_1> in the start_str
|
||||
start_str = "<Image>" + tile_pos_identifiers.pop(0)
|
||||
end_str = "</Image>"
|
||||
start_token_len = len(tokenizer.encode(start_str))
|
||||
end_token_len = len(tokenizer.encode(end_str))
|
||||
tile_token_len = sum(
|
||||
len(tokenizer.encode(identifier))
|
||||
for identifier in tile_pos_identifiers)
|
||||
non_image_tokens_num = start_token_len + end_token_len + tile_token_len
|
||||
return super().get_max_image_tokens() + non_image_tokens_num
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
|
||||
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
|
||||
class NVLMDummyInputsBuilder(InternVLDummyInputsBuilder[NVLMProcessingInfo]):
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
# The newline is necessary to separate ">" of the current item
|
||||
# and "<" of the next item
|
||||
prompt_text="<image>\n" * num_images,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
class NVLMMultiModalProcessor(InternVLMultiModalProcessor[NVLMProcessingInfo]):
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
if "image_num_patches" in out_mm_kwargs:
|
||||
image_num_patches = out_mm_kwargs["image_num_patches"]
|
||||
assert isinstance(image_num_patches, torch.Tensor)
|
||||
image_num_patches = image_num_patches.tolist()
|
||||
elif "image_embeds" in out_mm_kwargs:
|
||||
# TODO: Use image size information in dictionary embedding inputs
|
||||
# to compute num_patches (similar to Qwen2-VL)
|
||||
image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
|
||||
else:
|
||||
image_num_patches = []
|
||||
|
||||
def get_replacement_nvlm(item_idx: int):
|
||||
images = mm_items.get_items(
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
||||
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
feature_size = images.get_feature_size(item_idx)
|
||||
else:
|
||||
image_size = images.get_image_size(item_idx)
|
||||
feature_size = self.info.get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
processor=hf_processor,
|
||||
)
|
||||
|
||||
num_patches = image_num_patches[item_idx]
|
||||
if num_patches is not None:
|
||||
assert isinstance(num_patches, int)
|
||||
|
||||
return PromptReplacementDetails(
|
||||
full=hf_processor.get_image_repl_full(feature_size,
|
||||
num_patches) + "\n",
|
||||
features=hf_processor.get_image_repl_features(
|
||||
feature_size, num_patches) + "\n",
|
||||
)
|
||||
|
||||
# See note in dummy data regarding why we have the extra newline
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target="<image>\n",
|
||||
replacement=get_replacement_nvlm,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(NVLMMultiModalProcessor,
|
||||
info=NVLMProcessingInfo,
|
||||
dummy_inputs=NVLMDummyInputsBuilder)
|
||||
class NVLM_D_Model(InternVLChatModel):
|
||||
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
|
||||
|
@ -322,7 +322,11 @@ class Phi3VProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
max_image_tokens = self.get_num_image_tokens(
|
||||
|
@ -779,7 +779,11 @@ class QWenVLProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"image": self.get_num_image_tokens()}
|
||||
|
||||
def get_num_image_tokens(self) -> int:
|
||||
@ -799,13 +803,13 @@ class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
|
||||
|
||||
vision_config = hf_config.visual
|
||||
|
||||
max_image_size = vision_config["image_size"]
|
||||
target_width = target_height = vision_config["image_size"]
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=max_image_size,
|
||||
height=max_image_size,
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
|
@ -110,7 +110,11 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"audio": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
hf_config = self.get_hf_config()
|
||||
max_source_positions = hf_config.audio_config.max_source_positions
|
||||
max_output_lengths = (max_source_positions - 2) // 2 + 1
|
||||
|
@ -758,7 +758,11 @@ class Qwen2VLProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None, "video": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {
|
||||
"image": self.get_max_image_tokens(),
|
||||
"video": self.get_max_video_tokens(seq_len),
|
||||
@ -989,26 +993,21 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
|
||||
image_slice_idxs = [0] + image_grid_thw.prod(-1).cumsum_(0).tolist()
|
||||
image_slices = [
|
||||
slice(image_slice_idxs[i], image_slice_idxs[i + 1])
|
||||
for i in range(len(image_grid_thw))
|
||||
]
|
||||
image_grid_sizes = image_grid_thw.prod(-1)
|
||||
|
||||
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
|
||||
video_slice_idxs = [0] + video_grid_thw.prod(-1).cumsum_(0).tolist()
|
||||
video_slices = [
|
||||
slice(video_slice_idxs[i], video_slice_idxs[i + 1])
|
||||
for i in range(len(video_grid_thw))
|
||||
]
|
||||
video_grid_sizes = video_grid_thw.prod(-1)
|
||||
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.flat("image", image_slices),
|
||||
image_embeds=MultiModalFieldConfig.flat("image", image_slices),
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_grid_sizes),
|
||||
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_grid_sizes),
|
||||
image_grid_thw=MultiModalFieldConfig.batched("image"),
|
||||
pixel_values_videos=MultiModalFieldConfig.flat(
|
||||
"video", video_slices),
|
||||
video_embeds=MultiModalFieldConfig.flat("video", video_slices),
|
||||
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_grid_sizes),
|
||||
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||
"video", video_grid_sizes),
|
||||
video_grid_thw=MultiModalFieldConfig.batched("video"),
|
||||
)
|
||||
|
||||
|
@ -92,7 +92,11 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"audio": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
feature_extractor = self.get_feature_extractor()
|
||||
max_audio_tokens = math.ceil(feature_extractor.chunk_length *
|
||||
_AUDIO_TOKENS_PER_SECOND)
|
||||
|
@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
|
||||
Union, cast, final)
|
||||
|
||||
@ -258,6 +259,16 @@ class MultiModalFieldConfig:
|
||||
slices=slices,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
|
||||
slice_idxs = [0, *accumulate(size_per_item)]
|
||||
slices = [
|
||||
slice(slice_idxs[i], slice_idxs[i + 1])
|
||||
for i in range(len(size_per_item))
|
||||
]
|
||||
|
||||
return MultiModalFieldConfig.flat(modality, slices)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
field_cls: type[BaseMultiModalField],
|
||||
|
@ -680,7 +680,11 @@ class BaseProcessingInfo:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
"""
|
||||
Get the maximum possible number of tokens per data item
|
||||
for each modality.
|
||||
|
@ -151,7 +151,8 @@ class MultiModalProfiler(Generic[_I]):
|
||||
mm_counts = self.get_mm_limits()
|
||||
|
||||
info = self.processing_info
|
||||
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(seq_len)
|
||||
mm_max_tokens_per_item = info.get_mm_max_tokens_per_item(
|
||||
seq_len, mm_counts)
|
||||
|
||||
if mm_counts.keys() != mm_max_tokens_per_item.keys():
|
||||
raise AssertionError(
|
||||
|
@ -264,7 +264,9 @@ class MultiModalRegistry:
|
||||
)
|
||||
processor = self.create_processor(model_config, tokenizer)
|
||||
seq_len = model_config.max_model_len
|
||||
return processor.info.get_mm_max_tokens_per_item(seq_len)
|
||||
mm_limits = self.get_mm_limits_per_prompt(model_config)
|
||||
return processor.info.get_mm_max_tokens_per_item(
|
||||
seq_len, mm_limits)
|
||||
|
||||
return {
|
||||
key: plugin.get_max_multimodal_tokens(model_config)
|
||||
|
Reference in New Issue
Block a user