Fix wrong truncate_prompt_tokens type hint (#22761)
Signed-off-by: Gabriel Marinho <gmarinho@ibm.com>
Signed-off-by: Gabriel Marinho <104592062+gmarinho2@users.noreply.github.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
@@ -51,7 +51,7 @@ from vllm.tasks import PoolingTask
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Counter, Device, is_list_of
+from vllm.utils import Counter, Device, as_iter, is_list_of
 from vllm.v1.sample.logits_processor import LogitsProcessor
 
 if TYPE_CHECKING:
@@ -364,14 +364,6 @@ class LLM:
             # Use default sampling params.
             sampling_params = self.get_default_sampling_params()
 
-        tokenization_kwargs: dict[str, Any] = {}
-        truncate_prompt_tokens = None
-        if isinstance(sampling_params, SamplingParams):
-            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
-
-        _validate_truncation_size(model_config.max_model_len,
-                                  truncate_prompt_tokens, tokenization_kwargs)
-
         # Add any modality specific loras to the corresponding prompts
         lora_request = self._get_modality_specific_lora_reqs(
             prompts, lora_request)
@@ -381,7 +373,6 @@ class LLM:
             params=sampling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
-            tokenization_kwargs=tokenization_kwargs,
             priority=priority,
         )
 
@@ -871,6 +862,8 @@ class LLM:
                 If `False`, no progress bar is created.
             lora_request: LoRA request to use for generation, if any.
             pooling_task: Override the pooling task to use.
+            tokenization_kwargs: overrides tokenization_kwargs set in
+                pooling_params
 
         Returns:
             A list of `PoolingRequestOutput` objects containing the
@@ -916,24 +909,17 @@ class LLM:
             # Use default pooling params.
             pooling_params = PoolingParams()
 
-        if isinstance(pooling_params, PoolingParams):
-            pooling_params.verify(pooling_task, model_config)
-        else:
-            for pooling_param in pooling_params:
-                pooling_param.verify(pooling_task, model_config)
-
-        if tokenization_kwargs is None:
-            tokenization_kwargs = dict[str, Any]()
-        _validate_truncation_size(model_config.max_model_len,
-                                  truncate_prompt_tokens,
-                                  tokenization_kwargs)
+        for param in as_iter(pooling_params):
+            param.verify(pooling_task, model_config)
+            # for backwards compatibility
+            if truncate_prompt_tokens is not None:
+                param.truncate_prompt_tokens = truncate_prompt_tokens
 
         self._validate_and_add_requests(
             prompts=prompts,
             params=pooling_params,
             use_tqdm=use_tqdm,
             lora_request=lora_request,
-            tokenization_kwargs=tokenization_kwargs,
         )
 
         outputs = self._run_engine(use_tqdm=use_tqdm)
@@ -1385,7 +1371,6 @@ class LLM:
         *,
         use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
-        tokenization_kwargs: Optional[dict[str, Any]] = None,
         priority: Optional[list[int]] = None,
     ) -> None:
         if isinstance(prompts, (str, dict)):
@@ -1412,7 +1397,17 @@ class LLM:
             tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
             it = tqdm_func(it, desc="Adding requests")
 
+        model_config = self.llm_engine.model_config
+
         for i, prompt in enumerate(it):
+
+            param = params[i] if isinstance(params, Sequence) else params
+
+            tokenization_kwargs: dict[str, Any] = {}
+            _validate_truncation_size(model_config.max_model_len,
+                                      param.truncate_prompt_tokens,
+                                      tokenization_kwargs)
+
             self._add_request(
                 prompt,
                 params[i] if isinstance(params, Sequence) else params,
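
The hunks above move truncation handling out of generate()/encode() and into
_validate_and_add_requests, where tokenization_kwargs is rebuilt per request
from each param's truncate_prompt_tokens. A minimal sketch of that mapping,
assuming _validate_truncation_size roughly translates the setting into
HF-style truncation/max_length kwargs (build_tokenization_kwargs below is a
hypothetical stand-in, not vLLM code):

from typing import Any, Optional


def build_tokenization_kwargs(max_model_len: int,
                              truncate_prompt_tokens: Optional[int],
                              tokenization_kwargs: dict[str, Any]) -> None:
    # None disables truncation entirely.
    if truncate_prompt_tokens is None:
        return
    # -1 is shorthand for "truncate to the model's maximum length".
    max_length = (max_model_len
                  if truncate_prompt_tokens == -1 else truncate_prompt_tokens)
    tokenization_kwargs["truncation"] = True
    tokenization_kwargs["max_length"] = max_length


for value in (-1, 128, None):
    kwargs: dict[str, Any] = {}
    build_tokenization_kwargs(4096, value, kwargs)
    print(value, "->", kwargs)
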
@@ -452,7 +452,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     min_tokens: int = 0
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
     prompt_logprobs: Optional[int] = None
     allowed_token_ids: Optional[list[int]] = None
     bad_words: list[str] = Field(default_factory=list)
@@ -995,7 +995,7 @@ class CompletionRequest(OpenAIBaseModel):
     min_tokens: int = 0
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None
     allowed_token_ids: Optional[list[int]] = None
     prompt_logprobs: Optional[int] = None
     # --8<-- [end:completion-sampling-params]
@@ -1325,8 +1325,10 @@ class EmbeddingCompletionRequest(OpenAIBaseModel):
     # --8<-- [end:embedding-extra-params]
 
     def to_pooling_params(self):
-        return PoolingParams(dimensions=self.dimensions,
-                             normalize=self.normalize)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            dimensions=self.dimensions,
+            normalize=self.normalize)
 
 
 class EmbeddingChatRequest(OpenAIBaseModel):
@@ -1393,8 +1395,10 @@ class EmbeddingChatRequest(OpenAIBaseModel):
         return data
 
     def to_pooling_params(self):
-        return PoolingParams(dimensions=self.dimensions,
-                             normalize=self.normalize)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            dimensions=self.dimensions,
+            normalize=self.normalize)
 
 
 EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
@@ -1430,7 +1434,9 @@ class ScoreRequest(OpenAIBaseModel):
     # --8<-- [end:score-extra-params]
 
     def to_pooling_params(self):
-        return PoolingParams(activation=self.activation)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            activation=self.activation)
 
 
 class RerankRequest(OpenAIBaseModel):
@@ -1460,7 +1466,9 @@ class RerankRequest(OpenAIBaseModel):
     # --8<-- [end:rerank-extra-params]
 
     def to_pooling_params(self):
-        return PoolingParams(activation=self.activation)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            activation=self.activation)
 
 
 class RerankDocument(BaseModel):
@@ -1618,7 +1626,9 @@ class ClassificationRequest(OpenAIBaseModel):
     # --8<-- [end:classification-extra-params]
 
     def to_pooling_params(self):
-        return PoolingParams(activation=self.activation)
+        return PoolingParams(
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
+            activation=self.activation)
 
 
 class ClassificationData(OpenAIBaseModel):
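
Relaxing Field(ge=1) to Field(ge=-1) lets -1 ("truncate to the model limit")
pass schema validation; note that ge=-1 alone still admits 0, which is
rejected later by the SamplingParams check further down in this diff. A
minimal sketch with plain pydantic (the real request classes derive from
OpenAIBaseModel):

from typing import Annotated, Optional

from pydantic import BaseModel, Field, ValidationError


class DemoRequest(BaseModel):
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None


print(DemoRequest(truncate_prompt_tokens=-1))   # now accepted
print(DemoRequest(truncate_prompt_tokens=128))  # still accepted
try:
    DemoRequest(truncate_prompt_tokens=-2)      # still rejected by ge=-1
except ValidationError as e:
    print("rejected:", e.errors()[0]["type"])
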
@@ -237,7 +237,6 @@ class OpenAIServingChat(OpenAIServing):
                 documents=request.documents,
                 chat_template_kwargs=request.chat_template_kwargs,
                 tool_parser=tool_parser,
-                truncate_prompt_tokens=request.truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
         else:
@@ -61,7 +61,6 @@ class ClassificationMixin(OpenAIServing):
                 ctx.request,
                 ctx.tokenizer,
                 ctx.request.input,
-                truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
             )
 
             return None
@@ -157,18 +156,6 @@ class ServingClassification(ClassificationMixin):
 
         return await super().handle(ctx)  # type: ignore
 
-    @override
-    def _validate_request(
-        self,
-        ctx: ClassificationServeContext,
-    ) -> Optional[ErrorResponse]:
-        if error := super()._validate_request(ctx):
-            return error
-
-        ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
-
-        return None
-
     @override
     def _create_pooling_params(
         self,
@@ -137,7 +137,6 @@ class OpenAIServingCompletion(OpenAIServing):
                 request,
                 tokenizer,
                 request.prompt,
-                truncate_prompt_tokens=request.truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
         except ValueError as e:
@@ -97,7 +97,6 @@ class EmbeddingMixin(OpenAIServing):
                 # so there is no need to append extra tokens to the input
                 add_generation_prompt=False,
                 continue_final_message=False,
-                truncate_prompt_tokens=ctx.truncate_prompt_tokens,
                 add_special_tokens=ctx.request.add_special_tokens,
             )
         else:
@@ -106,7 +105,6 @@ class EmbeddingMixin(OpenAIServing):
                 ctx.request,
                 tokenizer,
                 ctx.request.input,
-                truncate_prompt_tokens=ctx.truncate_prompt_tokens,
                 add_special_tokens=ctx.request.add_special_tokens,
             )
         return None
@@ -631,18 +629,6 @@ class OpenAIServingEmbedding(EmbeddingMixin):
 
         return await super().handle(ctx)  # type: ignore
 
-    @override
-    def _validate_request(
-        self,
-        ctx: ServeContext[EmbeddingRequest],
-    ) -> Optional[ErrorResponse]:
-        if error := super()._validate_request(ctx):
-            return error
-
-        ctx.truncate_prompt_tokens = ctx.request.truncate_prompt_tokens
-
-        return None
-
     @override
     def _create_pooling_params(
         self,
@@ -165,7 +165,6 @@ class ServeContext(RequestProcessingMixin, ResponseGenerationMixin, BaseModel,
 
     # Shared across most requests
     tokenizer: Optional[AnyTokenizer] = None
-    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
 
     # `protected_namespaces` resolves Pydantic v2's warning
     # on conflict with protected namespace "model_"
@@ -297,14 +296,12 @@ class OpenAIServing:
         truncate_prompt_tokens = getattr(ctx.request, "truncate_prompt_tokens",
                                          None)
 
-        if truncate_prompt_tokens is not None:
-            if truncate_prompt_tokens <= self.max_model_len:
-                ctx.truncate_prompt_tokens = truncate_prompt_tokens
-            else:
-                return self.create_error_response(
-                    "truncate_prompt_tokens value is "
-                    "greater than max_model_len."
-                    " Please, select a smaller truncation size.")
+        if truncate_prompt_tokens is not None and \
+                truncate_prompt_tokens > self.max_model_len:
+            return self.create_error_response(
+                "truncate_prompt_tokens value is "
+                "greater than max_model_len."
+                " Please, select a smaller truncation size.")
         return None
 
     def _create_pooling_params(
@@ -528,7 +525,6 @@ class OpenAIServing:
         request: AnyRequest,
         prompt: str,
         tokenizer: AnyTokenizer,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]],
         add_special_tokens: bool,
     ) -> TextTokensPrompt:
         async_tokenizer = self._get_async_tokenizer(tokenizer)
@@ -538,6 +534,9 @@ class OpenAIServing:
                     "do_lower_case", False)):
             prompt = prompt.lower()
 
+        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
+                                         None)
+
         if truncate_prompt_tokens is None:
             encoded = await async_tokenizer(
                 prompt, add_special_tokens=add_special_tokens)
@@ -565,8 +564,10 @@ class OpenAIServing:
         request: AnyRequest,
         prompt_ids: list[int],
         tokenizer: Optional[AnyTokenizer],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
     ) -> TextTokensPrompt:
+        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
+                                         None)
+
         if truncate_prompt_tokens is None:
             input_ids = prompt_ids
         elif truncate_prompt_tokens < 0:
@@ -652,7 +653,6 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_input: Union[str, list[int]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> TextTokensPrompt:
         """
@@ -664,7 +664,6 @@ class OpenAIServing:
                 request,
                 tokenizer,
                 [prompt_input],
-                truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
         ):
             return result
@@ -675,7 +674,6 @@ class OpenAIServing:
         request: AnyRequest,
         tokenizer: AnyTokenizer,
         prompt_inputs: Iterable[Union[str, list[int]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> AsyncGenerator[TextTokensPrompt, None]:
         """
@@ -689,7 +687,6 @@ class OpenAIServing:
                     request,
                     prompt=prompt,
                     tokenizer=tokenizer,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=add_special_tokens,
                 )
             else:
@@ -697,7 +694,6 @@ class OpenAIServing:
                     request,
                     prompt_ids=prompt,
                     tokenizer=tokenizer,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
                 )
 
     async def _tokenize_prompt_input_or_inputs_async(
@@ -706,7 +702,6 @@ class OpenAIServing:
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Optional[Union[str, list[str], list[int],
                                         list[list[int]]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> tuple[list[TextTokensPrompt], list[EmbedsPrompt]]:
         """
@@ -719,6 +714,12 @@ class OpenAIServing:
         inputs_embeds = list[EmbedsPrompt]()
         inputs_text = list[TextTokensPrompt]()
 
+        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
+                                         None)
+
+        if (truncate_prompt_tokens or 0) < 0:
+            truncate_prompt_tokens = self.max_model_len
+
         if (isinstance(request, CompletionRequest)
                 and request.prompt_embeds is not None):
             inputs_embeds.extend(
@@ -748,14 +749,10 @@ class OpenAIServing:
                     request,
                     prompt_input["content"],
                     tokenizer=tokenizer,
-                    truncate_prompt_tokens=truncate_prompt_tokens,
                     add_special_tokens=add_special_tokens)
             else:
                 task = self._normalize_prompt_tokens_to_input(
-                    request,
-                    prompt_input["content"],
-                    tokenizer=tokenizer,
-                    truncate_prompt_tokens=truncate_prompt_tokens)
+                    request, prompt_input["content"], tokenizer=tokenizer)
             tasks.append(task)
 
         # Wait for all tokenization tasks to complete
@@ -772,7 +769,6 @@ class OpenAIServing:
                                TokenizeCompletionRequest],
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Union[str, list[str], list[int], list[list[int]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
         add_special_tokens: bool = ...,
     ) -> tuple[list[TextTokensPrompt], list[EngineTokensPrompt]]:
         ...
@@ -784,7 +780,6 @@ class OpenAIServing:
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Optional[Union[str, list[str], list[int],
                                         list[list[int]]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = ...,
         add_special_tokens: bool = ...,
     ) -> tuple[list[Union[TextTokensPrompt, EmbedsPrompt]], list[Union[
             EngineTokensPrompt, EngineEmbedsPrompt]]]:
@@ -796,7 +791,6 @@ class OpenAIServing:
         tokenizer: Optional[AnyTokenizer],
         input_or_inputs: Optional[Union[str, list[str], list[int],
                                         list[list[int]]]],
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=-1)]] = None,
         add_special_tokens: bool = True,
     ) -> tuple[Union[list[TextTokensPrompt], list[Union[
             TextTokensPrompt, EmbedsPrompt]]], Union[
@@ -813,7 +807,6 @@ class OpenAIServing:
             request,
             tokenizer,
             input_or_inputs,
-            truncate_prompt_tokens=truncate_prompt_tokens,
             add_special_tokens=add_special_tokens,
         )
 
@@ -866,7 +859,6 @@ class OpenAIServing:
         documents: Optional[list[dict[str, str]]] = None,
         chat_template_kwargs: Optional[dict[str, Any]] = None,
         tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None,
-        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
         add_special_tokens: bool = False,
     ) -> tuple[list[ConversationMessage], Sequence[RequestPrompt],
                list[EngineTokensPrompt]]:
@@ -941,7 +933,6 @@ class OpenAIServing:
                 request,
                 tokenizer,
                 request_prompt,
-                truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=add_special_tokens,
             )
         else:
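
Rather than threading truncate_prompt_tokens through every tokenization
helper, the serving code above now reads it off the request with getattr and
normalizes a negative value to max_model_len. A self-contained sketch of that
pattern (DemoRequest and resolve_truncation are hypothetical names):

from dataclasses import dataclass
from typing import Optional


@dataclass
class DemoRequest:
    truncate_prompt_tokens: Optional[int] = None


def resolve_truncation(request: object, max_model_len: int) -> Optional[int]:
    # Requests without the field fall back to None (no truncation).
    truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens", None)
    # A negative value (-1) means "use the model's maximum length".
    if (truncate_prompt_tokens or 0) < 0:
        truncate_prompt_tokens = max_model_len
    return truncate_prompt_tokens


print(resolve_truncation(DemoRequest(-1), 4096))   # 4096
print(resolve_truncation(DemoRequest(256), 4096))  # 256
print(resolve_truncation(object(), 4096))          # None
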
@@ -120,7 +120,6 @@ class OpenAIServingPooling(OpenAIServing):
                 # so there is no need to append extra tokens to the input
                 add_generation_prompt=False,
                 continue_final_message=False,
-                truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
         else:
@@ -129,7 +128,6 @@ class OpenAIServingPooling(OpenAIServing):
                 request,
                 tokenizer,
                 request.input,
-                truncate_prompt_tokens=truncate_prompt_tokens,
                 add_special_tokens=request.add_special_tokens,
             )
         except (ValueError, TypeError, jinja2.TemplateError) as e:
@@ -266,12 +266,14 @@ class ServingScores(OpenAIServing):
         request: Union[ScoreRequest, RerankRequest],
         request_id: str,
         raw_request: Optional[Request] = None,
-        truncate_prompt_tokens: Optional[int] = None,
     ) -> Union[list[PoolingRequestOutput], ErrorResponse]:
         lora_request = self._maybe_get_adapters(request)
 
         tokenizer = await self.engine_client.get_tokenizer(lora_request)
 
+        truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
+                                         None)
+
         tokenization_kwargs: dict[str, Any] = {}
         _validate_truncation_size(self.max_model_len, truncate_prompt_tokens,
                                   tokenization_kwargs)
@@ -343,7 +345,6 @@ class ServingScores(OpenAIServing):
             request,
             request_id,
             raw_request,
-            request.truncate_prompt_tokens,
         )
         if isinstance(final_res_batch, ErrorResponse):
             return final_res_batch
@@ -391,7 +392,6 @@ class ServingScores(OpenAIServing):
             request,
             request_id,
             raw_request,
-            request.truncate_prompt_tokens,
         )
         if isinstance(final_res_batch, ErrorResponse):
             return final_res_batch
@@ -346,6 +346,22 @@ class InputPreprocessor:
     ) -> EmbedsInputs:
         return self._process_embeds(parsed_content)
 
+    def _truncate_inputs(
+        self,
+        inputs: list[int],
+        tokenization_kwargs: Optional[dict[str, Any]] = None) -> list[int]:
+
+        if not tokenization_kwargs or "truncation" not in \
+            tokenization_kwargs or self.tokenizer is None:
+            return inputs
+
+        max_length = tokenization_kwargs["max_length"]
+
+        if self.tokenizer.truncation_side == "left":
+            return inputs[-max_length:]
+        else:
+            return inputs[:max_length]
+
     def _process_tokens(
         self,
         parsed_content: TokensPrompt,
@@ -354,7 +370,8 @@ class InputPreprocessor:
         *,
         mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
-        prompt_token_ids = parsed_content["prompt_token_ids"]
+        prompt_token_ids = self._truncate_inputs(
+            parsed_content["prompt_token_ids"], tokenization_kwargs)
 
         inputs: Union[TokenInputs, MultiModalInputs]
         if multi_modal_data := parsed_content.get("multi_modal_data"):
@@ -382,7 +399,8 @@ class InputPreprocessor:
         *,
         mm_hash_overrides: Optional[dict[str, list[str]]] = None,
     ) -> Union[TokenInputs, MultiModalInputs]:
-        prompt_token_ids = parsed_content["prompt_token_ids"]
+        prompt_token_ids = self._truncate_inputs(
+            parsed_content["prompt_token_ids"], tokenization_kwargs)
 
         inputs: Union[TokenInputs, MultiModalInputs]
         if multi_modal_data := parsed_content.get("multi_modal_data"):
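
_truncate_inputs only applies when the tokenization kwargs request
truncation, and it honors the tokenizer's truncation_side: left truncation
keeps the last max_length tokens, right truncation keeps the first. A
standalone sketch of the slicing (truncate_ids is a hypothetical helper
mirroring the method above):

def truncate_ids(ids: list[int], max_length: int,
                 truncation_side: str = "left") -> list[int]:
    if truncation_side == "left":
        # Keep the most recent tokens, dropping from the front.
        return ids[-max_length:]
    # Otherwise keep the beginning of the prompt.
    return ids[:max_length]


tokens = list(range(10))
print(truncate_ids(tokens, 4, "left"))   # [6, 7, 8, 9]
print(truncate_ids(tokens, 4, "right"))  # [0, 1, 2, 3]
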
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Annotated, Any, Optional
 
 import msgspec
 
@@ -27,6 +27,11 @@ class PoolingParams(
         the classification outputs.
         softmax: Whether to apply softmax to the reward outputs.
     """
+    truncate_prompt_tokens: Optional[Annotated[int,
+                                               msgspec.Meta(ge=-1)]] = None
+    """If set to -1, will use the truncation size supported by the model. If
+    set to an integer k, will use only the last k tokens from the prompt
+    (i.e., left truncation). If set to `None`, truncation is disabled."""
 
     ## for embeddings models
     dimensions: Optional[int] = None
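
With the field on PoolingParams, embedding/score/rerank/classification
requests can carry their own truncation setting, which is what the
to_pooling_params changes above feed. A hedged usage sketch (assumes an
embedding-capable model and that LLM.embed accepts pooling_params, as in
recent vLLM releases; the model name is only an example):

from vllm import LLM, PoolingParams

llm = LLM(model="intfloat/e5-small-v2", task="embed")
params = PoolingParams(truncate_prompt_tokens=-1)  # -1: clamp to max_model_len
outputs = llm.embed(["a very long document ..."], pooling_params=params)
print(len(outputs[0].outputs.embedding))
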
@@ -182,7 +182,8 @@ class SamplingParams(
     optionally prompt tokens as a first argument."""
     include_stop_str_in_output: bool = False
     """Whether to include the stop strings in output text."""
-    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None
+    truncate_prompt_tokens: Optional[Annotated[int,
+                                               msgspec.Meta(ge=-1)]] = None
     """If set to -1, will use the truncation size supported by the model. If
     set to an integer k, will use only the last k tokens from the prompt
     (i.e., left truncation). If set to `None`, truncation is disabled."""
@@ -241,7 +242,8 @@ class SamplingParams(
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[list[LogitsProcessor]] = None,
         truncate_prompt_tokens: Optional[Annotated[int,
-                                                   msgspec.Meta(ge=1)]] = None,
+                                                   msgspec.Meta(
+                                                       ge=-1)]] = None,
         output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE,
         guided_decoding: Optional[GuidedDecodingParams] = None,
         logit_bias: Optional[Union[dict[int, float], dict[str, float]]] = None,
@@ -411,9 +413,11 @@ class SamplingParams(
             raise ValueError(f"prompt_logprobs must be non-negative, got "
                              f"{self.prompt_logprobs}.")
         if (self.truncate_prompt_tokens is not None
-                and self.truncate_prompt_tokens < 1):
-            raise ValueError(f"truncate_prompt_tokens must be >= 1, "
-                             f"got {self.truncate_prompt_tokens}")
+                and (self.truncate_prompt_tokens == 0
+                     or self.truncate_prompt_tokens < -1)):
+            raise ValueError(
+                f"truncate_prompt_tokens must be an integer >= 1 or -1, "
+                f"got {self.truncate_prompt_tokens}")
         assert isinstance(self.stop_token_ids, list)
         if not all(isinstance(st_id, int) for st_id in self.stop_token_ids):
             raise ValueError(f"stop_token_ids must contain only integers, "
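
The new check accepts None, -1, and any integer k >= 1, while 0 and values
below -1 raise. A small sketch that mirrors the predicate (is_valid_truncation
is a hypothetical helper, not vLLM API):

from typing import Optional


def is_valid_truncation(value: Optional[int]) -> bool:
    # Mirrors the condition above: reject 0 and anything below -1.
    return value is None or value == -1 or value >= 1


for v in (None, -1, 1, 128, 0, -2):
    print(v, is_valid_truncation(v))  # 0 and -2 are the rejected cases
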
@@ -23,6 +23,7 @@ class TokenizerGroup:
         self.tokenizer_config = tokenizer_config
         self.enable_lora = enable_lora
         self.max_input_length = max_input_length
+        self.truncation_side = tokenizer_config.get("truncation_side", "left")
         self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
         max_loras = tokenizer_config.get("max_loras", 0)
         self.lora_tokenizers = LRUCache[int, AnyTokenizer](
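
TokenizerGroup now records the configured truncation_side, defaulting to
"left", which _truncate_inputs consults when slicing token ids. A trivial
sketch of the default handling (a plain dict stands in for the real tokenizer
config):

tokenizer_config = {"tokenizer_mode": "auto"}
truncation_side = tokenizer_config.get("truncation_side", "left")
print(truncation_side)  # "left" unless explicitly overridden
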
@@ -1328,6 +1328,12 @@ def as_list(maybe_list: Iterable[T]) -> list[T]:
     return maybe_list if isinstance(maybe_list, list) else list(maybe_list)
 
 
+def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]:
+    if isinstance(obj, str) or not isinstance(obj, Iterable):
+        obj = [obj]
+    return obj
+
+
 # `collections` helpers
 def is_list_of(
     value: object,
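
as_iter lets the pooling path treat a single PoolingParams and a sequence of
them uniformly, with strings deliberately treated as scalars rather than
iterated character by character. A self-contained replica for illustration:

from collections.abc import Iterable
from typing import TypeVar, Union

T = TypeVar("T")


def as_iter(obj: Union[T, Iterable[T]]) -> Iterable[T]:
    # Copied from the hunk above: strings count as scalars, not iterables.
    if isinstance(obj, str) or not isinstance(obj, Iterable):
        obj = [obj]
    return obj


print(list(as_iter(3)))        # [3]
print(list(as_iter("abc")))    # ['abc']
print(list(as_iter([1, 2])))   # [1, 2]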