Consolidate rendering parameters into RenderConfig dataclass (#24543)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Author: Flora Feng
Date: 2025-09-10 01:44:47 -07:00
Committed by: GitHub
Parent: feaf202e93
Commit: 77f62613f9
8 changed files with 167 additions and 108 deletions

View File

@ -10,7 +10,7 @@ import pybase64
import pytest
import torch
from vllm.entrypoints.renderer import CompletionRenderer
from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig
from vllm.inputs.data import is_embeds_prompt
@ -56,8 +56,8 @@ class TestRenderPrompt:
@pytest.mark.asyncio
async def test_token_input(self, renderer):
tokens = [101, 7592, 2088]
results = await renderer.render_prompt(prompt_or_prompts=tokens,
max_length=100)
results = await renderer.render_prompt(
prompt_or_prompts=tokens, config=RenderConfig(max_length=100))
assert len(results) == 1
assert results[0]["prompt_token_ids"] == tokens
@ -65,8 +65,8 @@ class TestRenderPrompt:
@pytest.mark.asyncio
async def test_token_list_input(self, renderer):
token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
results = await renderer.render_prompt(prompt_or_prompts=token_lists,
max_length=100)
results = await renderer.render_prompt(
prompt_or_prompts=token_lists, config=RenderConfig(max_length=100))
assert len(results) == 3
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
@ -80,8 +80,9 @@ class TestRenderPrompt:
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
max_length=100)
results = await renderer.render_prompt(
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=100))
assert len(results) == 1
assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
@ -96,7 +97,8 @@ class TestRenderPrompt:
text_list_input = ["Hello world", "How are you?", "Good morning"]
results = await renderer.render_prompt(
prompt_or_prompts=text_list_input, max_length=100)
prompt_or_prompts=text_list_input,
config=RenderConfig(max_length=100))
assert len(results) == 3
for result in results:
@ -110,8 +112,9 @@ class TestRenderPrompt:
renderer.async_tokenizer_pool[
renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
max_length=100)
results = await renderer.render_prompt(
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=100))
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
@ -126,8 +129,9 @@ class TestRenderPrompt:
renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
max_length=100,
truncate_prompt_tokens=50)
config=RenderConfig(
max_length=100,
truncate_prompt_tokens=50))
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
@ -143,8 +147,9 @@ class TestRenderPrompt:
renderer.tokenizer] = mock_async_tokenizer
results = await renderer.render_prompt(prompt_or_prompts="Hello world",
max_length=200,
truncate_prompt_tokens=-1)
config=RenderConfig(
max_length=200,
truncate_prompt_tokens=-1))
assert len(results) == 1
call_args = mock_async_tokenizer.call_args
@ -157,8 +162,9 @@ class TestRenderPrompt:
long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108,
109] # 10 tokens
results = await renderer.render_prompt(prompt_or_prompts=long_tokens,
max_length=100,
truncate_prompt_tokens=5)
config=RenderConfig(
max_length=100,
truncate_prompt_tokens=5))
assert len(results) == 1
# Should keep the last 5 tokens: [105, 106, 107, 108, 109]
@ -170,7 +176,7 @@ class TestRenderPrompt:
with pytest.raises(ValueError, match="maximum context length"):
await renderer.render_prompt(prompt_or_prompts=long_tokens,
max_length=100)
config=RenderConfig(max_length=100))
@pytest.mark.asyncio
async def test_no_tokenizer_for_text(self, mock_model_config):
@ -181,7 +187,8 @@ class TestRenderPrompt:
with pytest.raises(ValueError, match="No tokenizer available"):
await renderer_no_tokenizer.render_prompt(
prompt_or_prompts="Hello world", max_length=100)
prompt_or_prompts="Hello world",
config=RenderConfig(max_length=100))
@pytest.mark.asyncio
async def test_token_input_with_needs_detokenization(
@ -196,7 +203,7 @@ class TestRenderPrompt:
tokens = [1, 2, 3, 4]
results = await renderer.render_prompt(
prompt_or_prompts=tokens,
needs_detokenization=True,
config=RenderConfig(needs_detokenization=True),
)
assert len(results) == 1
@ -221,7 +228,9 @@ class TestRenderEmbedPrompt:
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes, cache_salt="test_salt")
prompt_embeds=embed_bytes,
config=RenderConfig(cache_salt="test_salt"),
)
assert len(results) == 1
assert is_embeds_prompt(results[0])
@ -240,7 +249,9 @@ class TestRenderEmbedPrompt:
]
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes_list)
prompt_embeds=embed_bytes_list,
config=RenderConfig(),
)
assert len(results) == 2
for i, result in enumerate(results):
@ -254,7 +265,9 @@ class TestRenderEmbedPrompt:
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes, truncate_prompt_tokens=10)
prompt_embeds=embed_bytes,
config=RenderConfig(truncate_prompt_tokens=10),
)
assert len(results) == 1
# Should keep last 10 tokens
@ -271,7 +284,9 @@ class TestRenderEmbedPrompt:
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes)
prompt_embeds=embed_bytes,
config=RenderConfig(),
)
assert len(results) == 1
assert results[0]["prompt_embeds"].dtype == dtype
@ -283,7 +298,9 @@ class TestRenderEmbedPrompt:
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_embeds=embed_bytes)
prompt_embeds=embed_bytes,
config=RenderConfig(),
)
assert len(results) == 1
# Should be squeezed to 2D
@ -303,7 +320,10 @@ class TestRenderEmbedPrompt:
embed_bytes = self._create_test_embed_bytes(test_tensor)
results = await renderer.render_prompt_and_embeds(
prompt_or_prompts="Hello world", prompt_embeds=embed_bytes)
prompt_or_prompts="Hello world",
prompt_embeds=embed_bytes,
config=RenderConfig(),
)
assert len(results) == 2
# First should be embed prompt
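
Taken together, the test changes show the shape of the new calling convention: the per-call keyword arguments are replaced by a single config parameter. A minimal sketch of the before/after, assuming an existing CompletionRenderer instance named renderer inside an async test (values are illustrative):

from vllm.entrypoints.renderer import RenderConfig

# Old style (pre-#24543): options passed as loose keyword arguments.
# results = await renderer.render_prompt(
#     prompt_or_prompts=[101, 7592, 2088], max_length=100)

# New style: the same options travel together in one RenderConfig.
results = await renderer.render_prompt(
    prompt_or_prompts=[101, 7592, 2088],
    config=RenderConfig(max_length=100, truncate_prompt_tokens=5),
)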

View File

@ -20,6 +20,7 @@ from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
OpenAIServing,
ServeContext)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger
from vllm.outputs import ClassificationOutput, PoolingRequestOutput
from vllm.pooling_params import PoolingParams
@ -57,8 +58,7 @@ class ClassificationMixin(OpenAIServing):
renderer = self._get_renderer(ctx.tokenizer)
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
max_length=self.max_model_len,
truncate_prompt_tokens=ctx.request.truncate_prompt_tokens)
config=self._build_render_config(ctx.request))
return None
@ -114,6 +114,12 @@ class ClassificationMixin(OpenAIServing):
usage=usage,
)
def _build_render_config(self,
request: ClassificationRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens)
class ServingClassification(ClassificationMixin):
request_id_prefix = "classify"

View File

@ -30,6 +30,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs)
# yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import get_max_tokens
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
is_tokens_prompt)
@ -129,18 +130,11 @@ class OpenAIServingCompletion(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
renderer = self._get_renderer(tokenizer)
max_input_tokens_len = self.max_model_len - (request.max_tokens
or 0)
engine_prompts = await renderer.render_prompt_and_embeds(
prompt_or_prompts=request.prompt,
prompt_embeds=request.prompt_embeds,
max_length=max_input_tokens_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=request.cache_salt,
needs_detokenization=bool(request.echo
and not request.return_token_ids),
config=self._build_render_config(request),
)
except ValueError as e:
logger.exception("Error in preprocessing prompt inputs")
@ -677,3 +671,18 @@ class OpenAIServingCompletion(OpenAIServing):
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)
def _build_render_config(
self,
request: CompletionRequest,
max_input_length: Optional[int] = None,
) -> RenderConfig:
max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
return RenderConfig(
max_length=max_input_tokens_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=request.cache_salt,
needs_detokenization=bool(request.echo
and not request.return_token_ids),
)
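
To make the length budget concrete, here is the arithmetic from _build_render_config with assumed numbers (not taken from the diff): the prompt may use whatever part of the context window the requested completion does not.

# Illustrative values only.
max_model_len = 8192        # engine context window (assumed)
request_max_tokens = 1024   # max_tokens requested by the client (assumed)

max_input_tokens_len = max_model_len - (request_max_tokens or 0)
# -> 7168 tokens left for the prompt; a request without max_tokens keeps all 8192.

Note that needs_detokenization is True only when the request echoes the prompt and does not ask for raw token IDs back.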

View File

@ -28,6 +28,7 @@ from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
TextTokensPrompt)
# yapf: enable
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
@ -97,23 +98,28 @@ class EmbeddingMixin(OpenAIServing):
add_special_tokens=ctx.request.add_special_tokens,
)
else:
# Set max_length based on chunked processing capability
if self._should_use_chunked_processing(ctx.request):
max_length = None
else:
max_length = self.max_embed_len or self.max_model_len
ctx.engine_prompts = await renderer.render_prompt(
prompt_or_prompts=ctx.request.input,
max_length=max_length,
truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
add_special_tokens=ctx.request.add_special_tokens,
config=self._build_render_config(ctx.request),
)
return None
except (ValueError, TypeError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
def _build_render_config(
self, request: EmbeddingCompletionRequest) -> RenderConfig:
# Set max_length based on chunked processing capability
if self._should_use_chunked_processing(request):
max_length = None
else:
max_length = self.max_embed_len or self.max_model_len
return RenderConfig(
max_length=max_length,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens)
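
A short sketch of the branch above with assumed limits: when chunked processing applies, no cap is imposed at render time; otherwise the cap is max_embed_len when configured, falling back to max_model_len.

# Assumed values for illustration.
max_embed_len = 32768   # may be None if not configured
max_model_len = 4096
use_chunked = False     # stand-in for _should_use_chunked_processing(request)

max_length = None if use_chunked else (max_embed_len or max_model_len)
# use_chunked=True  -> max_length is None (no render-time cap)
# use_chunked=False -> 32768 here, or 4096 if max_embed_len is None
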
@override
def _build_response(
self,

View File

@ -58,7 +58,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TranslationRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer
from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer,
RenderConfig)
# yapf: enable
from vllm.inputs.data import PromptType
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
@ -248,6 +249,19 @@ class OpenAIServing:
tokenizer=tokenizer,
async_tokenizer_pool=self._async_tokenizer_pool)
def _build_render_config(
self,
request: Any,
) -> RenderConfig:
"""
Build and return a `RenderConfig` for an endpoint.
Used by the renderer to control how prompts are prepared
(e.g., tokenization and length handling). Endpoints should
implement this with logic appropriate to their request type.
"""
raise NotImplementedError
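
Each endpoint supplies its own mapping from request fields to RenderConfig; the base method only raises NotImplementedError. A minimal sketch of such an override, using hypothetical names (MyServing and MyRequest are not part of this commit):

class MyServing(OpenAIServing):
    """Hypothetical endpoint, shown only to illustrate the override pattern."""

    def _build_render_config(self, request: MyRequest) -> RenderConfig:
        return RenderConfig(
            max_length=self.max_model_len,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
        )
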
def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
"""
Return (and cache) an `AsyncMicrobatchTokenizer` bound to the

View File

@ -28,6 +28,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
@ -149,10 +150,7 @@ class OpenAIServingPooling(OpenAIServing):
elif isinstance(request, PoolingCompletionRequest):
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.input,
max_length=self.max_model_len,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
cache_salt=getattr(request, 'cache_salt', None),
config=self._build_render_config(request),
)
else:
raise ValueError(
@ -270,3 +268,10 @@ class OpenAIServingPooling(OpenAIServing):
data=items,
usage=usage,
)
def _build_render_config(
self, request: PoolingCompletionRequest) -> RenderConfig:
return RenderConfig(
max_length=self.max_model_len,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens)

View File

@ -22,6 +22,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.renderer import RenderConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ -72,7 +73,7 @@ class OpenAIServingTokenization(OpenAIServing):
[tool.model_dump() for tool in request.tools])
(
_,
request_prompts,
_,
engine_prompts,
) = await self._preprocess_chat(
request,
@ -90,15 +91,14 @@ class OpenAIServingTokenization(OpenAIServing):
else:
engine_prompts = await renderer.render_prompt(
prompt_or_prompts=request.prompt,
add_special_tokens=request.add_special_tokens,
cache_salt=getattr(request, 'cache_salt', None),
config=self._build_render_config(request),
)
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(f"{e} {e.__cause__}")
input_ids: list[int] = []
for i, engine_prompt in enumerate(engine_prompts):
for engine_prompt in engine_prompts:
self._log_inputs(request_id,
engine_prompt,
params=None,
@ -157,6 +157,9 @@ class OpenAIServingTokenization(OpenAIServing):
return self.create_error_response(
f"Failed to get tokenizer info: {str(e)}")
def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
return RenderConfig(add_special_tokens=request.add_special_tokens)
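
The tokenize endpoint forwards only add_special_tokens; every other field keeps its dataclass default, so no length cap or truncation is applied at render time. An illustrative check of the resulting object:

cfg = RenderConfig(add_special_tokens=False)
assert cfg.max_length is None and cfg.truncate_prompt_tokens is None
assert cfg.cache_salt is None and cfg.needs_detokenization is False
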
@dataclass
class TokenizerInfo:

View File

@ -4,6 +4,7 @@
import asyncio
import io
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Annotated, Optional, Union
import pybase64
@ -18,6 +19,29 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import AsyncMicrobatchTokenizer
@dataclass(frozen=True)
class RenderConfig:
"""Configuration to control how prompts are prepared."""
max_length: Optional[int] = None
"""Maximum allowable total input token length. If provided,
token inputs longer than this raise ``ValueError``."""
truncate_prompt_tokens: Optional[int] = None
"""Number of tokens to keep. ``None`` means no truncation.
``0`` yields an empty list (and skips embeds).
``-1`` maps to ``model_config.max_model_len``."""
add_special_tokens: Optional[bool] = True
"""Whether to add model-specific special tokens during tokenization."""
cache_salt: Optional[str] = None
"""String to disambiguate prefix cache entries."""
needs_detokenization: Optional[bool] = False
"""If True, detokenize IDs back to text for inclusion in outputs."""
class BaseRenderer(ABC):
"""
Base class for unified input processing and rendering.
@ -48,12 +72,9 @@ class BaseRenderer(ABC):
@abstractmethod
async def render_prompt(
self,
*,
prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
max_length: Optional[int] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
needs_detokenization: Optional[bool] = False,
config: "RenderConfig",
) -> list[EngineTokensPrompt]:
"""
Convert text or token inputs into engine-ready TokensPrompt objects.
@ -68,16 +89,8 @@ class BaseRenderer(ABC):
- ``list[str]``: Batch of text prompts.
- ``list[int]``: Single pre-tokenized sequence.
- ``list[list[int]]``: Batch of pre-tokenized sequences.
max_length: Maximum allowable total input token length. If provided,
token inputs longer than this raise ``ValueError``.
truncate_prompt_tokens: Number of tokens to keep. ``None`` means no
truncation. ``0`` yields an empty list (and skips embeds).
``-1`` maps to ``model_config.max_model_len``.
add_special_tokens: Whether to add model-specific special tokens
during text tokenization.
cache_salt: Optional string to disambiguate prefix cache entries.
needs_detokenization: If True and ``prompt_or_prompts`` is token
input, detokenize IDs back to text for inclusion in outputs.
config: Render configuration controlling how prompts are prepared
(e.g., tokenization and length handling).
Returns:
list[EngineTokensPrompt]: Engine-ready token prompts.
@ -90,18 +103,15 @@ class BaseRenderer(ABC):
@abstractmethod
async def render_prompt_and_embeds(
self,
*,
prompt_or_prompts: Optional[Union[str, list[str], list[int],
list[list[int]]]] = None,
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
max_length: Optional[int] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
needs_detokenization: Optional[bool] = False,
config: "RenderConfig",
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
"""
Convert text/token and/or base64-encoded embeddings inputs into
engine-ready prompt objects.
engine-ready prompt objects using a unified RenderConfig.
At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be
provided and non-empty. If both are omitted or empty (e.g., empty
@ -111,15 +121,8 @@ class BaseRenderer(ABC):
prompt_or_prompts: Text or token inputs to include.
prompt_embeds: Base64-encoded bytes (or list thereof) containing a
torch-saved tensor to be used as prompt embeddings.
max_length: Maximum allowable total input token length. If provided,
inputs longer than this raise ``ValueError``.
truncate_prompt_tokens: Number of tokens/rows to keep from the end
of the sequence. ``-1`` maps to ``model_config.max_model_len``.
add_special_tokens: Whether to add model-specific special tokens
during text tokenization.
cache_salt: Optional string to disambiguate prefix cache entries.
needs_detokenization: If True and ``prompt_or_prompts`` is token
input, detokenize IDs back to text for inclusion in outputs.
config: Render configuration controlling how prompts are prepared
(e.g., tokenization and length handling).
Returns:
list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
@ -184,12 +187,9 @@ class CompletionRenderer(BaseRenderer):
async def render_prompt(
self,
*,
prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
max_length: Optional[int] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
needs_detokenization: Optional[bool] = False,
config: "RenderConfig",
) -> list[EngineTokensPrompt]:
"""Implementation of prompt rendering for completion-style requests.
@ -197,7 +197,7 @@ class CompletionRenderer(BaseRenderer):
for detailed parameter documentation.
"""
truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
truncate_prompt_tokens, max_length)
config.truncate_prompt_tokens, config.max_length)
if truncate_prompt_tokens == 0:
return []
@ -211,16 +211,19 @@ class CompletionRenderer(BaseRenderer):
detokenize_task = asyncio.create_task(
# Note: detokenization is needed when echo is enabled,
# where the input token IDs are decoded back to text.
self._maybe_detokenize(prompt_input["content"], max_length,
truncate_prompt_tokens, cache_salt,
needs_detokenization))
self._maybe_detokenize(prompt_input["content"],
config.max_length,
truncate_prompt_tokens,
config.cache_salt,
config.needs_detokenization))
tasks.append(detokenize_task)
else:
# Text input
tokenize_task = asyncio.create_task(
self._tokenize(prompt_input["content"], max_length,
truncate_prompt_tokens, add_special_tokens,
cache_salt))
self._tokenize(prompt_input["content"], config.max_length,
truncate_prompt_tokens,
config.add_special_tokens,
config.cache_salt))
tasks.append(tokenize_task)
# Wait for all text tokenization to finish
@ -232,21 +235,18 @@ class CompletionRenderer(BaseRenderer):
async def render_prompt_and_embeds(
self,
*,
prompt_or_prompts: Optional[Union[str, list[str], list[int],
list[list[int]]]] = None,
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
max_length: Optional[int] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: Optional[bool] = True,
cache_salt: Optional[str] = None,
needs_detokenization: Optional[bool] = False,
config: "RenderConfig",
) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
"""
Render text/token prompts and/or precomputed embedding prompts. At
least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
"""
truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
truncate_prompt_tokens, max_length)
config.truncate_prompt_tokens, config.max_length)
if truncate_prompt_tokens == 0:
return []
@ -255,17 +255,13 @@ class CompletionRenderer(BaseRenderer):
if prompt_embeds is not None:
rendered.extend(
self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens,
cache_salt))
config.cache_salt))
if prompt_or_prompts is None or prompt_or_prompts == "":
return rendered
token_prompts = await self.render_prompt(
prompt_or_prompts=prompt_or_prompts,
max_length=max_length,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
cache_salt=cache_salt,
needs_detokenization=needs_detokenization,
config=config,
)
rendered.extend(token_prompts)