[Docs] Add transcription support to model (#24664)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2025-09-11 16:39:01 +02:00
committed by GitHub
parent 817beef7f3
commit 404c85ca72
3 changed files with 275 additions and 0 deletions

View File

@ -44,6 +44,7 @@ nav:
- contributing/model/registration.md
- contributing/model/tests.md
- contributing/model/multimodal.md
- contributing/model/transcription.md
- CI: contributing/ci
- Design Documents: design
- API Reference:

View File

@ -15,6 +15,7 @@ Read through these pages for a step-by-step guide:
- [Registering a Model](registration.md)
- [Unit Testing](tests.md)
- [Multi-Modal Support](multimodal.md)
- [Speech-to-Text Support](transcription.md)
!!! tip
If you are encountering issues while integrating your model into vLLM, feel free to open a [GitHub issue](https://github.com/vllm-project/vllm/issues)

View File

@ -0,0 +1,273 @@
# Speech-to-Text (Transcription/Translation) Support
This document walks you through the steps to add support for speech-to-text (ASR) models to vLLMs transcription and translation APIs by implementing [SupportsTranscription][vllm.model_executor.models.interfaces.SupportsTranscription].
Please refer to the [supported models](../../models/supported_models.md#transcription) for further guidance.
## 1. Update the base vLLM model
It is assumed you have already implemented your model in vLLM according to the basic model guide. Extend your model with the [SupportsTranscription][vllm.model_executor.models.interfaces.SupportsTranscription] interface and implement the following class attributes and methods.
- Declare supported languages and capabilities:
??? code
```python
from typing import ClassVar, Mapping, Optional, Literal
import numpy as np
import torch
from torch import nn
from vllm.config import ModelConfig, SpeechToTextConfig
from vllm.inputs.data import PromptType
from vllm.model_executor.models.interfaces import SupportsTranscription
class YourASRModel(nn.Module, SupportsTranscription):
# Map of ISO 639-1 language codes to language names
supported_languages: ClassVar[Mapping[str, str]] = {
"en": "English",
"it": "Italian",
# ... add more as needed
}
# If your model only supports audio-conditioned generation
# (no text-only generation), enable this flag.
supports_transcription_only: ClassVar[bool] = True
```
- The `supported_languages` mapping is validated at init time.
- Set `supports_transcription_only=True` if the model should not serve text generation (eg Whisper).
- Provide an ASR configuration via [get_speech_to_text_config][vllm.model_executor.models.interfaces.SupportsTranscription.get_speech_to_text_config].
This is for controlling general behavior of the API when serving your model:
??? code
```python
class YourASRModel(nn.Module, SupportsTranscription):
...
@classmethod
def get_speech_to_text_config(
cls,
model_config: ModelConfig,
task_type: Literal["transcribe", "translate"],
) -> SpeechToTextConfig:
return SpeechToTextConfig(
sample_rate=16_000,
max_audio_clip_s=30,
# Set to None to disable server-side chunking if your
# model/processor handles it already
min_energy_split_window_size=None,
)
```
See the “Audio preprocessing and chunking” section for what each field controls.
- Implement the prompt construction via [get_generation_prompt][vllm.model_executor.models.interfaces.SupportsTranscription.get_generation_prompt]. The server passes you the resampled waveform and task parameters; you return a valid [PromptType][vllm.inputs.data.PromptType]. There are two common patterns:
#### A. Multimodal LLM with audio embeddings (e.g., Voxtral, Gemma3n)
Return a dict containing `multi_modal_data` with the audio, and either a `prompt` string or `prompt_token_ids`:
??? code
```python
class YourASRModel(nn.Module, SupportsTranscription):
...
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str],
) -> PromptType:
# Example with a free-form instruction prompt
task_word = "Transcribe" if task_type == "transcribe" else "Translate"
prompt = (
"<start_of_turn>user\n"
f"{task_word} this audio: <audio_soft_token>"
"<end_of_turn>\n<start_of_turn>model\n"
)
return {
"multi_modal_data": {"audio": (audio, stt_config.sample_rate)},
"prompt": prompt,
}
```
For further clarification on multi modal inputs, please refer to [Multi-Modal Inputs](../../features/multimodal_inputs.md).
#### B. Encoderdecoder audio-only (e.g., Whisper)
Return a dict with separate `encoder_prompt` and `decoder_prompt` entries:
??? code
```python
class YourASRModel(nn.Module, SupportsTranscription):
...
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str],
) -> PromptType:
if language is None:
raise ValueError("Language must be specified")
prompt = {
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
"audio": (audio, stt_config.sample_rate),
},
},
"decoder_prompt": (
(f"<|prev|>{request_prompt}" if request_prompt else "")
+ f"<|startoftranscript|><|{language}|>"
+ f"<|{task_type}|><|notimestamps|>"
),
}
return cast(PromptType, prompt)
```
- (Optional) Language validation via [validate_language][vllm.model_executor.models.interfaces.SupportsTranscription.validate_language]
If your model requires a language and you want a default, override this method (see Whisper):
```python
@classmethod
def validate_language(cls, language: Optional[str]) -> Optional[str]:
if language is None:
logger.warning(
"Defaulting to language='en'. If you wish to transcribe audio in a different language, pass the `language` field.")
language = "en"
return super().validate_language(language)
```
- (Optional) Token accounting for streaming via [get_num_audio_tokens][vllm.model_executor.models.interfaces.SupportsTranscription.get_num_audio_tokens]
Provide a fast duration→token estimate to improve streaming usage statistics:
??? code
```python
class YourASRModel(nn.Module, SupportsTranscription):
...
@classmethod
def get_num_audio_tokens(
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
) -> Optional[int]:
# Return None if unknown; otherwise return an estimate.
return int(audio_duration_s * stt_config.sample_rate // 320) # example
```
## 2. Audio preprocessing and chunking
The API server takes care of basic audio I/O and optional chunking before building prompts:
- Resampling: Input audio is resampled to `SpeechToTextConfig.sample_rate` using `librosa`.
- Chunking: If `SpeechToTextConfig.allow_audio_chunking` is True and the duration exceeds `max_audio_clip_s`, the server splits the audio into overlapping chunks and generates a prompt per chunk. Overlap is controlled by `overlap_chunk_second`.
- Energy-aware splitting: When `min_energy_split_window_size` is set, the server finds low-energy regions to minimize cutting within words.
Relevant server logic:
??? code
```python
# vllm/entrypoints/openai/speech_to_text.py
async def _preprocess_speech_to_text(...):
language = self.model_cls.validate_language(request.language)
...
y, sr = librosa.load(bytes_, sr=self.asr_config.sample_rate)
duration = librosa.get_duration(y=y, sr=sr)
do_split_audio = (self.asr_config.allow_audio_chunking
and duration > self.asr_config.max_audio_clip_s)
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
prompts = []
for chunk in chunks:
prompt = self.model_cls.get_generation_prompt(
audio=chunk,
stt_config=self.asr_config,
model_config=self.model_config,
language=language,
task_type=self.task_type,
request_prompt=request.prompt,
to_language=to_language,
)
prompts.append(prompt)
return prompts, duration
```
## 3. Exposing tasks automatically
- vLLM automatically advertises transcription support if your model implements the interface:
```python
if supports_transcription(model):
if model.supports_transcription_only:
return ["transcription"]
supported_tasks.append("transcription")
```
- When enabled, the server initializes the transcription and translation handlers:
```python
state.openai_serving_transcription = OpenAIServingTranscription(...) if "transcription" in supported_tasks else None
state.openai_serving_translation = OpenAIServingTranslation(...) if "transcription" in supported_tasks else None
```
No extra registration is required beyond having your model class available via the model registry and implementing `SupportsTranscription`.
## 4. Examples in-tree
- Whisper encoderdecoder (audio-only): <gh-file:vllm/model_executor/models/whisper.py>
- Voxtral decoder-only (audio embeddings + LLM): <gh-file:vllm/model_executor/models/voxtral.py>
- Gemma3n decoder-only with fixed instruction prompt: <gh-file:vllm/model_executor/models/gemma3n_mm.py>
## 5. Test with the API
Once your model implements `SupportsTranscription`, you can test the endpoints (API mimics OpenAI):
- Transcription (ASR):
```bash
curl -s -X POST \
-H "Authorization: Bearer $VLLM_API_KEY" \
-H "Content-Type: multipart/form-data" \
-F "file=@/path/to/audio.wav" \
-F "model=$MODEL_ID" \
http://localhost:8000/v1/audio/transcriptions
```
- Translation (source → English unless otherwise supported):
```bash
curl -s -X POST \
-H "Authorization: Bearer $VLLM_API_KEY" \
-H "Content-Type: multipart/form-data" \
-F "file=@/path/to/audio.wav" \
-F "model=$MODEL_ID" \
http://localhost:8000/v1/audio/translations
```
Or check out more examples in <gh-file:examples/online_serving>.
!!! note
- If your model handles chunking internally (e.g., via its processor or encoder), set `min_energy_split_window_size=None` in the returned `SpeechToTextConfig` to disable server-side chunking.
- Implementing `get_num_audio_tokens` improves accuracy of streaming usage metrics (`prompt_tokens`) without an extra forward pass.
- For multilingual behavior, keep `supported_languages` aligned with actual model capabilities.