Add filtering for chat template kwargs (#25794)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Russell Bryant
2025-09-27 06:46:49 -04:00
committed by GitHub
parent 3f5d902d2a
commit 7977e5027c
5 changed files with 158 additions and 6 deletions

View File

@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
parse_chat_messages,
parse_chat_messages_futures,
resolve_chat_template_content_format,
resolve_chat_template_kwargs,
resolve_hf_chat_template)
from vllm.multimodal import MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
@ -37,6 +38,7 @@ QWEN2AUDIO_MODEL_ID = "Qwen/Qwen2-Audio-7B-Instruct"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
QWEN25OMNI_MODEL_ID = "Qwen/Qwen2.5-Omni-7B"
QWEN3_MODEL_ID = "Qwen/Qwen3-8B"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
MISTRAL_MODEL_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
@ -2255,6 +2257,89 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
assert isinstance(chat_template, str)
@pytest.mark.parametrize(
"model, expected_kwargs",
[
(
QWEN2VL_MODEL_ID,
{
"add_vision_id", "add_generation_prompt",
"continue_final_message", "tools"
},
),
(
QWEN3_MODEL_ID,
{
"enable_thinking", "add_generation_prompt",
"continue_final_message", "tools"
},
),
],
)
def test_resolve_hf_chat_template_kwargs(sample_json_schema, model,
expected_kwargs):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
tools = ([{
"type": "function",
"function": {
"name": "dummy_function_name",
"description": "This is a dummy function",
"parameters": sample_json_schema,
},
}])
chat_template_kwargs = {
# both unused
"unsed_kwargs_1": 123,
"unsed_kwargs_2": "abc",
# should not appear
"chat_template": "{% Hello world! %}",
# used by tokenizer
"continue_final_message": True,
"tools": tools,
# both used by Qwen2-VL and Qwen3
"add_generation_prompt": True,
# only used by Qwen2-VL
"add_vision_id": True,
# only used by Qwen3
"enable_thinking": True,
}
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
skip_tokenizer_init=model_info.skip_tokenizer_init,
enforce_eager=model_info.enforce_eager,
dtype=model_info.dtype)
# Build the tokenizer
tokenizer = get_tokenizer(
model,
trust_remote_code=model_config.trust_remote_code,
)
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
tokenizer,
chat_template=None,
tools=tools,
model_config=model_config,
)
resolved_chat_template_kwargs = resolve_chat_template_kwargs(
tokenizer,
chat_template=chat_template,
chat_template_kwargs=chat_template_kwargs,
)
assert set(resolved_chat_template_kwargs.keys()) == expected_kwargs
# NOTE: Qwen2-Audio default chat template is specially defined inside
# processor class instead of using `tokenizer_config.json`
# yapf: disable

View File

@ -11,7 +11,12 @@ from pathlib import Path
from typing import (Any, Callable, Generic, Literal, Optional, TypeVar, Union,
cast)
import jinja2
import jinja2.ext
import jinja2.meta
import jinja2.nodes
import jinja2.parser
import jinja2.sandbox
import transformers.utils.chat_template_utils as hf_chat_utils
# yapf conflicts with isort for this block
# yapf: disable
@ -50,7 +55,7 @@ from vllm.transformers_utils.chat_templates import (
# yapf: enable
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import random_uuid
from vllm.utils import random_uuid, supports_kw
logger = init_logger(__name__)
@ -1554,6 +1559,46 @@ def parse_chat_messages_futures(
return conversation, mm_tracker.all_mm_data(), mm_tracker.all_mm_uuids()
# adapted from https://github.com/huggingface/transformers/blob/v4.56.2/src/transformers/utils/chat_template_utils.py#L398-L412
# only preserve the parse function used to resolve chat template kwargs
class AssistantTracker(jinja2.ext.Extension):
tags = {"generation"}
def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
call = self.call_method("_generation_support")
call_block = jinja2.nodes.CallBlock(call, [], [], body)
return call_block.set_lineno(lineno)
def resolve_chat_template_kwargs(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: str,
chat_template_kwargs: dict[str, Any],
) -> dict[str, Any]:
fn_kw = {
k for k in chat_template_kwargs
if supports_kw(tokenizer.apply_chat_template, k, allow_var_kwargs=False)
}
env = jinja2.sandbox.ImmutableSandboxedEnvironment(
trim_blocks=True,
lstrip_blocks=True,
extensions=[AssistantTracker, jinja2.ext.loopcontrols],
)
parsed_content = env.parse(chat_template)
template_vars = jinja2.meta.find_undeclared_variables(parsed_content)
# We exclude chat_template from kwargs here, because
# chat template has been already resolved at this stage
unexpected_vars = {"chat_template"}
accept_vars = (fn_kw | template_vars) - unexpected_vars
return {
k: v for k, v in chat_template_kwargs.items() if k in accept_vars
}
def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage],
@ -1579,12 +1624,17 @@ def apply_hf_chat_template(
)
try:
resolved_kwargs = resolve_chat_template_kwargs(
tokenizer=tokenizer,
chat_template=hf_chat_template,
chat_template_kwargs=kwargs,
)
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
tools=tools, # type: ignore[arg-type]
chat_template=hf_chat_template,
tokenize=tokenize,
**kwargs,
**resolved_kwargs,
)
# External library exceptions can sometimes occur despite the framework's

View File

@ -1716,6 +1716,7 @@ async def init_app_state(
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
trust_request_chat_template=args.trust_request_chat_template,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
exclude_tools_when_tool_choice_none=args.

View File

@ -103,9 +103,13 @@ class FrontendArgs:
chat_template_content_format: ChatTemplateContentFormatOption = "auto"
"""The format to render message content within a chat template.
* "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to OpenAI
schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
* "string" will render the content as a string. Example: `"Hello World"`
* "openai" will render the content as a list of dictionaries, similar to
OpenAI schema. Example: `[{"type": "text", "text": "Hello world!"}]`"""
trust_request_chat_template: bool = False
"""Whether to trust the chat template provided in the request. If False,
the server will always use the chat template specified by `--chat-template`
or the ones from tokenizer."""
response_role: str = "assistant"
"""The role name to return if `request.add_generation_prompt=true`."""
ssl_keyfile: Optional[str] = None

View File

@ -68,6 +68,7 @@ class OpenAIServingChat(OpenAIServing):
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
trust_request_chat_template: bool = False,
return_tokens_as_token_ids: bool = False,
reasoning_parser: str = "",
enable_auto_tools: bool = False,
@ -89,6 +90,7 @@ class OpenAIServingChat(OpenAIServing):
self.response_role = response_role
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
self.trust_request_chat_template = trust_request_chat_template
self.enable_log_outputs = enable_log_outputs
# set up tool use
@ -220,6 +222,16 @@ class OpenAIServingChat(OpenAIServing):
if not self.use_harmony:
# Common case.
request_chat_template = request.chat_template
chat_template_kwargs = request.chat_template_kwargs
if not self.trust_request_chat_template and (
request_chat_template is not None or
(chat_template_kwargs and
chat_template_kwargs.get("chat_template") is not None)):
return self.create_error_response(
"Chat template is passed with request, but "
"--trust-request-chat-template is not set. "
"Refused request with untrusted chat template.")
(
conversation,
request_prompts,
@ -228,7 +240,7 @@ class OpenAIServingChat(OpenAIServing):
request,
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template=request_chat_template or self.chat_template,
chat_template_content_format=self.
chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,