[Bugfix] Fix shape checking for Fuyu (#21709)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-07-28 15:05:56 +08:00
committed by GitHub
parent 18cc33dd60
commit 139a97ec56

View File

@ -55,14 +55,15 @@ class FuyuImagePatchInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- fn: Num channels * patch_size_x * patch_size_y
- bnp: Batch size * number of images * number of patches
- fn: patch_size_x * patch_size_y * num_channels
"""
type: Literal["image_patches"] = "image_patches"
flat_data: Annotated[
torch.Tensor,
TensorShape("bn", "fn"),
TensorShape("bnp", "fn"),
]
patches_per_image: Annotated[list[int], TensorShape("bn")]
@ -309,8 +310,8 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_patches = kwargs.pop("image_patches", None)
if image_patches is not None:
image_patches_flat = flatten_bn(image_patches)
flat_data = flatten_bn(image_patches, concat=True).data.to(
self.vision_embed_tokens.weight.dtype)
flat_data = flatten_bn(image_patches_flat, concat=True)
return FuyuImagePatchInputs(
type="image_patches",
flat_data=flat_data,