Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
Consolidate rendering parameters into RenderConfig dataclass (#24543)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
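
This change collapses the renderer's per-call keyword arguments (max_length, truncate_prompt_tokens, add_special_tokens, cache_salt, needs_detokenization) into a single config object. A minimal sketch of the call-site difference, assuming a renderer instance and a token list like the ones used in the tests below (the helper function itself is illustrative, not part of the diff):

    from vllm.entrypoints.renderer import RenderConfig

    async def render_with_config(renderer, tokens: list[int]):
        # Before: options were passed as separate keyword arguments, e.g.
        #   await renderer.render_prompt(prompt_or_prompts=tokens, max_length=100)
        # After: the same options travel together in one frozen dataclass.
        return await renderer.render_prompt(
            prompt_or_prompts=tokens,
            config=RenderConfig(max_length=100),
        )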
@@ -10,7 +10,7 @@ import pybase64
 import pytest
 import torch

-from vllm.entrypoints.renderer import CompletionRenderer
+from vllm.entrypoints.renderer import CompletionRenderer, RenderConfig
 from vllm.inputs.data import is_embeds_prompt


@@ -56,8 +56,8 @@ class TestRenderPrompt:
     @pytest.mark.asyncio
     async def test_token_input(self, renderer):
         tokens = [101, 7592, 2088]
-        results = await renderer.render_prompt(prompt_or_prompts=tokens,
-                                               max_length=100)
+        results = await renderer.render_prompt(
+            prompt_or_prompts=tokens, config=RenderConfig(max_length=100))

         assert len(results) == 1
         assert results[0]["prompt_token_ids"] == tokens
@@ -65,8 +65,8 @@ class TestRenderPrompt:
     @pytest.mark.asyncio
     async def test_token_list_input(self, renderer):
         token_lists = [[101, 7592, 2088], [102, 1234, 5678, 9012], [103, 4567]]
-        results = await renderer.render_prompt(prompt_or_prompts=token_lists,
-                                               max_length=100)
+        results = await renderer.render_prompt(
+            prompt_or_prompts=token_lists, config=RenderConfig(max_length=100))

         assert len(results) == 3
         assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
@@ -80,8 +80,9 @@ class TestRenderPrompt:
         renderer.async_tokenizer_pool[
             renderer.tokenizer] = mock_async_tokenizer

-        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
-                                               max_length=100)
+        results = await renderer.render_prompt(
+            prompt_or_prompts="Hello world",
+            config=RenderConfig(max_length=100))

         assert len(results) == 1
         assert results[0]["prompt_token_ids"] == [101, 7592, 2088]
@@ -96,7 +97,8 @@ class TestRenderPrompt:

         text_list_input = ["Hello world", "How are you?", "Good morning"]
         results = await renderer.render_prompt(
-            prompt_or_prompts=text_list_input, max_length=100)
+            prompt_or_prompts=text_list_input,
+            config=RenderConfig(max_length=100))

         assert len(results) == 3
         for result in results:
@@ -110,8 +112,9 @@ class TestRenderPrompt:
         renderer.async_tokenizer_pool[
             renderer.tokenizer] = mock_async_tokenizer

-        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
-                                               max_length=100)
+        results = await renderer.render_prompt(
+            prompt_or_prompts="Hello world",
+            config=RenderConfig(max_length=100))

         assert len(results) == 1
         call_args = mock_async_tokenizer.call_args
@@ -126,8 +129,9 @@ class TestRenderPrompt:
             renderer.tokenizer] = mock_async_tokenizer

         results = await renderer.render_prompt(prompt_or_prompts="Hello world",
-                                               max_length=100,
-                                               truncate_prompt_tokens=50)
+                                               config=RenderConfig(
+                                                   max_length=100,
+                                                   truncate_prompt_tokens=50))

         assert len(results) == 1
         call_args = mock_async_tokenizer.call_args
@@ -143,8 +147,9 @@ class TestRenderPrompt:
             renderer.tokenizer] = mock_async_tokenizer

         results = await renderer.render_prompt(prompt_or_prompts="Hello world",
-                                               max_length=200,
-                                               truncate_prompt_tokens=-1)
+                                               config=RenderConfig(
+                                                   max_length=200,
+                                                   truncate_prompt_tokens=-1))

         assert len(results) == 1
         call_args = mock_async_tokenizer.call_args
@@ -157,8 +162,9 @@ class TestRenderPrompt:
         long_tokens = [100, 101, 102, 103, 104, 105, 106, 107, 108,
                        109]  # 10 tokens
         results = await renderer.render_prompt(prompt_or_prompts=long_tokens,
-                                               max_length=100,
-                                               truncate_prompt_tokens=5)
+                                               config=RenderConfig(
+                                                   max_length=100,
+                                                   truncate_prompt_tokens=5))

         assert len(results) == 1
         # Should keep the last 5 tokens: [105, 106, 107, 108, 109]
@@ -170,7 +176,7 @@ class TestRenderPrompt:

         with pytest.raises(ValueError, match="maximum context length"):
             await renderer.render_prompt(prompt_or_prompts=long_tokens,
-                                         max_length=100)
+                                         config=RenderConfig(max_length=100))

     @pytest.mark.asyncio
     async def test_no_tokenizer_for_text(self, mock_model_config):
@@ -181,7 +187,8 @@ class TestRenderPrompt:

         with pytest.raises(ValueError, match="No tokenizer available"):
             await renderer_no_tokenizer.render_prompt(
-                prompt_or_prompts="Hello world", max_length=100)
+                prompt_or_prompts="Hello world",
+                config=RenderConfig(max_length=100))

     @pytest.mark.asyncio
     async def test_token_input_with_needs_detokenization(
@@ -196,7 +203,7 @@ class TestRenderPrompt:
         tokens = [1, 2, 3, 4]
         results = await renderer.render_prompt(
             prompt_or_prompts=tokens,
-            needs_detokenization=True,
+            config=RenderConfig(needs_detokenization=True),
         )

         assert len(results) == 1
@@ -221,7 +228,9 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)

         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes, cache_salt="test_salt")
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(cache_salt="test_salt"),
+        )

         assert len(results) == 1
         assert is_embeds_prompt(results[0])
@@ -240,7 +249,9 @@ class TestRenderEmbedPrompt:
         ]

         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes_list)
+            prompt_embeds=embed_bytes_list,
+            config=RenderConfig(),
+        )

         assert len(results) == 2
         for i, result in enumerate(results):
@@ -254,7 +265,9 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)

         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes, truncate_prompt_tokens=10)
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(truncate_prompt_tokens=10),
+        )

         assert len(results) == 1
         # Should keep last 10 tokens
@@ -271,7 +284,9 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)

         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes)
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(),
+        )

         assert len(results) == 1
         assert results[0]["prompt_embeds"].dtype == dtype
@@ -283,7 +298,9 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)

         results = await renderer.render_prompt_and_embeds(
-            prompt_embeds=embed_bytes)
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(),
+        )

         assert len(results) == 1
         # Should be squeezed to 2D
@@ -303,7 +320,10 @@ class TestRenderEmbedPrompt:
         embed_bytes = self._create_test_embed_bytes(test_tensor)

         results = await renderer.render_prompt_and_embeds(
-            prompt_or_prompts="Hello world", prompt_embeds=embed_bytes)
+            prompt_or_prompts="Hello world",
+            prompt_embeds=embed_bytes,
+            config=RenderConfig(),
+        )

         assert len(results) == 2
         # First should be embed prompt

@@ -20,6 +20,7 @@ from vllm.entrypoints.openai.serving_engine import (ClassificationServeContext,
                                                     OpenAIServing,
                                                     ServeContext)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.renderer import RenderConfig
 from vllm.logger import init_logger
 from vllm.outputs import ClassificationOutput, PoolingRequestOutput
 from vllm.pooling_params import PoolingParams
@@ -57,8 +58,7 @@ class ClassificationMixin(OpenAIServing):
             renderer = self._get_renderer(ctx.tokenizer)
             ctx.engine_prompts = await renderer.render_prompt(
                 prompt_or_prompts=ctx.request.input,
-                max_length=self.max_model_len,
-                truncate_prompt_tokens=ctx.request.truncate_prompt_tokens)
+                config=self._build_render_config(ctx.request))

             return None

@@ -114,6 +114,12 @@ class ClassificationMixin(OpenAIServing):
             usage=usage,
         )

+    def _build_render_config(self,
+                             request: ClassificationRequest) -> RenderConfig:
+        return RenderConfig(
+            max_length=self.max_model_len,
+            truncate_prompt_tokens=request.truncate_prompt_tokens)
+

 class ServingClassification(ClassificationMixin):
     request_id_prefix = "classify"

@@ -30,6 +30,7 @@ from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
                                                     clamp_prompt_logprobs)
 # yapf: enable
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.renderer import RenderConfig
 from vllm.entrypoints.utils import get_max_tokens
 from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
                               is_tokens_prompt)
@@ -129,18 +130,11 @@ class OpenAIServingCompletion(OpenAIServing):
             tokenizer = await self.engine_client.get_tokenizer(lora_request
                                                                )
             renderer = self._get_renderer(tokenizer)
-            max_input_tokens_len = self.max_model_len - (request.max_tokens
-                                                         or 0)

             engine_prompts = await renderer.render_prompt_and_embeds(
                 prompt_or_prompts=request.prompt,
                 prompt_embeds=request.prompt_embeds,
-                max_length=max_input_tokens_len,
-                truncate_prompt_tokens=request.truncate_prompt_tokens,
-                add_special_tokens=request.add_special_tokens,
-                cache_salt=request.cache_salt,
-                needs_detokenization=bool(request.echo
-                                          and not request.return_token_ids),
+                config=self._build_render_config(request),
             )
         except ValueError as e:
             logger.exception("Error in preprocessing prompt inputs")
@@ -677,3 +671,18 @@ class OpenAIServingCompletion(OpenAIServing):
             tokens=out_tokens,
             top_logprobs=out_top_logprobs,
         )
+
+    def _build_render_config(
+        self,
+        request: CompletionRequest,
+        max_input_length: Optional[int] = None,
+    ) -> RenderConfig:
+        max_input_tokens_len = self.max_model_len - (request.max_tokens or 0)
+        return RenderConfig(
+            max_length=max_input_tokens_len,
+            truncate_prompt_tokens=request.truncate_prompt_tokens,
+            add_special_tokens=request.add_special_tokens,
+            cache_salt=request.cache_salt,
+            needs_detokenization=bool(request.echo
+                                      and not request.return_token_ids),
+        )

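In the completion path above, _build_render_config keeps the previous length budget: the prompt may use whatever remains of the context window after reserving room for the requested completion (self.max_model_len - (request.max_tokens or 0)). A small illustration of that arithmetic, with made-up numbers:

    # Hypothetical values, for illustration only.
    max_model_len = 4096          # model context window
    requested_max_tokens = 256    # request.max_tokens; None counts as 0

    max_input_tokens_len = max_model_len - (requested_max_tokens or 0)
    assert max_input_tokens_len == 3840  # budget left for the input prompt
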
@@ -28,6 +28,7 @@ from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
                                                     TextTokensPrompt)
 # yapf: enable
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.renderer import RenderConfig
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
@@ -97,23 +98,28 @@ class EmbeddingMixin(OpenAIServing):
                     add_special_tokens=ctx.request.add_special_tokens,
                 )
             else:
-                # Set max_length based on chunked processing capability
-                if self._should_use_chunked_processing(ctx.request):
-                    max_length = None
-                else:
-                    max_length = self.max_embed_len or self.max_model_len
-
                 ctx.engine_prompts = await renderer.render_prompt(
                     prompt_or_prompts=ctx.request.input,
-                    max_length=max_length,
-                    truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
-                    add_special_tokens=ctx.request.add_special_tokens,
+                    config=self._build_render_config(ctx.request),
                 )
             return None
         except (ValueError, TypeError) as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))

+    def _build_render_config(
+            self, request: EmbeddingCompletionRequest) -> RenderConfig:
+        # Set max_length based on chunked processing capability
+        if self._should_use_chunked_processing(request):
+            max_length = None
+        else:
+            max_length = self.max_embed_len or self.max_model_len
+
+        return RenderConfig(
+            max_length=max_length,
+            truncate_prompt_tokens=request.truncate_prompt_tokens,
+            add_special_tokens=request.add_special_tokens)
+
     @override
     def _build_response(
         self,

@@ -58,7 +58,8 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               TranslationRequest)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.tool_parsers import ToolParser
-from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer
+from vllm.entrypoints.renderer import (BaseRenderer, CompletionRenderer,
+                                       RenderConfig)
 # yapf: enable
 from vllm.inputs.data import PromptType
 from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
@@ -248,6 +249,19 @@ class OpenAIServing:
             tokenizer=tokenizer,
             async_tokenizer_pool=self._async_tokenizer_pool)

+    def _build_render_config(
+        self,
+        request: Any,
+    ) -> RenderConfig:
+        """
+        Build and return a `RenderConfig` for an endpoint.
+
+        Used by the renderer to control how prompts are prepared
+        (e.g., tokenization and length handling). Endpoints should
+        implement this with logic appropriate to their request type.
+        """
+        raise NotImplementedError
+
     def _get_async_tokenizer(self, tokenizer) -> AsyncMicrobatchTokenizer:
         """
         Return (and cache) an `AsyncMicrobatchTokenizer` bound to the

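The base class leaves _build_render_config unimplemented; each endpoint overrides it with logic for its own request type, as the classification, completion, and embedding hunks above and the pooling and tokenization hunks below show. A minimal sketch of that pattern for a hypothetical endpoint (the class and the request fields used here are illustrative, not from this diff):

    from vllm.entrypoints.openai.serving_engine import OpenAIServing
    from vllm.entrypoints.renderer import RenderConfig

    class MyServingEndpoint(OpenAIServing):  # hypothetical endpoint
        def _build_render_config(self, request) -> RenderConfig:
            # Map whatever the request carries onto the shared config object.
            return RenderConfig(
                max_length=self.max_model_len,
                truncate_prompt_tokens=request.truncate_prompt_tokens,
            )
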
@@ -28,6 +28,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.renderer import RenderConfig
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
 from vllm.outputs import PoolingOutput, PoolingRequestOutput
@@ -149,10 +150,7 @@ class OpenAIServingPooling(OpenAIServing):
             elif isinstance(request, PoolingCompletionRequest):
                 engine_prompts = await renderer.render_prompt(
                     prompt_or_prompts=request.input,
-                    max_length=self.max_model_len,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
-                    add_special_tokens=request.add_special_tokens,
-                    cache_salt=getattr(request, 'cache_salt', None),
+                    config=self._build_render_config(request),
                 )
             else:
                 raise ValueError(
@@ -270,3 +268,10 @@ class OpenAIServingPooling(OpenAIServing):
             data=items,
             usage=usage,
         )
+
+    def _build_render_config(
+            self, request: PoolingCompletionRequest) -> RenderConfig:
+        return RenderConfig(
+            max_length=self.max_model_len,
+            truncate_prompt_tokens=request.truncate_prompt_tokens,
+            add_special_tokens=request.add_special_tokens)

@@ -22,6 +22,7 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.renderer import RenderConfig
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer

@@ -72,7 +73,7 @@ class OpenAIServingTokenization(OpenAIServing):
                     [tool.model_dump() for tool in request.tools])
                 (
                     _,
-                    request_prompts,
+                    _,
                     engine_prompts,
                 ) = await self._preprocess_chat(
                     request,
@@ -90,15 +91,14 @@ class OpenAIServingTokenization(OpenAIServing):
             else:
                 engine_prompts = await renderer.render_prompt(
                     prompt_or_prompts=request.prompt,
-                    add_special_tokens=request.add_special_tokens,
-                    cache_salt=getattr(request, 'cache_salt', None),
+                    config=self._build_render_config(request),
                 )
         except (ValueError, TypeError, jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(f"{e} {e.__cause__}")

         input_ids: list[int] = []
-        for i, engine_prompt in enumerate(engine_prompts):
+        for engine_prompt in engine_prompts:
             self._log_inputs(request_id,
                              engine_prompt,
                              params=None,
@@ -157,6 +157,9 @@ class OpenAIServingTokenization(OpenAIServing):
             return self.create_error_response(
                 f"Failed to get tokenizer info: {str(e)}")

+    def _build_render_config(self, request: TokenizeRequest) -> RenderConfig:
+        return RenderConfig(add_special_tokens=request.add_special_tokens)
+

 @dataclass
 class TokenizerInfo:

@@ -4,6 +4,7 @@
 import asyncio
 import io
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from typing import Annotated, Optional, Union

 import pybase64
@@ -18,6 +19,29 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import AsyncMicrobatchTokenizer


+@dataclass(frozen=True)
+class RenderConfig:
+    """Configuration to control how prompts are prepared."""
+
+    max_length: Optional[int] = None
+    """Maximum allowable total input token length. If provided,
+    token inputs longer than this raise ``ValueError``."""
+
+    truncate_prompt_tokens: Optional[int] = None
+    """Number of tokens to keep. ``None`` means no truncation.
+    ``0`` yields an empty list (and skips embeds).
+    ``-1`` maps to ``model_config.max_model_len``."""
+
+    add_special_tokens: Optional[bool] = True
+    """Whether to add model-specific special tokens during tokenization."""
+
+    cache_salt: Optional[str] = None
+    """String to disambiguate prefix cache entries."""
+
+    needs_detokenization: Optional[bool] = False
+    """If True, detokenize IDs back to text for inclusion in outputs."""
+
+
 class BaseRenderer(ABC):
     """
     Base class for unified input processing and rendering.

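Since RenderConfig is declared with @dataclass(frozen=True), a config is immutable once built, and every field has a default, so endpoints construct only the options they need. A short sketch of the defaults and the frozen behaviour, assuming the definition added above:

    import dataclasses

    from vllm.entrypoints.renderer import RenderConfig

    cfg = RenderConfig(max_length=100, truncate_prompt_tokens=50)
    assert cfg.add_special_tokens is True       # default
    assert cfg.cache_salt is None               # default
    assert cfg.needs_detokenization is False    # default

    try:
        cfg.max_length = 200  # frozen dataclass: assignment is rejected
    except dataclasses.FrozenInstanceError:
        pass
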
@@ -48,12 +72,9 @@ class BaseRenderer(ABC):
     @abstractmethod
     async def render_prompt(
         self,
+        *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        max_length: Optional[int] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
-        add_special_tokens: Optional[bool] = True,
-        cache_salt: Optional[str] = None,
-        needs_detokenization: Optional[bool] = False,
+        config: "RenderConfig",
     ) -> list[EngineTokensPrompt]:
         """
         Convert text or token inputs into engine-ready TokensPrompt objects.
@@ -68,16 +89,8 @@ class BaseRenderer(ABC):
                 - ``list[str]``: Batch of text prompts.
                 - ``list[int]``: Single pre-tokenized sequence.
                 - ``list[list[int]]``: Batch of pre-tokenized sequences.
-            max_length: Maximum allowable total input token length. If provided,
-                token inputs longer than this raise ``ValueError``.
-            truncate_prompt_tokens: Number of tokens to keep. ``None`` means no
-                truncation. ``0`` yields an empty list (and skips embeds).
-                ``-1`` maps to ``model_config.max_model_len``.
-            add_special_tokens: Whether to add model-specific special tokens
-                during text tokenization.
-            cache_salt: Optional string to disambiguate prefix cache entries.
-            needs_detokenization: If True and ``prompt_or_prompts`` is token
-                input, detokenize IDs back to text for inclusion in outputs.
+            config: Render configuration controlling how prompts are prepared
+                (e.g., tokenization and length handling).

         Returns:
             list[EngineTokensPrompt]: Engine-ready token prompts.
@@ -90,18 +103,15 @@ class BaseRenderer(ABC):
     @abstractmethod
     async def render_prompt_and_embeds(
         self,
+        *,
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        max_length: Optional[int] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
-        add_special_tokens: Optional[bool] = True,
-        cache_salt: Optional[str] = None,
-        needs_detokenization: Optional[bool] = False,
+        config: "RenderConfig",
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Convert text/token and/or base64-encoded embeddings inputs into
-        engine-ready prompt objects.
+        engine-ready prompt objects using a unified RenderConfig.

         At least one of ``prompt_or_prompts`` or ``prompt_embeds`` must be
         provided and non-empty. If both are omitted or empty (e.g., empty
@@ -111,15 +121,8 @@ class BaseRenderer(ABC):
             prompt_or_prompts: Text or token inputs to include.
             prompt_embeds: Base64-encoded bytes (or list thereof) containing a
                 torch-saved tensor to be used as prompt embeddings.
-            max_length: Maximum allowable total input token length. If provided,
-                inputs longer than this raise ``ValueError``.
-            truncate_prompt_tokens: Number of tokens/rows to keep from the end
-                of the sequence. ``-1`` maps to ``model_config.max_model_len``.
-            add_special_tokens: Whether to add model-specific special tokens
-                during text tokenization.
-            cache_salt: Optional string to disambiguate prefix cache entries.
-            needs_detokenization: If True and ``prompt_or_prompts`` is token
-                input, detokenize IDs back to text for inclusion in outputs.
+            config: Render configuration controlling how prompts are prepared
+                (e.g., tokenization and length handling).

         Returns:
             list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
@@ -184,12 +187,9 @@ class CompletionRenderer(BaseRenderer):

     async def render_prompt(
         self,
+        *,
         prompt_or_prompts: Union[str, list[str], list[int], list[list[int]]],
-        max_length: Optional[int] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
-        add_special_tokens: Optional[bool] = True,
-        cache_salt: Optional[str] = None,
-        needs_detokenization: Optional[bool] = False,
+        config: "RenderConfig",
     ) -> list[EngineTokensPrompt]:
         """Implementation of prompt rendering for completion-style requests.

@@ -197,7 +197,7 @@ class CompletionRenderer(BaseRenderer):
         for detailed parameter documentation.
         """
         truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            truncate_prompt_tokens, max_length)
+            config.truncate_prompt_tokens, config.max_length)
         if truncate_prompt_tokens == 0:
             return []

@@ -211,16 +211,19 @@ class CompletionRenderer(BaseRenderer):
                 detokenize_task = asyncio.create_task(
                     # Note: detokenization is needed when echo is enabled,
                     # where the input token IDs are decoded back to text.
-                    self._maybe_detokenize(prompt_input["content"], max_length,
-                                           truncate_prompt_tokens, cache_salt,
-                                           needs_detokenization))
+                    self._maybe_detokenize(prompt_input["content"],
+                                           config.max_length,
+                                           truncate_prompt_tokens,
+                                           config.cache_salt,
+                                           config.needs_detokenization))
                 tasks.append(detokenize_task)
             else:
                 # Text input
                 tokenize_task = asyncio.create_task(
-                    self._tokenize(prompt_input["content"], max_length,
-                                   truncate_prompt_tokens, add_special_tokens,
-                                   cache_salt))
+                    self._tokenize(prompt_input["content"], config.max_length,
+                                   truncate_prompt_tokens,
+                                   config.add_special_tokens,
+                                   config.cache_salt))
                 tasks.append(tokenize_task)

         # Wait for all text tokenization to finish
@@ -232,21 +235,18 @@ class CompletionRenderer(BaseRenderer):

     async def render_prompt_and_embeds(
         self,
+        *,
         prompt_or_prompts: Optional[Union[str, list[str], list[int],
                                           list[list[int]]]] = None,
         prompt_embeds: Optional[Union[bytes, list[bytes]]] = None,
-        max_length: Optional[int] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
-        add_special_tokens: Optional[bool] = True,
-        cache_salt: Optional[str] = None,
-        needs_detokenization: Optional[bool] = False,
+        config: "RenderConfig",
     ) -> list[Union[EngineTokensPrompt, EngineEmbedsPrompt]]:
         """
         Render text/token prompts and/or precomputed embedding prompts. At
         least one of `prompt_or_prompts` or `prompt_embeds` must be provided.
         """
         truncate_prompt_tokens = self._validate_and_normalize_truncate_tokens(
-            truncate_prompt_tokens, max_length)
+            config.truncate_prompt_tokens, config.max_length)
         if truncate_prompt_tokens == 0:
             return []

@@ -255,17 +255,13 @@ class CompletionRenderer(BaseRenderer):
         if prompt_embeds is not None:
             rendered.extend(
                 self.load_prompt_embeds(prompt_embeds, truncate_prompt_tokens,
-                                        cache_salt))
+                                        config.cache_salt))
         if prompt_or_prompts is None or prompt_or_prompts == "":
             return rendered

         token_prompts = await self.render_prompt(
             prompt_or_prompts=prompt_or_prompts,
-            max_length=max_length,
-            truncate_prompt_tokens=truncate_prompt_tokens,
-            add_special_tokens=add_special_tokens,
-            cache_salt=cache_salt,
-            needs_detokenization=needs_detokenization,
+            config=config,
         )
         rendered.extend(token_prompts)

@@ -394,4 +390,4 @@ class CompletionRenderer(BaseRenderer):
             tokens_prompt["cache_salt"] = cache_salt
         if prompt is not None:
             tokens_prompt["prompt"] = prompt
-        return tokens_prompt
\ No newline at end of file
+        return tokens_prompt