Fix wrong truncate_prompt_tokens type hint (#22761)

Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Gabriel Marinho <104592062+gmarinho2@users.noreply.github.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
Authored by Gabriel Marinho on 2025-08-30 17:39:38 -03:00; committed by GitHub.
parent 038e9be4eb
commit 5b8077b8ac
14 changed files with 101 additions and 102 deletions


@ -51,7 +51,7 @@ from vllm.tasks import PoolingTask
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, Device, is_list_of
from vllm.utils import Counter, Device, as_iter, is_list_of
from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING:
@ -364,14 +364,6 @@ class LLM:
# Use default sampling params.
sampling_params = self.get_default_sampling_params()
tokenization_kwargs: dict[str, Any] = {}
truncate_prompt_tokens = None
if isinstance(sampling_params, SamplingParams):
truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
_validate_truncation_size(model_config.max_model_len,
truncate_prompt_tokens, tokenization_kwargs)
# Add any modality specific loras to the corresponding prompts
lora_request = self._get_modality_specific_lora_reqs(
prompts, lora_request)
@ -381,7 +373,6 @@ class LLM:
params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
priority=priority,
)
@ -871,6 +862,8 @@ class LLM:
If `False`, no progress bar is created.
lora_request: LoRA request to use for generation, if any.
pooling_task: Override the pooling task to use.
tokenization_kwargs: overrides tokenization_kwargs set in
pooling_params
Returns:
A list of `PoolingRequestOutput` objects containing the
@ -916,24 +909,17 @@ class LLM:
# Use default pooling params.
pooling_params = PoolingParams()
if isinstance(pooling_params, PoolingParams):
pooling_params.verify(pooling_task, model_config)
else:
for pooling_param in pooling_params:
pooling_param.verify(pooling_task, model_config)
if tokenization_kwargs is None:
tokenization_kwargs = dict[str, Any]()
_validate_truncation_size(model_config.max_model_len,
truncate_prompt_tokens,
tokenization_kwargs)
for param in as_iter(pooling_params):
param.verify(pooling_task, model_config)
# for backwards compatibility
if truncate_prompt_tokens is not None:
param.truncate_prompt_tokens = truncate_prompt_tokens
self._validate_and_add_requests(
prompts=prompts,
params=pooling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
outputs = self._run_engine(use_tqdm=use_tqdm)
@ -1385,7 +1371,6 @@ class LLM:
*,
use_tqdm: Union[bool, Callable[..., tqdm]] = True,
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
tokenization_kwargs: Optional[dict[str, Any]] = None,
priority: Optional[list[int]] = None,
) -> None:
if isinstance(prompts, (str, dict)):
@ -1412,7 +1397,17 @@ class LLM:
tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
it = tqdm_func(it, desc="Adding requests")
model_config = self.llm_engine.model_config
for i, prompt in enumerate(it):
param = params[i] if isinstance(params, Sequence) else params
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(model_config.max_model_len,
param.truncate_prompt_tokens,
tokenization_kwargs)
self._add_request(
prompt,
params[i] if isinstance(params, Sequence) else params,
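
For orientation, here is a minimal standalone sketch of the validation step that these hunks move into the per-prompt loop of _validate_and_add_requests. The function name mirrors _validate_truncation_size, but the body is an assumption inferred from how tokenization_kwargs is consumed later in this commit (see the InputPreprocessor hunks below), not vLLM's actual implementation.

from typing import Any, Optional


def validate_truncation_size(
        max_model_len: int,
        truncate_prompt_tokens: Optional[int],
        tokenization_kwargs: dict[str, Any]) -> Optional[int]:
    # Hypothetical stand-in for vLLM's _validate_truncation_size.
    if truncate_prompt_tokens is None:
        return None  # truncation disabled
    if truncate_prompt_tokens == -1:
        # -1 is the sentinel for "truncate to the model's maximum length".
        truncate_prompt_tokens = max_model_len
    if truncate_prompt_tokens > max_model_len:
        raise ValueError("truncate_prompt_tokens value is greater than "
                         "max_model_len. Please, select a smaller "
                         "truncation size.")
    # These keys are what the input preprocessor reads further down.
    tokenization_kwargs["truncation"] = True
    tokenization_kwargs["max_length"] = truncate_prompt_tokens
    return truncate_prompt_tokens


kwargs: dict[str, Any] = {}
validate_truncation_size(4096, -1, kwargs)
print(kwargs)  # {'truncation': True, 'max_length': 4096}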


@ -452,7 +452,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
prompt_logprobs: Optional[int] = None
allowed_token_ids: Optional[list[int]] = None
bad_words: list[str] = Field(default_factory=list)
@ -995,7 +995,7 @@ class CompletionRequest(OpenAIBaseModel):
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
allowed_token_ids: Optional[list[int]] = None
prompt_logprobs: Optional[int] = None
# --8<-- [end:completion-sampling-params]
@ -1325,8 +1325,10 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
# --8<-- [end:embedding-extra-params]
def to_pooling_params(self):
return PoolingParams(dimensions=self.dimensions,
normalize=self.normalize)
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
dimensions=self.dimensions,
normalize=self.normalize)
class EmbeddingChatRequest(OpenAIBaseModel):
@ -1393,8 +1395,10 @@ class EmbeddingChatRequest(OpenAIBaseModel):
return data
def to_pooling_params(self):
return PoolingParams(dimensions=self.dimensions,
normalize=self.normalize)
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
dimensions=self.dimensions,
normalize=self.normalize)
EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
@ -1430,7 +1434,9 @@ class ScoreRequest(OpenAIBaseModel):
# --8<-- [end:score-extra-params]
def to_pooling_params(self):
return PoolingParams(activation=self.activation)
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
activation=self.activation)
class RerankRequest(OpenAIBaseModel):
@ -1460,7 +1466,9 @@ class RerankRequest(OpenAIBaseModel):
# --8<-- [end:rerank-extra-params]
def to_pooling_params(self):
return PoolingParams(activation=self.activation)
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
activation=self.activation)
class RerankDocument(BaseModel):
@ -1618,7 +1626,9 @@ class ClassificationRequest(OpenAIBaseModel):
# --8<-- [end:classification-extra-params]
def to_pooling_params(self):
return PoolingParams(activation=self.activation)
return PoolingParams(
truncate_prompt_tokens=self.truncate_prompt_tokens,
activation=self.activation)
class ClassificationData(OpenAIBaseModel):
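
A toy Pydantic model (not one of the request classes above) showing what relaxing the bound from ge=1 to ge=-1 changes at the schema level: the -1 sentinel is now accepted, while anything below -1 is still rejected. Note that 0 also satisfies ge=-1; the stricter rejection of 0 happens in SamplingParams validation further down in this commit.

from typing import Annotated, Optional

from pydantic import BaseModel, Field, ValidationError


class TruncationDemo(BaseModel):
    # Same constraint style as the request classes above.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None


print(TruncationDemo(truncate_prompt_tokens=-1))   # accepted with ge=-1
print(TruncationDemo(truncate_prompt_tokens=128))  # explicit sizes still work

try:
    TruncationDemo(truncate_prompt_tokens=-2)
except ValidationError as exc:
    print("rejected -2:", exc.error_count(), "validation error")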


@ -237,7 +237,6 @@ class OpenAIServingChat(OpenAIServing):
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
else:


@ -61,7 +61,6 @@ class ClassificationMixin(OpenAIServing):
ctx.request,
ctx.tokenizer,
ctx.request.input,
truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
)
return None
@ -157,18 +156,6 @@ class ServingClassification(ClassificationMixin):
return await super().handle(ctx) # type: ignore
@override
def _validate_request(
self,
ctx: ClassificationServeContext,
) -> Optional[ErrorResponse]:
if error := super()._validate_request(ctx):
return error
ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
return None
@override
def _create_pooling_params(
self,


@ -137,7 +137,6 @@ class OpenAIServingCompletion(OpenAIServing):
request,
tokenizer,
request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except ValueError as e:


@ -97,7 +97,6 @@ class EmbeddingMixin(OpenAIServing):
# so there is no need to append extra tokens to the input
add_generation_prompt=False,
continue_final_message=False,
truncate_prompt_tokens=ctx.truncate_prompt_tokens,
add_special_tokens=ctx.request.add_special_tokens,
)
else:
@ -106,7 +105,6 @@ class EmbeddingMixin(OpenAIServing):
ctx.request,
tokenizer,
ctx.request.input,
truncate_prompt_tokens=ctx.truncate_prompt_tokens,
add_special_tokens=ctx.request.add_special_tokens,
)
return None
@ -631,18 +629,6 @@ class OpenAIServingEmbedding(EmbeddingMixin):
return await super().handle(ctx) # type: ignore
@override
def _validate_request(
self,
ctx: ServeContext[EmbeddingRequest],
) -> Optional[ErrorResponse]:
if error := super()._validate_request(ctx):
return error
ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
return None
@override
def _create_pooling_params(
self,


@ -165,7 +165,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
# Shared across most requests
tokenizer: Optional[AnyTokenizer] = None
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# `protected_namespaces` resolves Pydantic v2's warning
# on conflict with protected namespace "model_"
@ -297,14 +296,12 @@ class OpenAIServing:
truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
None)
if truncate_prompt_tokens is not None:
if truncate_prompt_tokens <= self.max_model_len:
ctx.truncate_prompt_tokens = truncate_prompt_tokens
else:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
if truncate_prompt_tokens is not None and \
truncate_prompt_tokens > self.max_model_len:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")
return None
def _create_pooling_params(
@ -528,7 +525,6 @@ class OpenAIServing:
request: AnyRequest,
prompt: str,
tokenizer: AnyTokenizer,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
add_special_tokens: bool,
) -> TextTokensPrompt:
async_tokenizer = self._get_async_tokenizer(tokenizer)
@ -538,6 +534,9 @@ class OpenAIServing:
"do_lower_case", False)):
prompt = prompt.lower()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
if truncate_prompt_tokens is None:
encoded = await async_tokenizer(
prompt, add_special_tokens=add_special_tokens)
@ -565,8 +564,10 @@ class OpenAIServing:
request: AnyRequest,
prompt_ids: list[int],
tokenizer: Optional[AnyTokenizer],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
) -> TextTokensPrompt:
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
if truncate_prompt_tokens is None:
input_ids = prompt_ids
elif truncate_prompt_tokens < 0:
@ -652,7 +653,6 @@ class OpenAIServing:
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_input: Union[str, list[int]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: bool = True,
) -> TextTokensPrompt:
"""
@ -664,7 +664,6 @@ class OpenAIServing:
request,
tokenizer,
[prompt_input],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
):
return result
@ -675,7 +674,6 @@ class OpenAIServing:
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_inputs: Iterable[Union[str, list[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: bool = True,
) -> AsyncGenerator[TextTokensPrompt, None]:
"""
@ -689,7 +687,6 @@ class OpenAIServing:
request,
prompt=prompt,
tokenizer=tokenizer,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
@ -697,7 +694,6 @@ class OpenAIServing:
request,
prompt_ids=prompt,
tokenizer=tokenizer,
truncate_prompt_tokens=truncate_prompt_tokens,
)
async def _tokenize_prompt_input_or_inputs_async(
@ -706,7 +702,6 @@ class OpenAIServing:
tokenizer: Optional[AnyTokenizer],
input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: bool = True,
) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
"""
@ -719,6 +714,12 @@ class OpenAIServing:
inputs_embeds = list[EmbedsPrompt]()
inputs_text = list[TextTokensPrompt]()
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
if (truncate_prompt_tokens or 0) < 0:
truncate_prompt_tokens = self.max_model_len
if (isinstance(request, CompletionRequest)
and request.prompt_embeds is not None):
inputs_embeds.extend(
@ -748,14 +749,10 @@ class OpenAIServing:
request,
prompt_input["content"],
tokenizer=tokenizer,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens)
else:
task = self._normalize_prompt_tokens_to_input(
request,
prompt_input["content"],
tokenizer=tokenizer,
truncate_prompt_tokens=truncate_prompt_tokens)
request, prompt_input["content"], tokenizer=tokenizer)
tasks.append(task)
# Wait for all tokenization tasks to complete
@ -772,7 +769,6 @@ class OpenAIServing:
TokenizeCompletionRequest],
tokenizer: Optional[AnyTokenizer],
input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
add_special_tokens: bool = ...,
) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
...
@ -784,7 +780,6 @@ class OpenAIServing:
tokenizer: Optional[AnyTokenizer],
input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
add_special_tokens: bool = ...,
) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
EngineTokensPrompt, EngineEmbedsPrompt]]]:
@ -796,7 +791,6 @@ class OpenAIServing:
tokenizer: Optional[AnyTokenizer],
input_or_inputs: Optional[Union[str, list[str], list[int],
list[list[int]]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
add_special_tokens: bool = True,
) -> tuple[Union[list[TextTokensPrompt], list[Union[
TextTokensPrompt, EmbedsPrompt]]], Union[
@ -813,7 +807,6 @@ class OpenAIServing:
request,
tokenizer,
input_or_inputs,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
@ -866,7 +859,6 @@ class OpenAIServing:
documents: Optional[list[dict[str, str]]] = None,
chat_template_kwargs: Optional[dict[str, Any]] = None,
tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = False,
) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
list[EngineTokensPrompt]]:
@ -941,7 +933,6 @@ class OpenAIServing:
request,
tokenizer,
request_prompt,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
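
The net effect of the OpenAIServing hunks above is that truncate_prompt_tokens is no longer threaded through every helper signature; each helper now reads it off the request with getattr and applies left truncation to the token ids. A condensed, hedged sketch of that flow with toy names (not the real OpenAIServing methods):

from typing import Optional


class ToyRequest:
    # Only the attribute read via getattr() matters for this sketch.
    truncate_prompt_tokens: Optional[int] = -1


def normalize_prompt_tokens(request: object, prompt_ids: list[int],
                            max_model_len: int) -> list[int]:
    # Condensed from the hunks above; illustrative, not vLLM's code.
    truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
    if truncate_prompt_tokens is None:
        return prompt_ids
    if truncate_prompt_tokens < 0:
        # -1 means "use the truncation size supported by the model".
        truncate_prompt_tokens = max_model_len
    # Keep the last k tokens (left truncation).
    return prompt_ids[-truncate_prompt_tokens:]


print(normalize_prompt_tokens(ToyRequest(), list(range(10)), max_model_len=8))
# -> [2, 3, 4, 5, 6, 7, 8, 9]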


@ -120,7 +120,6 @@ class OpenAIServingPooling(OpenAIServing):
# so there is no need to append extra tokens to the input
add_generation_prompt=False,
continue_final_message=False,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
else:
@ -129,7 +128,6 @@ class OpenAIServingPooling(OpenAIServing):
request,
tokenizer,
request.input,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
except (ValueError, TypeError, jinja2.TemplateError) as e:


@ -266,12 +266,14 @@ class ServingScores(OpenAIServing):
request: Union[ScoreRequest, RerankRequest],
request_id: str,
raw_request: Optional[Request] = None,
truncate_prompt_tokens: Optional[int] = None,
) -> Union[list[PoolingRequestOutput], ErrorResponse]:
lora_request = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
tokenization_kwargs: dict[str, Any] = {}
_validate_truncation_size(self.max_model_len, truncate_prompt_tokens,
tokenization_kwargs)
@ -343,7 +345,6 @@ class ServingScores(OpenAIServing):
request,
request_id,
raw_request,
request.truncate_prompt_tokens,
)
if isinstance(final_res_batch, ErrorResponse):
return final_res_batch
@ -391,7 +392,6 @@ class ServingScores(OpenAIServing):
request,
request_id,
raw_request,
request.truncate_prompt_tokens,
)
if isinstance(final_res_batch, ErrorResponse):
return final_res_batch
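
To see where the tokenization_kwargs built from truncate_prompt_tokens end up, here is an illustrative round trip through a Hugging Face tokenizer; this is a toy flow, not the actual ServingScores code path, and "gpt2" is just an arbitrary example model.

from typing import Any

from transformers import AutoTokenizer

# What the validation step would produce for truncate_prompt_tokens=16.
tokenization_kwargs: dict[str, Any] = {"truncation": True, "max_length": 16}

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # arbitrary example model
encoded = tokenizer("a very long query " * 50, **tokenization_kwargs)
assert len(encoded["input_ids"]) <= 16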


@ -346,6 +346,22 @@ class InputPreprocessor:
) -> EmbedsInputs:
return self._process_embeds(parsed_content)
def _truncate_inputs(
self,
inputs: list[int],
tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]:
if not tokenization_kwargs or "truncation" not in \
tokenization_kwargs or self.tokenizer is None:
return inputs
max_length = tokenization_kwargs["max_length"]
if self.tokenizer.truncation_side == "left":
return inputs[-max_length:]
else:
return inputs[:max_length]
def _process_tokens(
self,
parsed_content: TokensPrompt,
@ -354,7 +370,8 @@ class InputPreprocessor:
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs)
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
@ -382,7 +399,8 @@ class InputPreprocessor:
*,
mm_hash_overrides: Optional[dict[str, list[str]]] = None,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
prompt_token_ids = self._truncate_inputs(
parsed_content["prompt_token_ids"], tokenization_kwargs)
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
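
The same truncation logic as _truncate_inputs above, pulled out into a standalone function so it can be run directly (the tokenizer's truncation_side becomes an argument instead of an attribute on self.tokenizer):

from typing import Any, Optional


def truncate_inputs(inputs: list[int],
                    tokenization_kwargs: Optional[dict[str, Any]],
                    truncation_side: str = "left") -> list[int]:
    if not tokenization_kwargs or "truncation" not in tokenization_kwargs:
        return inputs
    max_length = tokenization_kwargs["max_length"]
    if truncation_side == "left":
        return inputs[-max_length:]  # keep the last max_length tokens
    return inputs[:max_length]       # keep the first max_length tokens


kwargs = {"truncation": True, "max_length": 3}
print(truncate_inputs([1, 2, 3, 4, 5], kwargs))           # [3, 4, 5]
print(truncate_inputs([1, 2, 3, 4, 5], kwargs, "right"))  # [1, 2, 3]
print(truncate_inputs([1, 2, 3, 4, 5], None))             # [1, 2, 3, 4, 5]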


@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Annotated, Any, Optional
import msgspec
@ -27,6 +27,11 @@ class PoolingParams(
the classification outputs.
softmax: Whether to apply softmax to the reward outputs.
"""
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=-1)]] = None
"""If set to -1, will use the truncation size supported by the model. If
set to an integer k, will use only the last k tokens from the prompt
(i.e., left truncation). If set to `None`, truncation is disabled."""
## for embeddings models
dimensions: Optional[int] = None
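
With this field in place, truncation for pooling requests is carried on the params object itself. A short usage example, assuming a vLLM build that includes this commit:

from vllm import PoolingParams

# None (default): no truncation; -1: truncate to the model's max length;
# k > 0: keep only the last k prompt tokens (left truncation).
params = PoolingParams(truncate_prompt_tokens=512)
print(params.truncate_prompt_tokens)  # 512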


@ -182,7 +182,8 @@ class SamplingParams(
optionally prompt tokens as a first argument."""
include_stop_str_in_output: bool = False
"""Whether to include the stop strings in output text."""
truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=-1)]] = None
"""If set to -1, will use the truncation size supported by the model. If
set to an integer k, will use only the last k tokens from the prompt
(i.e., left truncation). If set to `None`, truncation is disabled."""
@ -241,7 +242,8 @@ class SamplingParams(
spaces_between_special_tokens: bool = True,
logits_processors: Optional[list[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
msgspec.Meta(ge=1)]] = None,
msgspec.Meta(
ge=-1)]] = None,
output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
guided_decoding: Optional[GuidedDecodingParams] = None,
logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
@ -411,9 +413,11 @@ class SamplingParams(
raise ValueError(f"prompt_logprobs must be non-negative, got "
f"{self.prompt_logprobs}.")
if (self.truncate_prompt_tokens is not None
and self.truncate_prompt_tokens < 1):
raise ValueError(f"truncate_prompt_tokens must be >= 1, "
f"got {self.truncate_prompt_tokens}")
and (self.truncate_prompt_tokens == 0
or self.truncate_prompt_tokens < -1)):
raise ValueError(
f"truncate_prompt_tokens must be an integer >= 1 or -1, "
f"got {self.truncate_prompt_tokens}")
assert isinstance(self.stop_token_ids, list)
if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
raise ValueError(f"stop_token_ids must contain only integers, "


@ -23,6 +23,7 @@ class TokenizerGroup:
self.tokenizer_config = tokenizer_config
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.truncation_side = tokenizer_config.get("truncation_side", "left")
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
max_loras = tokenizer_config.get("max_loras", 0)
self.lora_tokenizers = LRUCache[int, AnyTokenizer](
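
The recorded truncation_side (defaulting to "left") is what lines the left-truncation semantics up with the underlying Hugging Face tokenizer. A standalone illustration using "gpt2" as an arbitrary example model, not vLLM's TokenizerGroup:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.truncation_side = "left"  # keep the end of the prompt when truncating

ids = tokenizer("one two three four five six",
                truncation=True, max_length=3,
                add_special_tokens=False)["input_ids"]
print(len(ids))  # 3 -- the last three tokens; "right" would keep the first three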


@ -1328,6 +1328,12 @@ def as_list(maybe_list: Iterable[T]) -> list[T]:
return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]:
if isinstance(obj, str) or not isinstance(obj, Iterable):
obj = [obj]
return obj
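
The new as_iter helper lets the pooling path treat a single PoolingParams and a sequence of them uniformly. A self-contained copy of its logic with a few example calls:

from collections.abc import Iterable
from typing import TypeVar, Union

T = TypeVar("T")


def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]:
    # Same logic as the helper above, repeated so this snippet runs on its own.
    if isinstance(obj, str) or not isinstance(obj, Iterable):
        obj = [obj]
    return obj


print(list(as_iter(3)))          # [3]        -- a scalar is wrapped
print(list(as_iter("abc")))      # ['abc']    -- strings count as scalars
print(list(as_iter([1, 2, 3])))  # [1, 2, 3]  -- iterables pass through unchanged
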
# `collections` helpers
def is_list_of(
value: object,