Migrate Qwen inputs to TensorSchema (#23473)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
@ -11,7 +11,7 @@ import math
|
||||
import unicodedata
|
||||
from collections.abc import Collection, Mapping, Sequence, Set
|
||||
from functools import lru_cache, partial
|
||||
from typing import Callable, Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Callable, Literal, Optional, Union
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptUpdate, PromptUpdateDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
@ -47,26 +48,34 @@ from .qwen import QWenBaseModel, QWenModel
|
||||
from .utils import flatten_bn, merge_multimodal_embeddings
|
||||
|
||||
|
||||
class QwenImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: torch.Tensor
|
||||
class QwenImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Shape: `(batch_size * num_images, 3, image_size, image_size)`
|
||||
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- c: Number of channels (3)
|
||||
- h: Height
|
||||
- w: Width
|
||||
|
||||
Note that image_size is the value in the vision config to which we resize
|
||||
the image to in the normalization transform. Currently multi-image support
|
||||
can only be leveraged by passing image embeddings directly.
|
||||
"""
|
||||
type: Literal["pixel_values"] = "pixel_values"
|
||||
data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
|
||||
|
||||
|
||||
class QwenImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
data: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images, 256, hidden_size)`
|
||||
|
||||
class QwenImageEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- ifs: Image feature size (256)
|
||||
- hs: Hidden size
|
||||
|
||||
`hidden_size` must match the hidden size of the language model backbone
|
||||
and is stored in the visual config of the model if we have one.
|
||||
"""
|
||||
type: Literal["image_embeds"] = "image_embeds"
|
||||
data: Annotated[torch.Tensor, TensorShape("bn", 256, "hs")]
|
||||
|
||||
|
||||
QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]
|
||||
@ -697,19 +706,6 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
|
||||
|
||||
self.transformer: QwenVLModel
|
||||
|
||||
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
|
||||
h = w = self.config.visual["image_size"]
|
||||
expected_dims = (3, h, w)
|
||||
actual_dims = tuple(data.shape[1:])
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = ("batch_size", *map(str, expected_dims))
|
||||
raise ValueError(
|
||||
f"The expected shape of pixel values is {expected_expr}. "
|
||||
f"You supplied {tuple(data.shape)}.")
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[QwenImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
@ -720,10 +716,13 @@ class QwenVLForConditionalGeneration(QWenBaseModel, SupportsPP, SupportsLoRA,
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
expected_h = expected_w = self.config.visual["image_size"]
|
||||
resolve_bindings = {"h": expected_h, "w": expected_w}
|
||||
|
||||
return QwenImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(
|
||||
flatten_bn(pixel_values, concat=True)),
|
||||
data=flatten_bn(pixel_values, concat=True),
|
||||
resolve_bindings=resolve_bindings,
|
||||
)
|
||||
|
||||
if image_embeds is not None:
|
||||
|
Reference in New Issue
Block a user