Migrate GraniteSpeechAudioInputs to TensorSchema (#21682)

Signed-off-by: Benji Beck <benjibeck@meta.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Benji Beck
2025-07-27 22:37:20 -07:00
committed by GitHub
parent 304dcdf575
commit 75856bc2cb

View File

@ -25,7 +25,7 @@
"""Inference-only IBM Granite speech model."""
import math
from collections.abc import Iterable, Mapping
from typing import Optional, TypedDict, Union
from typing import Annotated, Optional, Union
import torch
import torch.nn.functional as F
@ -48,6 +48,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .blip2 import Blip2QFormerModel
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@ -57,16 +58,24 @@ from .utils import (AutoWeightsLoader, embed_multimodal,
### Audio Input
class GraniteSpeechAudioInputs(TypedDict):
class GraniteSpeechAudioInputs(TensorSchema):
"""
Audio input features for Granite Speech model.
Dimensions:
- b: Batch size
- nf: Number of audio features (variable length)
- 160: Fixed feature dimension for Mel spectrogram features
"""
input_features: torch.Tensor
"""Shape: `(bsz, num_features, 160)`"""
input_features: Annotated[torch.Tensor, TensorShape("b", "nf", 160)]
"""Audio input features."""
input_features_mask: torch.Tensor
"""Shape: `(bsz, num_features)`"""
input_features_mask: Annotated[torch.Tensor, TensorShape("b", "nf")]
"""Mask for variable length audio features."""
audio_embed_sizes: list[int]
"""List of length `bsz`"""
audio_embed_sizes: Annotated[list[int], TensorShape("b")]
"""List of audio embedding sizes for each item in batch."""
class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
@ -581,6 +590,7 @@ class GraniteSpeechForConditionalGeneration(
input_features = kwargs.pop("input_features", None)
input_features_mask = kwargs.pop("input_features_mask", None)
audio_embed_sizes = kwargs.pop("audio_embed_sizes", None)
if input_features is None:
return None