[VLM] Merged multi-modal processor and V1 support for Qwen-VL (#12504)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2025-01-29 00:25:05 +08:00
committed by GitHub
parent 2079e43bee
commit 8f58a51358
4 changed files with 381 additions and 471 deletions
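For context, the user-facing flow that this merged processor and V1 support enables looks roughly like the following offline-inference sketch (the image path and prompt are illustrative and not part of this commit):

from vllm import LLM
from PIL import Image

# Qwen-VL ships a custom tokenizer/config, so trust_remote_code is required.
llm = LLM(model="Qwen/Qwen-VL-Chat", trust_remote_code=True)
image = Image.open("example.jpg")  # hypothetical local image
outputs = llm.generate({
    "prompt": "Picture 1: <img></img>\nWhat is shown in the image?",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)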


@@ -745,7 +745,7 @@ See [this page](#generative-models) for more information on how to use generativ
- `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc.
- ✅︎
- ✅︎
-
- ✅︎
* - `Qwen2AudioForConditionalGeneration`
- Qwen2-Audio
- T + A<sup>+</sup>


@@ -16,7 +16,6 @@ from ...registry import HF_EXAMPLE_MODELS
def _test_processing_correctness(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
num_batches: int,
simplify_rate: float,
@@ -25,11 +24,6 @@ def _test_processing_correctness(
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
limit_mm_per_prompt = {
modality: 3 if supports_multi else 1
for modality, supports_multi in modalities.items()
}
model_config = ModelConfig(
model_id,
task="auto",
@@ -40,18 +34,29 @@ def _test_processing_correctness(
dtype="float16",
revision=None,
hf_overrides=model_info.hf_overrides,
limit_mm_per_prompt=limit_mm_per_prompt,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
ctx = InputProcessingContext(
model_config,
tokenizer=cached_get_tokenizer(model_config.tokenizer),
tokenizer=cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_info.trust_remote_code,
),
)
# Ensure that it can fit all of the data
cache = ProcessingCache(capacity=1 << 30)
processing_info = factories.info(ctx)
supported_mm_limits = processing_info.get_supported_mm_limits()
limit_mm_per_prompt = {
modality: 3 if limit is None else limit
for modality, limit in supported_mm_limits.items()
}
model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
baseline_processor = factories.build_processor(ctx, cache=None)
cached_processor = factories.build_processor(ctx, cache=cache)
dummy_inputs = baseline_processor.dummy_inputs
@@ -82,8 +87,8 @@ def _test_processing_correctness(
mm_data = {
k:
[(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
for _ in range(rng.randint(limit_mm_per_prompt[k]))]
for k in modalities
for _ in range(rng.randint(limit))]
for k, limit in limit_mm_per_prompt.items()
}
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
@@ -135,21 +140,22 @@ def _test_processing_correctness(
# yapf: disable
# True if the model supports multiple data items of the modality per request
@pytest.mark.parametrize(("model_id", "modalities"), [
("rhymes-ai/Aria", {"image": True}),
("Salesforce/blip2-opt-2.7b", {"image": False}),
("facebook/chameleon-7b", {"image": False}),
("deepseek-ai/deepseek-vl2-tiny", {"image": True}),
("adept/fuyu-8b", {"image": False}),
("llava-hf/llava-1.5-7b-hf", {"image": True}),
("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}),
("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}), # noqa: E501
("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
("mistral-community/pixtral-12b", {"image": True}),
("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}),
("fixie-ai/ultravox-v0_3", {"audio": True}),
@pytest.mark.parametrize("model_id", [
"rhymes-ai/Aria",
"Salesforce/blip2-opt-2.7b",
"facebook/chameleon-7b",
"deepseek-ai/deepseek-vl2-tiny",
"adept/fuyu-8b",
"llava-hf/llava-1.5-7b-hf",
"llava-hf/llava-v1.6-mistral-7b-hf",
"llava-hf/LLaVA-NeXT-Video-7B-hf",
"llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
"TIGER-Lab/Mantis-8B-siglip-llama3",
"mistral-community/pixtral-12b",
"Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_3",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@@ -157,14 +163,12 @@ def _test_processing_correctness(
# yapf: enable
def test_processing_correctness(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
num_batches: int,
simplify_rate: float,
):
_test_processing_correctness(
model_id,
modalities,
hit_rate=hit_rate,
num_batches=num_batches,
simplify_rate=simplify_rate,
@@ -172,16 +176,13 @@ def test_processing_correctness(
# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
("microsoft/Phi-3-vision-128k-instruct", {"image": True}),
])
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3-vision-128k-instruct"])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_correctness_phi3v(
model_id: str,
modalities: dict[str, bool],
hit_rate: float,
num_batches: int,
simplify_rate: float,
@@ -195,7 +196,6 @@ def test_processing_correctness_phi3v(
_test_processing_correctness(
model_id,
modalities,
hit_rate=hit_rate,
num_batches=num_batches,
simplify_rate=simplify_rate,


@@ -1,144 +0,0 @@
"""Tests for Qwen's multimodal preprocessing kwargs."""
from typing import Dict, List, Union
import pytest
import torch
from PIL.Image import Image
from vllm.inputs import InputContext, token_inputs
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer
from ....conftest import IMAGE_ASSETS
from ...utils import build_model_context
### Multimodal preprocessing tests
SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image
# These values are specific to Qwen-VL/Chat; we can get these from the model
# config also, but they are hardcoded here to keep the parameterize/fixtures
# easy to read.
IMG_START_ID = 151857
IMG_END_ID = 151858
IMG_PAD_ID = 151859
TOKS_PER_IMG = 256
VIS_ENC_DIM = 4096
IMG_SIZE = 448
@pytest.fixture()
def input_mapper_for_qwen():
# Lazy import to avoid initializing CUDA during test collection
from vllm.model_executor.models.qwen import input_mapper_for_qwen
return input_mapper_for_qwen
@pytest.fixture()
def input_processor_for_qwen():
# Lazy import to avoid initializing CUDA during test collection
from vllm.model_executor.models.qwen import input_processor_for_qwen
return input_processor_for_qwen
@pytest.fixture()
def qwen_vl_context() -> InputContext:
"""Get an InputContext for Qwen-VL."""
return build_model_context(model_name="Qwen/Qwen-VL",
trust_remote_code=True)
# Happy path tests for single/multi-image scenarios for the multimodal
# input processor and mapper, respectively
@pytest.mark.parametrize("num_images", [1, 2])
def test_input_processor_valid_mm_data(input_processor_for_qwen,
qwen_vl_context: InputContext,
num_images: int):
"""Happy cases for image inputs to Qwen's multimodal input processor."""
prompt = "".join(
[f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)])
inputs = token_inputs(
prompt=prompt,
# When processing multimodal data for a multimodal model, the qwen
# input processor will overwrite the provided prompt_token_ids with
# the image prompts
prompt_token_ids=[],
multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
)
proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
assert isinstance(proc_inputs, dict)
# Each image should have one start / stop and a fixed context of 256
proc_tokens = proc_inputs["prompt_token_ids"]
assert proc_tokens.count(IMG_START_ID) == num_images
assert proc_tokens.count(IMG_END_ID) == num_images
assert proc_tokens.count(IMG_PAD_ID) == num_images * TOKS_PER_IMG
@pytest.mark.parametrize(
"img_data,expected_shape",
[
# single / multi-image
(SAMPLE_IMAGE, (1, 3, IMG_SIZE, IMG_SIZE)),
(2 * [SAMPLE_IMAGE], (2, 3, IMG_SIZE, IMG_SIZE)),
# single / multi-image embeddings
(torch.rand(
(TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
(torch.rand(
(1, TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
(torch.rand(
(2, TOKS_PER_IMG, VIS_ENC_DIM)), (2, TOKS_PER_IMG, VIS_ENC_DIM)),
])
def test_input_mapper_valid_mm_data(input_mapper_for_qwen,
qwen_vl_context: InputContext,
img_data: Union[torch.Tensor, List[Image],
Image],
expected_shape: List[int]):
"""Happy cases for image inputs to Qwen's multimodal input mapper."""
mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data)
# Ensure that we get the appropriately shaped pixel_values
# for images and image embeddings, respectively.
assert isinstance(mapped_img_data, MultiModalKwargs)
assert "pixel_values" in mapped_img_data
assert mapped_img_data["pixel_values"].shape == expected_shape
# Sad path tests for the multimodal input processor and mapper, respectively
@pytest.mark.parametrize("mm_data", [
{
"image": torch.rand(5)
},
{
"image": torch.rand((5, 5, 5, 5, 5))
},
])
def test_input_processor_invalid_mm_data(input_processor_for_qwen,
qwen_vl_context: InputContext,
mm_data: Dict[str, torch.Tensor]):
"""Test sad cases validated in Qwen's multimodal input processor."""
tokenizer = cached_get_tokenizer(qwen_vl_context.model_config.tokenizer,
trust_remote_code=True)
prompt = "Picture 1: <img></img>\n"
prompt_token_ids = tokenizer.encode(prompt)
inputs = token_inputs(prompt=prompt,
prompt_token_ids=prompt_token_ids,
multi_modal_data=mm_data)
# Should fail since we have too many or too few dimensions for embeddings
with pytest.raises(ValueError):
input_processor_for_qwen(qwen_vl_context, inputs)
@pytest.mark.parametrize(
"img_data",
[
# Wrong context length
torch.rand((1, TOKS_PER_IMG + 10, VIS_ENC_DIM)),
# Wrong visual encoder output size
torch.rand((1, TOKS_PER_IMG, VIS_ENC_DIM + 10)),
])
def test_input_mapper_invalid_mm_data(
input_mapper_for_qwen,
qwen_vl_context: InputContext,
img_data: Union[torch.Tensor, List[Image], Image],
):
"""Sad cases validated in Qwen VL's multimodal input mapper."""
with pytest.raises(ValueError):
input_mapper_for_qwen(qwen_vl_context, img_data)


@@ -4,26 +4,28 @@
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
import copy
import math
import re
from functools import partial
from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
Optional, Set, Tuple, TypedDict, Union)
import unicodedata
from functools import lru_cache, partial
from typing import (AbstractSet, Any, Callable, Collection, Dict, Iterable,
List, Literal, Mapping, Optional, Set, Tuple, TypedDict,
Union)
import numpy as np
import torch
from PIL import Image
from torch import nn
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import PretrainedConfig
from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer,
TensorType)
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -42,15 +44,20 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.utils import is_list_of
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
maybe_prefix, merge_multimodal_embeddings)
logger = init_logger(__name__)
@@ -353,8 +360,10 @@ class VisionTransformer(nn.Module):
self.ln_post = norm_layer(output_dim)
self.proj = nn.Parameter(
(output_dim**-0.5) * torch.randn(output_dim, output_dim))
self.image_start_id = image_start_id
self.image_end_id = image_start_id + 1
self.image_pad_id = image_start_id + 2
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.to(
@@ -383,21 +392,6 @@ class VisionTransformer(nn.Module):
return x
def get_image_positions(self,
input_ids: torch.Tensor) -> Optional[torch.Tensor]:
"""Given the input IDs, extracts start/stop points corresponding to
images.
        Args:
            input_ids: Token IDs of the prompt.
        Returns:
Optional torch tensor corresponding to start/stop pairs of images.
"""
if torch.any(input_ids == self.image_start_id):
bos_pos = torch.where(input_ids == self.image_start_id)
eos_pos = torch.where(input_ids == self.image_end_id)
return torch.stack((bos_pos[0], eos_pos[0]), dim=1)
return None
class QWenMLP(nn.Module):
"""MLP for the language component of the Qwen model, which contains a
@@ -579,9 +573,12 @@ class QWenModel(nn.Module):
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
self.visual = VisionTransformer(**config.visual,
quant_config=quant_config) if hasattr(
config, "visual") else None
if (vision_config := getattr(config, "visual", None)):
self.visual = VisionTransformer(**vision_config,
quant_config=quant_config)
else:
self.visual = None
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.wte(input_ids)
@@ -593,38 +590,13 @@ class QWenModel(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
pixel_values: Optional[QwenImageInputs],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
img_pos = None
# If pixel / visual embeddings are provided, this is a visual model
if pixel_values is not None and self.visual is not None:
if pixel_values["type"] != "image_embeds":
image_embeds = self.visual(pixel_values["data"])
else:
image_embeds = pixel_values["data"]
# features should be of shape (# images, 256, hidden_dim)
img_pos = self.visual.get_image_positions(input_ids)
if isinstance(
img_pos,
np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]:
raise ValueError(
f"Number of placeholders: {img_pos.shape[0]} "
f"does not match number of images {image_embeds.shape[0]}."
)
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states = self.wte(input_ids)
            # Merge the image embeddings into the hidden states if we actually have
# visual features and the corresponding image tokens
if img_pos is not None:
for idx, (img_bos, img_eos) in enumerate(img_pos):
hidden_states[img_bos + 1:img_eos] = image_embeds[idx]
residual = None
else:
assert intermediate_tensors is not None
@@ -648,159 +620,9 @@ class QWenModel(nn.Module):
return hidden_states
def get_image_text(image_num: int, padding: bool) -> str:
"""Retrieves a placeholder text that when tokenized, will be expanded with
image pads.
Args:
image_num: The number of the image that we want a text prompt for.
Images should be indexed starting at 1.
padding: Whether or not padding should be manually added.
Returns:
Text placeholder prompt for the image being considered.
"""
image_start = f"Picture {image_num}: {IMG_START}"
image_end = f"{IMG_END}\n"
if not padding:
return f"{image_start}{image_end}"
return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}"
def input_processor_for_qwen(ctx: InputContext,
inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
"""Processes the inputs, which may or may not be multimodal.
Multimodal inputs will only be processed if the model has a "visual"
component in its model config, otherwise they'll be ignored.
Args:
ctx: Context of the loaded model.
inputs: LLM inputs which may have a multi_modal_data attribute.
Returns:
        If the model is language-only or no multimodal inputs were provided,
returns inputs unmodified. Otherwise, processes the multimodal
images / image embeddings and adds the fixed-length image placeholders.
"""
multi_modal_data = inputs.get("multi_modal_data")
# Only process images if we have multimodal data and a visual config
hf_config = ctx.get_hf_config()
if (multi_modal_data is None or "image" not in multi_modal_data
or not hasattr(hf_config, "visual")):
return inputs
prompt = inputs.get("prompt")
prompt_token_ids = inputs["prompt_token_ids"]
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_data = multi_modal_data["image"]
if isinstance(image_data, torch.Tensor):
num_dims = len(image_data.shape)
if num_dims < 2 or num_dims > 3:
raise ValueError(
f"Expected img embeds to be have 3 dimensions, got {num_dims}")
num_images = 1 if num_dims == 2 else image_data.shape[0]
elif isinstance(image_data, Image.Image):
num_images = 1
elif is_list_of(image_data, Image.Image):
num_images = len(image_data)
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
if prompt is None:
prompt = tokenizer.decode(prompt_token_ids)
# Drops anything between <img>/</img> tags; encoding with the tokenizer
# will automatically add the image pads for the context.
new_prompt, num_matched_images = re.subn(
r"(Picture \d*: <img>).*?(<\/img>\n)",
r"\1\2",
prompt,
)
if num_matched_images != num_images:
logger.warning(
"Number of matched image placeholders %s doesn't match the number "
"of expected images %s; check your placeholder formatting.",
num_matched_images, num_images)
new_prompt_token_ids = tokenizer.encode(new_prompt)
return token_inputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalKwargs:
"""Maps the input data to its MultiModalKwargs (if any).
Args:
ctx: Context of the loaded model.
data: data potentially containing image/image embeddings to be mapped
to pixel_values in .forward() for a visual QWenLMHeadModel model.
Returns:
MultiModalKwargs containing the stacked normalized images tensor or
image embeddings.
"""
# Early exit if we have provided an image to a language only Qwen model
hf_config = ctx.get_hf_config()
if not hasattr(hf_config, "visual"):
logger.warning(
"Images were provided but this model has no visual config; "
"multimodal inputs will not be forwarded to the model.")
return MultiModalKwargs()
model_config = ctx.model_config
tokenizer = cached_get_tokenizer(
model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code)
image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
add_special_tokens=False,
return_tensors="pt").squeeze()
image_start_id = image_pair_tok[0]
image_end_id = image_pair_tok[-1]
if (image_start_id + 1) != image_end_id:
raise ValueError(
f"Found image end ID {image_end_id}, but expected {IMG_START} + 1")
if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2):
raise ValueError(
f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, "
f"but got {image_pair_tok - 2}")
hf_config = ctx.get_hf_config()
image_size = hf_config.visual["image_size"]
img_emb_size = hf_config.visual["output_dim"]
if isinstance(data, torch.Tensor):
# It's expected that our values have already been processed
# by the visual transformer; shape is expected to be:
# (# images, 256, hidden_size)
if len(data.shape) == 2:
# Assume only one image embed was provided; unsqueeze the extra dim
data = data.unsqueeze(0)
if len(data.shape) != 3 or data.shape[
1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size:
raise ValueError(
"Expected image embeds to be a tensor of shape"
f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
f"received shape [{data.shape}]")
pixel_values = data
else:
transform = build_normalization_transform(image_size)
if not isinstance(data, (list, tuple)):
data = [data]
transformed_images = [transform(datum) for datum in data]
pixel_values = torch.stack(transformed_images, dim=0)
return MultiModalKwargs({"pixel_values": pixel_values})
def build_normalization_transform(image_size: int) -> transforms.Compose:
"""Builds a normalization transform which can be applied to one or
"""
Build a normalization transform which can be applied to one or
more input images from which we want to extract visual features.
Args:
@@ -817,62 +639,251 @@ def build_normalization_transform(image_size: int) -> transforms.Compose:
])
def dummy_data_for_qwen(
    ctx: InputContext,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> DummyData:
    """Build dummy data for warming up Qwen models; this will only contain text
    matching the defaults for VLLM unless the model has a visual config.
    Args:
        ctx: Context of the loaded model.
        seq_len: Number of tokens in the text sequence.
        mm_counts: multimodal data counts.
    Returns:
        Tuple containing sequential and multimodal data.
    """
    hf_config = ctx.get_hf_config()
    # The presence of a visual config indicates this is a multimodal model.
    # If we don't have it, the model is considered an LLM for warmup purposes.
    if not hasattr(hf_config, "visual"):
        seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
        mm_data = None
        return DummyData(seq_data, mm_data)
    # We have a visual component - use images to warm up
    num_images = mm_counts["image"]
    model_config = ctx.model_config
    tokenizer = cached_get_tokenizer(
        model_config.tokenizer,
        trust_remote_code=model_config.trust_remote_code)
    # Build the image prompts with no imgpads; the tokenizer will add img pads
    image_prompt = ''.join(
        [get_image_text(idx, False) for idx in range(1, num_images + 1)])
    toks = tokenizer.encode(image_prompt, add_special_tokens=False)
    # Make sure we actually get the fixed context size per tok padding
    num_pads = toks.count(tokenizer.encode(IMG_PAD)[0])
    if num_pads != (num_images * MAX_QWEN_IMG_TOKENS):
        raise ValueError(
            f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads"
            f" per image, but got {num_pads} pads for {num_images} image(s)"
            " in total. Are you using a qwen tokenizer?")
    # Ensure the number of tokens is at minimum the sequence length provided
    if len(toks) < seq_len:
        toks += [0] * (seq_len - len(toks))
    seq_data = SequenceData.from_seqs(toks)
    # Build the input images; width/height doesn't actually matter here since
    # the data will get resized and the # of tokens per image is constant
    image = Image.new("RGB", (224, 224), color=0)
    mm_data = {"image": image if num_images == 1 else [image] * num_images}
    return DummyData(seq_data, mm_data)

@lru_cache(maxsize=1)
def _get_tokenizer_without_image_pad(
        tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
    """
    The logic of adding image pad tokens should only be applied in
    :class:`QWenVLProcessor`, so they are patched out here.
    The definition of the wrapped tokenizer can be found here:
    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
    """
    new_tokenizer = copy.deepcopy(tokenizer)

    class TokenizerWithoutImagePad(tokenizer.__class__):  # type: ignore

        def tokenize(
            self,
            text: str,
            allowed_special: Union[AbstractSet[str], str] = "all",
            disallowed_special: Union[Collection[str], str] = (),
            **kwargs,
        ) -> list[Union[bytes, str]]:
            text = unicodedata.normalize("NFC", text)
            return [
                self.decoder[t] for t in self.tokenizer.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            ]

        def _decode(
            self,
            token_ids: Union[int, List[int]],
            skip_special_tokens: bool = False,
            errors: Optional[str] = None,
            **kwargs,
        ) -> str:
            if isinstance(token_ids, int):
                token_ids = [token_ids]
            return self.tokenizer.decode(
                token_ids,
                errors=errors or self.errors,
            )

    TokenizerWithoutImagePad.__name__ = \
        f"{tokenizer.__class__.__name__}WithoutImagePad"
    new_tokenizer.__class__ = TokenizerWithoutImagePad
    return new_tokenizer
class QWenVLProcessor:
"""
This model doesn't define its own HF processor,
    so we implement our own here.
We call the wrapped tokenizer to automatically insert image pad tokens:
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245
The image processor is defined here:
https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
"""
def __init__(
self,
config: PretrainedConfig,
tokenizer: PreTrainedTokenizer,
) -> None:
super().__init__()
self.config = config
self.tokenizer = tokenizer
if hasattr(self.config, "visual"):
self.image_transform = build_normalization_transform(
config.visual["image_size"])
else:
self.image_transform = None
special_tokens: dict[str,
int] = tokenizer.special_tokens # type: ignore
self.img_start_id = special_tokens[IMG_START]
self.img_end_id = special_tokens[IMG_END]
def __call__(
self,
text: Optional[Union[TextInput, list[TextInput]]] = None,
images: Optional[Union[ImageInput, list[ImageInput]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
) -> BatchFeature:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
text_inputs = self.tokenizer(text)
if len(images) == 0:
image_inputs = {}
else:
if self.image_transform is None:
raise ValueError("This model does not support image inputs")
pixel_values = [self.image_transform(image) for image in images]
image_inputs = {"pixel_values": torch.stack(pixel_values)}
return BatchFeature(
{
**text_inputs,
**image_inputs,
},
tensor_type=return_tensors,
)
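A rough usage sketch of the processor defined above (it assumes the Qwen/Qwen-VL config and tokenizer can be fetched with trust_remote_code, and the commented outputs are indicative only, not verified behavior):

from PIL import Image
from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL", trust_remote_code=True)
processor = QWenVLProcessor(config, tokenizer)
out = processor(
    text="Picture 1: <img></img>\nDescribe the image.",
    images=Image.new("RGB", (448, 448)),
    return_tensors="pt",
)
# "input_ids" should now contain the image pad tokens inserted by the wrapped
# tokenizer, and "pixel_values" should have shape (1, 3, 448, 448).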
class QWenVLProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> PreTrainedTokenizer:
tokenizer = self.ctx.tokenizer
assert isinstance(tokenizer, PreTrainedTokenizer)
return _get_tokenizer_without_image_pad(tokenizer)
def get_hf_processor(self) -> QWenVLProcessor:
tokenizer = self.ctx.tokenizer
assert isinstance(tokenizer, PreTrainedTokenizer)
return QWenVLProcessor(self.get_hf_config(), tokenizer)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
return {"image": self.get_num_image_tokens()}
def get_num_image_tokens(self) -> int:
return MAX_QWEN_IMG_TOKENS
class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
hf_config = self.info.get_hf_config()
if not hasattr(hf_config, "visual"):
return ProcessorInputs(prompt_text="", mm_data={})
vision_config = hf_config.visual
max_image_size = vision_config["image_size"]
num_images = mm_counts.get("image", 0)
mm_data = {
"image":
self._get_dummy_images(width=max_image_size,
height=max_image_size,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="".join(f"Picture {i}: {IMG_START}{IMG_END}\n"
for i in range(1, num_images + 1)),
mm_data=mm_data,
)
class QWenVLMultiModalProcessor(BaseMultiModalProcessor[QWenVLProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
# Drops anything between <img>/</img> tags; encoding with the tokenizer
# will automatically add the image pads for the context.
prompt, num_matched_images = re.subn(
r"(Picture \d*: <img>).*?(<\/img>\n)",
r"\1\2",
prompt,
)
image_data = mm_data.get("images")
if image_data is not None:
assert isinstance(image_data, list)
num_images = len(image_data)
if num_matched_images != num_images:
logger.warning(
"Number of matched image placeholders %s doesn't match "
"the number of expected images %s; check your placeholder "
"formatting.", num_matched_images, num_images)
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
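For reference, a quick illustration of what the placeholder-stripping regex above does to a hypothetical prompt:

import re

prompt = "Picture 1: <img>path/to/image.jpg</img>\nWhat is shown here?"
new_prompt, num_matched = re.subn(
    r"(Picture \d*: <img>).*?(<\/img>\n)",
    r"\1\2",
    prompt,
)
# new_prompt == "Picture 1: <img></img>\nWhat is shown here?"
# num_matched == 1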
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
tokenizer = self.info.get_tokenizer()
special_tokens: dict[str,
int] = tokenizer.special_tokens # type: ignore
img_start_id = special_tokens[IMG_START]
img_end_id = special_tokens[IMG_END]
img_pad_id = special_tokens[IMG_PAD]
num_image_tokens = self.info.get_num_image_tokens()
image_tokens = [img_pad_id] * num_image_tokens
return [
PromptReplacement(
modality="image",
target=[img_start_id, img_end_id],
replacement=PromptReplacementDetails(
full=[img_start_id] + image_tokens + [img_end_id],
features=image_tokens,
),
)
]
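To make the replacement concrete, here is a small stand-alone sketch (not vLLM code) of its effect on a token sequence, reusing the constants from the deleted Qwen-VL test above:

IMG_START_ID, IMG_END_ID, IMG_PAD_ID = 151857, 151858, 151859
TOKS_PER_IMG = 256

def expand_image_placeholders(token_ids: list[int]) -> list[int]:
    # Each adjacent [IMG_START_ID, IMG_END_ID] pair becomes
    # [IMG_START_ID] + 256 pad tokens + [IMG_END_ID], mirroring the
    # PromptReplacement defined above.
    out: list[int] = []
    i = 0
    while i < len(token_ids):
        if (i + 1 < len(token_ids) and token_ids[i] == IMG_START_ID
                and token_ids[i + 1] == IMG_END_ID):
            out += [IMG_START_ID] + [IMG_PAD_ID] * TOKS_PER_IMG + [IMG_END_ID]
            i += 2
        else:
            out.append(token_ids[i])
            i += 1
    return out

assert len(expand_image_placeholders([IMG_START_ID, IMG_END_ID])) == 258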
class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
@@ -898,38 +909,77 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
    def _get_image_input_type(
            self,
            pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]:
        """Determines if the provided pixel_values are normalized pixel values
        or image embeddings.
        Args:
            pixel_values: Optional data to be processed into visual embeddings.
        Returns:
            None or the QwenImageInputs type used to determine whether or not
            the visual transformer needs to process the pixel_values.
        """
        if pixel_values is not None and self.transformer.visual is not None:
            pixel_values = flatten_bn(pixel_values)
            if len(pixel_values.shape) == 3 and pixel_values.shape[
                    1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[
                        2] == self.config.visual["output_dim"]:
                return QwenImageEmbeddingInputs(
                    type="image_embeds",
                    data=pixel_values,
                )
            else:
                # If we have the wrong shape, assume we still need to process
                return QwenImagePixelInputs(
                    type="pixel_values",
                    data=pixel_values,
                )
        return None

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.transformer.get_input_embeddings(input_ids)

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        h = w = self.config.visual["image_size"]
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])
        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")
        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[QwenImageInputs]:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        if pixel_values is not None:
            if not isinstance(pixel_values, torch.Tensor):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")
            return QwenImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )
        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")
            return QwenImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )
        return None
def _process_image_input(self,
image_input: QwenImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.transformer.visual is not None
return self.transformer.visual(image_input["data"])
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.transformer.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
assert self.transformer.visual is not None
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.transformer.visual.image_pad_id)
return inputs_embeds
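The merge step above scatters the vision embeddings into the text embeddings at the image pad positions. A rough stand-alone sketch of that behaviour (not the actual merge_multimodal_embeddings implementation, and assuming the image embeddings are already flattened to one row per pad token):

import torch

def merge_by_pad_mask(input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
                      image_embeds: torch.Tensor, pad_id: int) -> torch.Tensor:
    # Overwrite the embeddings at the image pad positions, in order.
    mask = input_ids == pad_id
    merged = inputs_embeds.clone()
    merged[mask] = image_embeds.to(dtype=merged.dtype)
    return merged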
def forward(
self,
@@ -938,18 +988,23 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
pixel_values: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
inputs_embeds = None
        # NOTE: In v1, inputs_embeds is always generated by the model runner;
        # this condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
pixel_values = None
else:
pixel_values = self._get_image_input_type(pixel_values)
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
pixel_values, inputs_embeds)
inputs_embeds)
return hidden_states
def compute_logits(
@@ -1063,10 +1118,9 @@ class QWenVL(QWenBaseModel, SupportsMultiModal):
tower_model="transformer.visual.transformer")
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
@MULTIMODAL_REGISTRY.register_processor(QWenVLMultiModalProcessor,
info=QWenVLProcessingInfo,
dummy_inputs=QWenVLDummyInputsBuilder)
class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
"""
QWenLMHeadModel is not only applicable to LLM but also to VL, which is not
@@ -1084,7 +1138,7 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
cls,
vllm_config: VllmConfig,
prefix: str = "",
) -> None:
) -> QWenBaseModel:
config = vllm_config.model_config.hf_config
# Initialize VL
if hasattr(config, "visual"):