Migrate tarsier inputs to TensorSchema (#23500)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Benji Beck authored 2025-08-24 21:42:36 -07:00, committed by GitHub
parent 170e8ea9ea
commit 99f8094400


@@ -3,7 +3,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
+from typing import (Annotated, Final, Literal, Optional, Protocol, TypeVar,
                     Union, cast)
 
 import torch
@@ -34,6 +34,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.utils.jsontree import json_map_leaves
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
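
The vllm.utils.tensor_schema helpers imported above are not shown in this diff. As a rough illustration only, not vLLM's actual implementation, a TensorShape/TensorSchema pair of this kind can be approximated with Annotated metadata that is checked when the schema object is constructed:

# Simplified approximation for illustration; the real vllm.utils.tensor_schema
# module differs in details.
from typing import Annotated, get_type_hints

import torch


class TensorShape:
    """Declares expected dims: ints are fixed sizes, strings are named symbols."""

    def __init__(self, *dims):
        self.dims = dims


class TensorSchema:
    """Validates annotated tensor fields against their TensorShape on construction."""

    def __init__(self, **kwargs):
        for name, value in kwargs.items():
            setattr(self, name, value)
        self._validate()

    def _validate(self) -> None:
        bindings: dict[str, int] = {}
        hints = get_type_hints(type(self), include_extras=True)
        for name, hint in hints.items():
            shapes = [m for m in getattr(hint, "__metadata__", ())
                      if isinstance(m, TensorShape)]
            if not shapes:
                continue
            tensor = getattr(self, name, None)
            if tensor is None:
                continue
            dims = shapes[0].dims
            if tensor.ndim != len(dims):
                raise ValueError(
                    f"{name}: expected {len(dims)} dims, got {tensor.ndim}")
            for expected, actual in zip(dims, tensor.shape):
                if isinstance(expected, int) and actual != expected:
                    raise ValueError(
                        f"{name}: expected size {expected}, got {actual}")
                if isinstance(expected, str):
                    # Named symbols (e.g. "bn", "h") must bind consistently
                    # across all annotated fields of the schema.
                    if bindings.setdefault(expected, actual) != actual:
                        raise ValueError(
                            f"{name}: {expected!r} is {bindings[expected]}, "
                            f"got {actual}")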
@@ -43,14 +44,28 @@ from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
 from .vision import VisionEncoderInfo, get_vision_encoder_info
 
-class TarsierImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
+class TarsierImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
 
-class TarsierImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: torch.Tensor
+class TarsierImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - ifs: Image feature size
+        - hs: Hidden size (must match the hidden size of language model
+          backbone)
+    """
+    type: Literal["image_embeds"] = "image_embeds"
+    data: Annotated[torch.Tensor, TensorShape("bn", "ifs", "hs")]
 
 TarsierImageInputs = Union[TarsierImagePixelInputs,
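
Not part of the diff, but a minimal usage sketch of the new pixel-input schema (the import path and the exact exception type are assumptions): constructing the class validates the tensor against its TensorShape annotation, which is what lets the hand-rolled _validate_pixel_values method be dropped further down.

import torch

# Import path assumed from the surrounding file; adjust if it differs.
from vllm.model_executor.models.tarsier import TarsierImagePixelInputs

# Matches ("bn", 3, "h", "w"): 6 flattened images, 3 channels, 336x336 pixels.
inputs = TarsierImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.zeros(6, 3, 336, 336),
)

# A wrong channel count is rejected at construction time by the schema
# (assumed to raise ValueError), replacing the manual shape check.
try:
    TarsierImagePixelInputs(
        type="pixel_values",
        pixel_values=torch.zeros(6, 4, 336, 336),
    )
except ValueError as exc:
    print(exc)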
@@ -432,18 +447,6 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)  # Assuming 3 channels
-        actual_dims = tuple(data.shape[1:])
-
-        if actual_dims != expected_dims:
-            expected_expr = ("batch_size", *map(str, expected_dims))
-            raise ValueError(
-                f"The expected shape of pixel values is {expected_expr}. "
-                f"You supplied {tuple(data.shape)}.")
-
-        return data
 
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[TarsierImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -459,8 +462,7 @@ class TarsierForConditionalGeneration(nn.Module, SupportsMultiModal,
 
             return TarsierImagePixelInputs(
                 type="pixel_values",
-                pixel_values=self._validate_pixel_values(
-                    flatten_bn(pixel_values, concat=True)),
+                pixel_values=flatten_bn(pixel_values, concat=True),
             )
 
         if image_embeds is not None:
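
The embedding variant works the same way; a short sketch under the same assumptions as above: the named symbols "bn", "ifs" and "hs" are free, but "hs" has to match the hidden size of the language-model backbone that consumes the embeddings.

import torch

from vllm.model_executor.models.tarsier import TarsierImageEmbeddingInputs

# 2 images, 576 image features each, hidden size 4096 (example values only).
image_embeds = TarsierImageEmbeddingInputs(
    type="image_embeds",
    data=torch.zeros(2, 576, 4096),
)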