Extend renderer with embedding support and integrate completion endpoint (#24405)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Flora Feng
2025-09-09 10:46:46 -07:00
committed by GitHub
parent 9ad0688e43
commit 15cb047e25
9 changed files with 410 additions and 309 deletions

View File

@@ -10,7 +10,7 @@ import pytest
import regex as re
import torch
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.renderer import BaseRenderer
from ...utils import RemoteOpenAIServer
@@ -27,12 +27,16 @@ async def test_empty_prompt():
with RemoteOpenAIServer(model_name, server_args) as remote_server:
client = remote_server.get_async_client()
with pytest.raises(openai.BadRequestError,
match="decoder prompt cannot be empty"):
with pytest.raises(
openai.BadRequestError,
match=
"Either prompt or prompt_embeds must be provided and non-empty."
):
await client.completions.create(model=model_name,
prompt="",
max_tokens=5,
temperature=0.0)
temperature=0.0,
extra_body={"prompt_embeds": []})
@pytest.mark.asyncio
@@ -83,7 +87,7 @@ def test_load_prompt_embeds(dtype: torch.dtype, layout: torch.layout,
buffer.seek(0)
encoded_tensor = pybase64.b64encode(buffer.getvalue())
loaded_prompt_embeds = OpenAIServing._load_prompt_embeds(encoded_tensor)
loaded_prompt_embeds = BaseRenderer.load_prompt_embeds(encoded_tensor)
assert len(loaded_prompt_embeds) == 1
loaded_tensor = loaded_prompt_embeds[0]["prompt_embeds"]
assert loaded_tensor.device.type == "cpu"

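As a point of reference, here is a minimal standalone sketch of the round-trip this updated test exercises: torch-save an embedding tensor, base64-encode it, and decode it back through the relocated BaseRenderer.load_prompt_embeds helper. The tensor shape is illustrative.

import io

import pybase64
import torch

from vllm.entrypoints.renderer import BaseRenderer

# Serialize a (seq_len, hidden_size) embedding tensor and base64-encode it.
tensor = torch.randn(4, 8, dtype=torch.float32)
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
encoded = pybase64.b64encode(buffer.getvalue())

# Decode it back into an engine-ready embeds prompt.
prompts = BaseRenderer.load_prompt_embeds(encoded)
assert len(prompts) == 1
assert torch.allclose(prompts[0]["prompt_embeds"], tensor)
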
View File

@@ -1,13 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io
from dataclasses import dataclass
from typing import Optional
from unittest.mock import AsyncMock, MagicMock
import pybase64
import pytest
import torch
from vllm.entrypoints.renderer import CompletionRenderer
from vllm.inputs.data import is_embeds_prompt
@dataclass
@@ -178,3 +182,132 @@ class TestRenderPrompt:
with pytest.raises(ValueError, match="No tokenizer available"):
await renderer_no_tokenizer.render_prompt(
prompt_or_prompts="Hello world", max_length=100)
@pytest.mark.asyncio
async def test_token_input_with_needs_detokenization(
self, renderer, mock_async_tokenizer):
# When needs_detokenization=True for token inputs, renderer should
# use the async tokenizer to decode and include the original text
# in the returned prompt object.
mock_async_tokenizer.decode = AsyncMock(return_value="decoded text")
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
tokens = [1, 2, 3, 4]
results = await renderer.render_prompt(
prompt_or_prompts=tokens,
needs_detokenization=True,
)
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
assert results[0]["prompt"] == "decoded text"
mock_async_tokenizer.decode.assert_awaited_once()
class TestRenderEmbedPrompt:
def _create_test_embed_bytes(self, tensor: torch.Tensor) -> bytes:
"""Helper to create base64-encoded tensor bytes"""
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
return pybase64.b64encode(buffer.read())
@pytest.mark.asyncio
async def test_single_prompt_embed(self, renderer):
# Create a test tensor
test_tensor = torch.randn(10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes, cache_salt="test_salt")
assert len(results) == 1
assert is_embeds_prompt(results[0])
assert torch.allclose(results[0]["prompt_embeds"], test_tensor)
assert results[0]["cache_salt"] == "test_salt"
@pytest.mark.asyncio
async def test_multiple_prompt_embeds(self, renderer):
# Create multiple test tensors
test_tensors = [
torch.randn(8, 512, dtype=torch.float32),
torch.randn(12, 512, dtype=torch.float32),
]
embed_bytes_list = [
self._create_test_embed_bytes(t) for t in test_tensors
]
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes_list)
assert len(results) == 2
for i, result in enumerate(results):
assert is_embeds_prompt(result)
assert torch.allclose(result["prompt_embeds"], test_tensors[i])
@pytest.mark.asyncio
async def test_prompt_embed_truncation(self, renderer):
# Create tensor with more tokens than truncation limit
test_tensor = torch.randn(20, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes, truncate_prompt_tokens=10)
assert len(results) == 1
# Should keep last 10 tokens
expected = test_tensor[-10:]
assert torch.allclose(results[0]["prompt_embeds"], expected)
@pytest.mark.asyncio
async def test_prompt_embed_different_dtypes(self, renderer):
# Test different supported dtypes
dtypes = [torch.float32, torch.float16, torch.bfloat16]
for dtype in dtypes:
test_tensor = torch.randn(5, 256, dtype=dtype)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes)
assert len(results) == 1
assert results[0]["prompt_embeds"].dtype == dtype
@pytest.mark.asyncio
async def test_prompt_embed_squeeze_batch_dim(self, renderer):
# Test tensor with batch dimension gets squeezed
test_tensor = torch.randn(1, 10, 768, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes)
assert len(results) == 1
# Should be squeezed to 2D
assert results[0]["prompt_embeds"].shape == (10, 768)
@pytest.mark.asyncio
async def test_both_prompts_and_embeds(self, renderer,
mock_async_tokenizer):
# Set up text tokenization
mock_async_tokenizer.return_value = MockTokenizerResult(
[101, 102, 103])
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
# Create embed
test_tensor = torch.randn(5, 256, dtype=torch.float32)
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_or_prompts="Hello world", prompt_embeds=embed_bytes)
assert len(results) == 2
# First should be embed prompt
assert is_embeds_prompt(results[0])
# Second should be tokens prompt
assert "prompt_token_ids" in results[1]
assert results[1]["prompt_token_ids"] == [101, 102, 103]

View File

@@ -686,7 +686,7 @@ async def test_invalid_grammar(client: openai.AsyncOpenAI, model_name: str):
async def test_completion_with_empty_prompt_embeds(
client: openai.AsyncOpenAI) -> None:
"""Test completion with empty prompt embeds."""
payload: dict[str, list] = {"prompt_embeds": []}
payload: dict[str, object] = {"prompt": "Hello", "prompt_embeds": []}
headers: dict[str, str] = {"Content-Type": "application/json"}
# base_url = http://localhost:8000/v1/completions
response = requests.post(f"{client.base_url}completions",

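For context, a hedged sketch of a client request that supplies non-empty prompt embeddings to /v1/completions, following the same encoding the tests use (torch.save plus base64) and passing the result through extra_body. The model name, tensor shape, and server address are placeholders, and the sketch assumes a running vLLM OpenAI-compatible server configured to accept prompt embeddings.

import io

import openai
import pybase64
import torch

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Torch-save a (seq_len, hidden_size) embedding tensor and base64-encode it as a string.
tensor = torch.randn(10, 4096, dtype=torch.float32)  # hidden size is model-specific
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
encoded = pybase64.b64encode(buffer.getvalue()).decode("utf-8")

completion = client.completions.create(
    model="my-model",  # placeholder model name
    prompt="",  # may be empty as long as prompt_embeds is non-empty
    max_tokens=5,
    temperature=0.0,
    extra_body={"prompt_embeds": encoded},
)
print(completion.choices[0].text)
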
View File

@@ -1270,9 +1270,20 @@ class CompletionRequest(OpenAIBaseModel):
@model_validator(mode="before")
@classmethod
def validate_prompt_and_prompt_embeds(cls, data):
if data.get("prompt") is None and data.get("prompt_embeds") is None:
prompt = data.get("prompt")
prompt_embeds = data.get("prompt_embeds")
prompt_is_empty = (prompt is None
or (isinstance(prompt, str) and prompt == ""))
embeds_is_empty = (prompt_embeds is None
or (isinstance(prompt_embeds, list)
and len(prompt_embeds) == 0))
if prompt_is_empty and embeds_is_empty:
raise ValueError(
"At least one of `prompt` or `prompt_embeds` must be set.")
"Either prompt or prompt_embeds must be provided and non-empty."
)
return data
@model_validator(mode="before")

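For illustration, a self-contained restatement of the emptiness check introduced above (plain Python, no vLLM import), showing which prompt/prompt_embeds combinations the validator now accepts or rejects.

def is_empty_completion_input(data: dict) -> bool:
    """Mirror of the validator's check on the raw request dict."""
    prompt = data.get("prompt")
    prompt_embeds = data.get("prompt_embeds")
    prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "")
    embeds_is_empty = prompt_embeds is None or (isinstance(prompt_embeds, list)
                                                and len(prompt_embeds) == 0)
    return prompt_is_empty and embeds_is_empty

assert not is_empty_completion_input({"prompt": "Hello"})                         # accepted
assert not is_empty_completion_input({"prompt": "", "prompt_embeds": ["<b64>"]})  # accepted
assert is_empty_completion_input({"prompt": "", "prompt_embeds": []})             # rejected with ValueError
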
View File

@@ -26,12 +26,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo)
from vllm.entrypoints.openai.serving_engine import (
EmbedsPrompt as ServingEngineEmbedsPrompt)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
TextTokensPrompt,
clamp_prompt_logprobs,
is_text_tokens_prompt)
clamp_prompt_logprobs)
# yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import get_max_tokens
@@ -132,12 +128,19 @@ class OpenAIServingCompletion(OpenAIServing):
else:
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
renderer = self._get_renderer(tokenizer)
max_input_tokens_len = self.max_model_len - (request.max_tokens
or 0)
request_prompts, engine_prompts = await self._preprocess_completion(
request,
tokenizer,
request.prompt,
engine_prompts = await renderer.render_prompt_and_embeds(
prompt_or_prompts=request.prompt,
prompt_embeds=request.prompt_embeds,
max_length=max_input_tokens_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=request.cache_salt,
needs_detokenization=bool(request.echo
and not request.return_token_ids),
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
@@ -198,7 +201,7 @@ class OpenAIServingCompletion(OpenAIServing):
self._log_inputs(
request_id_item,
request_prompts[i],
engine_prompt,
params=sampling_params,
lora_request=lora_request,
)
@@ -249,7 +252,7 @@ class OpenAIServingCompletion(OpenAIServing):
if stream:
return self.completion_stream_generator(
request,
request_prompts,
engine_prompts,
result_generator,
request_id,
created_time,
@@ -273,11 +276,9 @@ class OpenAIServingCompletion(OpenAIServing):
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
request_prompt = request_prompts[i]
if is_text_tokens_prompt(request_prompt):
final_res.prompt = request_prompt["prompt"]
else:
final_res.prompt = None
engine_prompt = engine_prompts[i]
final_res.prompt = None if is_embeds_prompt(
engine_prompt) else engine_prompt.get("prompt")
final_res_batch_checked = cast(list[RequestOutput],
final_res_batch)
@@ -313,8 +314,7 @@ class OpenAIServingCompletion(OpenAIServing):
async def completion_stream_generator(
self,
request: CompletionRequest,
request_prompts: list[Union[TextTokensPrompt,
ServingEngineEmbedsPrompt]],
engine_prompts: list[Union[TokensPrompt, EmbedsPrompt]],
result_generator: AsyncIterator[tuple[int, RequestOutput]],
request_id: str,
created_time: int,
@@ -350,14 +350,11 @@ class OpenAIServingCompletion(OpenAIServing):
num_cached_tokens = res.num_cached_tokens
first_iteration = False
if res.prompt is not None:
prompt_text = res.prompt
else:
request_prompt = request_prompts[prompt_idx]
if is_text_tokens_prompt(request_prompt):
prompt_text = request_prompt["prompt"]
else:
prompt_text = None
prompt_text = res.prompt
if prompt_text is None:
engine_prompt = engine_prompts[prompt_idx]
prompt_text = None if is_embeds_prompt(
engine_prompt) else engine_prompt.get("prompt")
# Prompt details are excluded from later streamed outputs
if prompt_token_ids is not None:
@@ -378,6 +375,8 @@ class OpenAIServingCompletion(OpenAIServing):
assert request.max_tokens is not None
if request.echo and not has_echoed[i]:
assert prompt_token_ids is not None
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None
if request.max_tokens == 0:
# only return the prompt
@@ -525,6 +524,8 @@ class OpenAIServingCompletion(OpenAIServing):
for output in final_res.outputs:
assert request.max_tokens is not None
if request.echo:
if request.return_token_ids:
prompt_text = ""
assert prompt_text is not None
if request.max_tokens == 0:
token_ids = prompt_token_ids

View File

@@ -28,7 +28,6 @@ from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
TextTokensPrompt)
# yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
@@ -290,7 +289,7 @@ class EmbeddingMixin(OpenAIServing):
async def _create_single_prompt_generator(
self,
ctx: EmbeddingServeContext,
engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt],
engine_prompt: EngineTokensPrompt,
pooling_params: PoolingParams,
trace_headers: Optional[Mapping[str, str]],
prompt_index: int,
@@ -303,12 +302,6 @@ class EmbeddingMixin(OpenAIServing):
params=pooling_params,
lora_request=ctx.lora_request)
# Mypy has an existing bug related to inferring the variance
# of TypedDicts with `builtins.enumerate`:
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
engine_prompt = cast(Union[EngineTokensPrompt, EngineEmbedsPrompt],
engine_prompt)
# Return the original generator without wrapping
return self.engine_client.encode(
engine_prompt,
@@ -375,12 +368,8 @@ class EmbeddingMixin(OpenAIServing):
continue
# Normal processing for short prompts or non-token prompts
# Cast engine_prompt to the expected type for mypy
engine_prompt_typed = cast(
Union[EngineTokensPrompt, EngineEmbedsPrompt],
engine_prompt)
generator = await self._create_single_prompt_generator(
ctx, engine_prompt_typed, pooling_params, trace_headers, i)
ctx, engine_prompt, pooling_params, trace_headers, i)
generators.append(generator)
from vllm.utils import merge_async_iterators

View File

@@ -1,7 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import json
import sys
import time
@@ -9,10 +7,8 @@ import traceback
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Annotated, Any, Callable, ClassVar, Generic, Optional,
TypeVar, Union, cast, overload)
from typing import Any, Callable, ClassVar, Generic, Optional, TypeVar, Union
import pybase64
import torch
from fastapi import Request
from pydantic import BaseModel, ConfigDict, Field
@@ -64,10 +60,8 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer
# yapf: enable
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import PromptType
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob, PromptLogprobs
from vllm.lora.request import LoRARequest
@@ -149,8 +143,7 @@ class RequestProcessingMixin(BaseModel):
"""
request_prompts: Optional[Sequence[RequestPrompt]] = []
engine_prompts: Optional[Union[list[EngineTokensPrompt],
list[EngineEmbedsPrompt]]] = []
engine_prompts: Optional[list[EngineTokensPrompt]] = []
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -368,13 +361,6 @@ class OpenAIServing:
for i, engine_prompt in enumerate(ctx.engine_prompts):
request_id_item = f"{ctx.request_id}-{i}"
# Mypy has an existing bug related to inferring the variance of
# TypedDicts with `builtins.enumerate`:
# https://github.com/python/mypy/issues/8586#issuecomment-2867698435
engine_prompt = cast(
Union[EngineTokensPrompt, EngineEmbedsPrompt],
engine_prompt)
self._log_inputs(
request_id_item,
engine_prompt,
@@ -737,170 +723,6 @@ class OpenAIServing:
tokenizer=tokenizer,
)
async def _tokenize_prompt_input_or_inputs_async(
self,
request: AnyRequest,
tokenizer: Optional[AnyTokenizer],
input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]],
add_special_tokens: bool = True,
) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
"""
Tokenize/detokenize depending on the input format.
According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
, each input can be a string or array of tokens. Note that each request
can pass one or more inputs.
"""
inputs_embeds = list[EmbedsPrompt]()
inputs_text = list[TextTokensPrompt]()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
if (truncate_prompt_tokens or 0) < 0:
truncate_prompt_tokens = self.max_model_len
if (isinstance(request, CompletionRequest)
and request.prompt_embeds is not None):
inputs_embeds.extend(
self._load_prompt_embeds(request.prompt_embeds,
truncate_prompt_tokens))
# Empty prompts are okay as long as there are prompt embeddings
if input_or_inputs is None or (inputs_embeds
and input_or_inputs == ""):
return [], inputs_embeds
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is False" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
# Parse and batch the input prompts
batch_inputs = parse_and_batch_prompt(input_or_inputs)
# Process each input in the batch concurrently
tasks = []
for prompt_input in batch_inputs:
if prompt_input["is_tokens"] is False:
assert tokenizer is not None, (
"Tokenizer is required for text prompts")
task = self._normalize_prompt_text_to_input(
request,
prompt_input["content"],
tokenizer=tokenizer,
add_special_tokens=add_special_tokens,
)
else:
task = self._normalize_prompt_tokens_to_input(
request, prompt_input["content"], tokenizer=tokenizer)
tasks.append(task)
# Wait for all tokenization tasks to complete
results = await asyncio.gather(*tasks)
inputs_text.extend(results)
return inputs_text, inputs_embeds
@overload
async def _preprocess_completion(
self,
request: Union[
DetokenizeRequest,
EmbeddingCompletionRequest,
RerankRequest,
ClassificationRequest,
ScoreRequest,
TokenizeCompletionRequest,
],
tokenizer: Optional[AnyTokenizer],
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
add_special_tokens: bool = ...,
) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
...
@overload
async def _preprocess_completion(
self,
request: CompletionRequest,
tokenizer: Optional[AnyTokenizer],
input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]],
add_special_tokens: bool = ...,
) -> tuple[
list[Union[TextTokensPrompt, EmbedsPrompt]],
list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
]:
...
async def _preprocess_completion(
self,
request: CompletionLikeRequest,
tokenizer: Optional[AnyTokenizer],
input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]],
add_special_tokens: bool = True,
) -> tuple[
Union[list[TextTokensPrompt], list[Union[TextTokensPrompt,
EmbedsPrompt]]],
Union[
list[EngineTokensPrompt],
list[Union[EngineTokensPrompt, EngineEmbedsPrompt]],
],
]:
if (not isinstance(request, CompletionRequest)
and input_or_inputs is None):
raise ValueError(
"Prompt embeds with non-completion requests is not"
" currently supported.")
(
request_prompts_text,
request_prompts_embeds,
) = await self._tokenize_prompt_input_or_inputs_async(
request,
tokenizer,
input_or_inputs,
add_special_tokens=add_special_tokens,
)
engine_prompts_text = [
EngineTokensPrompt(
prompt_token_ids=request_prompt_text["prompt_token_ids"])
for request_prompt_text in request_prompts_text
]
cache_salt = (request.cache_salt if
(hasattr(request, "cache_salt")
and request.cache_salt is not None) else None)
if cache_salt:
for prompt_text in engine_prompts_text:
prompt_text["cache_salt"] = cache_salt
# This check is equivalent to simply checking if
# `request_prompts_embeds` is empty, but it's difficult to propagate
# overloads to the private helper functions to enable this check.
# This overload is needed because only TextPrompts are allowed for
# non-completion requests and if we don't add the overload here,
# everywhere this function is used outside of serving_completion will
# need logic asserting that only text prompts are in the request.
if (not isinstance(request, CompletionRequest)
and input_or_inputs is not None):
return request_prompts_text, engine_prompts_text
engine_prompts_embeds = [
EngineEmbedsPrompt(
prompt_embeds=request_prompt_embeds["prompt_embeds"])
for request_prompt_embeds in request_prompts_embeds
]
if cache_salt:
for prompt_embed in engine_prompts_embeds:
prompt_embed["cache_salt"] = cache_salt
request_prompts = request_prompts_embeds + request_prompts_text
engine_prompts = engine_prompts_embeds + engine_prompts_text
return request_prompts, engine_prompts
async def _preprocess_chat(
self,
request: Union[ChatLikeRequest, ResponsesRequest],
@@ -1073,41 +895,6 @@ class OpenAIServing:
# OPTIMIZATION
priority = orig_priority - 1
@staticmethod
def _load_prompt_embeds(
prompt_embeds: Optional[Union[bytes, list[bytes]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
) -> list[EmbedsPrompt]:
def _load_and_validate_embed(embed: bytes) -> EmbedsPrompt:
tensor = torch.load(
io.BytesIO(pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
if truncate_prompt_tokens is not None:
tensor = tensor[-truncate_prompt_tokens:]
return {"prompt_embeds": tensor}
if prompt_embeds:
if isinstance(prompt_embeds, list):
return [
_load_and_validate_embed(embed) for embed in prompt_embeds
]
else:
return [_load_and_validate_embed(prompt_embeds)]
else:
return []
def _log_inputs(
self,
request_id: str,

View File

@@ -2,12 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
from abc import ABC, abstractmethod
from typing import Annotated, Optional, Union
import pybase64
import torch
from pydantic import Field
from vllm.config import ModelConfig
from vllm.inputs.data import EmbedsPrompt as EngineEmbedsPrompt
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -49,37 +53,121 @@ class BaseRenderer(ABC):
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
needs_detokenization: Optional[bool] = False,
) -> list[EngineTokensPrompt]:
"""
Convert input prompts into tokenized format for engine processing.
This is the core method that transforms various input formats into
standardized TokensPrompt objects. Implementations should handle
tokenization, special token insertion, truncation, and validation
according to model requirements.
Convert text or token inputs into engine-ready TokensPrompt objects.
This method accepts text or token inputs and produces a
list of [`TokensPrompt`][vllm.inputs.data.TokensPrompt] objects
for the engine.
Args:
prompt_or_prompts: Input data in various formats:
- str: Single text prompt
- list[str]: Batch of text prompts
- list[int]: Pre-tokenized sequence
- list[list[int]]: Batch of pre-tokenized sequences
max_length: Maximum sequence length (endpoint-specific behavior)
truncate_prompt_tokens: Truncate to last N tokens
(None=no truncation, 0=empty)
add_special_tokens: Add model-specific tokens (e.g., [CLS], [SEP])
to text inputs
cache_salt: Optional string to disambiguate cached prompts
prompt_or_prompts: One of:
- ``str``: Single text prompt.
- ``list[str]``: Batch of text prompts.
- ``list[int]``: Single pre-tokenized sequence.
- ``list[list[int]]``: Batch of pre-tokenized sequences.
max_length: Maximum allowable total input token length. If provided,
token inputs longer than this raise ``ValueError``.
truncate_prompt_tokens: Number of tokens to keep. ``None`` means no
truncation. ``0`` yields an empty list (and skips embeds).
``-1`` maps to ``model_config.max_model_len``.
add_special_tokens: Whether to add model-specific special tokens
during text tokenization.
cache_salt: Optional string to disambiguate prefix cache entries.
needs_detokenization: If True and ``prompt_or_prompts`` is token
input, detokenize IDs back to text for inclusion in outputs.
Returns:
list[EngineTokensPrompt]: Tokenized prompts ready for engine
consumption
list[EngineTokensPrompt]: Engine-ready token prompts.
Raises:
ValueError: If input format is invalid or length limits exceeded
ValueError: If input formats are invalid or length limits exceeded.
"""
raise NotImplementedError
@abstractmethod
async def render_prompt_and_embeds(
self,
prompt_or_prompts: Optional[Union[str, list[str], list[int],
list[list[int]]]] = None,
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
max_length: Optional[int] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
needs_detokenization: Optional[bool] = False,
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
"""
Convert text/token and/or base64-encoded embeddings inputs into
engine-ready prompt objects.
At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be
provided and non-empty. If both are omitted or empty (e.g., empty
string and empty list), a ``ValueError`` is raised.
Args:
prompt_or_prompts: Text or token inputs to include.
prompt_embeds: Base64-encoded bytes (or list thereof) containing a
torch-saved tensor to be used as prompt embeddings.
max_length: Maximum allowable total input token length. If provided,
inputs longer than this raise ``ValueError``.
truncate_prompt_tokens: Number of tokens/rows to keep from the end
of the sequence. ``-1`` maps to ``model_config.max_model_len``.
add_special_tokens: Whether to add model-specific special tokens
during text tokenization.
cache_salt: Optional string to disambiguate prefix cache entries.
needs_detokenization: If True and ``prompt_or_prompts`` is token
input, detokenize IDs back to text for inclusion in outputs.
Returns:
list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
Engine-ready prompt objects.
Raises:
ValueError: If both ``prompt_or_prompts`` and ``prompt_embeds``
are omitted or empty (decoder prompt cannot be empty), or if
length limits are exceeded.
"""
raise NotImplementedError
@classmethod
def load_prompt_embeds(
cls,
prompt_embeds: Union[bytes, list[bytes]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=0)]] = None,
cache_salt: Optional[str] = None,
) -> list[EngineEmbedsPrompt]:
"""Load and validate base64-encoded embeddings into prompt objects."""
def _load_and_validate_embed(embed: bytes) -> EngineEmbedsPrompt:
tensor = torch.load(
io.BytesIO(pybase64.b64decode(embed, validate=True)),
weights_only=True,
map_location=torch.device("cpu"),
)
assert isinstance(tensor, torch.Tensor) and tensor.dtype in (
torch.float32,
torch.bfloat16,
torch.float16,
)
tensor = tensor.to_dense()
if tensor.dim() > 2:
tensor = tensor.squeeze(0)
assert tensor.dim() == 2
if truncate_prompt_tokens is not None:
tensor = tensor[-truncate_prompt_tokens:]
embeds_prompt = EngineEmbedsPrompt(prompt_embeds=tensor)
if cache_salt is not None:
embeds_prompt["cache_salt"] = cache_salt
return embeds_prompt
if isinstance(prompt_embeds, list):
return [_load_and_validate_embed(embed) for embed in prompt_embeds]
else:
return [_load_and_validate_embed(prompt_embeds)]
class CompletionRenderer(BaseRenderer):
@@ -101,50 +189,110 @@ class CompletionRenderer(BaseRenderer):
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
needs_detokenization: Optional[bool] = False,
) -> list[EngineTokensPrompt]:
"""Implementation of prompt rendering for completion-style requests.
Uses async tokenizer pooling for improved performance. See base class
for detailed parameter documentation.
"""
if truncate_prompt_tokens is not None:
if truncate_prompt_tokens == 0:
return []
if truncate_prompt_tokens < 0:
truncate_prompt_tokens = self.model_config.max_model_len
if max_length is not None and truncate_prompt_tokens > max_length:
raise ValueError(
f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
f"cannot be greater than max_length ({max_length}). "
f"Please select a smaller truncation size.")
truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
truncate_prompt_tokens, max_length)
if truncate_prompt_tokens == 0:
return []
# Parse and batch the input prompts
batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
rendered_prompts: list[EngineTokensPrompt] = []
tokenize_tasks = []
tasks = []
for prompt_input in batch_inputs:
if prompt_input["is_tokens"] is True:
# Token input
token_ids = self._maybe_apply_truncation(
prompt_input["content"], truncate_prompt_tokens)
rendered_prompts.append(
self._create_tokens_prompt(token_ids, max_length,
cache_salt))
detokenize_task = asyncio.create_task(
# Note: detokenization is needed when echo is enabled,
# where the input token IDs are decoded back to text.
self._maybe_detokenize(prompt_input["content"], max_length,
truncate_prompt_tokens, cache_salt,
needs_detokenization))
tasks.append(detokenize_task)
else:
# Text input
tokenize_task = asyncio.create_task(
self._tokenize(prompt_input["content"], max_length,
truncate_prompt_tokens, add_special_tokens,
cache_salt))
tokenize_tasks.append(tokenize_task)
tasks.append(tokenize_task)
# Wait for all text tokenization to finish
if tokenize_tasks:
tokenized_text_prompts = await asyncio.gather(*tokenize_tasks)
rendered_prompts.extend(tokenized_text_prompts)
if tasks:
tokenized_text_prompts = await asyncio.gather(*tasks)
return tokenized_text_prompts
return rendered_prompts
return []
async def render_prompt_and_embeds(
self,
prompt_or_prompts: Optional[Union[str, list[str], list[int],
list[list[int]]]] = None,
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
max_length: Optional[int] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
needs_detokenization: Optional[bool] = False,
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
"""
Render text/token prompts and/or precomputed embedding prompts. At
least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
"""
truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
truncate_prompt_tokens, max_length)
if truncate_prompt_tokens == 0:
return []
rendered: list[Union[EngineTokensPrompt, EngineEmbedsPrompt]] = []
if prompt_embeds is not None:
rendered.extend(
self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens,
cache_salt))
if prompt_or_prompts is None or prompt_or_prompts == "":
return rendered
token_prompts = await self.render_prompt(
prompt_or_prompts=prompt_or_prompts,
max_length=max_length,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
cache_salt=cache_salt,
needs_detokenization=needs_detokenization,
)
rendered.extend(token_prompts)
return rendered
def _validate_and_normalize_truncate_tokens(
self,
truncate_prompt_tokens: Optional[int],
max_length: Optional[int],
) -> Optional[int]:
"""Validate and normalize truncate_prompt_tokens parameter."""
if truncate_prompt_tokens is None:
return None
if truncate_prompt_tokens == 0:
return 0
if truncate_prompt_tokens < 0:
truncate_prompt_tokens = self.model_config.max_model_len
if max_length is not None and truncate_prompt_tokens > max_length:
raise ValueError(
f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
f"cannot be greater than max_length ({max_length}). "
f"Please select a smaller truncation size.")
return truncate_prompt_tokens
def _maybe_apply_truncation(
self, token_ids: list[int],
@@ -186,7 +334,29 @@ class CompletionRenderer(BaseRenderer):
max_length=truncate_prompt_tokens)
return self._create_tokens_prompt(encoded.input_ids, max_length,
cache_salt)
cache_salt, text)
async def _maybe_detokenize(
self,
token_ids: list[int],
max_length: Optional[int],
truncate_prompt_tokens: Optional[int],
cache_salt: Optional[str],
needs_detokenization: Optional[bool] = False,
) -> EngineTokensPrompt:
"""Optionally detokenize token IDs and build a tokens prompt."""
token_ids = self._maybe_apply_truncation(token_ids,
truncate_prompt_tokens)
prompt = None
if needs_detokenization is True:
async_tokenizer = self._get_async_tokenizer()
prompt = await async_tokenizer.decode(token_ids)
return self._create_tokens_prompt(token_ids=token_ids,
max_length=max_length,
cache_salt=cache_salt,
prompt=prompt)
def _get_async_tokenizer(self) -> AsyncMicrobatchTokenizer:
"""Get or create async tokenizer using shared pool."""
@@ -210,6 +380,7 @@ class CompletionRenderer(BaseRenderer):
token_ids: list[int],
max_length: Optional[int] = None,
cache_salt: Optional[str] = None,
prompt: Optional[str] = None,
) -> EngineTokensPrompt:
"""Create validated EngineTokensPrompt."""
if max_length is not None and len(token_ids) > max_length:
@@ -221,4 +392,6 @@ class CompletionRenderer(BaseRenderer):
tokens_prompt = EngineTokensPrompt(prompt_token_ids=token_ids)
if cache_salt is not None:
tokens_prompt["cache_salt"] = cache_salt
return tokens_prompt
if prompt is not None:
tokens_prompt["prompt"] = prompt
return tokens_prompt

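A small sketch of the load_prompt_embeds behavior documented above, assuming illustrative shapes: a leading batch dimension of 1 is squeezed away, truncation keeps the last truncate_prompt_tokens rows, and cache_salt is attached to the resulting prompt.

import io

import pybase64
import torch

from vllm.entrypoints.renderer import BaseRenderer

# A (1, seq_len, hidden_size) tensor; the leading batch dim gets squeezed.
tensor = torch.randn(1, 20, 8, dtype=torch.float32)
buffer = io.BytesIO()
torch.save(tensor, buffer)
buffer.seek(0)
encoded = pybase64.b64encode(buffer.getvalue())

prompts = BaseRenderer.load_prompt_embeds(encoded,
                                          truncate_prompt_tokens=10,
                                          cache_salt="salt")
assert prompts[0]["prompt_embeds"].shape == (10, 8)  # last 10 rows kept
assert torch.allclose(prompts[0]["prompt_embeds"], tensor.squeeze(0)[-10:])
assert prompts[0]["cache_salt"] == "salt"
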
View File

@@ -52,6 +52,9 @@ class TokensPrompt(TypedDict):
prompt_token_ids: list[int]
"""A list of token IDs to pass to the model."""
prompt: NotRequired[str]
"""The prompt text corresponding to the token IDs, if available."""
token_type_ids: NotRequired[list[int]]
"""A list of token type IDs to pass to the cross encoder model."""