mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Migrate LlavaOnevisionMultiInputs to TensorSchema (#21844)
Signed-off-by: Benji Beck <benjibeck@meta.com>
This commit is contained in:
@ -3,7 +3,7 @@
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Final, Literal, Optional, Protocol, TypedDict, Union
|
||||
from typing import Annotated, Final, Literal, Optional, Protocol, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -11,7 +11,6 @@ from transformers import (BatchFeature, LlavaOnevisionConfig,
|
||||
LlavaOnevisionProcessor)
|
||||
from transformers.models.llava_onevision.modeling_llava_onevision import (
|
||||
get_anyres_image_grid_shape, unpad_image)
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
@ -23,6 +22,7 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
|
||||
VideoEmbeddingItems, VideoProcessorItems)
|
||||
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .clip import CLIPVisionModel
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
@ -38,44 +38,62 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
_MAX_FRAMES_PER_VIDEO = 16
|
||||
|
||||
|
||||
class LlavaOnevisionVideoPixelInputs(TypedDict):
|
||||
type: Literal["pixel_values_videos"]
|
||||
pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]]
|
||||
class LlavaOnevisionVideoPixelInputs(TensorSchema):
|
||||
"""
|
||||
Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)`
|
||||
Dimensions:
|
||||
- bn: Batch size * number of videos
|
||||
- f: Number of frames
|
||||
- c: Number of channels (3)
|
||||
- h: Height
|
||||
- w: Width
|
||||
|
||||
Note that `num_videos` may be different for each batch, and 'num_frames'
|
||||
may be different for each video, in which case the data is passed as a
|
||||
list instead of a batched tensor.
|
||||
Note that `num_videos` may be different for each batch, and 'num_frames'
|
||||
may be different for each video, in which case the data is passed as a
|
||||
list instead of a batched tensor.
|
||||
"""
|
||||
type: Literal["pixel_values_videos"] = "pixel_values_videos"
|
||||
|
||||
pixel_values_videos: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}),
|
||||
]
|
||||
|
||||
|
||||
class LlavaOnevisionImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
|
||||
class LlavaOnevisionImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- np: Number of patches (1 + num_patches)
|
||||
- c: Number of channels (3)
|
||||
- h: Height
|
||||
- w: Width
|
||||
|
||||
Note that `num_patches` may be different per batch and image,
|
||||
in which case the data is passed as a list instead of a batched tensor.
|
||||
Note that `num_patches` may be different per batch and image,
|
||||
in which case the data is passed as a list instead of a batched tensor.
|
||||
"""
|
||||
type: Literal["pixel_values"] = "pixel_values"
|
||||
|
||||
image_sizes: NotRequired[torch.Tensor]
|
||||
pixel_values: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", "np", 3, "h", "w"),
|
||||
]
|
||||
|
||||
image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
|
||||
|
||||
|
||||
class LlavaOnevisionImageEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Shape: `(batch_size * num_images, 2)`
|
||||
|
||||
This should be in `(height, width)` format.
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- ifs: Image feature size
|
||||
- hs: Hidden size (must match language model backbone)
|
||||
"""
|
||||
type: Literal["image_embeds"] = "image_embeds"
|
||||
|
||||
|
||||
class LlavaOnevisionImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
data: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
|
||||
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
"""
|
||||
data: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("bn", "ifs", "hs"),
|
||||
]
|
||||
|
||||
|
||||
LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs,
|
||||
@ -482,44 +500,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.model.make_empty_intermediate_tensors)
|
||||
|
||||
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
|
||||
expected_dims = (2, )
|
||||
|
||||
def _validate_shape(d: torch.Tensor):
|
||||
actual_dims = tuple(d.shape)
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = str(expected_dims)
|
||||
raise ValueError(
|
||||
f"The expected shape of image sizes per image per batch "
|
||||
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
||||
|
||||
for d in data:
|
||||
_validate_shape(d)
|
||||
|
||||
return data
|
||||
|
||||
def _validate_image_pixel_values(
|
||||
self, data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
|
||||
h = w = self.config.vision_config.image_size
|
||||
expected_dims = (3, h, w)
|
||||
|
||||
def _validate_shape(d: torch.Tensor):
|
||||
actual_dims = tuple(d.shape[1:])
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = ("num_patches", *map(str, expected_dims))
|
||||
raise ValueError(
|
||||
"The expected shape of pixel values per image per batch "
|
||||
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
||||
|
||||
for d in data:
|
||||
_validate_shape(d)
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
@ -540,11 +520,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return LlavaOnevisionImagePixelInputs(
|
||||
type="pixel_values",
|
||||
pixel_values=self._validate_image_pixel_values(
|
||||
flatten_bn(pixel_values)),
|
||||
image_sizes=self._validate_image_sizes(
|
||||
flatten_bn(image_sizes, concat=True)),
|
||||
)
|
||||
pixel_values=flatten_bn(pixel_values),
|
||||
image_sizes=flatten_bn(image_sizes, concat=True),
|
||||
resolve_bindings={
|
||||
"h": self.config.vision_config.image_size,
|
||||
"w": self.config.vision_config.image_size
|
||||
})
|
||||
|
||||
if image_embeds is not None:
|
||||
if not isinstance(image_embeds, torch.Tensor):
|
||||
@ -558,27 +539,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _validate_video_pixel_values(
|
||||
self, data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||
|
||||
h = w = self.config.vision_config.image_size
|
||||
expected_dims = (3, h, w)
|
||||
|
||||
def _validate_shape(d: torch.Tensor):
|
||||
actual_dims = tuple(d.shape[2:])
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = ("num_frames", *map(str, expected_dims))
|
||||
raise ValueError(
|
||||
"The expected shape of pixel values in each video frame "
|
||||
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
||||
|
||||
for d in data:
|
||||
_validate_shape(d)
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_video_input(
|
||||
self,
|
||||
**kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]:
|
||||
@ -600,7 +560,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
return LlavaOnevisionVideoPixelInputs(
|
||||
type="pixel_values_videos",
|
||||
pixel_values_videos=flatten_bn(pixel_values_videos),
|
||||
)
|
||||
resolve_bindings={
|
||||
"h": self.config.vision_config.image_size,
|
||||
"w": self.config.vision_config.image_size
|
||||
})
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
mm_input_by_modality = {}
|
||||
|
Reference in New Issue
Block a user