Migrate skyworkr1v inputs to TensorSchema (#23499)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Author: Benji Beck
Date: 2025-08-24 21:43:21 -07:00
Committed by: GitHub
Parent: 99f8094400
Commit: a5203d04df


@@ -8,7 +8,7 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -35,6 +35,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
@@ -48,27 +49,42 @@ IMAGENET_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_STD = (0.229, 0.224, 0.225)
 
-class SkyworkR1VImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values_flat: torch.Tensor
-    """
-    Shape:
-    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
-    """
-
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
+class SkyworkR1VImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bnp: Batch size * number of images * (1 + num_patches)
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+        - bn: Batch size * number of images
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+
+    pixel_values_flat: Annotated[
+        torch.Tensor,
+        TensorShape("bnp", 3, "h", "w"),
+    ]
+
+    num_patches: Annotated[
+        torch.Tensor,
+        TensorShape("bn"),
+    ]
 
 
-class SkyworkR1VImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    data: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
-    or a list of tensors of shape `(total_image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+class SkyworkR1VImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - ni: Number of images
+        - ifs: Image feature size
+        - hs: Hidden size (must match the hidden size of language model
+          backbone)
+    """
+    type: Literal["image_embeds"] = "image_embeds"
+
+    data: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("ni", "ifs", "hs"),
+    ]
 
 
 SkyworkR1VImageInputs = Union[SkyworkR1VImagePixelInputs,
                               SkyworkR1VImageEmbeddingInputs]
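As an orientation aid (not part of the diff): a minimal usage sketch of the two schema classes above. The concrete sizes (two images, four crops each including the thumbnail, 448x448 pixels, hidden size 4096) are illustrative assumptions; the `resolve_bindings` argument mirrors the call this commit adds in `_parse_and_validate_image_input` below.

import torch

# Pixel-value variant: each annotated field is checked against its
# declared TensorShape when the schema object is constructed.
pixel_inputs = SkyworkR1VImagePixelInputs(
    type="pixel_values",
    pixel_values_flat=torch.empty(2 * 4, 3, 448, 448),  # ("bnp", 3, "h", "w")
    num_patches=torch.tensor([4, 4]),                    # ("bn",)
    resolve_bindings={"h": 448, "w": 448},
)

# Embedding variant: a stacked tensor (or a per-image list) validated
# against TensorShape("ni", "ifs", "hs"); the hidden size must match
# the language model backbone.
embed_inputs = SkyworkR1VImageEmbeddingInputs(
    type="image_embeds",
    data=torch.empty(2, 256, 4096),
)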
@@ -731,26 +747,6 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         vit_embeds = self.mlp1(vit_embeds)
         return vit_embeds
 
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-
-        h = w = self.config.vision_config.image_size
-        expected_dims = (3, h, w)
-
-        def _validate_shape(d: torch.Tensor):
-            actual_dims = tuple(d.shape)
-
-            if actual_dims != expected_dims:
-                expected_expr = str(expected_dims)
-                raise ValueError(
-                    "The expected shape of pixel values per image per batch "
-                    f" per patch is {expected_expr}. "
-                    f"You supplied {tuple(d.shape)}.")
-
-        for d in data:
-            _validate_shape(d)
-
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[SkyworkR1VImageInputs]:
         pixel_values_flat = kwargs.pop("pixel_values_flat", None)
@@ -788,10 +784,12 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
 
             return SkyworkR1VImagePixelInputs(
                 type="pixel_values",
-                pixel_values_flat=self._validate_pixel_values(
-                    pixel_values_flat),
+                pixel_values_flat=pixel_values_flat,
                 num_patches=image_num_patches,
-            )
+                resolve_bindings={
+                    "h": self.config.vision_config.image_size,
+                    "w": self.config.vision_config.image_size,
+                })
 
         raise AssertionError("This line should be unreachable.")
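With the schema classes in place, the per-patch shape check that `_validate_pixel_values` used to perform by hand now happens declaratively at construction time: `resolve_bindings` pins the symbolic `h` and `w` dimensions to the model's configured image size. A small sketch of the failure mode, with illustrative sizes, assuming the schema rejects mismatched inputs much as the removed validator did:

import torch

# One image, 8 crops of 224x224, validated against bindings that
# resolve "h" and "w" to 448, so construction should fail.
try:
    SkyworkR1VImagePixelInputs(
        type="pixel_values",
        pixel_values_flat=torch.empty(8, 3, 224, 224),
        num_patches=torch.tensor([8]),
        resolve_bindings={"h": 448, "w": 448},
    )
except Exception as exc:  # exact exception type is up to TensorSchema
    print(f"shape validation rejected the input: {exc}")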