[Frontend] Automatic detection of chat content format from AST (#9919)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Committed by: GitHub
Date: 2024-11-16 13:35:40 +08:00
Parent: 4f168f69a3
Commit: 32e46e000f
16 changed files with 788 additions and 350 deletions

View File

@ -172,12 +172,20 @@ completion = client.chat.completions.create(
]
)
```
Most chat templates for LLMs expect the `content` to be a `string` but there are some newer models like
`meta-llama/Llama-Guard-3-1B` that expect the content to be parsed with the new OpenAI spec. In order to choose which
format the content needs to be parsed in by vLLM, please use the `--chat-template-text-format` argument to specify
between `string` or `openai`. The default value is `string` and vLLM internally converts both spec formats to match
this, unless explicitly specified.
Most chat templates for LLMs expect the `content` field to be a string, but some newer models like
`meta-llama/Llama-Guard-3-1B` expect the content to be formatted according to the OpenAI schema in the
request. vLLM makes a best-effort attempt to detect this automatically; the result is logged as a message like
*"Detected the chat template content format to be..."*, and incoming requests are internally converted to match
the detected format, which can be one of:

- `"string"`: A string.
  - Example: `"Hello world"`
- `"openai"`: A list of dictionaries, similar to the OpenAI schema.
  - Example: `[{"type": "text", "text": "Hello world!"}]`

If the result is not what you expect, you can set the `--chat-template-content-format` CLI argument
to override which format to use.
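Either request style below is accepted by the server regardless of the detected format; the setting only controls how vLLM renders the `content` into the chat template. A minimal sketch, assuming a vLLM OpenAI-compatible server already running at `http://localhost:8000/v1` and serving the model above:

```python
from openai import OpenAI

# Assumes a local vLLM OpenAI-compatible server; adjust the URL as needed.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# "string" form: `content` is a plain string.
client.chat.completions.create(
    model="meta-llama/Llama-Guard-3-1B",
    messages=[{"role": "user", "content": "Hello world!"}],
)

# "openai" form: `content` is a list of typed parts.
client.chat.completions.create(
    model="meta-llama/Llama-Guard-3-1B",
    messages=[{
        "role": "user",
        "content": [{"type": "text", "text": "Hello world!"}],
    }],
)
```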
## Command line arguments for the server

View File

@ -26,7 +26,6 @@ class MockModelConfig:
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
chat_template_text_format = "string"
max_model_len = 100
tokenizer_revision = None
multimodal_config = MultiModalConfig()
@ -49,6 +48,7 @@ async def _async_serving_chat_init():
BASE_MODEL_PATHS,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
lora_modules=None,
prompt_adapters=None,
request_logger=None)
@ -70,6 +70,7 @@ def test_serving_chat_should_set_correct_max_tokens():
BASE_MODEL_PATHS,
response_role="assistant",
chat_template=CHAT_TEMPLATE,
chat_template_content_format="auto",
lora_modules=None,
prompt_adapters=None,
request_logger=None)

View File

@ -6,15 +6,24 @@ from PIL import Image
from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (parse_chat_messages,
parse_chat_messages_futures)
from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
parse_chat_messages,
parse_chat_messages_futures,
resolve_chat_template_content_format)
from vllm.entrypoints.llm import apply_hf_chat_template
from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import encode_image_base64
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from ..utils import VLLM_PATH
EXAMPLES_DIR = VLLM_PATH / "examples"
PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_3"
QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
@pytest.fixture(scope="function")
@ -26,7 +35,6 @@ def phi3v_model_config():
trust_remote_code=True,
dtype="bfloat16",
seed=0,
chat_template_text_format="string",
limit_mm_per_prompt={
"image": 2,
})
@ -94,19 +102,24 @@ def test_parse_chat_messages_single_image(
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in the image?"
}]
}], phi3v_model_config, phi3v_tokenizer)
conversation, mm_data = parse_chat_messages(
[{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in the image?"
}]
}],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
assert conversation == [{
"role": "user",
@ -121,19 +134,24 @@ async def test_parse_chat_messages_single_image_async(
phi3v_tokenizer,
image_url,
):
conversation, mm_future = parse_chat_messages_futures([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in the image?"
}]
}], phi3v_model_config, phi3v_tokenizer)
conversation, mm_future = parse_chat_messages_futures(
[{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in the image?"
}]
}],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
assert conversation == [{
"role": "user",
@ -147,24 +165,29 @@ def test_parse_chat_messages_multiple_images(
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in these images?"
}]
}], phi3v_model_config, phi3v_tokenizer)
conversation, mm_data = parse_chat_messages(
[{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in these images?"
}]
}],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
assert conversation == [{
"role":
@ -181,24 +204,29 @@ async def test_parse_chat_messages_multiple_images_async(
phi3v_tokenizer,
image_url,
):
conversation, mm_future = parse_chat_messages_futures([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in these images?"
}]
}], phi3v_model_config, phi3v_tokenizer)
conversation, mm_future = parse_chat_messages_futures(
[{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in these images?"
}]
}],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
assert conversation == [{
"role":
@ -214,27 +242,31 @@ def test_parse_chat_messages_placeholder_already_in_prompt(
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type":
"text",
"text":
"What's in <|image_1|> and how does it compare to <|image_2|>?"
}]
}], phi3v_model_config, phi3v_tokenizer)
conversation, mm_data = parse_chat_messages(
[{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type":
"text",
"text":
"What's in <|image_1|> and how does it compare to <|image_2|>?"
}]
}],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
assert conversation == [{
"role":
"user",
@ -249,26 +281,35 @@ def test_parse_chat_messages_placeholder_one_already_in_prompt(
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type":
"text",
"text":
"What's in <|image_1|> and how does it compare to the other one?"
}]
}], phi3v_model_config, phi3v_tokenizer)
conversation, mm_data = parse_chat_messages(
[{
"role":
"user",
"content": [
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
{
"type":
"text",
"text":
"What's in <|image_1|> and how does it compare to the other one?" # noqa: E501
}
]
}],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
assert conversation == [{
"role":
@ -285,34 +326,39 @@ def test_parse_chat_messages_multiple_images_across_messages(
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
conversation, mm_data = parse_chat_messages(
[{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in this image?"
}]
}, {
"type": "text",
"text": "What's in this image?"
}]
}, {
"role": "assistant",
"content": "Some stuff."
}, {
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
"role": "assistant",
"content": "Some stuff."
}, {
"type": "text",
"text": "What about this one?"
}]
}], phi3v_model_config, phi3v_tokenizer)
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What about this one?"
}]
}],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
assert conversation == [
{
@ -335,7 +381,6 @@ def test_parse_chat_messages_context_text_format(
phi3v_model_config,
phi3v_tokenizer,
):
phi3v_model_config.chat_template_text_format = "openai"
conversation, mm_data = parse_chat_messages(
[{
"role": "user",
@ -349,7 +394,11 @@ def test_parse_chat_messages_context_text_format(
}, {
"role": "user",
"content": "What about this one?"
}], phi3v_model_config, phi3v_tokenizer)
}],
phi3v_model_config,
phi3v_tokenizer,
content_format="openai",
)
assert conversation == [
{
@ -389,29 +438,34 @@ def test_parse_chat_messages_rejects_too_many_images_in_one_message(
ValueError,
match="At most 2 image\\(s\\) may be provided in one request\\."
):
parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in these images?"
}]
}], phi3v_model_config, phi3v_tokenizer)
parse_chat_messages(
[{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in these images?"
}]
}],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
def test_parse_chat_messages_rejects_too_many_images_across_messages(
@ -427,39 +481,44 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
ValueError,
match="At most 2 image\\(s\\) may be provided in one request\\."
):
parse_chat_messages([{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
parse_chat_messages(
[{
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What's in this image?"
}]
}, {
"type": "text",
"text": "What's in this image?"
}]
}, {
"role": "assistant",
"content": "Some stuff."
}, {
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
"role": "assistant",
"content": "Some stuff."
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What about these two?"
}]
}], phi3v_model_config, phi3v_tokenizer)
"role":
"user",
"content": [{
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}, {
"type": "text",
"text": "What about these two?"
}]
}],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
def test_parse_chat_messages_multiple_images_uncommon_input(
@ -467,17 +526,22 @@ def test_parse_chat_messages_multiple_images_uncommon_input(
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [
"What's in these images?", {
"image_url": image_url
}, {
"image_url": image_url
}
]
}], phi3v_model_config, phi3v_tokenizer)
conversation, mm_data = parse_chat_messages(
[{
"role":
"user",
"content": [
"What's in these images?", {
"image_url": image_url
}, {
"image_url": image_url
}
]
}],
phi3v_model_config,
phi3v_tokenizer,
content_format="string",
)
assert conversation == [{
"role":
@ -495,16 +559,21 @@ def test_mllama_single_image(
image_url,
):
"""Ensures that a single image is parsed correctly for mllama."""
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [{
'type': 'text',
'text': 'The content of this image is:'
}, {
"image_url": image_url
}]
}], mllama_model_config, mllama_tokenizer)
conversation, mm_data = parse_chat_messages(
[{
"role":
"user",
"content": [{
'type': 'text',
'text': 'The content of this image is:'
}, {
"image_url": image_url
}]
}],
mllama_model_config,
mllama_tokenizer,
content_format="openai",
)
_assert_mm_data_is_image_input(mm_data, 1)
assert conversation == [{
'role':
@ -524,26 +593,31 @@ def test_mllama_interleaved_images(
image_url,
):
"""Ensures that multiple images are parsed as interleaved dicts."""
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [
{
'type': 'text',
'text': 'The content of the first image is:'
},
{
"image_url": image_url
},
{
'type': 'text',
'text': 'The content of the second image is:'
},
{
"image_url": image_url
},
]
}], mllama_model_config, mllama_tokenizer)
conversation, mm_data = parse_chat_messages(
[{
"role":
"user",
"content": [
{
'type': 'text',
'text': 'The content of the first image is:'
},
{
"image_url": image_url
},
{
'type': 'text',
'text': 'The content of the second image is:'
},
{
"image_url": image_url
},
]
}],
mllama_model_config,
mllama_tokenizer,
content_format="openai",
)
_assert_mm_data_is_image_input(mm_data, 2)
assert conversation == [{
'role':
@ -626,6 +700,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
vllm_conversation,
model_config,
tokenizer_group,
content_format="openai",
)
vllm_result = apply_hf_chat_template(
@ -636,3 +711,89 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
)
assert hf_result == vllm_result
# yapf: disable
@pytest.mark.parametrize(
("model", "expected_format"),
[(PHI3V_MODEL_ID, "string"),
(QWEN2VL_MODEL_ID, "openai"),
(ULTRAVOX_MODEL_ID, "string"),
(MLLAMA_MODEL_ID, "openai"),
(LLAMA_GUARD_MODEL_ID, "openai")],
)
# yapf: enable
def test_resolve_content_format_hf_defined(model, expected_format):
tokenizer_group = TokenizerGroup(
model,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
tokenizer = tokenizer_group.tokenizer
chat_template = tokenizer.chat_template
assert isinstance(chat_template, str)
print("[TEXT]")
print(chat_template)
print("[AST]")
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
None, # Test detecting the tokenizer's chat_template
"auto",
tokenizer,
)
assert resolved_format == expected_format
# yapf: disable
@pytest.mark.parametrize(
("template_path", "expected_format"),
[("template_alpaca.jinja", "string"),
("template_baichuan.jinja", "string"),
("template_blip2.jinja", "string"),
("template_chatglm.jinja", "string"),
("template_chatglm2.jinja", "string"),
("template_chatml.jinja", "string"),
("template_falcon_180b.jinja", "string"),
("template_falcon.jinja", "string"),
("template_inkbot.jinja", "string"),
("template_llava.jinja", "string"),
("template_vlm2vec.jinja", "openai"),
("tool_chat_template_granite_20b_fc.jinja", "string"),
("tool_chat_template_hermes.jinja", "string"),
("tool_chat_template_internlm2_tool.jinja", "string"),
("tool_chat_template_llama3.1_json.jinja", "string"),
("tool_chat_template_llama3.2_json.jinja", "string"),
("tool_chat_template_mistral_parallel.jinja", "string"),
("tool_chat_template_mistral.jinja", "string")],
)
# yapf: enable
def test_resolve_content_format_examples(template_path, expected_format):
tokenizer_group = TokenizerGroup(
PHI3V_MODEL_ID,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
)
dummy_tokenizer = tokenizer_group.tokenizer
dummy_tokenizer.chat_template = None
chat_template = load_chat_template(EXAMPLES_DIR / template_path)
assert isinstance(chat_template, str)
print("[TEXT]")
print(chat_template)
print("[AST]")
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
chat_template,
"auto",
dummy_tokenizer,
)
assert resolved_format == expected_format

View File

@ -155,7 +155,6 @@ class ModelConfig:
limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
chat_template_text_format: str = "string",
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
override_neuron_config: Optional[Dict[str, Any]] = None,
@ -216,7 +215,6 @@ class ModelConfig:
self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc
self.chat_template_text_format = chat_template_text_format
self.mm_processor_kwargs = mm_processor_kwargs
# Set enforce_eager to False if the value is unset.

View File

@ -90,7 +90,6 @@ class EngineArgs:
task: TaskOption = "auto"
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
chat_template_text_format: str = 'string'
trust_remote_code: bool = False
allowed_local_media_path: str = ""
download_dir: Optional[str] = None
@ -258,14 +257,6 @@ class EngineArgs:
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer.')
parser.add_argument(
'--chat-template-text-format',
type=str,
default=EngineArgs.chat_template_text_format,
choices=['string', 'openai'],
help='The format to render text content within a chat template. '
'"string" will keep the content field as a string whereas '
'"openai" will parse content in the current OpenAI format.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
@ -894,7 +885,6 @@ class EngineArgs:
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
tokenizer_mode=self.tokenizer_mode,
chat_template_text_format=self.chat_template_text_format,
trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path,
dtype=self.dtype,

View File

@ -262,8 +262,7 @@ class LLMEngine:
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"chat_template_text_format=%s, mm_processor_kwargs=%s, "
"pooler_config=%r)",
"mm_processor_kwargs=%s, pooler_config=%r)",
VLLM_VERSION,
model_config.model,
speculative_config,
@ -296,7 +295,6 @@ class LLMEngine:
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
use_cached_outputs,
model_config.chat_template_text_format,
model_config.mm_processor_kwargs,
model_config.pooler_config,
)

View File

@ -2,12 +2,14 @@ import asyncio
import codecs
import json
from abc import ABC, abstractmethod
from collections import defaultdict
from collections import defaultdict, deque
from functools import lru_cache, partial
from pathlib import Path
from typing import (Any, Awaitable, Callable, Dict, Generic, Iterable, List,
Literal, Mapping, Optional, Tuple, TypeVar, Union, cast)
import jinja2.nodes
import transformers.utils.chat_template_utils as hf_chat_utils
# yapf conflicts with isort for this block
# yapf: disable
from openai.types.chat import (ChatCompletionAssistantMessageParam,
@ -153,6 +155,199 @@ class ConversationMessage(TypedDict, total=False):
"""The tool calls generated by the model, such as function calls."""
# Passed in by user
ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
# Used internally
_ChatTemplateContentFormat = Literal["string", "openai"]
def _is_var_access(node: jinja2.nodes.Node, varname: str) -> bool:
if isinstance(node, jinja2.nodes.Name):
return node.ctx == "load" and node.name == varname
return False
def _is_attr_access(node: jinja2.nodes.Node, varname: str, key: str) -> bool:
if isinstance(node, jinja2.nodes.Getitem):
return (_is_var_access(node.node, varname)
and isinstance(node.arg, jinja2.nodes.Const)
and node.arg.value == key)
if isinstance(node, jinja2.nodes.Getattr):
return _is_var_access(node.node, varname) and node.attr == key
return False
def _is_var_or_elems_access(
node: jinja2.nodes.Node,
varname: str,
key: Optional[str] = None,
) -> bool:
if isinstance(node, jinja2.nodes.Filter):
return (node.node is not None
and _is_var_or_elems_access(node.node, varname, key))
if isinstance(node, jinja2.nodes.Test):
return _is_var_or_elems_access(node.node, varname, key)
if (isinstance(node, jinja2.nodes.Getitem)
and isinstance(node.arg, jinja2.nodes.Slice)):
return _is_var_or_elems_access(node.node, varname, key)
# yapf: disable
return (
_is_attr_access(node, varname, key) if key
else _is_var_access(node, varname)
) # yapf: enable
def _iter_nodes_assign_var_or_elems(root: jinja2.nodes.Node, varname: str):
# Global variable that is implicitly defined at the root
yield root, varname
# Iterative BFS
related_varnames = deque([varname])
while related_varnames:
related_varname = related_varnames.popleft()
for assign_ast in root.find_all(jinja2.nodes.Assign):
lhs = assign_ast.target
rhs = assign_ast.node
if _is_var_or_elems_access(rhs, related_varname):
assert isinstance(lhs, jinja2.nodes.Name)
yield assign_ast, lhs.name
# Avoid infinite looping for self-assignment
if lhs.name != related_varname:
related_varnames.append(lhs.name)
# NOTE: The proper way to handle this is to build a CFG so that we can handle
# the scope in which each variable is defined, but that is too complicated
def _iter_nodes_assign_messages_item(root: jinja2.nodes.Node):
messages_varnames = [
varname
for _, varname in _iter_nodes_assign_var_or_elems(root, "messages")
]
# Search for {%- for message in messages -%} loops
for loop_ast in root.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
loop_target = loop_ast.target
for varname in messages_varnames:
if _is_var_or_elems_access(loop_iter, varname):
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name
break
def _iter_nodes_assign_content_item(root: jinja2.nodes.Node):
message_varnames = [
varname for _, varname in _iter_nodes_assign_messages_item(root)
]
# Search for {%- for content in message['content'] -%} loops
for loop_ast in root.find_all(jinja2.nodes.For):
loop_iter = loop_ast.iter
loop_target = loop_ast.target
for varname in message_varnames:
if _is_var_or_elems_access(loop_iter, varname, "content"):
assert isinstance(loop_target, jinja2.nodes.Name)
yield loop_ast, loop_target.name
break
def _try_extract_ast(chat_template: str) -> Optional[jinja2.nodes.Template]:
try:
jinja_compiled = hf_chat_utils._compile_jinja_template(chat_template)
return jinja_compiled.environment.parse(chat_template)
except Exception:
logger.exception("Error when compiling Jinja template")
return None
def _detect_content_format(
chat_template: str,
*,
default: _ChatTemplateContentFormat,
) -> _ChatTemplateContentFormat:
jinja_ast = _try_extract_ast(chat_template)
if jinja_ast is None:
return default
try:
next(_iter_nodes_assign_content_item(jinja_ast))
except StopIteration:
return "string"
except Exception:
logger.exception("Error when parsing AST of Jinja template")
return default
else:
return "openai"
def _resolve_chat_template_content_format(
chat_template: Optional[str],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
tokenizer_chat_template = tokenizer.chat_template
else:
tokenizer_chat_template = None
jinja_text: Optional[str]
if isinstance(tokenizer_chat_template, str) and chat_template is None:
jinja_text = tokenizer_chat_template
elif (isinstance(tokenizer_chat_template, dict)
and chat_template in tokenizer_chat_template):
jinja_text = tokenizer_chat_template[chat_template]
else:
jinja_text = load_chat_template(chat_template, is_literal=True)
detected_format = ("string" if jinja_text is None else
_detect_content_format(jinja_text, default="string"))
return detected_format if given_format == "auto" else given_format
@lru_cache
def resolve_chat_template_content_format(
chat_template: Optional[str],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
) -> _ChatTemplateContentFormat:
detected_format = _resolve_chat_template_content_format(
chat_template,
given_format,
tokenizer,
)
logger.info(
"Detected the chat template content format to be '%s'. "
"You can set `--chat-template-content-format` to override this.",
detected_format,
)
if given_format != "auto" and given_format != detected_format:
logger.warning(
"You specified `--chat-template-content-format %s` "
"which is different from the detected format '%s'. "
"If our automatic detection is incorrect, please consider "
"opening a GitHub issue so that we can improve it: "
"https://github.com/vllm-project/vllm/issues/new/choose",
given_format,
detected_format,
)
return detected_format
ModalityStr = Literal["image", "audio", "video"]
_T = TypeVar("_T")
@ -407,12 +602,23 @@ def validate_chat_template(chat_template: Optional[Union[Path, str]]):
def load_chat_template(
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
chat_template: Optional[Union[Path, str]],
*,
is_literal: bool = False,
) -> Optional[str]:
if chat_template is None:
return None
if is_literal:
if isinstance(chat_template, Path):
raise TypeError("chat_template is expected to be read directly "
"from its value")
return codecs.decode(chat_template, "unicode_escape")
try:
with open(chat_template) as f:
resolved_chat_template = f.read()
return f.read()
except OSError as e:
if isinstance(chat_template, Path):
raise
@ -426,10 +632,7 @@ def load_chat_template(
# If opening a file fails, set chat template to be args to
# ensure we decode so our escape are interpreted correctly
resolved_chat_template = codecs.decode(chat_template, "unicode_escape")
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
return resolved_chat_template
return load_chat_template(chat_template, is_literal=True)
# TODO: Let user specify how to insert multimodal tokens into prompt
@ -464,7 +667,6 @@ _ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_AudioParser = partial(cast, ChatCompletionContentPartAudioParam)
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'}
# Define a mapping from part types to their corresponding parsing functions.
MM_PARSER_MAP: Dict[str, Callable[[ChatCompletionContentPartParam], str]] = {
@ -542,18 +744,12 @@ def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
mm_tracker: BaseMultiModalItemTracker,
chat_template_text_format: str,
*,
wrap_dicts: bool,
) -> List[ConversationMessage]:
content: List[Union[str, Dict[str, str]]] = []
mm_parser = mm_tracker.create_parser()
model_config = mm_tracker.model_config
wrap_dicts = (chat_template_text_format == "openai"
or (model_config.task == "embedding"
and model_config.is_multimodal_model)
or (model_config.hf_config.model_type
in MODEL_KEEP_MULTI_MODAL_CONTENT))
for part in parts:
parse_res = _parse_chat_message_content_part(
@ -578,9 +774,11 @@ def _parse_chat_message_content_parts(
def _parse_chat_message_content_part(
part: ChatCompletionContentPartParam,
mm_parser: BaseMultiModalContentParser,
wrap_dicts: bool) -> Optional[Union[str, Dict[str, str]]]:
part: ChatCompletionContentPartParam,
mm_parser: BaseMultiModalContentParser,
*,
wrap_dicts: bool,
) -> Optional[Union[str, Dict[str, str]]]:
"""Parses a single part of a conversation. If wrap_dicts is True,
structured dictionary pieces for texts and images will be
wrapped in dictionaries, i.e., {"type": "text", "text": ...} and
@ -629,7 +827,7 @@ _ToolParser = partial(cast, ChatCompletionToolMessageParam)
def _parse_chat_message_content(
message: ChatCompletionMessageParam,
mm_tracker: BaseMultiModalItemTracker,
chat_template_text_format: str,
content_format: _ChatTemplateContentFormat,
) -> List[ConversationMessage]:
role = message["role"]
content = message.get("content")
@ -645,7 +843,7 @@ def _parse_chat_message_content(
role,
content, # type: ignore
mm_tracker,
chat_template_text_format,
wrap_dicts=(content_format == "openai"),
)
for result_msg in result:
@ -684,6 +882,7 @@ def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
content_format: _ChatTemplateContentFormat,
) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]:
conversation: List[ConversationMessage] = []
mm_tracker = MultiModalItemTracker(model_config, tokenizer)
@ -692,7 +891,7 @@ def parse_chat_messages(
sub_messages = _parse_chat_message_content(
msg,
mm_tracker,
model_config.chat_template_text_format,
content_format,
)
conversation.extend(sub_messages)
@ -706,6 +905,7 @@ def parse_chat_messages_futures(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: AnyTokenizer,
content_format: _ChatTemplateContentFormat,
) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]:
conversation: List[ConversationMessage] = []
mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer)
@ -714,7 +914,7 @@ def parse_chat_messages_futures(
sub_messages = _parse_chat_message_content(
msg,
mm_tracker,
model_config.chat_template_text_format,
content_format,
)
conversation.extend(sub_messages)
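The detection added above walks the chat template's Jinja AST looking for a loop over `message['content']`; if one is found, the template is assumed to consume OpenAI-style content parts. A much-simplified standalone sketch of the same idea using plain `jinja2` (it omits the variable-alias tracking done by `_iter_nodes_assign_var_or_elems`, so treat it as an approximation rather than vLLM's actual helper):

```python
import jinja2
import jinja2.nodes


def guess_content_format(chat_template: str) -> str:
    """Rough approximation of the AST-based detection above."""
    ast = jinja2.Environment().parse(chat_template)

    # Look for loops like {% for part in message['content'] %}; iterating
    # over a `content` attribute/key suggests OpenAI-style content parts.
    for loop in ast.find_all(jinja2.nodes.For):
        it = loop.iter
        if isinstance(it, jinja2.nodes.Getattr) and it.attr == "content":
            return "openai"
        if (isinstance(it, jinja2.nodes.Getitem)
                and isinstance(it.arg, jinja2.nodes.Const)
                and it.arg.value == "content"):
            return "openai"
    return "string"


string_template = "{% for m in messages %}{{ m['content'] }}{% endfor %}"
openai_template = ("{% for m in messages %}{% for part in m['content'] %}"
                   "{{ part['text'] }}{% endfor %}{% endfor %}")
print(guess_content_format(string_template))  # -> string
print(guess_content_format(openai_template))  # -> openai
```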

View File

@ -13,9 +13,11 @@ from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption)
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages)
parse_chat_messages,
resolve_chat_template_content_format)
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
@ -523,6 +525,7 @@ class LLM:
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
chat_template: Optional[str] = None,
chat_template_content_format: ChatTemplateContentFormatOption = "auto",
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tools: Optional[List[Dict[str, Any]]] = None,
@ -539,9 +542,11 @@ class LLM:
to the OpenAI API.
Args:
messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
messages: A list of conversations or a single conversation.
- Each conversation is represented as a list of messages.
- Each message is a dictionary with 'role' and 'content' keys.
sampling_params: The sampling parameters for text generation.
If None, we use the default sampling parameters. When it
is a single value, it is applied to every prompt. When it
@ -551,11 +556,19 @@ class LLM:
lora_request: LoRA request to use for generation, if any.
chat_template: The template to use for structuring the chat.
If not provided, the model's default chat template will be used.
chat_template_content_format: The format to render message content.
- "string" will render the content as a string.
Example: ``"Who are you?"``
- "openai" will render the content as a list of dictionaries,
similar to OpenAI schema.
Example: ``[{"type": "text", "text": "Who are you?"}]``
add_generation_prompt: If True, adds a generation template
to each message.
continue_final_message: If True, continues the final message in
the conversation instead of starting a new one. Cannot be `True`
if `add_generation_prompt` is also `True`.
the conversation instead of starting a new one. Cannot be
``True`` if ``add_generation_prompt`` is also ``True``.
mm_processor_kwargs: Multimodal processor kwarg overrides for this
chat request. Only used for offline requests.
@ -576,17 +589,26 @@ class LLM:
cast(List[ChatCompletionMessageParam], messages)
]
tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format(
chat_template,
chat_template_content_format,
tokenizer,
)
prompts: List[Union[TokensPrompt, TextPrompt]] = []
for msgs in list_of_messages:
tokenizer = self.get_tokenizer()
model_config = self.llm_engine.get_model_config()
# NOTE: _parse_chat_message_content_parts() currently doesn't
# handle mm_processor_kwargs, since there is no implementation in
# the chat message parsing for it.
conversation, mm_data = parse_chat_messages(
msgs, model_config, tokenizer)
msgs,
model_config,
tokenizer,
content_format=resolved_content_format,
)
prompt_data: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer):
@ -737,7 +759,7 @@ class LLM:
generation, if any.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
A list of ``EmbeddingRequestOutput`` objects containing the
generated embeddings in the same order as the input prompts.
Note:

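For the offline entrypoint, a minimal usage sketch of the new `chat_template_content_format` parameter of `LLM.chat` (the model choice and sampling settings are placeholders):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-Guard-3-1B")  # placeholder model choice

# "auto" (the default) detects the format from the chat template's AST;
# pass "string" or "openai" explicitly to override the detection.
outputs = llm.chat(
    [{"role": "user", "content": [{"type": "text", "text": "Hello world!"}]}],
    sampling_params=SamplingParams(max_tokens=32),
    chat_template_content_format="auto",
)
print(outputs[0].outputs[0].text)
```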
View File

@ -29,6 +29,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
@ -529,6 +530,9 @@ def init_app_state(
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
resolved_chat_template = load_chat_template(args.chat_template)
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
@ -537,7 +541,8 @@ def init_app_state(
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser,
@ -557,7 +562,8 @@ def init_app_state(
model_config,
base_model_paths,
request_logger=request_logger,
chat_template=args.chat_template,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
) if model_config.task == "embedding" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
@ -565,7 +571,8 @@ def init_app_state(
base_model_paths,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=args.chat_template,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
)

View File

@ -7,10 +7,11 @@ purposes.
import argparse
import json
import ssl
from typing import List, Optional, Sequence, Union
from typing import List, Optional, Sequence, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import validate_chat_template
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
@ -132,6 +133,18 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument(
'--chat-template-content-format',
type=str,
default="auto",
choices=get_args(ChatTemplateContentFormatOption),
help='The format to render message content within a chat template.'
'\n\n'
'* "string" will render the content as a string. '
'Example: "Hello World"\n'
'* "openai" will render the content as a list of dictionaries, '
'similar to OpenAI schema. '
'Example: [{"type": "text", "text": "Hello world!"}]')
parser.add_argument("--response-role",
type=nullable_str,
default="assistant",

View File

@ -5,9 +5,8 @@ from argparse import Namespace
from typing import Any, Dict, List, Literal, Optional, Union
import torch
from openai.types.chat import ChatCompletionContentPartParam
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated, Required, TypedDict
from typing_extensions import Annotated
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.pooling_params import PoolingParams
@ -35,26 +34,6 @@ assert _LONG_INFO.min == _MOCK_LONG_INFO.min
assert _LONG_INFO.max == _MOCK_LONG_INFO.max
class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""
role: Required[str]
"""The role of the message's author."""
content: Union[str, List[ChatCompletionContentPartParam]]
"""The contents of the message."""
name: str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
tool_call_id: Optional[str]
tool_calls: Optional[List[dict]]
class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
@ -1054,16 +1033,56 @@ class TokenizeCompletionRequest(OpenAIBaseModel):
model: str
prompt: str
add_special_tokens: bool = Field(default=True)
add_special_tokens: bool = Field(
default=True,
description=(
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."),
)
class TokenizeChatRequest(OpenAIBaseModel):
model: str
messages: List[ChatCompletionMessageParam]
add_generation_prompt: bool = Field(default=True)
continue_final_message: bool = Field(default=False)
add_special_tokens: bool = Field(default=False)
add_generation_prompt: bool = Field(
default=True,
description=
("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
continue_final_message: bool = Field(
default=False,
description=
("If this is set, the chat will be formatted so that the final "
"message in the chat is open-ended, without any EOS tokens. The "
"model will continue this message rather than starting a new one. "
"This allows you to \"prefill\" part of the model's response for it. "
"Cannot be used at the same time as `add_generation_prompt`."),
)
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to false (as is the "
"default)."),
)
chat_template: Optional[str] = Field(
default=None,
description=(
"A Jinja template to use for this conversion. "
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one."),
)
chat_template_kwargs: Optional[Dict[str, Any]] = Field(
default=None,
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
)
@model_validator(mode="before")
@classmethod

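The expanded field descriptions above are for the OpenAI-compatible `/tokenize` endpoint; a hedged request example exercising them (the server URL is an assumption, and the `chat_template_kwargs` entry uses a purely hypothetical variable name):

```python
import requests

# Assumes a local vLLM OpenAI-compatible server; adjust the URL as needed.
resp = requests.post(
    "http://localhost:8000/tokenize",
    json={
        "model": "meta-llama/Llama-Guard-3-1B",
        "messages": [{"role": "user", "content": "Hello world!"}],
        "add_generation_prompt": True,
        "add_special_tokens": False,
        # Optional: an inline Jinja template and extra variables made
        # visible to the renderer ("custom_flag" is hypothetical).
        "chat_template": None,
        "chat_template_kwargs": {"custom_flag": True},
    },
)
print(resp.json())
```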
View File

@ -222,6 +222,7 @@ async def main(args):
prompt_adapters=None,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
) if model_config.task == "generate" else None
openai_serving_embedding = OpenAIServingEmbedding(
@ -230,6 +231,7 @@ async def main(args):
base_model_paths,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
) if model_config.task == "embedding" else None
tracker = BatchProgressTracker()

View File

@ -10,7 +10,8 @@ from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ConversationMessage, load_chat_template
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
ConversationMessage)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb, ChatCompletionLogProbs,
@ -38,20 +39,23 @@ logger = init_logger(__name__)
class OpenAIServingChat(OpenAIServing):
def __init__(self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
return_tokens_as_token_ids: bool = False,
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None,
enable_prompt_tokens_details: bool = False):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
return_tokens_as_token_ids: bool = False,
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None,
enable_prompt_tokens_details: bool = False,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
@ -61,8 +65,8 @@ class OpenAIServingChat(OpenAIServing):
return_tokens_as_token_ids=return_tokens_as_token_ids)
self.response_role = response_role
self.use_tool_use_model_template = False
self.chat_template = load_chat_template(chat_template)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
# set up tool use
self.enable_auto_tools: bool = enable_auto_tools
@ -120,6 +124,7 @@ class OpenAIServingChat(OpenAIServing):
) = self._maybe_get_adapters(request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)
tool_parser = self.tool_parser
# validation for OpenAI tools
@ -157,6 +162,7 @@ class OpenAIServingChat(OpenAIServing):
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=tool_dicts,

View File

@ -1,7 +1,7 @@
import asyncio
import base64
import time
from typing import AsyncGenerator, List, Literal, Optional, Union, cast
from typing import AsyncGenerator, Final, List, Literal, Optional, Union, cast
import numpy as np
from fastapi import Request
@ -9,7 +9,7 @@ from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
EmbeddingRequest,
@ -77,7 +77,8 @@ class OpenAIServingEmbedding(OpenAIServing):
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
chat_template_content_format: ChatTemplateContentFormatOption,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
@ -85,7 +86,8 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_adapters=None,
request_logger=request_logger)
self.chat_template = load_chat_template(chat_template)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
async def create_embedding(
self,
@ -144,6 +146,8 @@ class OpenAIServingEmbedding(OpenAIServing):
tokenizer,
request.messages,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.
chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
truncate_prompt_tokens=truncate_prompt_tokens,

View File

@ -11,14 +11,16 @@ from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
ChatTemplateContentFormatOption,
ConversationMessage,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages_futures)
parse_chat_messages_futures,
resolve_chat_template_content_format)
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
DetokenizeRequest,
@ -426,7 +428,8 @@ class OpenAIServing:
request: ChatLikeRequest,
tokenizer: AnyTokenizer,
messages: List[ChatCompletionMessageParam],
chat_template: Optional[str] = None,
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
add_generation_prompt: bool = True,
continue_final_message: bool = False,
tool_dicts: Optional[List[Dict[str, Any]]] = None,
@ -437,10 +440,16 @@ class OpenAIServing:
add_special_tokens: bool = False,
) -> Tuple[List[ConversationMessage], Sequence[RequestPrompt],
List[TokensPrompt]]:
resolved_content_format = resolve_chat_template_content_format(
chat_template,
chat_template_content_format,
tokenizer,
)
conversation, mm_data_future = parse_chat_messages_futures(
messages,
self.model_config,
tokenizer,
content_format=resolved_content_format,
)
_chat_template_kwargs: Dict[str, Any] = dict(

View File

@ -1,8 +1,8 @@
from typing import List, Optional, Union
from typing import Final, List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
@ -33,7 +33,8 @@ class OpenAIServingTokenization(OpenAIServing):
lora_modules: Optional[List[LoRAModulePath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
chat_template_content_format: ChatTemplateContentFormatOption,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
@ -41,12 +42,8 @@ class OpenAIServingTokenization(OpenAIServing):
prompt_adapters=None,
request_logger=request_logger)
# If this is None we use the tokenizer's default chat template
# the list of commonly-used chat template names for HF named templates
hf_chat_templates: List[str] = ['default', 'tool_use']
self.chat_template = chat_template \
if chat_template in hf_chat_templates \
else load_chat_template(chat_template)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
async def create_tokenize(
self,
@ -75,9 +72,12 @@ class OpenAIServingTokenization(OpenAIServing):
request,
tokenizer,
request.messages,
chat_template=self.chat_template,
chat_template=request.chat_template or self.chat_template,
chat_template_content_format=self.
chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
chat_template_kwargs=request.chat_template_kwargs,
add_special_tokens=request.add_special_tokens,
)
else: