[Frontend] Add /v1/audio/transcriptions OpenAI API endpoint (#12909)

This commit is contained in:
Nicolò Lucchesi
2025-02-13 16:23:45 +01:00
committed by GitHub
parent 37dfa60037
commit d84cef76eb
20 changed files with 910 additions and 19 deletions

View File

@ -117,7 +117,7 @@ steps:
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/correctness/
- pytest -v -s entrypoints/test_chat_utils.py
- pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
@ -205,7 +205,7 @@ steps:
- VLLM_USE_V1=1 pytest -v -s v1/e2e
# Integration test for streaming correctness (requires special branch).
- pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api
- pytest -v -s entrypoints/openai/test_accuracy.py::test_lm_eval_accuracy_v1_engine
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
- label: Examples Test # 25min
working_dir: "/vllm-workspace/examples"
@ -339,6 +339,14 @@ steps:
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- bash ./run-tests.sh -c configs/models-small.txt -t 1
- label: OpenAI API correctness
source_file_dependencies:
- csrc/
- vllm/entrypoints/openai/
- vllm/model_executor/models/whisper.py
commands: # LMEval+Transcription WER check
- pytest -s entrypoints/openai/correctness/
- label: Encoder Decoder tests # 5min
source_file_dependencies:
- vllm/

View File

@ -41,6 +41,8 @@ We currently support the following OpenAI APIs:
- *Note: `parallel_tool_calls` and `user` parameters are ignored.*
- [Embeddings API](#embeddings-api) (`/v1/embeddings`)
- Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`).
- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`)
- Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`).
In addition, we have the following custom APIs:
@ -296,6 +298,17 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s
:end-before: end-chat-embedding-extra-params
:::
(transcriptions-api)=
### Transcriptions API
Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
<!-- TODO: api enforced limits + uploading audios -->
Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
(tokenizer-api)=
### Tokenizer API

View File

@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
from openai import OpenAI
from vllm.assets.audio import AudioAsset
mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path()
winning_call = AudioAsset('winning_call').get_local_path()
# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
with open(str(mary_had_lamb), "rb") as f:
transcription = client.audio.transcriptions.create(
file=f,
model="openai/whisper-large-v3",
language="en",
response_format="text",
temperature=0.0)
print("transcription result:", transcription)

View File

@ -8,12 +8,11 @@ py-cpuinfo
transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
fastapi >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9'
fastapi[standard] >= 0.107.0, < 0.113.0; python_version < '3.9'
fastapi[standard] >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9'
aiohttp
openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support)
uvicorn[standard]
pydantic >= 2.9 # Required for fastapi >= 0.113.0
pydantic >= 2.9
prometheus_client >= 0.18.0
pillow # Required for image processing
prometheus-fastapi-instrumentator >= 7.0.0

View File

@ -19,6 +19,7 @@ pqdm
ray[adag]==2.40.0
sentence-transformers # required for embedding tests
soundfile # required for audio tests
jiwer # required for audio tests
timm # required for internvl test
torch==2.5.1
torchaudio==2.5.1

View File

@ -66,6 +66,7 @@ charset-normalizer==3.4.0
click==8.1.7
# via
# black
# jiwer
# nltk
# ray
colorama==0.4.6
@ -187,6 +188,8 @@ jinja2==3.1.4
# via
# datamodel-code-generator
# torch
jiwer==3.0.5
# via -r requirements-test.in
jmespath==1.0.1
# via
# boto3
@ -470,6 +473,8 @@ pyyaml==6.0.2
# timm
# transformers
# vocos
rapidfuzz==3.12.1
# via jiwer
ray[adag]==2.40.0
# via -r requirements-test.in
redis==5.2.0

View File

@ -13,7 +13,7 @@ import pytest
from vllm.platforms import current_platform
from ...utils import RemoteOpenAIServer
from ....utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
NUM_CONCURRENT = 500

View File

@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
"""
Evaluate Transcription API correctness by computing Word Error Rate (WER)
on a given ASR dataset. When provided, it will also compare the WER against
a baseline.
This simulates real work usage of the API and makes sure that the frontend and
AsyncLLMEngine are working correctly.
"""
import asyncio
import io
import time
from statistics import mean, median
from typing import List
import librosa
import pytest
import soundfile
import torch
from datasets import load_dataset
from evaluate import load
from transformers import AutoTokenizer
from ....utils import RemoteOpenAIServer
def to_bytes(y, sr):
buffer = io.BytesIO()
soundfile.write(buffer, y, sr, format="WAV")
buffer.seek(0)
return buffer
async def transcribe_audio(client, tokenizer, y, sr):
# Send loaded audio directly instead of loading from disk,
# dont account for that time though
with to_bytes(y, sr) as f:
start_time = time.perf_counter()
transcription = await client.audio.transcriptions.create(
file=f,
model=tokenizer.name_or_path,
language="en",
temperature=0.0,
)
end_time = time.perf_counter()
# NOTE there's no streaming in transcriptions, can't measure ttft
latency = end_time - start_time
num_output_tokens = len(
tokenizer(transcription.text, add_special_tokens=False).input_ids)
return latency, num_output_tokens, transcription.text
async def bound_transcribe(model_name, sem, client, audio, reference):
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use semaphore to limit concurrent requests.
async with sem:
result = await transcribe_audio(client, tokenizer, *audio)
# Normalize *english* output/reference for evaluation.
out = tokenizer.normalize(result[2])
ref = tokenizer.normalize(reference)
return result[:2] + (out, ref)
async def process_dataset(model, client, data, concurrent_request):
sem = asyncio.Semaphore(concurrent_request)
# Warmup call as the first `librosa.load` server-side is quite slow.
audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"]
_ = await bound_transcribe(model, sem, client, (audio, sr), "")
tasks: List[asyncio.Task] = []
for sample in data:
audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
task = asyncio.create_task(
bound_transcribe(model, sem, client, (audio, sr), sample["text"]))
tasks.append(task)
return await asyncio.gather(*tasks)
def print_performance_metrics(results, total_time):
latencies = [res[0] for res in results]
total_tokens = sum([res[1] for res in results])
total = len(results)
print(f"Total Requests: {total}")
print(f"Successful Requests: {len(latencies)}")
print(f"Average Latency: {mean(latencies):.4f} seconds")
print(f"Median Latency: {median(latencies):.4f} seconds")
perc = sorted(latencies)[int(len(latencies) * 0.95) - 1]
print(f"95th Percentile Latency: {perc:.4f} seconds")
# Throughput
req_throughput = len(latencies) / total_time
print(f"Estimated req_Throughput: {req_throughput:.2f} requests/s")
throughput = total_tokens / total_time
print(f"Estimated Throughput: {throughput:.2f} tok/s")
def add_duration(sample):
y, sr = sample['audio']["array"], sample['audio']["sampling_rate"]
sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000
return sample
def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs):
## Load and filter the dataset
dataset = load_dataset(dataset_repo, split=split, **hf_kwargs)
if 'duration_ms' not in dataset[0]:
# compute duration to filter
dataset = dataset.map(add_duration)
# Whisper max supported duration
dataset = dataset.filter(lambda example: example['duration_ms'] < 30000)
return dataset
def run_evaluation(model: str,
client,
dataset,
max_concurrent_reqs: int,
n_examples: int = -1,
print_metrics: bool = True):
if n_examples > 0:
dataset = dataset.select(range(n_examples))
start = time.perf_counter()
results = asyncio.run(
process_dataset(model, client, dataset, max_concurrent_reqs))
end = time.perf_counter()
total_time = end - start
print(f"Total Test Time: {total_time:.4f} seconds")
if print_metrics:
print_performance_metrics(results, total_time)
# Compute WER
predictions = [res[2] for res in results]
references = [res[3] for res in results]
wer = load("wer")
wer_score = 100 * wer.compute(references=references,
predictions=predictions)
print("WER:", wer_score)
return wer_score
# alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo"..
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
@pytest.mark.parametrize(
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"])
# NOTE: Expected WER measured with equivalent hf.transformers args:
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
@pytest.mark.parametrize("expected_wer", [12.744980])
def test_wer_correctness(model_name,
dataset_repo,
expected_wer,
n_examples=-1,
max_concurrent_request=None):
with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server:
dataset = load_hf_dataset(dataset_repo)
if not max_concurrent_request:
# No max concurrency
max_concurrent_request = n_examples if n_examples > 0\
else len(dataset)
client = remote_server.get_async_client()
wer = run_evaluation(model_name, client, dataset,
max_concurrent_request, n_examples)
if expected_wer:
torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)

View File

@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# imports for guided decoding tests
import io
import json
import librosa
import numpy as np
import openai
import pytest
import soundfile as sf
from vllm.assets.audio import AudioAsset
from ...utils import RemoteOpenAIServer
@pytest.fixture
def mary_had_lamb():
path = AudioAsset('mary_had_lamb').get_local_path()
with open(str(path), "rb") as f:
yield f
@pytest.fixture
def winning_call():
path = AudioAsset('winning_call').get_local_path()
with open(str(path), "rb") as f:
yield f
@pytest.mark.asyncio
async def test_basic_audio(mary_had_lamb):
model_name = "openai/whisper-large-v3-turbo"
server_args = ["--enforce-eager"]
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
prompt = "THE FIRST WORDS I SPOKE"
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
transcription = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
language="en",
response_format="text",
temperature=0.0)
out = json.loads(transcription)['text']
assert "Mary had a little lamb," in out
# This should "force" whisper to continue prompt in all caps
transcription_wprompt = await client.audio.transcriptions.create(
model=model_name,
file=mary_had_lamb,
language="en",
response_format="text",
prompt=prompt,
temperature=0.0)
out_capital = json.loads(transcription_wprompt)['text']
assert prompt not in out_capital
@pytest.mark.asyncio
async def test_bad_requests(mary_had_lamb):
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
# invalid language
with pytest.raises(openai.BadRequestError):
await client.audio.transcriptions.create(model=model_name,
file=mary_had_lamb,
language="hh",
temperature=0.0)
# Expect audio too long: repeat the timeseries
mary_had_lamb.seek(0)
audio, sr = librosa.load(mary_had_lamb)
repeated_audio = np.tile(audio, 10)
# Repeated audio to buffer
buffer = io.BytesIO()
sf.write(buffer, repeated_audio, sr, format='WAV')
buffer.seek(0)
with pytest.raises(openai.BadRequestError):
await client.audio.transcriptions.create(model=model_name,
file=buffer,
language="en",
temperature=0.0)
@pytest.mark.asyncio
async def test_non_asr_model(winning_call):
# text to text model
model_name = "JackFram/llama-68m"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
res = await client.audio.transcriptions.create(model=model_name,
file=winning_call,
language="en",
temperature=0.0)
assert res.code == 400 and not res.text
assert res.message == "The model does not support Transcriptions API"
@pytest.mark.asyncio
async def test_completion_endpoints():
# text to text model
model_name = "openai/whisper-small"
server_args = ["--enforce-eager"]
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
res = await client.chat.completions.create(
model=model_name,
messages=[{
"role": "system",
"content": "You are a helpful assistant."
}])
assert res.code == 400
assert res.message == "The model does not support Chat Completions API"
res = await client.completions.create(model=model_name, prompt="Hello")
assert res.code == 400
assert res.message == "The model does not support Completions API"

View File

@ -17,6 +17,7 @@ from vllm.platforms import current_platform
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "score"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
("openai/whisper-small", "transcription", "transcription"),
],
)
def test_auto_task(model_id, expected_runner_type, expected_task):

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from pathlib import Path
from typing import Literal
from urllib.parse import urljoin
@ -28,6 +29,10 @@ class AudioAsset:
s3_prefix=ASSET_DIR)
return librosa.load(audio_path, sr=None)
def get_local_path(self) -> Path:
return get_vllm_public_assets(filename=f"{self.name}.ogg",
s3_prefix=ASSET_DIR)
@property
def url(self) -> str:
return urljoin(VLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg")

View File

@ -54,17 +54,18 @@ _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
"score", "reward"]
"score", "reward", "transcription"]
_ResolvedTask = Literal["generate", "embed", "classify", "score", "reward",
"draft"]
"draft", "transcription"]
RunnerType = Literal["generate", "pooling", "draft"]
RunnerType = Literal["generate", "pooling", "draft", "transcription"]
_RUNNER_TASKS: Dict[RunnerType, List[_ResolvedTask]] = {
"generate": ["generate"],
"pooling": ["embed", "classify", "score", "reward"],
"draft": ["draft"],
"transcription": ["transcription"],
}
_TASK_RUNNER: Dict[_ResolvedTask, RunnerType] = {
@ -484,6 +485,8 @@ class ModelConfig:
return "embed"
if ModelRegistry.is_cross_encoder_model(architectures):
return "score"
if ModelRegistry.is_transcription_model(architectures):
return "transcription"
suffix_to_preferred_task: List[Tuple[str, _ResolvedTask]] = [
# Other models follow this pattern
@ -516,6 +519,8 @@ class ModelConfig:
runner_support: Dict[RunnerType, bool] = {
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"transcription":
ModelRegistry.is_transcription_model(architectures),
"generate": ModelRegistry.is_text_generation_model(architectures),
"pooling": ModelRegistry.is_pooling_model(architectures),
}

View File

@ -16,10 +16,10 @@ from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
from typing import Annotated, AsyncIterator, Dict, Optional, Set, Tuple, Union
import uvloop
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
@ -61,6 +61,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ScoreRequest, ScoreResponse,
TokenizeRequest,
TokenizeResponse,
TranscriptionRequest,
TranscriptionResponse,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable
@ -75,6 +77,8 @@ from vllm.entrypoints.openai.serving_rerank import JinaAIServingRerank
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.serving_transcription import (
OpenAIServingTranscription)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
@ -327,6 +331,10 @@ def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def transcription(request: Request) -> OpenAIServingTranscription:
return request.app.state.openai_serving_transcription
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@ -520,6 +528,31 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)
@router.post("/v1/audio/transcriptions")
@with_cancellation
async def create_transcriptions(request: Annotated[TranscriptionRequest,
Form()],
raw_request: Request):
handler = transcription(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Transcriptions API")
audio_data = await request.file.read()
generator = await handler.create_transcription(audio_data, request,
raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, TranscriptionResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
async def do_rerank(request: RerankRequest, raw_request: Request):
@ -832,6 +865,12 @@ async def init_app_state(
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
)
state.openai_serving_transcription = OpenAIServingTranscription(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
) if model_config.runner_type == "transcription" else None
state.task = model_config.task

View File

@ -8,9 +8,10 @@ from argparse import Namespace
from typing import Any, ClassVar, Dict, List, Literal, Optional, Set, Union
import torch
from fastapi import UploadFile
from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter,
ValidationInfo, field_validator, model_validator)
from typing_extensions import Annotated
from typing_extensions import Annotated, TypeAlias
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.logger import init_logger
@ -1426,3 +1427,163 @@ class LoadLoraAdapterRequest(BaseModel):
class UnloadLoraAdapterRequest(BaseModel):
lora_name: str
lora_int_id: Optional[int] = Field(default=None)
## Protocols for Audio
AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json",
"vtt"]
class TranscriptionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
#https://platform.openai.com/docs/api-reference/audio/createTranscription
file: UploadFile
"""
The audio file object (not file name) to transcribe, in one of these
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
"""
model: str
"""ID of the model to use.
"""
language: Optional[str] = None
"""The language of the input audio.
Supplying the input language in
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
will improve accuracy and latency.
"""
prompt: str = Field(default="")
"""An optional text to guide the model's style or continue a previous audio
segment.
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
should match the audio language.
"""
response_format: AudioResponseFormat = Field(default="json")
"""
The format of the output, in one of these options: `json`, `text`, `srt`,
`verbose_json`, or `vtt`.
"""
## TODO (varun) : Support if set to 0, certain thresholds are met !!
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
Higher values like 0.8 will make the output more random, while lower values
like 0.2 will make it more focused / deterministic. If set to 0, the model
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
to automatically increase the temperature until certain thresholds are hit.
"""
timestamp_granularities: List[Literal["word", "segment"]] = Field(
alias="timestamp_granularities[]", default=[])
"""The timestamp granularities to populate for this transcription.
`response_format` must be set `verbose_json` to use timestamp granularities.
Either or both of these options are supported: `word`, or `segment`. Note:
There is no additional latency for segment timestamps, but generating word
timestamps incurs additional latency.
"""
# Default sampling parameters for transcription requests.
_DEFAULT_SAMPLING_PARAMS: dict = {
"temperature": 0,
}
def to_sampling_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
return SamplingParams.from_optional(temperature=temperature,
max_tokens=max_tokens)
# Transcription response objects
class TranscriptionResponse(OpenAIBaseModel):
text: str
"""The transcribed text."""
class TranscriptionWord(OpenAIBaseModel):
end: float
"""End time of the word in seconds."""
start: float
"""Start time of the word in seconds."""
word: str
"""The text content of the word."""
class TranscriptionSegment(OpenAIBaseModel):
id: int
"""Unique identifier of the segment."""
avg_logprob: float
"""Average logprob of the segment.
If the value is lower than -1, consider the logprobs failed.
"""
compression_ratio: float
"""Compression ratio of the segment.
If the value is greater than 2.4, consider the compression failed.
"""
end: float
"""End time of the segment in seconds."""
no_speech_prob: float
"""Probability of no speech in the segment.
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
this segment silent.
"""
seek: int
"""Seek offset of the segment."""
start: float
"""Start time of the segment in seconds."""
temperature: float
"""Temperature parameter used for generating the segment."""
text: str
"""Text content of the segment."""
tokens: List[int]
"""Array of token IDs for the text content."""
class TranscriptionResponseVerbose(OpenAIBaseModel):
duration: str
"""The duration of the input audio."""
language: str
"""The language of the input audio."""
text: str
"""The transcribed text."""
segments: Optional[List[TranscriptionSegment]] = None
"""Segments of the transcribed text and their corresponding details."""
words: Optional[List[TranscriptionWord]] = None
"""Extracted words and their corresponding timestamps."""

View File

@ -31,7 +31,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ErrorResponse, RerankRequest,
ScoreRequest,
TokenizeChatRequest,
TokenizeCompletionRequest)
TokenizeCompletionRequest,
TranscriptionRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
@ -57,7 +58,8 @@ CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
TokenizeChatRequest]
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest]
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest,
TranscriptionRequest]
class TextTokensPrompt(TypedDict):

View File

@ -0,0 +1,305 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import io
from typing import AsyncGenerator, Optional, Union, cast
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse,
RequestResponseMetadata,
TranscriptionRequest,
TranscriptionResponse,
TranscriptionResponseVerbose)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.utils import PlaceholderModule
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
logger = init_logger(__name__)
# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages
# TODO these configs should live somewhere with the model so we can support
# additional ones
ISO639_1_SUPPORTED_LANGS = {
"af": "Afrikaans",
"ar": "Arabic",
"hy": "Armenian",
"az": "Azerbaijani",
"be": "Belarusian",
"bs": "Bosnian",
"bg": "Bulgarian",
"ca": "Catalan",
"zh": "Chinese",
"hr": "Croatian",
"cs": "Czech",
"da": "Danish",
"nl": "Dutch",
"en": "English",
"et": "Estonian",
"fi": "Finnish",
"fr": "French",
"gl": "Galician",
"de": "German",
"el": "Greek",
"he": "Hebrew",
"hi": "Hindi",
"hu": "Hungarian",
"is": "Icelandic",
"id": "Indonesian",
"it": "Italian",
"ja": "Japanese",
"kn": "Kannada",
"kk": "Kazakh",
"ko": "Korean",
"lv": "Latvian",
"lt": "Lithuanian",
"mk": "Macedonian",
"ms": "Malay",
"mr": "Marathi",
"mi": "Maori",
"ne": "Nepali",
"no": "Norwegian",
"fa": "Persian",
"pl": "Polish",
"pt": "Portuguese",
"ro": "Romanian",
"ru": "Russian",
"sr": "Serbian",
"sk": "Slovak",
"sl": "Slovenian",
"es": "Spanish",
"sw": "Swahili",
"sv": "Swedish",
"tl": "Tagalog",
"ta": "Tamil",
"th": "Thai",
"tr": "Turkish",
"uk": "Ukrainian",
"ur": "Urdu",
"vi": "Vietnamese",
"cy": "Welsh"
}
ISO639_1_OTHER_LANGS = {
"lo": "Lao",
"jw": "Javanese",
"tk": "Turkmen",
"yi": "Yiddish",
"so": "Somali",
"bn": "Bengali",
"nn": "Norwegian Nynorsk",
"si": "Sinhala",
"yo": "Yoruba",
"sa": "Sanskrit",
"mi": "Māori",
"fo": "Faroese", # codespell:ignore
"mt": "Maltese",
"tg": "Tajik",
"mg": "Malagasy",
"haw": "Hawaiian",
"km": "Khmer",
"br": "Breton",
"ps": "Pashto",
"ln": "Lingala",
"la": "Latin",
"ml": "Malayalam",
"sq": "Albanian",
"su": "Sundanese",
"eu": "Basque",
"ka": "Georgian",
"uz": "Uzbek",
"sn": "Shona",
"ht": "Haitian",
"as": "Assamese",
"mn": "Mongolian",
"te": "Telugu",
"pa": "Panjabi",
"tt": "Tatar",
"gu": "Gujarati",
"oc": "Occitan",
"ha": "Hausa",
"ba": "Bashkir",
"my": "Burmese",
"sd": "Sindhi",
"am": "Amharic",
"lb": "Luxembourgish",
"bo": "Tibetan"
}
# As per https://platform.openai.com/docs/guides/speech-to-text#overview.
# TODO configurable
MAX_AUDIO_CLIP_FILESIZE_MB = 25
# TODO get from processor.feature_extractor.chunk_length
MAX_AUDIO_CLIP_DURATION_S = 30
class OpenAIServingTranscription(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
diff_sampling_param = self.model_config.get_diff_sampling_param()
if diff_sampling_param:
logger.info(
"Overwriting default completion sampling param with: %s",
diff_sampling_param)
async def _preprocess_transcription(
self,
request: TranscriptionRequest,
audio_data: bytes,
) -> PromptType:
# Validate request
# TODO language should be optional and can be guessed.
# For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
lang_token = f"<|{request.language}|>" if request.language else "<|en|>"
if request.language:
if request.language in ISO639_1_SUPPORTED_LANGS:
pass
elif request.language in ISO639_1_OTHER_LANGS:
logger.warning(
"The selected language %s has limited accuracy with"
" reported WER>=0.5. Results may be less accurate "
"for this choice.", request.language)
else:
raise ValueError(
f"Unsupported language: {request.language}."
"Language should be one of:" +
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
f"or {list(ISO639_1_OTHER_LANGS.values())}")
if len(audio_data) / 1024**2 > MAX_AUDIO_CLIP_FILESIZE_MB:
raise ValueError("Maximum file size exceeded.")
with io.BytesIO(audio_data) as bytes_:
y, sr = librosa.load(bytes_)
if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S:
raise ValueError(
f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) "
"exceeded.")
prompt = {
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {
"audio": (y, sr),
},
},
"decoder_prompt":
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
}
return cast(PromptType, prompt)
# TODO (varun) : Make verbose response work !
async def create_transcription(
self, audio_data: bytes, request: TranscriptionRequest,
raw_request: Request
) -> Union[TranscriptionResponse, TranscriptionResponseVerbose,
ErrorResponse]:
"""Transcription API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/createTranscription
for the API specification. This API mimics the OpenAI transcription API.
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# If the engine is dead, raise the engine's DEAD_ERROR.
# This is required for the streaming case, where we return a
# success status before we actually start generating text :).
if self.engine_client.errored:
raise self.engine_client.dead_error
if request.response_format not in ['text', 'json']:
return self.create_error_response(
"Currently only support response_format `text` or `json`")
# TODO cmpl->transcription?
request_id = f"cmpl-{self._base_request_id(raw_request)}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
raw_request.state.request_metadata = request_metadata
try:
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
if lora_request:
return self.create_error_response(
"Currently do not support LoRA for Transcription.")
if prompt_adapter_request:
return self.create_error_response(
"Currently do not support PromptAdapter for Transcription."
)
prompt = await self._preprocess_transcription(
request=request,
audio_data=audio_data,
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
result_generator: Optional[AsyncGenerator[RequestOutput, None]] = None
try:
# TODO(rob): subtract len of tokenized prompt.
default_max_tokens = self.model_config.max_model_len
default_params = self.model_config.get_diff_sampling_param()
sampling_params = request.to_sampling_params(
default_max_tokens, default_params)
self._log_inputs(
request_id,
prompt['decoder_prompt'], # type: ignore
params=sampling_params,
lora_request=None,
prompt_adapter_request=None)
result_generator = self.engine_client.generate(
prompt,
sampling_params,
request_id,
)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
# TODO(rob): figure out a way to pipe streaming in.
# Non-streaming response.
try:
async for op in result_generator:
result = op
return TranscriptionResponse(text=result.outputs[0].text)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

View File

@ -441,3 +441,30 @@ def supports_cross_encoding(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
return is_pooling_model(model) and _supports_cross_encoding(model)
@runtime_checkable
class SupportsTranscription(Protocol):
"""The interface required for all models that support transcription."""
supports_transcription: ClassVar[Literal[True]] = True
@overload
def supports_transcription(
model: Type[object]) -> TypeIs[Type[SupportsTranscription]]:
...
@overload
def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
...
def supports_transcription(
model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
if isinstance(model, type):
return isinstance(model, SupportsTranscription)
return isinstance(model, SupportsTranscription)

View File

@ -22,7 +22,7 @@ from vllm.logger import init_logger
from .interfaces import (has_inner_state, is_attention_free, is_hybrid,
supports_cross_encoding, supports_multimodal,
supports_pp)
supports_pp, supports_transcription)
from .interfaces_base import is_text_generation_model
logger = init_logger(__name__)
@ -224,6 +224,7 @@ class _ModelInfo:
has_inner_state: bool
is_attention_free: bool
is_hybrid: bool
supports_transcription: bool
@staticmethod
def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo":
@ -237,7 +238,7 @@ class _ModelInfo:
has_inner_state=has_inner_state(model),
is_attention_free=is_attention_free(model),
is_hybrid=is_hybrid(model),
)
supports_transcription=supports_transcription(model))
class _BaseRegisteredModel(ABC):
@ -485,6 +486,13 @@ class _ModelRegistry:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.is_hybrid
def is_transcription_model(
self,
architectures: Union[str, List[str]],
) -> bool:
model_cls, _ = self.inspect_model_cls(architectures)
return model_cls.supports_transcription
ModelRegistry = _ModelRegistry({
model_arch:

View File

@ -31,7 +31,7 @@ from vllm.multimodal.audio import resample_audio
from vllm.sequence import SequenceData
from vllm.transformers_utils.processor import cached_get_processor
from .interfaces import SupportsMultiModal
from .interfaces import SupportsMultiModal, SupportsTranscription
from .utils import AutoWeightsLoader, WeightsMapper, make_layers
logger = init_logger(__name__)
@ -637,7 +637,8 @@ def input_mapper_for_whisper(
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_max_whisper_audio_tokens)
class WhisperForConditionalGeneration(nn.Module, SupportsMultiModal):
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
SupportsMultiModal):
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",