Migrate LlavaOnevisionMultiInputs to TensorSchema (#21844)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Author: Benji Beck
Date: 2025-08-19 10:02:02 -07:00
Committed by: GitHub
Parent: 24f4d1a224
Commit: a70d0bd0a3


@@ -3,7 +3,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Final, Literal, Optional, Protocol, TypedDict, Union
from typing import Annotated, Final, Literal, Optional, Protocol, Union
import torch
import torch.nn as nn
@@ -11,7 +11,6 @@ from transformers import (BatchFeature, LlavaOnevisionConfig,
LlavaOnevisionProcessor)
from transformers.models.llava_onevision.modeling_llava_onevision import (
get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired
from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
@@ -23,6 +22,7 @@ from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -38,44 +38,62 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
_MAX_FRAMES_PER_VIDEO = 16
class LlavaOnevisionVideoPixelInputs(TypedDict):
type: Literal["pixel_values_videos"]
pixel_values_videos: Union[torch.Tensor, list[torch.Tensor]]
class LlavaOnevisionVideoPixelInputs(TensorSchema):
"""
Shape: `(batch_size * num_videos, num_frames, num_channels, height, width)`
Dimensions:
- bn: Batch size * number of videos
- f: Number of frames
- c: Number of channels (3)
- h: Height
- w: Width
Note that `num_videos` may be different for each batch, and `num_frames`
may be different for each video, in which case the data is passed as a
list instead of a batched tensor.
"""
type: Literal["pixel_values_videos"] = "pixel_values_videos"
pixel_values_videos: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "f", 3, "h", "w", dynamic_dims={"f"}),
]
class LlavaOnevisionImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: Union[torch.Tensor, list[torch.Tensor]]
class LlavaOnevisionImagePixelInputs(TensorSchema):
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
Dimensions:
- bn: Batch size * number of images
- np: Number of patches (1 + num_patches)
- c: Number of channels (3)
- h: Height
- w: Width
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""
type: Literal["pixel_values"] = "pixel_values"
image_sizes: NotRequired[torch.Tensor]
pixel_values: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "np", 3, "h", "w"),
]
image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""

class LlavaOnevisionImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: torch.Tensor
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""

class LlavaOnevisionImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- ifs: Image feature size
- hs: Hidden size (must match language model backbone)
"""
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[
torch.Tensor,
TensorShape("bn", "ifs", "hs"),
]
LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs,
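For context, a minimal sketch of how the migrated schema classes are meant to be used, with illustrative sizes; this assumes, as the manual validators removed below suggest, that TensorSchema checks each annotated field against its TensorShape when an instance is constructed:

import torch

# Per-video tensors with differing frame counts are allowed because "f" is
# declared in dynamic_dims; the list length plays the role of the "bn" dim.
video_input = LlavaOnevisionVideoPixelInputs(
    type="pixel_values_videos",
    pixel_values_videos=[
        torch.randn(8, 3, 384, 384),   # 8 frames
        torch.randn(16, 3, 384, 384),  # 16 frames
    ],
)

# A single batched tensor with one frame count per video also matches the
# declared ("bn", "f", 3, "h", "w") shape.
video_input_batched = LlavaOnevisionVideoPixelInputs(
    type="pixel_values_videos",
    pixel_values_videos=torch.randn(2, 8, 3, 384, 384),
)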
@@ -482,44 +500,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.model.make_empty_intermediate_tensors)
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
expected_dims = (2, )
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape)
if actual_dims != expected_dims:
expected_expr = str(expected_dims)
raise ValueError(
f"The expected shape of image sizes per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _validate_image_pixel_values(
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values per image per batch "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@@ -540,11 +520,12 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return LlavaOnevisionImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_image_pixel_values(
flatten_bn(pixel_values)),
image_sizes=self._validate_image_sizes(
flatten_bn(image_sizes, concat=True)),
)
pixel_values=flatten_bn(pixel_values),
image_sizes=flatten_bn(image_sizes, concat=True),
resolve_bindings={
"h": self.config.vision_config.image_size,
"w": self.config.vision_config.image_size
})
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
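The resolve_bindings mapping above takes over from the hand-rolled _validate_image_sizes and _validate_image_pixel_values helpers that this commit deletes: it pins the symbolic "h" and "w" dims to the configured vision image size, so pixel values at the wrong resolution are rejected by the schema itself. A rough illustration with hypothetical values, again assuming validation happens at construction:

import torch

image_size = 384  # stands in for self.config.vision_config.image_size

image_input = LlavaOnevisionImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.randn(2, 10, 3, image_size, image_size),  # (bn, np, 3, h, w)
    image_sizes=torch.tensor([[672, 1008], [336, 672]]),          # (bn, 2)
    resolve_bindings={"h": image_size, "w": image_size},
)

# With a mismatched resolution (e.g. 224 instead of 384), the same call is
# expected to fail schema validation rather than pass through silently.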
@@ -558,27 +539,6 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
raise AssertionError("This line should be unreachable.")
def _validate_video_pixel_values(
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[2:])
if actual_dims != expected_dims:
expected_expr = ("num_frames", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values in each video frame "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data
def _parse_and_validate_video_input(
self,
**kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]:
@@ -600,7 +560,10 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return LlavaOnevisionVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=flatten_bn(pixel_values_videos),
)
resolve_bindings={
"h": self.config.vision_config.image_size,
"w": self.config.vision_config.image_size
})
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
mm_input_by_modality = {}
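Taken together, the change swaps imperative per-model shape checks for declarative annotations. In miniature, the pattern this commit adopts looks roughly like the sketch below; it uses only the constructs visible in this diff, and the Foo* names are hypothetical:

from typing import Annotated, Literal, Union

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class FooPixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of items
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """
    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", 3, "h", "w"),
    ]


# The shape contract now lives in the annotation; call sites bind the symbolic
# dims they know at runtime instead of writing a _validate_* helper.
inputs = FooPixelInputs(
    type="pixel_values",
    pixel_values=torch.randn(4, 3, 384, 384),
    resolve_bindings={"h": 384, "w": 384},
)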