Files
vllm/vllm/entrypoints/openai/serving_engine.py
Gabriel Marinho 5b8077b8ac Fix wrong truncate_prompt_tokens type hint (#22761)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Gabriel Marinho <104592062+gmarinho2@users.noreply.github.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
2025-08-30 20:39:38 +00:00

1141 lines
44 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import json
import sys
import time
import traceback
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
TypeVar, Union, cast, overload)
import pybase64
import torch
from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field
from starlette.datastructures import Headers
from typing_extensions import TypeIs
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
ConversationMessage,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages_futures,
resolve_chat_template_content_format)
from vllm.entrypoints.context import ConversationContext
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
ChatCompletionResponse,
ClassificationRequest,
ClassificationResponse,
CompletionRequest,
CompletionResponse,
DetokenizeRequest,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse, ErrorInfo,
ErrorResponse, PoolingResponse,
RerankRequest, ResponsesRequest,
ScoreRequest, ScoreResponse,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeResponse,
TranscriptionRequest,
TranscriptionResponse,
TranslationRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import ( # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
MultiModalDataDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob, PromptLogprobs
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import (AsyncMicrobatchTokenizer, is_list_of,
merge_async_iterators, random_uuid)
# Module-level logger, named after this module.
logger = init_logger(__name__)

# Union of the request models grouped as "completion-like" in this module.
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest, RerankRequest,
                              ClassificationRequest, ScoreRequest,
                              TokenizeCompletionRequest]

# Union of the chat-flavoured request models.
ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]

# Union of the transcription / translation request models.
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]

# Every request model accepted by the helpers in this module.
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest,
                   ResponsesRequest]

# Every non-error response model that `_build_response` / `handle` may return.
AnyResponse = Union[
    CompletionResponse,
    ChatCompletionResponse,
    EmbeddingResponse,
    TranscriptionResponse,
    TokenizeResponse,
    PoolingResponse,
    ClassificationResponse,
    ScoreResponse,
]
class TextTokensPrompt(TypedDict):
    """A prompt carried as both its text and its token ids."""
    # Prompt text; the token-id normalisation path in this module stores ""
    # here when no tokenizer is available to decode the ids.
    prompt: str
    # Token ids for the prompt.
    prompt_token_ids: list[int]
class EmbedsPrompt(TypedDict):
    """A prompt supplied as a pre-computed embedding tensor."""
    # Embedding tensor; `_load_prompt_embeds` in this module only produces
    # 2-D tensors for this field.
    prompt_embeds: torch.Tensor
# Any of the prompt shapes this module passes around: raw token ids, raw
# text, a text+tokens pair, or prompt embeddings.
RequestPrompt = Union[list[int], str, TextTokensPrompt, EmbedsPrompt]
def is_text_tokens_prompt(prompt: RequestPrompt) -> TypeIs[TextTokensPrompt]:
    """Return True when *prompt* is a dict shaped like `TextTokensPrompt`.

    A prompt qualifies when it is a dict that carries ``prompt_token_ids``
    and does not carry ``prompt_embeds``.
    """
    if not isinstance(prompt, dict):
        return False
    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_token_ids and not has_embeds
def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:
    """Return True when *prompt* is a dict shaped like `EmbedsPrompt`.

    A prompt qualifies when it is a dict that carries ``prompt_embeds``
    and does not carry ``prompt_token_ids``.
    """
    if not isinstance(prompt, dict):
        return False
    has_token_ids = "prompt_token_ids" in prompt
    has_embeds = "prompt_embeds" in prompt
    return has_embeds and not has_token_ids
# Type variable used to parametrize `ServeContext` by its request model.
RequestT = TypeVar("RequestT", bound=AnyRequest)
class RequestProcessingMixin(BaseModel):
    """
    Mixin for request processing,
    handling prompt preparation and engine input.
    """
    # Prompts in request form; `_prepare_generators` indexes this in step
    # with `engine_prompts` when logging inputs.
    request_prompts: Optional[Sequence[RequestPrompt]] = []
    # Prompts in the form handed to the engine client.
    engine_prompts: Optional[Union[list[EngineTokensPrompt],
                                   list[EngineEmbedsPrompt]]] = []

    # Non-pydantic field types are used above, so arbitrary types are allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)
class ResponseGenerationMixin(BaseModel):
    """
    Mixin for response generation,
    managing result generators and final batch results.
    """
    # Async stream of (prompt index, output) pairs; set by
    # `_prepare_generators` and drained by `_collect_batch`.
    result_generator: Optional[AsyncGenerator[tuple[int, Union[
        RequestOutput, PoolingRequestOutput]], None]] = None
    # Outputs ordered by prompt index; filled in by `_collect_batch`.
    final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
        default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)
