# SPDX-License-Identifier: Apache-2.0

import json
import sys
import time
from collections.abc import (AsyncGenerator, Iterable, Iterator, Mapping,
                             Sequence)
from concurrent.futures.thread import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
                    TypeVar, Union)

from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field
from starlette.datastructures import Headers

if sys.version_info >= (3, 12):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         ChatTemplateContentFormatOption,
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              ClassificationRequest,
                                              ClassificationResponse,
                                              CompletionRequest,
                                              CompletionResponse,
                                              DetokenizeRequest,
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
                                              PoolingResponse, RerankRequest,
                                              ScoreRequest, ScoreResponse,
                                              TokenizeChatRequest,
                                              TokenizeCompletionRequest,
                                              TokenizeResponse,
                                              TranscriptionRequest,
                                              TranscriptionResponse)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
from vllm.inputs import TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import (  # noqa: F401 - Required to resolve Pydantic error in RequestProcessingMixin
    MultiModalDataDict)
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob, PromptLogprobs
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                          log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import (is_list_of, make_async, merge_async_iterators,
                        random_uuid)

logger = init_logger(__name__)

CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
                              EmbeddingCompletionRequest, RerankRequest,
                              ClassificationRequest, ScoreRequest,
                              TokenizeCompletionRequest]

ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
                        TokenizeChatRequest]

AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest,
                   TranscriptionRequest]

AnyResponse = Union[
    CompletionResponse,
    ChatCompletionResponse,
    EmbeddingResponse,
    TranscriptionResponse,
    TokenizeResponse,
    PoolingResponse,
    ClassificationResponse,
    ScoreResponse,
]


class TextTokensPrompt(TypedDict):
    prompt: str
    prompt_token_ids: list[int]

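# Illustrative example (hypothetical token IDs, not from any real tokenizer):
# a TextTokensPrompt keeps the original text alongside its tokenization, e.g.
#   TextTokensPrompt(prompt="Hello world", prompt_token_ids=[9906, 1917])
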
RequestPrompt = Union[list[int], str, TextTokensPrompt]

RequestT = TypeVar("RequestT", bound=AnyRequest)


class RequestProcessingMixin(BaseModel):
    """
    Mixin for request processing,
    handling prompt preparation and engine input.
    """
    request_prompts: Optional[Sequence[RequestPrompt]] = \
                            Field(default_factory=list)
    engine_prompts: Optional[list[TokensPrompt]] = \
                            Field(default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ResponseGenerationMixin(BaseModel):
    """
    Mixin for response generation,
    managing result generators and final batch results.
    """
    result_generator: Optional[AsyncGenerator[tuple[int, Union[
        RequestOutput, PoolingRequestOutput]], None]] = None
    final_res_batch: list[Union[RequestOutput, PoolingRequestOutput]] = Field(
        default_factory=list)

    model_config = ConfigDict(arbitrary_types_allowed=True)


class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
                   Generic[RequestT]):
    # Shared across all requests
    request: RequestT
    raw_request: Optional[Request] = None
    model_name: str
    request_id: str
    created_time: int = Field(default_factory=lambda: int(time.time()))
    lora_request: Optional[LoRARequest] = None
    prompt_adapter_request: Optional[PromptAdapterRequest] = None

    # Shared across most requests
    tokenizer: Optional[AnyTokenizer] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # `protected_namespaces` resolves Pydantic v2's warning
    # on conflict with protected namespace "model_"
    model_config = ConfigDict(
        protected_namespaces=(),
        arbitrary_types_allowed=True,
    )


ClassificationServeContext = ServeContext[ClassificationRequest]


class EmbeddingServeContext(ServeContext[EmbeddingRequest]):
    chat_template: Optional[str] = None
    chat_template_content_format: ChatTemplateContentFormatOption


# Used to resolve the Pydantic error related to
# forward reference of MultiModalDataDict in TokensPrompt
RequestProcessingMixin.model_rebuild()
ServeContext.model_rebuild()
ClassificationServeContext.model_rebuild()
EmbeddingServeContext.model_rebuild()

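# Illustrative usage (hypothetical variable names): an endpoint constructs a
# typed context and hands it to the shared serving pipeline defined on
# OpenAIServing below:
#   ctx = ClassificationServeContext(
#       request=classification_request,
#       model_name="my-model",
#       request_id="classify-abc123",
#   )
#   response = await serving.handle(ctx)
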
class OpenAIServing:
    request_id_prefix: ClassVar[str]
    """
    A short string prepended to every request's ID (e.g. "embd", "classify")
    so you can easily tell whether an ID came from, say, the Embedding
    endpoint or the Classification endpoint.
    """

    def __init__(
        self,
        engine_client: EngineClient,
        model_config: ModelConfig,
        models: OpenAIServingModels,
        *,
        request_logger: Optional[RequestLogger],
        return_tokens_as_token_ids: bool = False,
    ):
        super().__init__()

        self.engine_client = engine_client
        self.model_config = model_config
        self.max_model_len = model_config.max_model_len

        self.models = models

        self.request_logger = request_logger
        self.return_tokens_as_token_ids = return_tokens_as_token_ids

        self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

        self._tokenize_prompt_input_async = make_async(
            self._tokenize_prompt_input, executor=self._tokenizer_executor)
        self._tokenize_prompt_input_or_inputs_async = make_async(
            self._tokenize_prompt_input_or_inputs,
            executor=self._tokenizer_executor)

    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """
        Default preprocessing hook. Subclasses may override
        to prepare `ctx` (classification, embedding, etc.).
        """
        return None

    def _build_response(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        """
        Default response builder. Subclasses may override this method
        to return the appropriate response object.
        """
        return self.create_error_response("unimplemented endpoint")

    async def handle(
        self,
        ctx: ServeContext,
    ) -> Union[AnyResponse, ErrorResponse]:
        generation: AsyncGenerator[Union[AnyResponse, ErrorResponse], None]
        generation = self._pipeline(ctx)

        # The pipeline yields an error as soon as one occurs, so the first
        # yielded item is the final response for this request.
        async for response in generation:
            return response

        return self.create_error_response("No response yielded from pipeline")

    async def _pipeline(
        self,
        ctx: ServeContext,
    ) -> AsyncGenerator[Union[AnyResponse, ErrorResponse], None]:
        """Execute the request processing pipeline, yielding responses."""
        if error := await self._check_model(ctx.request):
            yield error
        if error := self._validate_request(ctx):
            yield error

        preprocess_ret = await self._preprocess(ctx)
        if isinstance(preprocess_ret, ErrorResponse):
            yield preprocess_ret

        generators_ret = await self._prepare_generators(ctx)
        if isinstance(generators_ret, ErrorResponse):
            yield generators_ret

        collect_ret = await self._collect_batch(ctx)
        if isinstance(collect_ret, ErrorResponse):
            yield collect_ret

        yield self._build_response(ctx)

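    # Illustrative subclass pattern (hypothetical class name): a concrete
    # endpoint overrides the hooks above and reuses the shared pipeline, e.g.
    #   class ServingFoo(OpenAIServing):
    #       async def _preprocess(self, ctx):   # fill ctx.engine_prompts
    #           ...
    #       def _build_response(self, ctx):     # wrap ctx.final_res_batch
    #           ...
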
    def _validate_request(self, ctx: ServeContext) -> Optional[ErrorResponse]:
        truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
                                         None)

        if truncate_prompt_tokens is not None:
            if truncate_prompt_tokens <= self.max_model_len:
                ctx.truncate_prompt_tokens = truncate_prompt_tokens
            else:
                return self.create_error_response(
                    "truncate_prompt_tokens value is "
                    "greater than max_model_len. "
                    "Please select a smaller truncation size.")
        return None

    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Schedule the request and get the result generator."""
        generators: list[AsyncGenerator[Union[RequestOutput,
                                              PoolingRequestOutput],
                                        None]] = []

        try:
            trace_headers = (None if ctx.raw_request is None else await
                             self._get_trace_headers(ctx.raw_request.headers))

            if not hasattr(ctx.request, "to_pooling_params"):
                return self.create_error_response(
                    "Request type does not support pooling parameters")

            pooling_params = ctx.request.to_pooling_params()

            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                request_id_item = f"{ctx.request_id}-{i}"

                if ctx.request_prompts is None:
                    return self.create_error_response(
                        "Request prompts not available")

                self._log_inputs(
                    request_id_item,
                    ctx.request_prompts[i],
                    params=pooling_params,
                    lora_request=ctx.lora_request,
                    prompt_adapter_request=ctx.prompt_adapter_request)

                generator = self.engine_client.encode(
                    engine_prompt,
                    pooling_params,
                    request_id_item,
                    lora_request=ctx.lora_request,
                    trace_headers=trace_headers,
                    priority=getattr(ctx.request, "priority", 0),
                )

                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

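    # merge_async_iterators interleaves the per-prompt generators and tags
    # each item with its generator index, so downstream consumers receive
    # tuples like (0, out_a), (2, out_c), (1, out_b), ... in completion order
    # (illustrative ordering and names).
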
    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> Optional[ErrorResponse]:
        """Collect batch results from the result generator."""
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response(
                    "Engine prompts not available")

            num_prompts = len(ctx.engine_prompts)
            final_res_batch: list[Optional[Union[RequestOutput,
                                                 PoolingRequestOutput]]]
            final_res_batch = [None] * num_prompts

            if ctx.result_generator is None:
                return self.create_error_response(
                    "Result generator not available")

            async for i, res in ctx.result_generator:
                final_res_batch[i] = res

            if None in final_res_batch:
                return self.create_error_response(
                    "Failed to generate results for all prompts")

            ctx.final_res_batch = [
                res for res in final_res_batch if res is not None
            ]

            return None

        except Exception as e:
            return self.create_error_response(str(e))

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    def create_streaming_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
        json_str = json.dumps({
            "error":
            self.create_error_response(message=message,
                                       err_type=err_type,
                                       status_code=status_code).model_dump()
        })
        return json_str

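    # The streaming error payload serializes to JSON of the form
    # (illustrative; ErrorResponse may carry additional fields):
    #   {"error": {"message": "...", "type": "BadRequestError", "code": 400}}
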
    async def _check_model(
        self,
        request: AnyRequest,
    ) -> Optional[ErrorResponse]:
        error_response = None

        if self._is_model_supported(request.model):
            return None
        if request.model in [
                lora.lora_name for lora in self.models.lora_requests
        ]:
            return None
        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
                load_result := await self.models.resolve_lora(request.model)):
            if isinstance(load_result, LoRARequest):
                return None
            if isinstance(load_result, ErrorResponse) and \
                load_result.code == HTTPStatus.BAD_REQUEST.value:
                error_response = load_result
        if request.model in [
                prompt_adapter.prompt_adapter_name
                for prompt_adapter in self.models.prompt_adapter_requests
        ]:
            return None

        return error_response or self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[tuple[None, None], tuple[LoRARequest, None], tuple[
            None, PromptAdapterRequest]]:
        if self._is_model_supported(request.model):
            return None, None
        for lora in self.models.lora_requests:
            if request.model == lora.lora_name:
                return lora, None
        for prompt_adapter in self.models.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

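    # Resolution order, mirroring the checks above (illustrative): the base
    # model maps to (None, None), a known LoRA name to (lora, None), and a
    # known prompt-adapter name to (None, prompt_adapter).
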
    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        if (self.model_config.encoder_config is not None
                and self.model_config.encoder_config.get(
                    "do_lower_case", False)):
            prompt = prompt.lower()

        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        elif truncate_prompt_tokens < 0:
            # Negative means we cap at the model's max length
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=self.max_model_len)
        else:
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        input_ids = encoded.input_ids
        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: list[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        elif truncate_prompt_tokens < 0:
            input_ids = prompt_ids[-self.max_model_len:]
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

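    # Truncation keeps the last tokens; for example (illustrative):
    # prompt_ids=[1, 2, 3, 4, 5] with truncate_prompt_tokens=3 yields
    # input_ids=[3, 4, 5].
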
    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        token_num = len(input_ids)

        # Note: EmbeddingRequest, ClassificationRequest,
        # and ScoreRequest don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest, ClassificationRequest)):
            operation = {
                ScoreRequest: "score",
                ClassificationRequest: "classification"
            }.get(type(request), "embedding generation")

            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # The chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = getattr(request, "max_tokens", None)
        if max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, list[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of {meth}`_tokenize_prompt_input_or_inputs`
        that assumes a single input.
        """
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of {meth}`_tokenize_prompt_input_or_inputs`
        that assumes multiple inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> list[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to the `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_,
        each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.
        """
        # Although our type checking is based on mypy,
        # VSCode Pyright extension should still work properly.
        # An identity comparison ("is False") is required for Pyright
        # to perform type narrowing
        # See: https://github.com/microsoft/pyright/issues/7672
        return [
            self._normalize_prompt_text_to_input(
                request,
                tokenizer,
                prompt=prompt_input["content"],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens)
            if prompt_input["is_tokens"] is False else
            self._normalize_prompt_tokens_to_input(
                request,
                tokenizer,
                prompt_ids=prompt_input["content"],
                truncate_prompt_tokens=truncate_prompt_tokens)
            for prompt_input in parse_and_batch_prompt(input_or_inputs)
        ]

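    # The four accepted input shapes, mirroring the OpenAI embeddings API
    # (illustrative token IDs):
    #   "a"            -> one text prompt
    #   ["a", "b"]     -> two text prompts
    #   [1, 2]         -> one token-ID prompt
    #   [[1, 2], [3]]  -> two token-ID prompts
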
    async def _preprocess_completion(
        self,
        request: CompletionLikeRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
        add_special_tokens: bool = True,
    ) -> tuple[list[TextTokensPrompt], list[TokensPrompt]]:
        request_prompts = await self._tokenize_prompt_input_or_inputs_async(
            request,
            tokenizer,
            input_or_inputs,
            truncate_prompt_tokens=truncate_prompt_tokens,
            add_special_tokens=add_special_tokens,
        )

        engine_prompts = [
            TokensPrompt(prompt_token_ids=request_prompt["prompt_token_ids"])
            for request_prompt in request_prompts
        ]

        return request_prompts, engine_prompts

    async def _preprocess_chat(
        self,
        request: ChatLikeRequest,
        tokenizer: AnyTokenizer,
        messages: list[ChatCompletionMessageParam],
        chat_template: Optional[str],
        chat_template_content_format: ChatTemplateContentFormatOption,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tool_dicts: Optional[list[dict[str, Any]]] = None,
        documents: Optional[list[dict[str, str]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = False,
    ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
               list[TokensPrompt]]:
        model_config = self.model_config

        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tool_dicts,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )
        conversation, mm_data_future = parse_chat_messages_futures(
            messages,
            model_config,
            tokenizer,
            content_format=resolved_content_format,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tool_dicts,
            documents=documents,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        request_prompt: Union[str, list[int]]
        if isinstance(tokenizer, MistralTokenizer):
            request_prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                **_chat_template_kwargs,
            )
        else:
            request_prompt = apply_hf_chat_template(
                tokenizer=tokenizer,
                conversation=conversation,
                model_config=model_config,
                **_chat_template_kwargs,
            )

        mm_data = await mm_data_future

        # Tool parsing is done only if a tool_parser has been set and if
        # tool_choice is not "none" (if tool_choice is "none" but a tool_parser
        # is set, we want to prevent parsing a tool_call hallucinated by the
        # LLM).
        should_parse_tools = tool_parser is not None and (hasattr(
            request, "tool_choice") and request.tool_choice != "none")

        if should_parse_tools:
            if not isinstance(request, ChatCompletionRequest):
                msg = "Tool usage is only supported for Chat Completions API"
                raise NotImplementedError(msg)

            request = tool_parser(tokenizer).adjust_request(  # type: ignore
                request=request)

        if isinstance(request_prompt, str):
            prompt_inputs = await self._tokenize_prompt_input_async(
                request,
                tokenizer,
                request_prompt,
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            )
        else:
            # For MistralTokenizer
            assert is_list_of(request_prompt, int), (
                "Prompt has to be either a string or a list of token ids")
            prompt_inputs = TextTokensPrompt(
                prompt=tokenizer.decode(request_prompt),
                prompt_token_ids=request_prompt)

        engine_prompt = TokensPrompt(
            prompt_token_ids=prompt_inputs["prompt_token_ids"])
        if mm_data is not None:
            engine_prompt["multi_modal_data"] = mm_data
        if request.mm_processor_kwargs is not None:
            engine_prompt["mm_processor_kwargs"] = request.mm_processor_kwargs

        if hasattr(request, "cache_salt") and request.cache_salt is not None:
            engine_prompt["cache_salt"] = request.cache_salt

        return conversation, [request_prompt], [engine_prompt]

    def _log_inputs(
        self,
        request_id: str,
        inputs: RequestPrompt,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        if self.request_logger is None:
            return

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
        else:
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    async def _get_trace_headers(
        self,
        headers: Headers,
    ) -> Optional[Mapping[str, str]]:
        is_tracing_enabled = await self.engine_client.is_tracing_enabled()

        if is_tracing_enabled:
            return extract_trace_headers(headers)

        if contains_trace_headers(headers):
            log_tracing_disabled_warning()

        return None

    @staticmethod
    def _base_request_id(raw_request: Optional[Request],
                         default: Optional[str] = None) -> Optional[str]:
        """Pulls the request ID to use from a header, if provided."""
        default = default or random_uuid()
        if raw_request is None:
            return default

        return raw_request.headers.get("X-Request-Id", default)

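    # For example (illustrative), a client can pin its own ID via the header:
    #   curl -H "X-Request-Id: my-trace-123" ...
    # in which case "my-trace-123" becomes the base request ID instead of a
    # random UUID.
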
    @staticmethod
    def _get_decoded_token(logprob: Logprob,
                           token_id: int,
                           tokenizer: AnyTokenizer,
                           return_as_token_id: bool = False) -> str:
        if return_as_token_id:
            return f"token_id:{token_id}"

        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return tokenizer.decode(token_id)

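    # For example (illustrative), with return_as_token_id=True the logprob
    # entry for token ID 42 is rendered as the literal string "token_id:42".
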
    def _is_model_supported(self, model_name: Optional[str]) -> bool:
        if not model_name:
            return True
        return self.models.is_base_model(model_name)

    def _get_model_name(self,
                        model_name: Optional[str] = None,
                        lora_request: Optional[LoRARequest] = None) -> str:
        if lora_request:
            return lora_request.lora_name
        if not model_name:
            return self.models.base_model_paths[0].name
        return model_name


def clamp_prompt_logprobs(
    prompt_logprobs: Union[PromptLogprobs,
                           None]) -> Union[PromptLogprobs, None]:
    if prompt_logprobs is None:
        return prompt_logprobs

    for logprob_dict in prompt_logprobs:
        if logprob_dict is None:
            continue
        for logprob_values in logprob_dict.values():
            # -inf is not representable in strict JSON, so clamp it to a
            # large negative sentinel before the response is serialized
            if logprob_values.logprob == float('-inf'):
                logprob_values.logprob = -9999.0
    return prompt_logprobs
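
# Worked example (illustrative values):
#   logprobs = [{11: Logprob(logprob=float('-inf'))}]
#   clamp_prompt_logprobs(logprobs)  # the entry's logprob becomes -9999.0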