Compare commits


1 commit

d3eddd6ef1 initial
Signed-off-by: Roger Wang <ywang@roblox.com>
2025-04-01 16:06:59 -07:00
3 changed files with 395 additions and 52 deletions

vllm/entrypoints/openai/api_server.py

@@ -67,6 +67,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeResponse,
TranscriptionRequest,
TranscriptionResponse,
TranslationRequest,
TranslationResponse,
UnloadLoRAAdapterRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
@@ -80,7 +82,7 @@ from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.entrypoints.openai.serving_transcription import (
- OpenAIServingTranscription)
+ OpenAIServingTranscription, OpenAIServingTranslation)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
with_cancellation)
@@ -383,6 +385,10 @@ def transcription(request: Request) -> OpenAIServingTranscription:
return request.app.state.openai_serving_transcription
def translation(request: Request) -> OpenAIServingTranslation:
return request.app.state.openai_serving_translation
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@@ -625,6 +631,31 @@ async def create_transcriptions(request: Annotated[TranscriptionRequest,
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/v1/audio/translations")
@with_cancellation
@load_aware_call
async def create_translations(request: Annotated[TranslationRequest,
Form()],
raw_request: Request):
handler = translation(raw_request)
if handler is None:
return base(raw_request).create_error_response(
message="The model does not support Translations API")
audio_data = await request.file.read()
generator = await handler.create_translation(audio_data, request,
raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
elif isinstance(generator, TranslationResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
@router.post("/rerank", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
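
For reference, a minimal sketch of exercising the new /v1/audio/translations route with the official OpenAI Python client; the server address, API key, model name, and audio file are illustrative assumptions, not part of this change:

from openai import OpenAI

# Assumes a vLLM server running locally with a Whisper-style model loaded.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with open("sample.wav", "rb") as audio:  # hypothetical audio file
    translation = client.audio.translations.create(
        model="openai/whisper-large-v3",  # assumed model name
        file=audio,
    )
print(translation.text)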

vllm/entrypoints/openai/protocol.py

@@ -1652,3 +1652,196 @@ class TranscriptionResponseVerbose(OpenAIBaseModel):
words: Optional[list[TranscriptionWord]] = None
"""Extracted words and their corresponding timestamps."""
class TranslationResponseStreamChoice(OpenAIBaseModel):
delta: DeltaMessage
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None
class TranslationStreamResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}")
object: Literal["translation.chunk"] = "translation.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[TranslationResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
class TranslationRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/audio/createTranslation
file: UploadFile
"""
The audio file object (not file name) to translate, in one of these
formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
"""
model: Optional[str] = None
"""ID of the model to use.
"""
language: Optional[str] = None
"""The language of the input audio.
Supplying the input language in
[ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format
will improve accuracy and latency.
"""
prompt: str = Field(default="")
"""An optional text to guide the model's style or continue a previous audio
segment.
The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting)
should match the audio language.
"""
response_format: AudioResponseFormat = Field(default="json")
"""
The format of the output, in one of these options: `json`, `text`, `srt`,
`verbose_json`, or `vtt`.
"""
## TODO (varun) : Support if set to 0, certain thresholds are met !!
temperature: float = Field(default=0.0)
"""The sampling temperature, between 0 and 1.
Higher values like 0.8 will make the output more random, while lower values
like 0.2 will make it more focused / deterministic. If set to 0, the model
will use [log probability](https://en.wikipedia.org/wiki/Log_probability)
to automatically increase the temperature until certain thresholds are hit.
"""
timestamp_granularities: list[Literal["word", "segment"]] = Field(
alias="timestamp_granularities[]", default=[])
"""The timestamp granularities to populate for this translation.
`response_format` must be set `verbose_json` to use timestamp granularities.
Either or both of these options are supported: `word`, or `segment`. Note:
There is no additional latency for segment timestamps, but generating word
timestamps incurs additional latency.
"""
stream: Optional[bool] = False
"""Custom field not present in the original OpenAI definition. When set,
it will enable output to be streamed in a similar fashion as the Chat
Completion endpoint.
"""
# Flattened stream option to simplify form data.
stream_include_usage: Optional[bool] = False
stream_continuous_usage_stats: Optional[bool] = False
# Default sampling parameters for translation requests.
_DEFAULT_SAMPLING_PARAMS: dict = {
"temperature": 0,
}
def to_sampling_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None) -> SamplingParams:
# TODO(#9845): remove max_tokens when field is removed from OpenAI API
max_tokens = default_max_tokens
if default_sampling_params is None:
default_sampling_params = {}
# Default parameters
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
return SamplingParams.from_optional(temperature=temperature,
max_tokens=max_tokens,
output_kind=RequestOutputKind.DELTA
if self.stream \
else RequestOutputKind.FINAL_ONLY)
@model_validator(mode="before")
@classmethod
def validate_stream_options(cls, data):
stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
stream = data.get("stream", False)
if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
raise ValueError(
"Stream options can only be defined when `stream=True`.")
return data
# Translation response objects
class TranslationResponse(OpenAIBaseModel):
text: str
"""The translated text."""
class TranslationWord(OpenAIBaseModel):
end: float
"""End time of the word in seconds."""
start: float
"""Start time of the word in seconds."""
word: str
"""The text content of the word."""
class TranslationSegment(OpenAIBaseModel):
id: int
"""Unique identifier of the segment."""
avg_logprob: float
"""Average logprob of the segment.
If the value is lower than -1, consider the logprobs failed.
"""
compression_ratio: float
"""Compression ratio of the segment.
If the value is greater than 2.4, consider the compression failed.
"""
end: float
"""End time of the segment in seconds."""
no_speech_prob: float
"""Probability of no speech in the segment.
If the value is higher than 1.0 and the `avg_logprob` is below -1, consider
this segment silent.
"""
seek: int
"""Seek offset of the segment."""
start: float
"""Start time of the segment in seconds."""
temperature: float
"""Temperature parameter used for generating the segment."""
text: str
"""Text content of the segment."""
tokens: list[int]
"""Array of token IDs for the text content."""
class TranslationResponseVerbose(OpenAIBaseModel):
duration: str
"""The duration of the input audio."""
language: str
"""The language of the input audio."""
text: str
"""The translated text."""
segments: Optional[list[TranslationSegment]] = None
"""Segments of the translated text and their corresponding details."""
words: Optional[list[TranslationWord]] = None
"""Extracted words and their corresponding timestamps."""

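Since stream is a custom field not in the OpenAI spec, streaming is easiest to exercise over plain HTTP; a hedged sketch of consuming the SSE output (URL, model name, and file are assumptions; the form fields and "data: ... [DONE]" framing follow the definitions in this PR):

import json
import requests

with open("sample.wav", "rb") as audio:  # hypothetical audio file
    resp = requests.post(
        "http://localhost:8000/v1/audio/translations",  # assumed server address
        files={"file": audio},
        data={"model": "openai/whisper-large-v3",  # assumed model name
              "stream": "true",
              "stream_include_usage": "true"},
        stream=True,
    )

for line in resp.iter_lines():
    if not line.startswith(b"data: "):
        continue  # skip blank SSE separators
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":  # final sentinel from the stream generator
        break
    chunk = json.loads(payload)  # one TranslationStreamResponse per event
    if chunk.get("choices"):
        print(chunk["choices"][0]["delta"].get("content") or "", end="")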
vllm/entrypoints/openai/serving_transcription.py

@@ -4,7 +4,7 @@ import io
import time
from collections.abc import AsyncGenerator
from math import ceil
- from typing import Final, Optional, Union, cast
+ from typing import Callable, Optional, Union, cast
from fastapi import Request
@@ -14,7 +14,8 @@ from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
TranscriptionResponse, TranscriptionResponseStreamChoice,
- TranscriptionStreamResponse, UsageInfo)
+ TranscriptionStreamResponse, TranslationRequest, TranslationResponse,
+ TranslationResponseStreamChoice, TranslationStreamResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import PromptType
@@ -30,7 +31,7 @@ except ImportError:
logger = init_logger(__name__)
- # From https://platform.openai.com/docs/guides/speech-to-text/supported-languages#supported-languages
+ # From https://platform.openai.com/docs/guides/speech-to-text/supported-languages
# TODO these configs should live somewhere with the model so we can support
# additional ones
@@ -144,16 +145,19 @@ ISO639_1_OTHER_LANGS = {
MAX_AUDIO_CLIP_FILESIZE_MB = 25
- class OpenAIServingTranscription(OpenAIServing):
+ class OpenAISpeechToText(OpenAIServing):
+ """Base class for speech-to-text operations like transcription and
+ translation."""
def __init__(
- self,
- engine_client: EngineClient,
- model_config: ModelConfig,
- models: OpenAIServingModels,
- *,
- request_logger: Optional[RequestLogger],
- return_tokens_as_token_ids: bool = False,
+ self,
+ engine_client: EngineClient,
+ model_config: ModelConfig,
+ models: OpenAIServingModels,
+ *,
+ request_logger: Optional[RequestLogger],
+ return_tokens_as_token_ids: bool = False,
+ task_type: str = "transcribe", # or "translate"
):
super().__init__(engine_client=engine_client,
model_config=model_config,
@@ -167,15 +171,16 @@ class OpenAIServingTranscription(OpenAIServing):
self.max_audio_clip_s = processor.feature_extractor.chunk_length
self.model_sr = processor.feature_extractor.sampling_rate
self.hop_length = processor.feature_extractor.hop_length
self.task_type = task_type
if self.default_sampling_params:
logger.info(
"Overwriting default completion sampling param with: %s",
self.default_sampling_params)
- async def _preprocess_transcription(
+ async def _preprocess_speech_to_text(
self,
- request: TranscriptionRequest,
+ request: Union[TranscriptionRequest, TranslationRequest],
audio_data: bytes,
) -> tuple[PromptType, float]:
# Validate request
@@ -218,21 +223,22 @@
},
},
"decoder_prompt":
f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
(f"<|startoftranscript|>{lang_token}"
f"<|{self.task_type}|><|notimestamps|>{request.prompt}")
}
return cast(PromptType, prompt), duration
# TODO (varun) : Make verbose response work !
- async def create_transcription(
- self, audio_data: bytes, request: TranscriptionRequest,
- raw_request: Request
- ) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
- ErrorResponse]:
- """Transcription API similar to OpenAI's API.
- See https://platform.openai.com/docs/api-reference/audio/createTranscription
- for the API specification. This API mimics the OpenAI transcription API.
- """
+ async def _create_speech_to_text(
+ self,
+ audio_data: bytes,
+ request: Union[TranscriptionRequest, TranslationRequest],
+ raw_request: Request,
+ response_class: Union[TranscriptionResponse, TranslationResponse],
+ stream_generator_method: Callable,
+ ) -> Union[Union[TranscriptionResponse, TranslationResponse],
+ AsyncGenerator[str, None], ErrorResponse]:
+ """Base method for speech-to-text operations like transcription and
+ translation."""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
@@ -247,7 +253,7 @@
return self.create_error_response(
"Currently only support response_format `text` or `json`")
request_id = f"trsc-{self._base_request_id(raw_request)}"
request_id = f"{self.task_type}-{self._base_request_id(raw_request)}"
request_metadata = RequestResponseMetadata(request_id=request_id)
if raw_request:
@@ -261,13 +267,14 @@
if lora_request:
return self.create_error_response(
"Currently do not support LoRA for Transcription.")
"Currently do not support LoRA for "
f"{self.task_type.title()}.")
if prompt_adapter_request:
return self.create_error_response(
"Currently do not support PromptAdapter for Transcription."
)
f"Currently do not support PromptAdapter for "
f"{self.task_type.title()}.")
- prompt, duration_s = await self._preprocess_transcription(
+ prompt, duration_s = await self._preprocess_speech_to_text(
request=request,
audio_data=audio_data,
)
@@ -300,31 +307,36 @@
return self.create_error_response(str(e))
if request.stream:
- return self.transcription_stream_generator(request,
- result_generator,
- request_id,
- request_metadata,
- duration_s)
+ return stream_generator_method(request, result_generator,
+ request_id, request_metadata,
+ duration_s)
# Non-streaming response.
try:
assert result_generator is not None
async for op in result_generator:
result = op
- return TranscriptionResponse(text=result.outputs[0].text)
+ return response_class(text=result.outputs[0].text)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
- async def transcription_stream_generator(
- self, request: TranscriptionRequest,
- result_generator: AsyncGenerator[RequestOutput, None],
- request_id: str, request_metadata: RequestResponseMetadata,
- audio_duration_s: float) -> AsyncGenerator[str, None]:
+ async def _speech_to_text_stream_generator(
+ self,
+ request: Union[TranscriptionRequest, TranslationRequest],
+ result_generator: AsyncGenerator[RequestOutput, None],
+ request_id: str,
+ request_metadata: RequestResponseMetadata,
+ audio_duration_s: float,
+ chunk_object_type: str,
+ response_stream_choice_class: Union[TranscriptionResponseStreamChoice,
+ TranslationResponseStreamChoice],
+ stream_response_class: Union[TranscriptionStreamResponse,
+ TranslationStreamResponse],
+ ) -> AsyncGenerator[str, None]:
created_time = int(time.time())
model_name = request.model
- chunk_object_type: Final = "transcription.chunk"
completion_tokens = 0
num_prompt_tokens = 0
@@ -361,20 +373,20 @@
if output.finish_reason is None:
# Still generating, send delta update.
- choice_data = TranscriptionResponseStreamChoice(
+ choice_data = response_stream_choice_class(
delta=delta_message)
else:
# Model is finished generating.
- choice_data = TranscriptionResponseStreamChoice(
+ choice_data = response_stream_choice_class(
delta=delta_message,
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
- chunk = TranscriptionStreamResponse(id=request_id,
- object=chunk_object_type,
- created=created_time,
- choices=[choice_data],
- model=model_name)
+ chunk = stream_response_class(id=request_id,
+ object=chunk_object_type,
+ created=created_time,
+ choices=[choice_data],
+ model=model_name)
# handle usage stats if requested & if continuous
if include_continuous_usage:
@@ -395,7 +407,7 @@
total_tokens=num_prompt_tokens +
completion_tokens)
- final_usage_chunk = TranscriptionStreamResponse(
+ final_usage_chunk = stream_response_class(
id=request_id,
object=chunk_object_type,
created=created_time,
@@ -414,8 +426,115 @@
except Exception as e:
# TODO: Use a vllm-specific Validation Error
logger.exception("Error in chat completion stream generator.")
logger.exception("Error in %s stream generator.", self.task_type)
data = self.create_streaming_error_response(str(e))
yield f"data: {data}\n\n"
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
class OpenAIServingTranscription(OpenAISpeechToText):
"""Handles transcription requests."""
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="transcribe")
async def create_transcription(
self, audio_data: bytes, request: TranscriptionRequest,
raw_request: Request
) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
ErrorResponse]:
"""Transcription API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/createTranscription
for the API specification. This API mimics the OpenAI transcription API.
"""
return await self._create_speech_to_text(
audio_data=audio_data,
request=request,
raw_request=raw_request,
response_class=TranscriptionResponse,
stream_generator_method=self.transcription_stream_generator,
)
async def transcription_stream_generator(
self, request: TranscriptionRequest,
result_generator: AsyncGenerator[RequestOutput, None],
request_id: str, request_metadata: RequestResponseMetadata,
audio_duration_s: float) -> AsyncGenerator[str, None]:
return await self._speech_to_text_stream_generator(
request=request,
result_generator=result_generator,
request_id=request_id,
request_metadata=request_metadata,
audio_duration_s=audio_duration_s,
chunk_object_type="transcription.chunk",
response_stream_choice_class=TranscriptionResponseStreamChoice,
stream_response_class=TranscriptionStreamResponse,
)
class OpenAIServingTranslation(OpenAISpeechToText):
"""Handles translation requests."""
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids,
task_type="translate")
async def create_translation(
self, audio_data: bytes, request: TranslationRequest,
raw_request: Request
) -> Union[TranslationResponse, AsyncGenerator[str, None], ErrorResponse]:
"""Translation API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/createTranslation
for the API specification. This API mimics the OpenAI translation API.
"""
return await self._create_speech_to_text(
audio_data=audio_data,
request=request,
raw_request=raw_request,
response_class=TranslationResponse,
stream_generator_method=self.translation_stream_generator,
)
async def translation_stream_generator(
self, request: TranslationRequest,
result_generator: AsyncGenerator[RequestOutput, None],
request_id: str, request_metadata: RequestResponseMetadata,
audio_duration_s: float) -> AsyncGenerator[str, None]:
return await self._speech_to_text_stream_generator(
request=request,
result_generator=result_generator,
request_id=request_id,
request_metadata=request_metadata,
audio_duration_s=audio_duration_s,
chunk_object_type="translation.chunk",
response_stream_choice_class=TranslationResponseStreamChoice,
stream_response_class=TranslationStreamResponse,
)
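
For clarity, the only model-facing difference between the two subclasses is the task token in the Whisper decoder prompt built by _preprocess_speech_to_text; a hypothetical helper showing the token layout (lang_token and prompt are stand-ins):

def build_decoder_prompt(task_type: str, lang_token: str, prompt: str) -> str:
    # task_type is "transcribe" for OpenAIServingTranscription and
    # "translate" for OpenAIServingTranslation.
    return (f"<|startoftranscript|>{lang_token}"
            f"<|{task_type}|><|notimestamps|>{prompt}")

print(build_decoder_prompt("translate", "<|fr|>", ""))
# -> <|startoftranscript|><|fr|><|translate|><|notimestamps|>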