[VLM] Merged multi-modal processor and V1 support for Qwen-VL (#12504)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
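
A minimal usage sketch (not part of this commit) of how the merged Qwen-VL multi-modal processor path is exercised from the public API. The model name and the "Picture <n>: <img></img>" placeholder format come from the diff below; everything else (image size, prompt text, sampling settings) is an illustrative assumption.

    from vllm import LLM, SamplingParams
    from PIL import Image

    # Any RGB PIL image works; the new QWenVLMultiModalProcessor expands each
    # "Picture <n>: <img></img>" pair into the fixed run of IMG_PAD tokens.
    llm = LLM(model="Qwen/Qwen-VL-Chat", trust_remote_code=True)
    image = Image.new("RGB", (448, 448))

    outputs = llm.generate(
        {
            "prompt": "Picture 1: <img></img>\nDescribe the image.",
            "multi_modal_data": {"image": image},
        },
        SamplingParams(max_tokens=64),
    )
    print(outputs[0].outputs[0].text)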
@@ -745,7 +745,7 @@ See [this page](#generative-models) for more information on how to use generative models.
   - `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc.
   - ✅︎
   - ✅︎
-  -
+  - ✅︎
 * - `Qwen2AudioForConditionalGeneration`
   - Qwen2-Audio
   - T + A<sup>+</sup>
@@ -16,7 +16,6 @@ from ...registry import HF_EXAMPLE_MODELS

 def _test_processing_correctness(
     model_id: str,
-    modalities: dict[str, bool],
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
@@ -25,11 +24,6 @@ def _test_processing_correctness(
     model_info.check_available_online(on_fail="skip")
     model_info.check_transformers_version(on_fail="skip")

-    limit_mm_per_prompt = {
-        modality: 3 if supports_multi else 1
-        for modality, supports_multi in modalities.items()
-    }
-
     model_config = ModelConfig(
         model_id,
         task="auto",
@@ -40,18 +34,29 @@ def _test_processing_correctness(
         dtype="float16",
         revision=None,
         hf_overrides=model_info.hf_overrides,
-        limit_mm_per_prompt=limit_mm_per_prompt,
     )

     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
     factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
     ctx = InputProcessingContext(
         model_config,
-        tokenizer=cached_get_tokenizer(model_config.tokenizer),
+        tokenizer=cached_get_tokenizer(
+            model_config.tokenizer,
+            trust_remote_code=model_info.trust_remote_code,
+        ),
     )
     # Ensure that it can fit all of the data
     cache = ProcessingCache(capacity=1 << 30)

+    processing_info = factories.info(ctx)
+    supported_mm_limits = processing_info.get_supported_mm_limits()
+    limit_mm_per_prompt = {
+        modality: 3 if limit is None else limit
+        for modality, limit in supported_mm_limits.items()
+    }
+
+    model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
+
     baseline_processor = factories.build_processor(ctx, cache=None)
     cached_processor = factories.build_processor(ctx, cache=cache)
     dummy_inputs = baseline_processor.dummy_inputs
@@ -82,8 +87,8 @@ def _test_processing_correctness(
         mm_data = {
             k:
             [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
-             for _ in range(rng.randint(limit_mm_per_prompt[k]))]
-            for k in modalities
+             for _ in range(rng.randint(limit))]
+            for k, limit in limit_mm_per_prompt.items()
         }

         mm_counts = {k: len(vs) for k, vs in mm_data.items()}
@@ -135,21 +140,22 @@ def _test_processing_correctness(

 # yapf: disable
 # True if the model supports multiple data items of the modality per request
-@pytest.mark.parametrize(("model_id", "modalities"), [
-    ("rhymes-ai/Aria", {"image": True}),
-    ("Salesforce/blip2-opt-2.7b", {"image": False}),
-    ("facebook/chameleon-7b", {"image": False}),
-    ("deepseek-ai/deepseek-vl2-tiny", {"image": True}),
-    ("adept/fuyu-8b", {"image": False}),
-    ("llava-hf/llava-1.5-7b-hf", {"image": True}),
-    ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
-    ("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}),
-    ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}),  # noqa: E501
-    ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
-    ("mistral-community/pixtral-12b", {"image": True}),
-    ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
-    ("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}),
-    ("fixie-ai/ultravox-v0_3", {"audio": True}),
+@pytest.mark.parametrize("model_id", [
+    "rhymes-ai/Aria",
+    "Salesforce/blip2-opt-2.7b",
+    "facebook/chameleon-7b",
+    "deepseek-ai/deepseek-vl2-tiny",
+    "adept/fuyu-8b",
+    "llava-hf/llava-1.5-7b-hf",
+    "llava-hf/llava-v1.6-mistral-7b-hf",
+    "llava-hf/LLaVA-NeXT-Video-7B-hf",
+    "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
+    "TIGER-Lab/Mantis-8B-siglip-llama3",
+    "mistral-community/pixtral-12b",
+    "Qwen/Qwen-VL-Chat",
+    "Qwen/Qwen2-VL-2B-Instruct",
+    "Qwen/Qwen2-Audio-7B-Instruct",
+    "fixie-ai/ultravox-v0_3",
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])
@@ -157,14 +163,12 @@ def _test_processing_correctness(
 # yapf: enable
 def test_processing_correctness(
     model_id: str,
-    modalities: dict[str, bool],
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
 ):
     _test_processing_correctness(
         model_id,
-        modalities,
         hit_rate=hit_rate,
         num_batches=num_batches,
         simplify_rate=simplify_rate,
@@ -172,16 +176,13 @@ def test_processing_correctness(


 # yapf: disable
-@pytest.mark.parametrize(("model_id", "modalities"), [
-    ("microsoft/Phi-3-vision-128k-instruct", {"image": True}),
-])
+@pytest.mark.parametrize("model_id", ["microsoft/Phi-3-vision-128k-instruct"])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])
 @pytest.mark.parametrize("simplify_rate", [1.0])
 # yapf: enable
 def test_processing_correctness_phi3v(
     model_id: str,
-    modalities: dict[str, bool],
     hit_rate: float,
     num_batches: int,
     simplify_rate: float,
@@ -195,7 +196,6 @@ def test_processing_correctness_phi3v(

     _test_processing_correctness(
         model_id,
-        modalities,
         hit_rate=hit_rate,
         num_batches=num_batches,
         simplify_rate=simplify_rate,
@@ -1,144 +0,0 @@
-"""Tests for Qwen's multimodal preprocessing kwargs."""
-from typing import Dict, List, Union
-
-import pytest
-import torch
-from PIL.Image import Image
-
-from vllm.inputs import InputContext, token_inputs
-from vllm.multimodal import MultiModalKwargs
-from vllm.multimodal.utils import cached_get_tokenizer
-
-from ....conftest import IMAGE_ASSETS
-from ...utils import build_model_context
-
-### Multimodal preprocessing tests
-SAMPLE_IMAGE = IMAGE_ASSETS[0].pil_image
-# These values are specific to Qwen-VL/Chat; we can get these from the model
-# config also, but they are hardcoded here to keep the parameterize/fixtures
-# easy to read.
-IMG_START_ID = 151857
-IMG_END_ID = 151858
-IMG_PAD_ID = 151859
-TOKS_PER_IMG = 256
-VIS_ENC_DIM = 4096
-IMG_SIZE = 448
-
-
-@pytest.fixture()
-def input_mapper_for_qwen():
-    # Lazy import to avoid initializing CUDA during test collection
-    from vllm.model_executor.models.qwen import input_mapper_for_qwen
-    return input_mapper_for_qwen
-
-
-@pytest.fixture()
-def input_processor_for_qwen():
-    # Lazy import to avoid initializing CUDA during test collection
-    from vllm.model_executor.models.qwen import input_processor_for_qwen
-    return input_processor_for_qwen
-
-
-@pytest.fixture()
-def qwen_vl_context() -> InputContext:
-    """Get an InputContext for Qwen-VL."""
-    return build_model_context(model_name="Qwen/Qwen-VL",
-                               trust_remote_code=True)
-
-
-# Happy path tests for single/multi-image scenarios for the multimodal
-# input processor and mapper, respectively
-@pytest.mark.parametrize("num_images", [1, 2])
-def test_input_processor_valid_mm_data(input_processor_for_qwen,
-                                       qwen_vl_context: InputContext,
-                                       num_images: int):
-    """Happy cases for image inputs to Qwen's multimodal input processor."""
-    prompt = "".join(
-        [f"Picture {num}: <img></img>\n" for num in range(1, num_images + 1)])
-    inputs = token_inputs(
-        prompt=prompt,
-        # When processing multimodal data for a multimodal model, the qwen
-        # input processor will overwrite the provided prompt_token_ids with
-        # the image prompts
-        prompt_token_ids=[],
-        multi_modal_data={"image": torch.rand(num_images, TOKS_PER_IMG, 4096)},
-    )
-    proc_inputs = input_processor_for_qwen(qwen_vl_context, inputs)
-    assert isinstance(proc_inputs, dict)
-
-    # Each image should have one start / stop and a fixed context of 256
-    proc_tokens = proc_inputs["prompt_token_ids"]
-    assert proc_tokens.count(IMG_START_ID) == num_images
-    assert proc_tokens.count(IMG_END_ID) == num_images
-    assert proc_tokens.count(IMG_PAD_ID) == num_images * TOKS_PER_IMG
-
-
-@pytest.mark.parametrize(
-    "img_data,expected_shape",
-    [
-        # single / multi-image
-        (SAMPLE_IMAGE, (1, 3, IMG_SIZE, IMG_SIZE)),
-        (2 * [SAMPLE_IMAGE], (2, 3, IMG_SIZE, IMG_SIZE)),
-        # single / multi-image embeddings
-        (torch.rand(
-            (TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
-        (torch.rand(
-            (1, TOKS_PER_IMG, VIS_ENC_DIM)), (1, TOKS_PER_IMG, VIS_ENC_DIM)),
-        (torch.rand(
-            (2, TOKS_PER_IMG, VIS_ENC_DIM)), (2, TOKS_PER_IMG, VIS_ENC_DIM)),
-    ])
-def test_input_mapper_valid_mm_data(input_mapper_for_qwen,
-                                    qwen_vl_context: InputContext,
-                                    img_data: Union[torch.Tensor, List[Image],
-                                                    Image],
-                                    expected_shape: List[int]):
-    """Happy cases for image inputs to Qwen's multimodal input mapper."""
-    mapped_img_data = input_mapper_for_qwen(qwen_vl_context, img_data)
-    # Ensure that we get the appropriately shaped pixel_values
-    # for images and image embeddings, respectively.
-    assert isinstance(mapped_img_data, MultiModalKwargs)
-    assert "pixel_values" in mapped_img_data
-    assert mapped_img_data["pixel_values"].shape == expected_shape
-
-
-# Sad path tests for the multimodal input processor and mapper, respectively
-@pytest.mark.parametrize("mm_data", [
-    {
-        "image": torch.rand(5)
-    },
-    {
-        "image": torch.rand((5, 5, 5, 5, 5))
-    },
-])
-def test_input_processor_invalid_mm_data(input_processor_for_qwen,
-                                         qwen_vl_context: InputContext,
-                                         mm_data: Dict[str, torch.Tensor]):
-    """Test sad cases validated in Qwen's multimodal input processor."""
-    tokenizer = cached_get_tokenizer(qwen_vl_context.model_config.tokenizer,
-                                     trust_remote_code=True)
-    prompt = "Picture 1: <img></img>\n"
-    prompt_token_ids = tokenizer.encode(prompt)
-    inputs = token_inputs(prompt=prompt,
-                          prompt_token_ids=prompt_token_ids,
-                          multi_modal_data=mm_data)
-    # Should fail since we have too many or too few dimensions for embeddings
-    with pytest.raises(ValueError):
-        input_processor_for_qwen(qwen_vl_context, inputs)
-
-
-@pytest.mark.parametrize(
-    "img_data",
-    [
-        # Wrong context length
-        torch.rand((1, TOKS_PER_IMG + 10, VIS_ENC_DIM)),
-        # Wrong visual encoder output size
-        torch.rand((1, TOKS_PER_IMG, VIS_ENC_DIM + 10)),
-    ])
-def test_input_mapper_invalid_mm_data(
-    input_mapper_for_qwen,
-    qwen_vl_context: InputContext,
-    img_data: Union[torch.Tensor, List[Image], Image],
-):
-    """Sad cases validated in Qwen VL's multimodal input mapper."""
-    with pytest.raises(ValueError):
-        input_mapper_for_qwen(qwen_vl_context, img_data)
@@ -4,26 +4,28 @@
 # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
 """Inference-only QWen model compatible with HuggingFace weights."""

+import copy
 import math
 import re
-from functools import partial
-from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
-                    Optional, Set, Tuple, TypedDict, Union)
+import unicodedata
+from functools import lru_cache, partial
+from typing import (AbstractSet, Any, Callable, Collection, Dict, Iterable,
+                    List, Literal, Mapping, Optional, Set, Tuple, TypedDict,
+                    Union)

-import numpy as np
 import torch
-from PIL import Image
 from torch import nn
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
-from transformers import PretrainedConfig
+from transformers import (BatchFeature, PretrainedConfig, PreTrainedTokenizer,
+                          TensorType)
+from transformers.image_utils import ImageInput
+from transformers.tokenization_utils_base import TextInput

 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -42,15 +44,20 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
-from vllm.multimodal.utils import cached_get_tokenizer
-from vllm.sequence import IntermediateTensors, SequenceData
-from vllm.utils import is_list_of
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
+                                    NestedTensors)
+from vllm.multimodal.parse import MultiModalDataItems
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        BaseProcessingInfo, PromptReplacement,
+                                        PromptReplacementDetails)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
 from .utils import (flatten_bn, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
-                    maybe_prefix)
+                    maybe_prefix, merge_multimodal_embeddings)

 logger = init_logger(__name__)

@@ -353,8 +360,10 @@ class VisionTransformer(nn.Module):
         self.ln_post = norm_layer(output_dim)
         self.proj = nn.Parameter(
             (output_dim**-0.5) * torch.randn(output_dim, output_dim))

         self.image_start_id = image_start_id
         self.image_end_id = image_start_id + 1
+        self.image_pad_id = image_start_id + 2

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.to(
@@ -383,21 +392,6 @@ class VisionTransformer(nn.Module):

         return x

-    def get_image_positions(self,
-                            input_ids: torch.Tensor) -> Optional[torch.Tensor]:
-        """Given the input IDs, extracts start/stop points corresponding to
-        images.
-
-        args:
-        Returns:
-            Optional torch tensor corresponding to start/stop pairs of images.
-        """
-        if torch.any(input_ids == self.image_start_id):
-            bos_pos = torch.where(input_ids == self.image_start_id)
-            eos_pos = torch.where(input_ids == self.image_end_id)
-            return torch.stack((bos_pos[0], eos_pos[0]), dim=1)
-        return None
-

 class QWenMLP(nn.Module):
     """MLP for the language component of the Qwen model, which contains a
@@ -579,9 +573,12 @@ class QWenModel(nn.Module):
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
-        self.visual = VisionTransformer(**config.visual,
-                                        quant_config=quant_config) if hasattr(
-                                            config, "visual") else None
+
+        if (vision_config := getattr(config, "visual", None)):
+            self.visual = VisionTransformer(**vision_config,
+                                            quant_config=quant_config)
+        else:
+            self.visual = None

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.wte(input_ids)
@@ -593,38 +590,13 @@ class QWenModel(nn.Module):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
-        pixel_values: Optional[QwenImageInputs],
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        img_pos = None
-        # If pixel / visual embeddings are provided, this is a visual model
-        if pixel_values is not None and self.visual is not None:
-            if pixel_values["type"] != "image_embeds":
-                image_embeds = self.visual(pixel_values["data"])
-            else:
-                image_embeds = pixel_values["data"]
-
-            # features should be of shape (# images, 256, hidden_dim)
-            img_pos = self.visual.get_image_positions(input_ids)
-            if isinstance(
-                    img_pos,
-                    np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]:
-                raise ValueError(
-                    f"Number of placeholders: {img_pos.shape[0]} "
-                    f"does not match number of images {image_embeds.shape[0]}."
-                )
-
         if get_pp_group().is_first_rank:
             if inputs_embeds is not None:
                 hidden_states = inputs_embeds
             else:
                 hidden_states = self.get_input_embeddings(input_ids)
-                hidden_states = self.wte(input_ids)
-            # Merge the image embeddings into the hidden states if actually have
-            # visual features and the corresponding image tokens
-            if img_pos is not None:
-                for idx, (img_bos, img_eos) in enumerate(img_pos):
-                    hidden_states[img_bos + 1:img_eos] = image_embeds[idx]
             residual = None
         else:
             assert intermediate_tensors is not None
@@ -648,159 +620,9 @@ class QWenModel(nn.Module):
         return hidden_states


-def get_image_text(image_num: int, padding: bool) -> str:
-    """Retrieves a placeholder text that when tokenized, will be expanded with
-    image pads.
-
-    Args:
-        image_num: The number of the image that we want a text prompt for.
-            Images should be indexed starting at 1.
-        padding: Whether or not padding should be manually added.
-
-    Returns:
-        Text placeholder prompt for the image being considered.
-    """
-    image_start = f"Picture {image_num}: {IMG_START}"
-    image_end = f"{IMG_END}\n"
-    if not padding:
-        return f"{image_start}{image_end}"
-    return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}"
-
-
-def input_processor_for_qwen(ctx: InputContext,
-                             inputs: DecoderOnlyInputs) -> DecoderOnlyInputs:
-    """Processes the inputs, which may or may not be multimodal.
-    Multimodal inputs will only be processed if the model has a "visual"
-    component in its model config, otherwise they'll be ignored.
-
-    Args:
-        ctx: Context of the loaded model.
-        inputs: LLM inputs which may have a multi_modal_data attribute.
-
-    Returns:
-        If the model is language only or not multimodal inputs were provided,
-        returns inputs unmodified. Otherwise, processes the multimodal
-        images / image embeddings and adds the fixed-length image placeholders.
-    """
-    multi_modal_data = inputs.get("multi_modal_data")
-
-    # Only process images if we have multimodal data and a visual config
-    hf_config = ctx.get_hf_config()
-    if (multi_modal_data is None or "image" not in multi_modal_data
-            or not hasattr(hf_config, "visual")):
-        return inputs
-
-    prompt = inputs.get("prompt")
-    prompt_token_ids = inputs["prompt_token_ids"]
-    model_config = ctx.model_config
-    tokenizer = cached_get_tokenizer(
-        model_config.tokenizer,
-        trust_remote_code=model_config.trust_remote_code)
-    image_data = multi_modal_data["image"]
-    if isinstance(image_data, torch.Tensor):
-        num_dims = len(image_data.shape)
-        if num_dims < 2 or num_dims > 3:
-            raise ValueError(
-                f"Expected img embeds to be have 3 dimensions, got {num_dims}")
-        num_images = 1 if num_dims == 2 else image_data.shape[0]
-    elif isinstance(image_data, Image.Image):
-        num_images = 1
-    elif is_list_of(image_data, Image.Image):
-        num_images = len(image_data)
-    else:
-        raise TypeError(f"Invalid image type: {type(image_data)}")
-
-    if prompt is None:
-        prompt = tokenizer.decode(prompt_token_ids)
-
-    # Drops anything between <img>/</img> tags; encoding with the tokenizer
-    # will automatically add the image pads for the context.
-    new_prompt, num_matched_images = re.subn(
-        r"(Picture \d*: <img>).*?(<\/img>\n)",
-        r"\1\2",
-        prompt,
-    )
-
-    if num_matched_images != num_images:
-        logger.warning(
-            "Number of matched image placeholders %s doesn't match the number "
-            "of expected images %s; check your placeholder formatting.",
-            num_matched_images, num_images)
-
-    new_prompt_token_ids = tokenizer.encode(new_prompt)
-
-    return token_inputs(prompt=new_prompt,
-                        prompt_token_ids=new_prompt_token_ids,
-                        multi_modal_data=multi_modal_data)
-
-
-def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalKwargs:
-    """Maps the input data to its MultiModalKwargs (if any).
-
-    Args:
-        ctx: Context of the loaded model.
-        data: data potentially containing image/image embeddings to be mapped
-            to pixel_values in .forward() for a visual QWenLMHeadModel model.
-
-    Returns:
-        MultiModalKwargs containing the stacked normalized images tensor or
-        image embeddings.
-    """
-    # Early exit if we have provided an image to a language only Qwen model
-    hf_config = ctx.get_hf_config()
-    if not hasattr(hf_config, "visual"):
-        logger.warning(
-            "Images were provided but this model has no visual config; "
-            "multimodal inputs will not be forwarded to the model.")
-        return MultiModalKwargs()
-
-    model_config = ctx.model_config
-    tokenizer = cached_get_tokenizer(
-        model_config.tokenizer,
-        trust_remote_code=model_config.trust_remote_code)
-
-    image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
-                                      add_special_tokens=False,
-                                      return_tensors="pt").squeeze()
-    image_start_id = image_pair_tok[0]
-    image_end_id = image_pair_tok[-1]
-    if (image_start_id + 1) != image_end_id:
-        raise ValueError(
-            f"Found image end ID {image_end_id}, but expected {IMG_START} + 1")
-    if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2):
-        raise ValueError(
-            f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, "
-            f"but got {image_pair_tok - 2}")
-
-    hf_config = ctx.get_hf_config()
-    image_size = hf_config.visual["image_size"]
-    img_emb_size = hf_config.visual["output_dim"]
-
-    if isinstance(data, torch.Tensor):
-        # It's expected that our values have already been processed
-        # by the visual transformer; shape is expected to be:
-        # (# images, 256, hidden_size)
-        if len(data.shape) == 2:
-            # Assume only one image embed was provided; unsqueeze the extra dim
-            data = data.unsqueeze(0)
-        if len(data.shape) != 3 or data.shape[
-                1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size:
-            raise ValueError(
-                "Expected image embeds to be a tensor of shape"
-                f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
-                f"received shape [{data.shape}]")
-        pixel_values = data
-    else:
-        transform = build_normalization_transform(image_size)
-        if not isinstance(data, (list, tuple)):
-            data = [data]
-        transformed_images = [transform(datum) for datum in data]
-        pixel_values = torch.stack(transformed_images, dim=0)
-    return MultiModalKwargs({"pixel_values": pixel_values})
-
-
 def build_normalization_transform(image_size: int) -> transforms.Compose:
-    """Builds a normalization transform which can be applied to one or
+    """
+    Build a normalization transform which can be applied to one or
     more input images from which we want to extract visual features.

     Args:
@@ -817,62 +639,251 @@ def build_normalization_transform(image_size: int) -> transforms.Compose:
     ])


-def dummy_data_for_qwen(
-    ctx: InputContext,
-    seq_len: int,
-    mm_counts: Mapping[str, int],
-) -> DummyData:
-    """Build dummy data for warming up Qwen models; this will only contain text
-    matching the defaults for VLLM unless the model has a visual config.
-
-    Args:
-        ctx: Context of the loaded model.
-        seq_len: Number of tokens in the text sequence.
-        mm_counts: multimodal data counts.
-
-    Returns:
-        Tuple containing sequential and multimodal data.
-    """
-    hf_config = ctx.get_hf_config()
-
-    # The presence of a visual config indicates this is a multimodal model.
-    # If we don't have it, the model is considered an LLM for warmup purposes.
-    if not hasattr(hf_config, "visual"):
-        seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
-        mm_data = None
-        return DummyData(seq_data, mm_data)
-
-    # We have a visual component - use images to warm up
-    num_images = mm_counts["image"]
-    model_config = ctx.model_config
-    tokenizer = cached_get_tokenizer(
-        model_config.tokenizer,
-        trust_remote_code=model_config.trust_remote_code)
-
-    # Build the image prompts with no imgpads; the tokenizer will add img pads
-    image_prompt = ''.join(
-        [get_image_text(idx, False) for idx in range(1, num_images + 1)])
-    toks = tokenizer.encode(image_prompt, add_special_tokens=False)
-
-    # Make sure we actually get the fixed context size per tok padding
-    num_pads = toks.count(tokenizer.encode(IMG_PAD)[0])
-    if num_pads != (num_images * MAX_QWEN_IMG_TOKENS):
-        raise ValueError(
-            f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads"
-            f" per image, but got {num_pads} pads for {num_images} image(s)"
-            " in total. Are you using a qwen tokenizer?")
-
-    # Ensure the number of tokens is at minimum the sequence length provided
-    if len(toks) < seq_len:
-        toks += [0] * (seq_len - len(toks))
-
-    seq_data = SequenceData.from_seqs(toks)
-
-    # Build the input images; width/height doesn't actually matter here since
-    # the data will get resized and the # of tokens per image is constant
-    image = Image.new("RGB", (224, 224), color=0)
-    mm_data = {"image": image if num_images == 1 else [image] * num_images}
-    return DummyData(seq_data, mm_data)
+@lru_cache(maxsize=1)
+def _get_tokenizer_without_image_pad(
+        tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer:
+    """
+    The logic of adding image pad tokens should only be applied in
+    :class:`QWenVLProcessor`, so they are patched out here.
+
+    The definition of the wrapped tokenizer can be found here:
+    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
+    """
+    new_tokenizer = copy.deepcopy(tokenizer)
+
+    class TokenizerWithoutImagePad(tokenizer.__class__):  # type: ignore
+
+        def tokenize(
+            self,
+            text: str,
+            allowed_special: Union[AbstractSet[str], str] = "all",
+            disallowed_special: Union[Collection[str], str] = (),
+            **kwargs,
+        ) -> list[Union[bytes, str]]:
+            text = unicodedata.normalize("NFC", text)
+
+            return [
+                self.decoder[t] for t in self.tokenizer.encode(
+                    text,
+                    allowed_special=allowed_special,
+                    disallowed_special=disallowed_special,
+                )
+            ]
+
+        def _decode(
+            self,
+            token_ids: Union[int, List[int]],
+            skip_special_tokens: bool = False,
+            errors: Optional[str] = None,
+            **kwargs,
+        ) -> str:
+            if isinstance(token_ids, int):
+                token_ids = [token_ids]
+
+            return self.tokenizer.decode(
+                token_ids,
+                errors=errors or self.errors,
+            )
+
+    TokenizerWithoutImagePad.__name__ = \
+        f"{tokenizer.__class__.__name__}WithoutImagePad"
+
+    new_tokenizer.__class__ = TokenizerWithoutImagePad
+    return new_tokenizer
+
+
+class QWenVLProcessor:
+    """
+    This model doesn't define its own HF processor,
+    so we implement our own one here.
+
+    We call the wrapped tokenizer to automatically insert image pad tokens:
+    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245
+
+    The image processor is defined here:
+    https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
+    """
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        tokenizer: PreTrainedTokenizer,
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.tokenizer = tokenizer
+
+        if hasattr(self.config, "visual"):
+            self.image_transform = build_normalization_transform(
+                config.visual["image_size"])
+        else:
+            self.image_transform = None
+
+        special_tokens: dict[str,
+                             int] = tokenizer.special_tokens  # type: ignore
+        self.img_start_id = special_tokens[IMG_START]
+        self.img_end_id = special_tokens[IMG_END]
+
+    def __call__(
+        self,
+        text: Optional[Union[TextInput, list[TextInput]]] = None,
+        images: Optional[Union[ImageInput, list[ImageInput]]] = None,
+        return_tensors: Optional[Union[str, TensorType]] = None,
+    ) -> BatchFeature:
+        if text is None:
+            text = []
+        if not isinstance(text, list):
+            text = [text]
+        if images is None:
+            images = []
+        if not isinstance(images, list):
+            images = [images]
+
+        text_inputs = self.tokenizer(text)
+
+        if len(images) == 0:
+            image_inputs = {}
+        else:
+            if self.image_transform is None:
+                raise ValueError("This model does not support image inputs")
+
+            pixel_values = [self.image_transform(image) for image in images]
+            image_inputs = {"pixel_values": torch.stack(pixel_values)}
+
+        return BatchFeature(
+            {
+                **text_inputs,
+                **image_inputs,
+            },
+            tensor_type=return_tensors,
+        )
+
+
+class QWenVLProcessingInfo(BaseProcessingInfo):
+
+    def get_tokenizer(self) -> PreTrainedTokenizer:
+        tokenizer = self.ctx.tokenizer
+        assert isinstance(tokenizer, PreTrainedTokenizer)
+
+        return _get_tokenizer_without_image_pad(tokenizer)
+
+    def get_hf_processor(self) -> QWenVLProcessor:
+        tokenizer = self.ctx.tokenizer
+        assert isinstance(tokenizer, PreTrainedTokenizer)
+
+        return QWenVLProcessor(self.get_hf_config(), tokenizer)
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"image": None}
+
+    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
+        return {"image": self.get_num_image_tokens()}
+
+    def get_num_image_tokens(self) -> int:
+        return MAX_QWEN_IMG_TOKENS
+
+
+class QWenVLDummyInputsBuilder(BaseDummyInputsBuilder[QWenVLProcessingInfo]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        hf_config = self.info.get_hf_config()
+        if not hasattr(hf_config, "visual"):
+            return ProcessorInputs(prompt_text="", mm_data={})
+
+        vision_config = hf_config.visual
+
+        max_image_size = vision_config["image_size"]
+        num_images = mm_counts.get("image", 0)
+
+        mm_data = {
+            "image":
+            self._get_dummy_images(width=max_image_size,
+                                   height=max_image_size,
+                                   num_images=num_images)
+        }
+
+        return ProcessorInputs(
+            prompt_text="".join(f"Picture {i}: {IMG_START}{IMG_END}\n"
+                                for i in range(1, num_images + 1)),
+            mm_data=mm_data,
+        )
+
+
+class QWenVLMultiModalProcessor(BaseMultiModalProcessor[QWenVLProcessingInfo]):
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        # Drops anything between <img>/</img> tags; encoding with the tokenizer
+        # will automatically add the image pads for the context.
+        prompt, num_matched_images = re.subn(
+            r"(Picture \d*: <img>).*?(<\/img>\n)",
+            r"\1\2",
+            prompt,
+        )
+
+        image_data = mm_data.get("images")
+        if image_data is not None:
+            assert isinstance(image_data, list)
+
+            num_images = len(image_data)
+            if num_matched_images != num_images:
+                logger.warning(
+                    "Number of matched image placeholders %s doesn't match "
+                    "the number of expected images %s; check your placeholder "
+                    "formatting.", num_matched_images, num_images)
+
+        return super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
+        )
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        tokenizer = self.info.get_tokenizer()
+        special_tokens: dict[str,
+                             int] = tokenizer.special_tokens  # type: ignore
+
+        img_start_id = special_tokens[IMG_START]
+        img_end_id = special_tokens[IMG_END]
+        img_pad_id = special_tokens[IMG_PAD]
+
+        num_image_tokens = self.info.get_num_image_tokens()
+        image_tokens = [img_pad_id] * num_image_tokens
+
+        return [
+            PromptReplacement(
+                modality="image",
+                target=[img_start_id, img_end_id],
+                replacement=PromptReplacementDetails(
+                    full=[img_start_id] + image_tokens + [img_end_id],
+                    features=image_tokens,
+                ),
+            )
+        ]


 class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
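
For reference, a small illustration (not part of the diff) of what the PromptReplacement defined in the hunk above does to one image placeholder, using the token IDs that were hard-coded in the deleted test file.

    # Assumed constants from the old test file: IMG_START=151857,
    # IMG_END=151858, IMG_PAD=151859, and 256 visual tokens per image.
    IMG_START_ID, IMG_END_ID, IMG_PAD_ID, TOKS_PER_IMG = 151857, 151858, 151859, 256

    target = [IMG_START_ID, IMG_END_ID]  # the bare "<img></img>" token pair
    full = [IMG_START_ID] + [IMG_PAD_ID] * TOKS_PER_IMG + [IMG_END_ID]
    assert len(full) == TOKS_PER_IMG + 2  # 258 tokens replace the 2-token target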
@@ -898,38 +909,77 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)

-    def _get_image_input_type(
-            self,
-            pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]:
-        """Determines if the provided pixel_values are normalized pixel values
-        or image embeddings.
-
-        Args:
-            pixel_values: Optional data to processed into visual embeddings.
-
-        Returns:
-            None of the QwenImageInputs type used to determine whether or not
-            the visual transformer needs to process the pixel_values.
-        """
-        if pixel_values is not None and self.transformer.visual is not None:
-            pixel_values = flatten_bn(pixel_values)
-            if len(pixel_values.shape) == 3 and pixel_values.shape[
-                    1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[
-                        2] == self.config.visual["output_dim"]:
-                return QwenImageEmbeddingInputs(
-                    type="image_embeds",
-                    data=pixel_values,
-                )
-            else:
-                # If we have the wrong shape, assume we still need to process
-                return QwenImagePixelInputs(
-                    type="pixel_values",
-                    data=pixel_values,
-                )
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+        h = w = self.config.visual["image_size"]
+        expected_dims = (3, h, w)
+        actual_dims = tuple(data.shape[1:])
+
+        if actual_dims != expected_dims:
+            expected_expr = ("batch_size", *map(str, expected_dims))
+            raise ValueError(
+                f"The expected shape of pixel values is {expected_expr}. "
+                f"You supplied {tuple(data.shape)}.")
+
+        return data
+
+    def _parse_and_validate_image_input(
+            self, **kwargs: object) -> Optional[QwenImageInputs]:
+        pixel_values = kwargs.pop("pixel_values", None)
+        image_embeds = kwargs.pop("image_embeds", None)
+
+        if pixel_values is not None:
+            if not isinstance(pixel_values, torch.Tensor):
+                raise ValueError("Incorrect type of pixel values. "
+                                 f"Got type: {type(pixel_values)}")
+
+            return QwenImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(
+                    flatten_bn(pixel_values, concat=True)),
+            )
+
+        if image_embeds is not None:
+            if not isinstance(image_embeds, torch.Tensor):
+                raise ValueError("Incorrect type of image embeddings. "
+                                 f"Got type: {type(image_embeds)}")
+
+            return QwenImageEmbeddingInputs(
+                type="image_embeds",
+                data=flatten_bn(image_embeds),
+            )
+
         return None

-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.transformer.get_input_embeddings(input_ids)
+    def _process_image_input(self,
+                             image_input: QwenImageInputs) -> torch.Tensor:
+        if image_input["type"] == "image_embeds":
+            return image_input["data"]
+
+        assert self.transformer.visual is not None
+        return self.transformer.visual(image_input["data"])
+
+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.transformer.get_input_embeddings(input_ids)
+
+        if multimodal_embeddings is not None:
+            assert self.transformer.visual is not None
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, multimodal_embeddings,
+                self.transformer.visual.image_pad_id)
+
+        return inputs_embeds

     def forward(
         self,
@@ -938,18 +988,23 @@ class QWenBaseModel(nn.Module, SupportsPP, SupportsLoRA):
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
-        pixel_values: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if intermediate_tensors is not None:
+            inputs_embeds = None
+
+        # NOTE: In v1, inputs_embeds is always generated at model runner, this
+        # condition is for v0 compatibility.
+        elif inputs_embeds is None:
+            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
             input_ids = None
-            pixel_values = None
-        else:
-            pixel_values = self._get_image_input_type(pixel_values)

         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          attn_metadata, intermediate_tensors,
-                                         pixel_values, inputs_embeds)
+                                         inputs_embeds)
         return hidden_states

     def compute_logits(
@@ -1063,10 +1118,9 @@ class QWenVL(QWenBaseModel, SupportsMultiModal):
                              tower_model="transformer.visual.transformer")


-@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen)
+@MULTIMODAL_REGISTRY.register_processor(QWenVLMultiModalProcessor,
+                                        info=QWenVLProcessingInfo,
+                                        dummy_inputs=QWenVLDummyInputsBuilder)
 class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
     """
     QWenLMHeadModel is not only applicable to LLM but also to VL, which is not
@@ -1084,7 +1138,7 @@ class QWenLMHeadModel(QWenBaseModel, SupportsMultiModal, SupportsLoRA):
         cls,
         vllm_config: VllmConfig,
         prefix: str = "",
-    ) -> None:
+    ) -> QWenBaseModel:
         config = vllm_config.model_config.hf_config
         # Initialize VL
         if hasattr(config, "visual"):