mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Migrate GraniteSpeechAudioInputs to TensorSchema (#21682)
Signed-off-by: Benji Beck <benjibeck@meta.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@ -25,7 +25,7 @@
|
||||
"""Inference-only IBM Granite speech model."""
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Optional, TypedDict, Union
|
||||
from typing import Annotated, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -48,6 +48,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .blip2 import Blip2QFormerModel
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
@ -57,16 +58,24 @@ from .utils import (AutoWeightsLoader, embed_multimodal,
|
||||
|
||||
|
||||
### Audio Input
|
||||
class GraniteSpeechAudioInputs(TypedDict):
|
||||
class GraniteSpeechAudioInputs(TensorSchema):
|
||||
"""
|
||||
Audio input features for Granite Speech model.
|
||||
|
||||
Dimensions:
|
||||
- b: Batch size
|
||||
- nf: Number of audio features (variable length)
|
||||
- 160: Fixed feature dimension for Mel spectrogram features
|
||||
"""
|
||||
|
||||
input_features: torch.Tensor
|
||||
"""Shape: `(bsz, num_features, 160)`"""
|
||||
input_features: Annotated[torch.Tensor, TensorShape("b", "nf", 160)]
|
||||
"""Audio input features."""
|
||||
|
||||
input_features_mask: torch.Tensor
|
||||
"""Shape: `(bsz, num_features)`"""
|
||||
input_features_mask: Annotated[torch.Tensor, TensorShape("b", "nf")]
|
||||
"""Mask for variable length audio features."""
|
||||
|
||||
audio_embed_sizes: list[int]
|
||||
"""List of length `bsz`"""
|
||||
audio_embed_sizes: Annotated[list[int], TensorShape("b")]
|
||||
"""List of audio embedding sizes for each item in batch."""
|
||||
|
||||
|
||||
class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
|
||||
@ -581,6 +590,7 @@ class GraniteSpeechForConditionalGeneration(
|
||||
input_features = kwargs.pop("input_features", None)
|
||||
input_features_mask = kwargs.pop("input_features_mask", None)
|
||||
audio_embed_sizes = kwargs.pop("audio_embed_sizes", None)
|
||||
|
||||
if input_features is None:
|
||||
return None
|
||||
|
||||
|
Reference in New Issue
Block a user