vllm-dev/vllm/entrypoints/openai/serving_engine.py
# SPDX-License-Identifier: Apache-2.0

import json
from concurrent.futures.thread import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                    Optional, Sequence, Tuple, TypedDict, Union)

from fastapi import Request
from pydantic import Field
from starlette.datastructures import Headers
from typing_extensions import Annotated

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         ChatTemplateContentFormatOption,
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
                                              ErrorResponse, RerankRequest,
                                              ScoreRequest,
                                              TokenizeChatRequest,
                                              TokenizeCompletionRequest,
                                              TranscriptionRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
from vllm.inputs import TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import is_list_of, make_async, random_uuid

logger = init_logger(__name__)
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest, RerankRequest,
                              ScoreRequest, TokenizeCompletionRequest]

ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]

AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest,
                   TranscriptionRequest]


class TextTokensPrompt(TypedDict):
    prompt: str
    prompt_token_ids: List[int]


RequestPrompt = Union[List[int], str, TextTokensPrompt]

class OpenAIServing:

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__()

        self.engine_client = engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
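
        # Tokenization is offloaded to a dedicated single-worker thread pool
        # via make_async so that CPU-bound tokenizer calls do not block the
        # asyncio event loop.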
        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        self._tokenize_prompt_input_async = make_async(
            self._tokenize_prompt_input, executor=self._tokenizer_executor)
        self._tokenize_prompt_input_or_inputs_async = make_async(
            self._tokenize_prompt_input_or_inputs,
            executor=self._tokenizer_executor)

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
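        """Serialize an error as a JSON string for streaming responses."""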
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str

    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
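        """Return an error if the requested model is neither the base model
        nor a registered LoRA or prompt adapter."""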
        if self._is_model_supported(request.model):
            return None
        if request.model in [
                lora.lora_name for lora in self.models.lora_requests
        ]:
            return None
        if request.model in [
                prompt_adapter.prompt_adapter_name
                for prompt_adapter in self.models.prompt_adapter_requests
        ]:
            return None
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
            None, PromptAdapterRequest]]:
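        """Resolve the requested model name to a matching LoRA or prompt
        adapter request; return (None, None) for the base model."""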
        if self._is_model_supported(request.model):
            return None, None
        for lora in self.models.lora_requests:
            if request.model == lora.lora_name:
                return lora, None
        for prompt_adapter in self.models.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            prompt = prompt.lower()

        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        input_ids = encoded.input_ids
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: List[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
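        """Check the tokenized prompt against the model's context length,
        accounting for the requested completion length where applicable."""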
        token_num = len(input_ids)

        # Note: embedding, score and rerank requests don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):

            operation = "score" if isinstance(request, ScoreRequest) \
                else "embedding generation"
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens
        if max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, List[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes a single input.
        """
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes multiple inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> List[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
        , each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.
        """
        # Although our type checking is based on mypy,
        # VSCode Pyright extension should still work properly.
        # The literal `is False` comparison is required for Pyright to
        # perform type narrowing.
        # See: https://github.com/microsoft/pyright/issues/7672
        return [
            self._normalize_prompt_text_to_input(
                request,
                tokenizer,
                prompt=prompt_input["content"],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens)
            if prompt_input["is_tokens"] is False else
            self._normalize_prompt_tokens_to_input(
                request,
                tokenizer,
                prompt_ids=prompt_input["content"],
                truncate_prompt_tokens=truncate_prompt_tokens)
            for prompt_input in parse_and_batch_prompt(input_or_inputs)
        ]

    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Tuple[List[TextTokensPrompt], List[TokensPrompt]]:
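        """Tokenize completion-style inputs and build the token prompts that
        are handed to the engine."""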
        request_prompts = await self._tokenize_prompt_input_or_inputs_async(
            request,
            tokenizer,
            input_or_inputs,
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )

        engine_prompts = [
            TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
            for request_prompt in request_prompts
        ]

        return request_prompts, engine_prompts

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest,
        tokenizer: AnyTokenizer,
        messages: List[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[List[Dict[str, Any]]] = None,
        documents: Optional[List[Dict[str, str]]] = None,
        chat_template_kwargs: Optional[Dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
               List[TokensPrompt]]:
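        """Apply the chat template, resolve any multimodal data, and tokenize
        the rendered prompt for the engine."""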
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            self.model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: Dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: Union[str, List[int]]
        if isinstance(tokenizer, MistralTokenizer):
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # Tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM).
        should_parse_tools = tool_parser is not None and (hasattr(
            request, "tool_choice") and request.tool_choice != "none")

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request)

        if isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)

        engine_prompt = TokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data
        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        return conversation, [request_prompt], [engine_prompt]

    def _log_inputs(
        self,
        request_id: str,
        inputs: RequestPrompt,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
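        """Log the request inputs through the configured request logger."""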
        if self.request_logger is None:
            return

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
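        """Extract trace propagation headers when tracing is enabled."""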
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

    @staticmethod
    def _base_request_id(raw_request: Optional[Request],
                         default: Optional[str] = None) -> Optional[str]:
        """Pulls the request id to use from a header, if provided"""
        default = default or random_uuid()
        if raw_request is None:
            return default
        return raw_request.headers.get("X-Request-Id", default)

    @staticmethod
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        if return_as_token_id:
            return f"token_id:{token_id}"

        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return tokenizer.decode(token_id)

    def _is_model_supported(self, model_name: Optional[str]) -> bool:
        if not model_name:
            return True

        return self.models.is_base_model(model_name)

    def _get_model_name(self,
                        model_name: Optional[str] = None,
                        lora_request: Optional[LoRARequest] = None) -> str:
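        """Return the model name to report in responses: the LoRA adapter name
        when a LoRA request is active, otherwise the requested model name or
        the first served base model."""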
        if lora_request:
            return lora_request.lora_name
        if model_name is None:
            return self.models.base_model_paths[0].name
        return model_name