class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
                   Generic[RequestT]):
    """Per-request state threaded through the `OpenAIServing` pipeline."""
    # Shared across all requests
    request: RequestT
    # Incoming HTTP request, when available; its headers are read to
    # obtain trace headers in `_prepare_generators`.
    raw_request: Optional[Request] = None
    model_name: str
    # Base id; `_prepare_generators` appends "-<index>" per prompt.
    request_id: str
    # Creation time as integer seconds from `time.time()`.
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: Optional[LoRARequest] = None

    # Shared across most requests
    tokenizer: Optional[AnyTokenizer] = None

    # `protected_namespaces` resolves Pydantic v2's warning
    # on conflict with protected namespace "model_"
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )
# `ServeContext` specialised for classification requests.
ClassificationServeContext = ServeContext[ClassificationRequest]
class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    """`ServeContext` for embedding requests, with chat-template settings."""
    # Optional chat template string.
    chat_template: Optional[str] = None
    # Required: the content-format option for the chat template.
    chat_template_content_format: ChatTemplateContentFormatOption
# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt.
# Each model class defined above is rebuilt explicitly, including the
# parametrized and derived `ServeContext` variants.
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()
class OpenAIServing:
# NOTE(review): the default value below is a human-readable description
# rather than an actual prefix; presumably concrete serving classes
# override it — confirm against the subclasses.
request_id_prefix: ClassVar[str] = """
A short string prepended to every requests ID (e.g. "embd", "classify")
so you can easily tell “this ID came from Embedding vs Classification.”
"""
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
enable_force_include_usage: bool = False,
log_error_stack: bool = False,
):
super().__init__()
self.engine_client = engine_client
self.model_config = model_config
self.max_model_len = model_config.max_model_len
self.models = models
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
self.enable_force_include_usage = enable_force_include_usage
self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)
self._async_tokenizer_pool: dict[AnyTokenizer,
AsyncMicrobatchTokenizer] = {}
self.log_error_stack = log_error_stack
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
"""
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the
given tokenizer.
"""
async_tokenizer = self._async_tokenizer_pool.get(tokenizer)
if async_tokenizer is None:
async_tokenizer = AsyncMicrobatchTokenizer(tokenizer)
self._async_tokenizer_pool[tokenizer] = async_tokenizer
return async_tokenizer
async def _preprocess(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""
Default preprocessing hook. Subclasses may override
to prepare `ctx` (classification, embedding, etc.).
"""
return None
def _build_response(
self,
ctx: ServeContext,
) -> Union[AnyResponse, ErrorResponse]:
"""
Default response builder. Subclass may override this method
to return the appropriate response object.
"""
return self.create_error_response("unimplemented endpoint")
async def handle(
    self,
    ctx: ServeContext,
) -> Union[AnyResponse, ErrorResponse]:
    """Run the pipeline for `ctx` and return the first response it yields.

    If the pipeline finishes without yielding anything, an error response
    is returned instead.
    """
    first_response: Union[AnyResponse, ErrorResponse, None] = None
    produced = False

    async for item in self._pipeline(ctx):
        # Only the first yielded item is used; stop iterating immediately.
        first_response = item
        produced = True
        break

    if produced:
        return first_response
    return self.create_error_response("No response yielded from pipeline")
async def _pipeline(
self,
ctx: ServeContext,
) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
"""Execute the request processing pipeline yielding responses."""
if error := await self._check_model(ctx.request):
yield error
if error := self._validate_request(ctx):
yield error
preprocess_ret = await self._preprocess(ctx)
if isinstance(preprocess_ret, ErrorResponse):
yield preprocess_ret
generators_ret = await self._prepare_generators(ctx)
if isinstance(generators_ret, ErrorResponse):
yield generators_ret
collect_ret = await self._collect_batch(ctx)
if isinstance(collect_ret, ErrorResponse):
yield collect_ret
yield self._build_response(ctx)
def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
None)
if truncate_prompt_tokens is not None and \
truncate_prompt_tokens > self.max_model_len:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
return None
def _create_pooling_params(
self,
ctx: ServeContext,
) -> Union[PoolingParams, ErrorResponse]:
if not hasattr(ctx.request, "to_pooling_params"):
return self.create_error_response(
"Request type does not support pooling parameters")
return ctx.request.to_pooling_params()
async def _prepare_generators(
    self,
    ctx: ServeContext,
) -> Optional[ErrorResponse]:
    """Submit every engine prompt for encoding and merge the result streams.

    On success the merged async iterator is stored on
    ``ctx.result_generator`` and ``None`` is returned. Any failure,
    including an exception raised while scheduling, is reported as an
    error response.
    """
    try:
        if ctx.raw_request is None:
            trace_headers = None
        else:
            trace_headers = await self._get_trace_headers(
                ctx.raw_request.headers)

        pooling_params = self._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        if ctx.engine_prompts is None:
            return self.create_error_response(
                "Engine prompts not available")

        pending: list[AsyncGenerator[Union[RequestOutput,
                                           PoolingRequestOutput],
                                     None]] = []
        for index, prompt in enumerate(ctx.engine_prompts):
            # Each prompt gets its own id derived from the request id.
            item_id = f"{ctx.request_id}-{index}"

            if ctx.request_prompts is None:
                return self.create_error_response(
                    "Request prompts not available")

            self._log_inputs(item_id,
                             ctx.request_prompts[index],
                             params=pooling_params,
                             lora_request=ctx.lora_request)

            # Mypy has an existing bug related to inferring the variance of
            # TypedDicts with `builtins.enumerate`:
            # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
            typed_prompt = cast(
                Union[EngineTokensPrompt, EngineEmbedsPrompt], prompt)

            pending.append(
                self.engine_client.encode(
                    typed_prompt,
                    pooling_params,
                    item_id,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=getattr(ctx.request, "priority", 0),
                ))

        ctx.result_generator = merge_async_iterators(*pending)
    except Exception as exc:
        # TODO: Use a vllm-specific Validation Error
        return self.create_error_response(str(exc))

    return None
async def _collect_batch(
self,
ctx: ServeContext,
) -> Optional[ErrorResponse]:
"""Collect batch results from the result generator."""
try:
if ctx.engine_prompts is None:
return self.create_error_response(
"Engine prompts not available")
num_prompts = len(ctx.engine_prompts)
final_res_batch: list[Optional[Union[RequestOutput,
PoolingRequestOutput]]]
final_res_batch = [None] * num_prompts
if ctx.result_generator is None:
return self.create_error_response(
"Result generator not available")
async for i, res in ctx.result_generator:
final_res_batch[i] = res
if None in final_res_batch:
return self.create_error_response(
"Failed to generate results for all prompts")
ctx.final_res_batch = [
res for res in final_res_batch if res is not None
]
return None
except Exception as e:
return self.create_error_response(str(e))
def create_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
    """Build an `ErrorResponse` for `message`.

    When `self.log_error_stack` is set, the active exception's traceback
    is printed if an exception is being handled; otherwise the current
    call stack is printed.
    """
    if self.log_error_stack:
        active_exc_type = sys.exc_info()[0]
        if active_exc_type is None:
            traceback.print_stack()
        else:
            traceback.print_exc()

    info = ErrorInfo(message=message,
                     type=err_type,
                     code=status_code.value)
    return ErrorResponse(error=info)
def create_streaming_error_response(
        self,
        message: str,
        err_type: str = "BadRequestError",
        status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
    """Return the error response for `message` serialized as a JSON string.

    The error is built with `create_error_response` and its
    ``model_dump()`` output is passed to ``json.dumps``.
    """
    error = self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code)
    return json.dumps(error.model_dump())
async def _check_model(
self,
request: AnyRequest,
) -> Optional[ErrorResponse]:
error_response = None
if self._is_model_supported(request.model):
return None
if request.model in self.models.lora_requests:
return None
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
load_result := await self.models.resolve_lora(request.model)):
if isinstance(load_result, LoRARequest):
return None
if isinstance(load_result, ErrorResponse) and \
load_result.error.code == HTTPStatus.BAD_REQUEST.value:
error_response = load_result
return error_response or self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
def _get_active_default_mm_loras(
self, request: AnyRequest) -> Optional[LoRARequest]:
"""Determine if there are any active default multimodal loras."""
# TODO: Currently this is only enabled for chat completions
# to be better aligned with only being enabled for .generate
# when run offline. It would be nice to support additional
# tasks types in the future.
message_types = self._get_message_types(request)
default_mm_loras = set()
for lora in self.models.lora_requests.values():
# Best effort match for default multimodal lora adapters;
# There is probably a better way to do this, but currently
# this matches against the set of 'types' in any content lists
# up until '_', e.g., to match audio_url -> audio
if lora.lora_name in message_types:
default_mm_loras.add(lora)
# Currently only support default modality specific loras if
# we have exactly one lora matched on the request.
if len(default_mm_loras) == 1:
return default_mm_loras.pop()
return None
def _maybe_get_adapters(
self,
request: AnyRequest,
supports_default_mm_loras: bool = False,
) -> Optional[LoRARequest]:
if request.model in self.models.lora_requests:
return self.models.lora_requests[request.model]
# Currently only support default modality specific loras
# if we have exactly one lora matched on the request.
if supports_default_mm_loras:
default_mm_lora = self._get_active_default_mm_loras(request)
if default_mm_lora is not None:
return default_mm_lora
if self._is_model_supported(request.model):
return None
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
def _get_message_types(self, request: AnyRequest) -> set[str]:
"""Retrieve the set of types from message content dicts up
until `_`; we use this to match potential multimodal data
with default per modality loras.
"""
message_types: set[str] = set()
if not hasattr(request, "messages"):
return message_types
for message in request.messages:
if (isinstance(message, dict) and "content" in message
and isinstance(message["content"], list)):
for content_dict in message["content"]:
if "type" in content_dict:
message_types.add(content_dict["type"].split("_")[0])
return message_types
async def _normalize_prompt_text_to_input(
self,
request: AnyRequest,
prompt: str,
tokenizer: AnyTokenizer,
add_special_tokens: bool,
) -> TextTokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer)
if (self.model_config.encoder_config is not None
and self.model_config.encoder_config.get(
"do_lower_case", False)):
prompt = prompt.lower()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
if truncate_prompt_tokens is None:
encoded = await async_tokenizer(
prompt, add_special_tokens=add_special_tokens)
elif truncate_prompt_tokens < 0:
# Negative means we cap at the model's max length
encoded = await async_tokenizer(
prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=self.max_model_len)
else:
encoded = await async_tokenizer(
prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens)
input_ids = encoded.input_ids
input_text = prompt
return self._validate_input(request, input_ids, input_text)
async def _normalize_prompt_tokens_to_input(
self,
request: AnyRequest,
prompt_ids: list[int],
tokenizer: Optional[AnyTokenizer],
) -> TextTokensPrompt:
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
if truncate_prompt_tokens is None:
input_ids = prompt_ids
elif truncate_prompt_tokens < 0:
input_ids = prompt_ids[-self.max_model_len:]
else:
input_ids = prompt_ids[-truncate_prompt_tokens:]
if tokenizer is None:
input_text = ""
else:
async_tokenizer = self._get_async_tokenizer(tokenizer)
input_text = await async_tokenizer.decode(input_ids)
return self._validate_input(request, input_ids, input_text)
def _validate_input(
    self,
    request: AnyRequest,
    input_ids: list[int],
    input_text: str,
) -> TextTokensPrompt:
    """Check the prompt length against the model context and package it.

    Three request families are handled:

    * embedding / score / rerank / classification requests may use up to
      the full ``self.max_model_len``;
    * tokenize / detokenize requests are not length-checked;
    * all other requests must be shorter than ``self.max_model_len`` and,
      when a max-token budget is given, prompt plus budget must fit.

    Raises:
        ValueError: If the applicable length check fails.
    """
    token_num = len(input_ids)
    pooling_request_types = (EmbeddingChatRequest,
                             EmbeddingCompletionRequest, ScoreRequest,
                             RerankRequest, ClassificationRequest)
    tokenization_request_types = (TokenizeCompletionRequest,
                                  TokenizeChatRequest, DetokenizeRequest)

    # Note: EmbeddingRequest, ClassificationRequest,
    # and ScoreRequest doesn't have max_tokens
    if isinstance(request, pooling_request_types):
        # Note: input length can be up to the entire model context length
        # since these requests don't generate tokens.
        if token_num > self.max_model_len:
            request_type = type(request)
            if request_type is ScoreRequest:
                operation = "score"
            elif request_type is ClassificationRequest:
                operation = "classification"
            else:
                operation = "embedding generation"
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{token_num} tokens in the input for {operation}. "
                f"Please reduce the length of the input.")
        return TextTokensPrompt(prompt=input_text,
                                prompt_token_ids=input_ids)

    # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
    # and does not require model context length validation
    if isinstance(request, tokenization_request_types):
        return TextTokensPrompt(prompt=input_text,
                                prompt_token_ids=input_ids)

    # chat completion endpoint supports max_completion_tokens
    if not isinstance(request, ChatCompletionRequest):
        max_tokens = getattr(request, "max_tokens", None)
    else:
        # TODO(#9845): remove max_tokens when field dropped from OpenAI API
        max_tokens = request.max_completion_tokens or request.max_tokens

    # Note: input length can be up to model context length - 1 for
    # completion-like requests.
    if token_num >= self.max_model_len:
        raise ValueError(
            f"This model's maximum context length is "
            f"{self.max_model_len} tokens. However, your request has "
            f"{token_num} input tokens. Please reduce the length of "
            "the input messages.")

    budget_exceeded = (max_tokens is not None
                       and token_num + max_tokens > self.max_model_len)
    if budget_exceeded:
        raise ValueError(
            "'max_tokens' or 'max_completion_tokens' is too large: "
            f"{max_tokens}. This model's maximum context length is "
            f"{self.max_model_len} tokens and your request has "
            f"{token_num} input tokens ({max_tokens} > {self.max_model_len}"
            f" - {token_num}).")

    return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
async def _tokenize_prompt_input_async(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_input: Union[str, list[int]],
add_special_tokens: bool = True,
) -> TextTokensPrompt:
"""
A simpler implementation of
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
that assumes single input.
"""
async for result in self._tokenize_prompt_inputs_async(
request,
tokenizer,
[prompt_input],
add_special_tokens=add_special_tokens,
):
return result
raise ValueError("No results yielded from tokenization")
async def _tokenize_prompt_inputs_async(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_inputs: Iterable[Union[str, list[int]]],
add_special_tokens: bool = True,
) -> AsyncGenerator[TextTokensPrompt, None]:
"""
A simpler implementation of
[`_tokenize_prompt_input_or_inputs`][vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs]
that assumes multiple inputs.
"""
for prompt in prompt_inputs:
if isinstance(prompt, str):
yield await self._normalize_prompt_text_to_input(
request,
prompt=prompt,
tokenizer=tokenizer,
add_special_tokens=add_special_tokens,
)
else:
yield await self._normalize_prompt_tokens_to_input(
request,
prompt_ids=prompt,
tokenizer=tokenizer,
)
async def _tokenize_prompt_input_or_inputs_async(
    self,
    request: AnyRequest,
    tokenizer: Optional[AnyTokenizer],
    input_or_inputs: Optional[Union[str, list[str], list[int],
                                    list[list[int]]]],
    add_special_tokens: bool = True,
) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
    """
    Turn the request's prompt input(s) into text/token prompts and,
    for completion requests carrying `prompt_embeds`, embedding prompts.

    According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
    , each input can be a string or array of tokens. Note that each request
    can pass one or more inputs.

    Returns:
        ``(text_prompts, embeds_prompts)``. The text list is empty when
        `input_or_inputs` is ``None``, or is ``""`` while embeddings were
        supplied.
    """
    embeds_prompts: list[EmbedsPrompt] = []

    truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
                                     None)
    # A negative truncation value stands for "the model's max length".
    if (truncate_prompt_tokens or 0) < 0:
        truncate_prompt_tokens = self.max_model_len

    if (isinstance(request, CompletionRequest)
            and request.prompt_embeds is not None):
        embeds_prompts.extend(
            self._load_prompt_embeds(request.prompt_embeds,
                                     truncate_prompt_tokens))

    # Empty prompts are okay as long as there are prompt embeddings
    if input_or_inputs is None or (embeds_prompts
                                   and input_or_inputs == ""):
        return [], embeds_prompts

    # Parse and batch the input prompts, then build one coroutine per
    # input so they can all be awaited together.
    pending = []
    for parsed in parse_and_batch_prompt(input_or_inputs):
        # "is False" is required for Pyright to perform type narrowing
        # See: https://github.com/microsoft/pyright/issues/7672
        if parsed["is_tokens"] is False:
            assert tokenizer is not None, \
                "Tokenizer is required for text prompts"
            coroutine = self._normalize_prompt_text_to_input(
                request,
                parsed["content"],
                tokenizer=tokenizer,
                add_special_tokens=add_special_tokens)
        else:
            coroutine = self._normalize_prompt_tokens_to_input(
                request, parsed["content"], tokenizer=tokenizer)
        pending.append(coroutine)

    # Wait for all tokenization tasks to complete
    text_prompts: list[TextTokensPrompt] = list(
        await asyncio.gather(*pending))
    return text_prompts, embeds_prompts
# Overload: non-completion requests require an input and only ever yield
# text/token prompts.
@overload
async def _preprocess_completion(
    self,
    request: Union[DetokenizeRequest, EmbeddingCompletionRequest,
                   RerankRequest, ClassificationRequest, ScoreRequest,
                   TokenizeCompletionRequest],
    tokenizer: Optional[AnyTokenizer],
    input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
    add_special_tokens: bool = ...,
) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
    ...

# Overload: completion requests may omit the input and may yield a mix of
# embedding and text/token prompts.
@overload
async def _preprocess_completion(
    self,
    request: CompletionRequest,
    tokenizer: Optional[AnyTokenizer],
    input_or_inputs: Optional[Union[str, list[str], list[int],
                                    list[list[int]]]],
    add_special_tokens: bool = ...,
) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
        EngineTokensPrompt, EngineEmbedsPrompt]]]:
    ...

async def _preprocess_completion(
    self,
    request: CompletionLikeRequest,
    tokenizer: Optional[AnyTokenizer],
    input_or_inputs: Optional[Union[str, list[str], list[int],
                                    list[list[int]]]],
    add_special_tokens: bool = True,
) -> tuple[Union[list[TextTokensPrompt], list[Union[
        TextTokensPrompt, EmbedsPrompt]]], Union[
            list[EngineTokensPrompt], list[Union[EngineTokensPrompt,
                                                 EngineEmbedsPrompt]]]]:
    """Build request prompts and matching engine prompts for a request.

    The inputs are tokenized via `_tokenize_prompt_input_or_inputs_async`,
    then one engine prompt is created per request prompt. If the request
    has a non-empty ``cache_salt`` it is attached to every engine prompt.

    Returns:
        ``(request_prompts, engine_prompts)``. For non-completion requests
        only text/token prompts are returned. For `CompletionRequest` the
        embedding prompts come first, followed by the text/token prompts,
        in both lists.

    Raises:
        ValueError: If `input_or_inputs` is ``None`` for a request that is
            not a `CompletionRequest`.
    """
    if not isinstance(request,
                      CompletionRequest) and input_or_inputs is None:
        raise ValueError(
            "Prompt embeds with non-completion requests is not"
            " currently supported.")

    (request_prompts_text, request_prompts_embeds
     ) = await self._tokenize_prompt_input_or_inputs_async(
         request,
         tokenizer,
         input_or_inputs,
         add_special_tokens=add_special_tokens,
     )

    # Engine prompts for the text inputs carry only the token ids.
    engine_prompts_text = [
        EngineTokensPrompt(
            prompt_token_ids=request_prompt_text["prompt_token_ids"])
        for request_prompt_text in request_prompts_text
    ]
    cache_salt = request.cache_salt if (
        hasattr(request, "cache_salt")
        and request.cache_salt is not None) else None
    if cache_salt:
        for prompt_text in engine_prompts_text:
            prompt_text["cache_salt"] = cache_salt

    # This check is equivalent to simply checking if
    # `request_prompts_embeds` is empty, but it's difficult to propagate
    # overloads to the private helper functions to enable this check.
    # This overload is needed because only TextPrompts are allowed for
    # non-completion requests and if we don't add the overload here,
    # everywhere this function is used outside of serving_completion will
    # need logic asserting that only text prompts are in the request.
    if not isinstance(request,
                      CompletionRequest) and input_or_inputs is not None:
        return request_prompts_text, engine_prompts_text

    engine_prompts_embeds = [
        EngineEmbedsPrompt(
            prompt_embeds=request_prompt_embeds["prompt_embeds"])
        for request_prompt_embeds in request_prompts_embeds
    ]
    if cache_salt:
        for prompt_embed in engine_prompts_embeds:
            prompt_embed["cache_salt"] = cache_salt

    # Embedding prompts are placed ahead of the text prompts in both lists
    # so the two lists stay index-aligned.
    request_prompts = request_prompts_embeds + request_prompts_text
    engine_prompts = engine_prompts_embeds + engine_prompts_text
    return request_prompts, engine_prompts
async def _preprocess_chat(
    self,
    request: Union[ChatLikeRequest, ResponsesRequest],
    tokenizer: AnyTokenizer,
    messages: list[ChatCompletionMessageParam],
    chat_template: Optional[str],
    chat_template_content_format: ChatTemplateContentFormatOption,
    add_generation_prompt: bool = True,
    continue_final_message: bool = False,
    tool_dicts: Optional[list[dict[str, Any]]] = None,
    documents: Optional[list[dict[str, str]]] = None,
    chat_template_kwargs: Optional[dict[str, Any]] = None,
    tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
    add_special_tokens: bool = False,
) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
           list[EngineTokensPrompt]]:
    """Render chat messages into a prompt and build the engine prompt.

    Steps: resolve the chat-template content format, parse the messages
    (multimodal data is awaited later), apply the Mistral or HF chat
    template, optionally let the tool parser adjust the request, tokenize
    the rendered prompt, and assemble a single `EngineTokensPrompt`.

    Entries in `chat_template_kwargs` override the explicitly passed
    template arguments of the same name.

    Returns:
        ``(conversation, [request_prompt], [engine_prompt])``, where
        `request_prompt` is the rendered prompt (text, or token ids for
        the Mistral tokenizer).

    Raises:
        NotImplementedError: If tool parsing is requested for a request
            that is not a `ChatCompletionRequest`.
    """
    model_config = self.model_config

    resolved_content_format = resolve_chat_template_content_format(
        chat_template,
        tool_dicts,
        chat_template_content_format,
        tokenizer,
        model_config=model_config,
    )
    conversation, mm_data_future = parse_chat_messages_futures(
        messages,
        model_config,
        tokenizer,
        content_format=resolved_content_format,
    )

    _chat_template_kwargs: dict[str, Any] = dict(
        chat_template=chat_template,
        add_generation_prompt=add_generation_prompt,
        continue_final_message=continue_final_message,
        tools=tool_dicts,
        documents=documents,
    )
    # Caller-supplied kwargs take precedence over the defaults above.
    _chat_template_kwargs.update(chat_template_kwargs or {})

    request_prompt: Union[str, list[int]]

    if tokenizer is None:
        request_prompt = "placeholder"
    elif isinstance(tokenizer, MistralTokenizer):
        request_prompt = apply_mistral_chat_template(
            tokenizer,
            messages=messages,
            **_chat_template_kwargs,
        )
    else:
        request_prompt = apply_hf_chat_template(
            tokenizer=tokenizer,
            conversation=conversation,
            model_config=model_config,
            **_chat_template_kwargs,
        )

    mm_data = await mm_data_future

    # tool parsing is done only if a tool_parser has been set and if
    # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
    # is set, we want to prevent parsing a tool_call hallucinated by the LLM
    should_parse_tools = tool_parser is not None and (hasattr(
        request, "tool_choice") and request.tool_choice != "none")

    if should_parse_tools:
        if not isinstance(request, ChatCompletionRequest):
            msg = "Tool usage is only supported for Chat Completions API"
            raise NotImplementedError(msg)

        request = tool_parser(tokenizer).adjust_request(  # type: ignore
            request=request)

    if tokenizer is None:
        # The two message fragments are joined into one string; a comma
        # between them would make the assertion message a tuple.
        assert isinstance(request_prompt, str), (
            "Prompt has to be a string "
            "when the tokenizer is not initialised")
        prompt_inputs = TextTokensPrompt(prompt=request_prompt,
                                         prompt_token_ids=[1])
    elif isinstance(request_prompt, str):
        prompt_inputs = await self._tokenize_prompt_input_async(
            request,
            tokenizer,
            request_prompt,
            add_special_tokens=add_special_tokens,
        )
    else:
        # For MistralTokenizer
        assert is_list_of(request_prompt, int), (
            "Prompt has to be either a string or a list of token ids")
        prompt_inputs = TextTokensPrompt(
            prompt=tokenizer.decode(request_prompt),
            prompt_token_ids=request_prompt)

    engine_prompt = EngineTokensPrompt(
        prompt_token_ids=prompt_inputs["prompt_token_ids"])
    if mm_data is not None:
        engine_prompt["multi_modal_data"] = mm_data

    if request.mm_processor_kwargs is not None:
        engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

    if hasattr(request, "cache_salt") and request.cache_salt is not None:
        engine_prompt["cache_salt"] = request.cache_salt

    return conversation, [request_prompt], [engine_prompt]
async def _generate_with_builtin_tools(
    self,
    request_id: str,
    request_prompt: RequestPrompt,
    engine_prompt: EngineTokensPrompt,
    sampling_params: SamplingParams,
    context: ConversationContext,
    lora_request: Optional[LoRARequest] = None,
    priority: int = 0,
    **kwargs,
):
    """Generate, executing built-in tool calls between turns as needed.

    This is an async generator. Each engine output is appended to
    `context`, and `context` itself is yielded after every output. When a
    turn ends and the context asks for a built-in tool call, the tool is
    run, its output is appended, and a new turn is started from the
    re-rendered prompt. The loop ends once a turn finishes without
    requesting a tool call.

    Note that `sampling_params.max_tokens` is mutated in place for
    follow-up turns. Extra `kwargs` are forwarded to
    `engine_client.generate`.
    """
    # Remembered so every follow-up turn uses the same adjusted priority.
    orig_priority = priority
    while True:
        self._log_inputs(
            request_id,
            request_prompt,
            params=sampling_params,
            lora_request=lora_request,
        )
        generator = self.engine_client.generate(
            engine_prompt,
            sampling_params,
            request_id,
            lora_request=lora_request,
            priority=priority,
            **kwargs,
        )
        async for res in generator:
            context.append_output(res)
            # NOTE(woosuk): The stop condition is handled by the engine.
            yield context

        if not context.need_builtin_tool_call():
            # The model did not ask for a tool call, so we're done.
            break

        # Call the tool and update the context with the result.
        tool_output = await context.call_tool()
        context.append_output(tool_output)

        # TODO: uncomment this and enable tool output streaming
        # yield context

        # Create inputs for the next turn.
        # Render the next prompt token ids.
        prompt_token_ids = context.render_for_completion()
        engine_prompt = EngineTokensPrompt(
            prompt_token_ids=prompt_token_ids)
        request_prompt = prompt_token_ids
        # Update the sampling params.
        # The generation budget becomes whatever context length remains
        # after the re-rendered prompt.
        sampling_params.max_tokens = (self.max_model_len -
                                      len(prompt_token_ids))
        # OPTIMIZATION
        # NOTE(review): follow-up turns use a priority one lower than the
        # original; presumably a lower value is scheduled sooner — confirm
        # against the engine's priority semantics.
        priority = orig_priority - 1
@staticmethod
def _load_prompt_embeds(
prompt_embeds: Optional[Union[bytes, list[bytes]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
) -> list[EmbedsPrompt]:
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
tensor = torch.load(io.BytesIO(
pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"))
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
if truncate_prompt_tokens is not None:
tensor = tensor[-truncate_prompt_tokens:]
return {"prompt_embeds": tensor}
if prompt_embeds:
if isinstance(prompt_embeds, list):
return [
_load_and_validate_embed(embed) for embed in prompt_embeds
]
else:
return [_load_and_validate_embed(prompt_embeds)]
else:
return []
def _log_inputs(
self,
request_id: str,
inputs: RequestPrompt,
params: Optional[Union[SamplingParams, PoolingParams,
BeamSearchParams]],
lora_request: Optional[LoRARequest],
) -> None:
if self.request_logger is None:
return
prompt, prompt_token_ids, prompt_embeds = None, None, None
if isinstance(inputs, str):
prompt = inputs
elif isinstance(inputs, list):
prompt_token_ids = inputs
elif 'prompt_embeds' in inputs:
prompt_embeds = inputs.get("prompt_embeds")
else:
prompt = inputs["prompt"]
prompt_token_ids = inputs["prompt_token_ids"]
self.request_logger.log_inputs(
request_id,
prompt,
prompt_token_ids,
prompt_embeds,
params=params,
lora_request=lora_request,
)
async def _get_trace_headers(
    self,
    headers: Headers,
) -> Optional[Mapping[str, str]]:
    """Extract trace headers when the engine has tracing enabled.

    When tracing is disabled, ``None`` is returned; if the request
    nevertheless carries trace headers, the tracing-disabled warning is
    logged first.
    """
    if await self.engine_client.is_tracing_enabled():
        return extract_trace_headers(headers)

    # Tracing is off: let the operator know trace headers are being ignored.
    if contains_trace_headers(headers):
        log_tracing_disabled_warning()

    return None
@staticmethod
def _base_request_id(raw_request: Optional[Request],
default: Optional[str] = None) -> Optional[str]:
"""Pulls the request id to use from a header, if provided"""
default = default or random_uuid()
if raw_request is None:
return default
return raw_request.headers.get("X-Request-Id", default)
@staticmethod
def _get_decoded_token(logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
return_as_token_id: bool = False) -> str:
if return_as_token_id:
return f"token_id:{token_id}"
if logprob.decoded_token is not None:
return logprob.decoded_token
return tokenizer.decode(token_id)
def _is_model_supported(self, model_name: Optional[str]) -> bool:
if not model_name:
return True
return self.models.is_base_model(model_name)
def _get_model_name(self,
model_name: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> str:
if lora_request:
return lora_request.lora_name
if not model_name:
return self.models.base_model_paths[0].name
return model_name
def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs,
                           None]) -> Union[PromptLogprobs, None]:
    """Replace ``-inf`` log-probabilities with ``-9999.0``, in place.

    ``None`` entries in the sequence are skipped. The same object that was
    passed in is returned (``None`` stays ``None``).
    """
    if prompt_logprobs is None:
        return prompt_logprobs

    negative_infinity = float('-inf')
    populated_entries = (entry for entry in prompt_logprobs
                         if entry is not None)
    for entry in populated_entries:
        for value in entry.values():
            if value.logprob == negative_infinity:
                value.logprob = -9999.0
    return prompt_logprobs