[Bugfix] Fix Nemotron VL image processing (#22739)
Co-authored-by: ducviet00-h2 <viet.d.hoang@h2corporation.jp>
@@ -23,15 +23,15 @@ def _get_expected_num_patches(
     min_num: int,
     max_num: int,
 ):
-    from vllm.model_executor.models.internvl import (
-        calculate_internvl_targets, get_internvl_target_ratios)
+    from vllm.model_executor.models.nemotron_vl import (
+        calculate_nemotron_vl_targets, get_nemotron_vl_target_ratios)

     width, height = image.size

-    blocks, _, _ = calculate_internvl_targets(
+    blocks, _, _ = calculate_nemotron_vl_targets(
         orig_width=width,
         orig_height=height,
-        target_ratios=get_internvl_target_ratios(
+        target_ratios=get_nemotron_vl_target_ratios(
             min_num,
             max_num,
         ),
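For context, the expected patch count computed above can be reproduced standalone. A minimal sketch, assuming vLLM is installed; the image dimensions, tile size, and patch bounds below are illustrative values, not taken from this diff:

    from vllm.model_executor.models.nemotron_vl import (
        calculate_nemotron_vl_targets, get_nemotron_vl_target_ratios)

    # Illustrative inputs; the real test derives these from its fixtures.
    width, height = 1280, 800

    blocks, target_width, target_height = calculate_nemotron_vl_targets(
        orig_width=width,
        orig_height=height,
        target_ratios=get_nemotron_vl_target_ratios(1, 12),
        image_size=512,  # tile edge in pixels (illustrative)
        use_thumbnail=False,
    )
    print(blocks, target_width, target_height)  # 6 1536 1024 for these inputs

With these values the closest grid is 3x2, so the image is resized to 1536x1024 and split into 6 tiles of 512x512.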
@@ -13,6 +13,7 @@ from typing import Optional

 import torch
 import torch.nn as nn
+import torchvision.transforms as T
 from PIL import Image
 from transformers import AutoModel, PretrainedConfig
 from transformers.image_processing_utils_fast import BaseImageProcessorFast
@@ -27,6 +28,7 @@ from vllm.model_executor.models.internvl import (
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.image import convert_image_mode
 from vllm.multimodal.inputs import NestedTensors
 from vllm.multimodal.processing import PromptUpdateDetails
 from vllm.sequence import IntermediateTensors
@@ -44,6 +46,146 @@ IMG_END = '</img>'
 IMG_CONTEXT = '<image>'


+def build_transform(input_size: int):
+    return T.Compose([
+        T.Lambda(lambda img: convert_image_mode(img, 'RGB')),
+        T.Resize((input_size, input_size),
+                 interpolation=T.InterpolationMode.BICUBIC),
+        T.ToTensor(),
+    ])
+
+
+# adapted from https://huggingface.co/nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1
+def find_closest_aspect_ratio(
+    aspect_ratio: float,
+    target_ratios: list[tuple[int, int]],
+    *,
+    width: int,
+    height: int,
+    image_size: int,
+) -> tuple[int, int]:
+    best_factor = float('-inf')
+    best_ratio = (1, 1)
+    area = width * height
+
+    for rw, rh in target_ratios:
+        target_aspect_ratio = rw / rh
+        size_factor = min((rw * rh * image_size * image_size) / area, 0.6)
+        ratio_closeness = min(target_aspect_ratio / aspect_ratio,
+                              aspect_ratio / target_aspect_ratio)
+        factor = size_factor * ratio_closeness
+
+        if factor > best_factor:
+            best_factor = factor
+            best_ratio = (rw, rh)
+
+    return best_ratio
+
+
+def calculate_nemotron_vl_targets(
+    *,
+    orig_width: int,
+    orig_height: int,
+    target_ratios: list[tuple[int, int]],
+    image_size: int,
+    use_thumbnail: bool,
+) -> tuple[int, int, int]:
+    aspect_ratio = orig_width / orig_height
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(
+        aspect_ratio,
+        target_ratios,
+        width=orig_width,
+        height=orig_height,
+        image_size=image_size,
+    )
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # add thumbnail image if num_blocks != 1
+    if use_thumbnail and blocks != 1:
+        blocks += 1
+
+    return blocks, target_width, target_height
+
+
+def dynamic_preprocess_nemotron_vl(
+    image: Image.Image,
+    *,
+    target_ratios: list[tuple[int, int]],
+    image_size: int,
+    use_thumbnail: bool,
+) -> list[Image.Image]:
+    orig_width, orig_height = image.size
+
+    # calculate the number of blocks without thumbnail
+    blocks, target_width, target_height = calculate_nemotron_vl_targets(
+        orig_width=orig_width,
+        orig_height=orig_height,
+        target_ratios=target_ratios,
+        image_size=image_size,
+        use_thumbnail=False,
+    )
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = ((i % (target_width // image_size)) * image_size,
+               (i // (target_width // image_size)) * image_size,
+               ((i % (target_width // image_size)) + 1) * image_size,
+               ((i // (target_width // image_size)) + 1) * image_size)
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+
+    assert len(processed_images) == blocks
+
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+
+    return processed_images
+
+
+def get_nemotron_vl_target_ratios(
+    min_num: int,
+    max_num: int,
+) -> list[tuple[int, int]]:
+    target_ratios = {(i, j)
+                     for n in range(min_num, max_num + 1)
+                     for i in range(1, n + 1)
+                     for j in range(1, n + 1) if min_num <= i * j <= max_num}
+    return sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+
+def image_to_pixel_values_nemotron_vl(
+    image: Image.Image,
+    *,
+    input_size: int,
+    min_num: int,
+    max_num: int,
+    use_thumbnail: bool,
+) -> torch.Tensor:
+    target_ratios = get_nemotron_vl_target_ratios(min_num, max_num)
+
+    transform = build_transform(input_size=input_size)
+
+    images = dynamic_preprocess_nemotron_vl(
+        image,
+        target_ratios=target_ratios,
+        image_size=input_size,
+        use_thumbnail=use_thumbnail,
+    )
+
+    pixel_values = torch.stack([transform(image) for image in images])
+    return pixel_values
+
+
 class NemotronVLProcessor(InternVLProcessor):

     def __init__(
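Put together, these helpers pick a tile grid, crop the resized image into image_size-square patches, optionally append a thumbnail, and stack the transformed crops into one tensor. A short usage sketch, assuming vLLM and Pillow are installed; the blank 1024x768 image and the 448 / 1..6 values are illustrative:

    from PIL import Image

    from vllm.model_executor.models.nemotron_vl import (
        image_to_pixel_values_nemotron_vl)

    img = Image.new('RGB', (1024, 768))  # stand-in for a real image
    pixel_values = image_to_pixel_values_nemotron_vl(
        img,
        input_size=448,
        min_num=1,
        max_num=6,
        use_thumbnail=True,
    )
    # One 3 x 448 x 448 tensor per crop, plus one for the thumbnail when
    # more than one block is produced: torch.Size([7, 3, 448, 448]) here.
    print(pixel_values.shape)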
@@ -87,6 +229,50 @@ class NemotronVLProcessor(InternVLProcessor):
     def image_token_id(self) -> int:
         return self.tokenizer.get_vocab()[IMG_CONTEXT]

+    def get_num_image_tokens(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+    ) -> int:
+        target_ratios = self.resolve_target_ratios(
+            use_thumbnail=False,  # Applied in calculate_targets
+        )
+
+        num_patches, _, _ = calculate_nemotron_vl_targets(
+            orig_width=image_width,
+            orig_height=image_height,
+            image_size=self.image_size,
+            target_ratios=target_ratios,
+            use_thumbnail=self.use_thumbnail,
+        )
+
+        return num_patches * self.num_image_token
+
+    def _images_to_pixel_values_lst(
+        self,
+        images: list[Image.Image],
+        min_dynamic_patch: Optional[int] = None,
+        max_dynamic_patch: Optional[int] = None,
+        dynamic_image_size: Optional[bool] = None,
+    ) -> list[torch.Tensor]:
+        min_num, max_num = self.resolve_min_max_num(
+            min_dynamic_patch=min_dynamic_patch,
+            max_dynamic_patch=max_dynamic_patch,
+            dynamic_image_size=dynamic_image_size,
+            use_thumbnail=False,  # Applied in image_to_pixel_values
+        )
+
+        return [
+            image_to_pixel_values_nemotron_vl(
+                image,
+                input_size=self.image_size,
+                min_num=min_num,
+                max_num=max_num,
+                use_thumbnail=self.use_thumbnail,
+            ) for image in images
+        ]
+
     def _preprocess_image(
         self,
         text: list[str],
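As the new get_num_image_tokens shows, prompt length scales linearly with the tile count: total image tokens = num_patches * num_image_token. A sketch of that accounting; the 256 tokens-per-tile value is an assumption for illustration, not read from this diff:

    from vllm.model_executor.models.nemotron_vl import (
        calculate_nemotron_vl_targets, get_nemotron_vl_target_ratios)

    num_image_token = 256  # assumed per-tile token count (illustrative)

    num_patches, _, _ = calculate_nemotron_vl_targets(
        orig_width=1024,
        orig_height=768,
        target_ratios=get_nemotron_vl_target_ratios(1, 6),
        image_size=448,
        use_thumbnail=True,
    )
    print(num_patches * num_image_token)  # 7 * 256 = 1792 image tokens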