Migrate AyaVisionImagePixelInputs to TensorSchema for shape validation (#21622)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Benji Beck authored 2025-07-26 06:08:18 -07:00, committed by GitHub
parent 9d197280fa
commit de10ff0b7c
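
What the migration buys: a TensorSchema subclass declares each tensor field as Annotated[torch.Tensor, TensorShape(...)], where shape entries are fixed ints or named symbolic dimensions, and the shapes are checked when the inputs object is constructed. A rough, self-contained sketch of the mechanics this commit relies on (the classes below are stand-ins with assumed semantics, not the actual vllm.utils.tensor_schema implementation):

    from typing import Annotated, get_type_hints

    import torch


    class TensorShape:
        # Expected shape of one field: each entry is a fixed int or a
        # string naming a symbolic dimension shared across fields.
        def __init__(self, *dims):
            self.dims = dims


    class TensorSchema:
        # Checks every annotated tensor field at construction time.
        # resolve_bindings pins symbolic dimensions to known sizes.
        def __init__(self, resolve_bindings=None, **fields):
            bindings = dict(resolve_bindings or {})
            hints = get_type_hints(type(self), include_extras=True)
            for name, value in fields.items():
                setattr(self, name, value)
                meta = getattr(hints.get(name), "__metadata__", ())
                shape = next(
                    (m for m in meta if isinstance(m, TensorShape)), None)
                if shape is None or not isinstance(value, torch.Tensor):
                    continue
                if value.ndim != len(shape.dims):
                    raise ValueError(f"{name}: expected {len(shape.dims)} "
                                     f"dims, got {value.ndim}")
                for actual, expected in zip(value.shape, shape.dims):
                    if isinstance(expected, str):
                        # First sighting binds the symbol; later fields
                        # using the same name must agree with it.
                        expected = bindings.setdefault(expected, int(actual))
                    if actual != expected:
                        raise ValueError(
                            f"{name}: shape {tuple(value.shape)} does not "
                            f"match {shape.dims}")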


@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 # Adapted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/aya_vision
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union, cast
+from typing import Annotated, Literal, Optional, Union, cast
 
 import torch
 from torch import nn
@@ -29,6 +29,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdateDetails)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
@@ -37,18 +38,28 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     merge_multimodal_embeddings)
 
 
-class AyaVisionImagePixelInputs(TypedDict):
+class AyaVisionImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - np: The total number of patches over each image over each prompt in
+          the batch
+        - c: Number of channels
+        - h: Height of each image patch
+        - w: Width of each image patch
+        - bn: Batch size * number of images
+    """
     type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    """
-    Shape: `(num_patches_total, num_channels, height, width)`
-
-    `num_patches_total` is the total number of patches over each image over each
-    prompt in the batch.
-    """
+    pixel_values: Annotated[
+        torch.Tensor,
+        TensorShape("np", 3, "h", "w"),
+    ]
 
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
+    num_patches: Annotated[
+        torch.Tensor,
+        TensorShape("bn"),
+    ]
 
 
 class AyaVisionMultiModalProjector(nn.Module):
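
Reusing the sketch above, the new schema behaves like this: correctly shaped inputs construct silently, while a height/width mismatch now fails in one declarative place instead of the hand-rolled _validate_pixel_values loop removed further down. The 364 image size is assumed here for illustration only; the real value comes from config.vision_config.image_size via resolve_bindings:

    from typing import Annotated, Literal

    import torch


    class ImagePixelInputs(TensorSchema):  # stand-in mirroring the diff
        type: Literal["pixel_values"]
        pixel_values: Annotated[torch.Tensor, TensorShape("np", 3, "h", "w")]
        num_patches: Annotated[torch.Tensor, TensorShape("bn")]


    ok = ImagePixelInputs(
        type="pixel_values",
        pixel_values=torch.zeros(7, 3, 364, 364),  # np = 7 patches in total
        num_patches=torch.tensor([3, 4]),          # bn = 2 (batch * images)
        resolve_bindings={"h": 364, "w": 364},
    )

    try:
        ImagePixelInputs(
            type="pixel_values",
            pixel_values=torch.zeros(7, 3, 224, 224),  # wrong h/w
            num_patches=torch.tensor([3, 4]),
            resolve_bindings={"h": 364, "w": 364},
        )
    except ValueError as err:
        print(err)  # pixel_values: shape (7, 3, 224, 224) does not match ...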
@@ -383,21 +394,6 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
             e.flatten(0, 2) for e in image_embeds.split(num_patches.tolist())
         ]
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            if d.shape != expected_dims:
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f"is {expected_dims}. You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[AyaVisionImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -405,22 +401,17 @@ class AyaVisionForConditionalGeneration(nn.Module, SupportsMultiModal,
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Aya Vision does not support image_embeds."
 
-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        if num_patches is not None and not isinstance(num_patches,
-                                                      (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_patches. "
-                             f"Got type: {type(num_patches)}")
-
-        pixel_values = flatten_bn(pixel_values, concat=True)
-        num_patches = flatten_bn(num_patches, concat=True)
-
         if pixel_values is None:
             return None
 
         return AyaVisionImagePixelInputs(
             type="pixel_values",
-            pixel_values=self._validate_pixel_values(pixel_values),
-            num_patches=num_patches,
-        )
+            pixel_values=flatten_bn(pixel_values, concat=True),
+            num_patches=flatten_bn(num_patches, concat=True),
+            resolve_bindings={
+                "h": self.config.vision_config.image_size,
+                "w": self.config.vision_config.image_size,
+            })
 
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
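
One more moving part: flatten_bn(..., concat=True) now runs inside the constructor call, so the schema always sees flattened tensors. Assuming vLLM's usual flattening behavior for multimodal inputs (a sketch of the helper imported from .utils, not its actual implementation), it is what turns the per-prompt patch tensors into the single ("np", ...) tensor and the per-prompt counts into the flat ("bn",) tensor the schema expects:

    from typing import Union

    import torch


    def flatten_bn_sketch(
            x: Union[torch.Tensor, list[torch.Tensor]],
            *,
            concat: bool = False) -> torch.Tensor:
        # A batched (bs, n, ...) tensor loses its leading split; a list of
        # per-prompt (n_i, ...) tensors is concatenated along dim 0.
        if isinstance(x, torch.Tensor):
            return x.flatten(0, 1)
        assert concat, "sketch covers only the concat=True list case"
        return torch.cat(list(x))


    per_prompt = [torch.zeros(3, 3, 364, 364), torch.zeros(4, 3, 364, 364)]
    print(flatten_bn_sketch(per_prompt, concat=True).shape)
    # torch.Size([7, 3, 364, 364]) -> binds "np" = 7 in the schema above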