mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix] Fix shape checking for Fuyu (#21709)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -55,14 +55,15 @@ class FuyuImagePatchInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- fn: Num channels * patch_size_x * patch_size_y
|
||||
- bnp: Batch size * number of images * number of patches
|
||||
- fn: patch_size_x * patch_size_y * num_channels
|
||||
"""
|
||||
|
||||
type: Literal["image_patches"] = "image_patches"
|
||||
|
||||
flat_data: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("bn", "fn"),
|
||||
TensorShape("bnp", "fn"),
|
||||
]
|
||||
|
||||
patches_per_image: Annotated[list[int], TensorShape("bn")]
|
||||
@@ -309,8 +310,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
image_patches = kwargs.pop("image_patches", None)
|
||||
if image_patches is not None:
|
||||
image_patches_flat = flatten_bn(image_patches)
|
||||
flat_data = flatten_bn(image_patches, concat=True).data.to(
|
||||
self.vision_embed_tokens.weight.dtype)
|
||||
flat_data = flatten_bn(image_patches_flat, concat=True)
|
||||
|
||||
return FuyuImagePatchInputs(
|
||||
type="image_patches",
|
||||
flat_data=flat_data,
|
||||
|
Reference in New Issue
Block a user