[Frontend][Bug Fix] Update llama4 pythonic jinja template and llama4_pythonic parser (#17917)
Signed-off-by: Kai Wu <kaiwu@meta.com>
@@ -158,13 +158,13 @@ All Llama 3.1, 3.2 and 4 models should be supported.
* `meta-llama/Llama-3.2-*`
* `meta-llama/Llama-4-*`

The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below.
The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. For Llama 4 models, the `llama4_pythonic` tool parser is recommended.

Other tool calling formats, like the built-in Python tool calling or custom tool calling, are not supported.

Known issues:

1. Parallel tool calls are not supported.
1. Parallel tool calls are not supported for Llama 3 models, but they are supported for Llama 4 models.
2. The model can generate parameters with a wrong format, such as generating
   an array serialized as a string instead of an array.

@@ -177,11 +177,10 @@ images.

Recommended flags: `--tool-call-parser llama3_json --chat-template {see_above}`

VLLM also provides a JSON based chat template for Llama 4:
* <gh-file:examples/tool_chat_template_llama4_json.jinja> - this is based on the "official" chat template for the Llama 4
  models, but tweaked so that it works better with vLLM.
vLLM also provides pythonic and JSON based chat templates for Llama 4, but pythonic tool calling is recommended:
* <gh-file:examples/tool_chat_template_llama4_pythonic.jinja> - this is based on the [official chat template](https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/) for the Llama 4 models.

For Llama 4 use `--tool-call-parser llama4_json examples/tool_chat_template_llama4_json.jinja`.
For Llama 4 models, use `--tool-call-parser llama4_pythonic --chat-template examples/tool_chat_template_llama4_pythonic.jinja`.
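To illustrate the recommended setup, here is a hedged client-side sketch (not part of this commit). It assumes a server started with `vllm serve meta-llama/Llama-4-Scout-17B-16E-Instruct --enforce-eager --enable-auto-tool-choice --tool-call-parser llama4_pythonic --chat-template examples/tool_chat_template_llama4_pythonic.jinja`; the `get_weather` tool and the question are made up for illustration.

```python
# Illustrative sketch only: query a vLLM OpenAI-compatible server started with
# the llama4_pythonic parser and the pythonic chat template (see flags above).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Hypothetical tool definition, for illustration only.
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    messages=[{"role": "user", "content": "What's the weather in LA?"}],
    tools=tools,
)
# The llama4_pythonic parser converts the model's pythonic output, e.g.
# [get_weather(city="LA")], into OpenAI-style tool_calls.
print(response.choices[0].message.tool_calls)
```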
#### IBM Granite
@@ -1,16 +1,17 @@
{{- bos_token }}
{%- if custom_tools is defined %}
{%- if custom_tools is defined and custom_tools%}
{%- set tools = custom_tools %}
{%- endif %}
{%- if not tools_in_user_message is defined %}
{%- set tools_in_user_message = false %}
{%- endif %}
{%- if not tools is defined %}
{%- if tools is defined and tools %}
{%- set tool_definition = tool_definition ~ (tools | tojson(indent=4)) %}
{%- else %}
{%- set tools = none %}
{%- endif %}

{#- This block extracts the system message, so we can slot it into the right place. #}
{%- if messages[0]['role'] == 'system' %}
{%- set user_provided_system_message = true %}
{%- if messages[0]['content'] is string %}
{%- set system_message = messages[0]['content']|trim %}
{%- else %}
@@ -18,68 +19,33 @@
{%- endif %}
{%- set messages = messages[1:] %}
{%- else %}
{%- if tools is not none %}
{#- Add default tool system message when tools are provided #}
{%- set system_message = "You are a helpful assistant with tool calling "
"capabilities. Only reply with a tool call if the function exists in the "
"library provided by the user. If it doesn't exist, just reply directly in "
"natural language. When you receive a tool call response, use the output to "
"format an answer to the original user question." %}
{%- if tools is not none %}
{#- Since no system_message was provided by the user, if tools are provided, system_message is now the default tool system message #}
{#- This system message is from the llama website: https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/ #}
{%- set system_message = "You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:\n\n1. FUNCTION CALLS:\n- ONLY use functions that are EXPLICITLY listed in the function list below\n- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If a function is not in the list, respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)\n- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\nExamples:\nCORRECT: [get_weather(location=\"Vancouver\"), calculate_route(start=\"Boston\", end=\"New York\")] <- Only if get_weather and calculate_route are in function list\nINCORRECT: get_weather(location=\"New York\")\nINCORRECT: Let me check the weather: [get_weather(location=\"New York\")]\nINCORRECT: [get_events(location=\"Singapore\")] <- If function not in list\n\n2. RESPONSE RULES:\n- For pure function requests matching a listed function: ONLY output the function call(s)\n- For knowledge questions: ONLY output text\n- For missing parameters: ONLY request the specific missing parameters\n- For unavailable services (not in function list): output ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\". Do NOT execute a function call.\n- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations\n- NEVER combine text and function calls in the same response\n- NEVER suggest alternative functions when the requested service is unavailable\n- NEVER create or invent new functions not listed below\n\n3. STRICT BOUNDARIES:\n- ONLY use functions from the list below - no exceptions\n- NEVER use a function as an alternative to unavailable information\n- NEVER call functions not present in the function list\n- NEVER add explanatory text to function calls\n- NEVER respond with empty brackets\n- Use proper Python/JSON syntax for function calls\n- Check the function list carefully before responding\n\n4. TOOL RESPONSE HANDLING:\n- When receiving tool responses: provide concise, natural language responses\n- Don't repeat tool response verbatim\n- Don't add supplementary information\n\nHere is a list of functions in JSON format that you can invoke:\n" %}
{%- else %}
{%- set system_message = "" %}
{%- endif %}
{%- endif %}

{#- System message if the user supplied one, or if tools are used (default tool system message) #}
{#- Now writing the system message: use the user provided system message if user_provided_system_message, else the default tool system message if tools are present #}
{%- if system_message %}
{#- always use the user provided system message to override the default tool system message #}
{{- "<|header_start|>system<|header_end|>\n\n" }}
{{- system_message }}
{%- if tools is not none and not tools_in_user_message %}
{{- "Tools: You have access to the following tools. You might need to use one "
"or more function/tool calls to fulfill the task. \n"
"If none are needed, then proceed to the response.\n\n"
"Tool Call Syntax: You can call tools using the following syntax:\n"
"[func_name1(params_name1=params_value1, params_name2=params_value2, ...), ...]\n"
"Do not include anything else when calling the tools with the syntax above.\n\n"
"Here is a list of functions in JSON format that you can invoke.\n " }}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{%- if user_provided_system_message and tools %}
{{- "\nHere is a list of functions in JSON format that you can invoke. Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\n" }}
{{- tool_definition -}}
{%- elif tool_definition %}
{{- tool_definition -}}
{%- endif %}
{{- "<|eot|>" }}
{%- endif %}

{#- Custom tools are passed in a user message with some extra guidance #}
{%- if tools_in_user_message and tools is not none %}
{#- Extract the first user message so we can plug it in here #}
{%- if messages | length != 0 %}
{%- if messages[0]['content'] is string %}
{%- set first_user_message = messages[0]['content']|trim %}
{%- else %}
{%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %}
{%- endif %}
{%- set messages = messages[1:] %}
{%- else %}
{{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
{%- endif %}
{{- '<|header_start|>user<|header_end|>\n\n' -}}
{{- first_user_message}}
{{- "\nHere is a list of functions in JSON format that you can invoke:"}}
{%- for t in tools %}
{{- t | tojson(indent=4) }}
{{- "\n\n" }}
{%- endfor %}
{{- "Should you decide to return the function call(s), put them in the format "
"of [func_name1(params_name1=params_value1, params_name2=params_value2, "
"...), ...]\nDo not include anything else when calling the tools with the "
"syntax above." }}
{%- endif %}

{#- Now deal with all other messages #}
{%- for message in messages %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
{#- Base case: messages that are not from the tool role and have an empty tool_calls list #}
{%- if not (message.role == 'ipython' or message.role == 'tool' or ('tool_calls' in message and message.tool_calls|length != 0 )) %}
{{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }}
{%- if message['content'] is string %}
{{- message['content'] }}
{%- else %}
@@ -91,10 +57,12 @@
{%- endif %}
{%- endfor %}
{%- endif %}
{{- "<|eot|>" }}
{%- elif 'tool_calls' in message and message.tool_calls|length > 0 %}
{%- set tool_call = message.tool_calls[0].function %}
{{- '<|header_start|>assistant<|header_end|>\n\n' -}}
{{- "<|eot|>" }}
{#- Tool case: the message has a non-empty tool_calls list and must come from the assistant #}
{%- elif 'tool_calls' in message %}
{#- assume tool_calls always come from the assistant #}
{%- if message.role == 'assistant' %}
{{- '<|header_start|>assistant<|header_end|>\n\n' -}}
{%- if message['content'] is string %}
{{- message['content'] }}
{%- else %}
@@ -106,32 +74,36 @@
{%- endif %}
{%- endfor %}
{%- endif %}
{{- "[" }}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- tool_call.name + '(' -}}
{%- for param in tool_call.arguments %}
{{- param + '=' -}}
{{- param + '="' -}}
{{- "%s" | format(tool_call.arguments[param]) -}}
{{- '"' -}}
{% if not loop.last %}, {% endif %}
{%- endfor %}
{{- ')' -}}
{% if not loop.last %}, {% endif %}
{%- endfor %}
{{- "<|eom|>" }}
{{- "]<|eot|>" }}
{%- endif %}
{#- Tool_response case: messages coming from a tool response #}
{%- elif message.role == "tool" or message.role == "ipython" %}
{{- "<|header_start|>ipython<|header_end|>\n\n" }}
{%- if message.content is string %}
{{- message.content | tojson }}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'text' %}
{{- content['text'] | tojson }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- "<|eom|>" }}
{{- "<|eot|>" }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
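To see the prompt this template produces, it can be rendered outside vLLM. The sketch below is illustrative only: it assumes a local checkout of <gh-file:examples/tool_chat_template_llama4_pythonic.jinja>, access to a Llama 4 tokenizer, and a made-up `get_weather` tool schema.

```python
# Hedged sketch: render the pythonic chat template with a tool definition to
# inspect the system prompt and pythonic tool-call instructions it emits.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct")  # assumes access to the repo
with open("examples/tool_chat_template_llama4_pythonic.jinja") as f:
    chat_template = f.read()

# Hypothetical tool schema, for illustration only.
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]
messages = [{"role": "user", "content": "What's the weather in LA?"}]

prompt = tokenizer.apply_chat_template(
    messages,
    tools=tools,
    chat_template=chat_template,  # override the model's bundled template
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)
```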
@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import MagicMock

import pytest

from tests.entrypoints.openai.tool_parsers.utils import (
    run_tool_extraction, run_tool_extraction_streaming)
from vllm.entrypoints.openai.protocol import FunctionCall
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager

# Test cases similar to the pythonic parser, but with the Llama 4 specific format
SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]"
SIMPLE_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "LA", "metric": "C"}',
)
MORE_TYPES_FUNCTION_OUTPUT = ("[register_user(name='Doe', "
                              "age=9, "
                              "address={'city': 'LA', 'state': 'CA'}, "
                              "role=None, "
                              "passed_test=True, "
                              "aliases=['John', 'Johnny'])]")
MORE_TYPES_FUNCTION_CALL = FunctionCall(
    name="register_user",
    arguments='{"name": "Doe", '
    '"age": 9, '
    '"address": {"city": "LA", "state": "CA"}, '
    '"role": null, '
    '"passed_test": true, '
    '"aliases": ["John", "Johnny"]}',
)
PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]"
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{}',
)
EMPTY_DICT_FUNCTION_OUTPUT = "[do_something_cool(additional_data={})]"
EMPTY_DICT_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"additional_data": {}}',
)
EMPTY_LIST_FUNCTION_OUTPUT = "[do_something_cool(steps=[])]"
EMPTY_LIST_FUNCTION_CALL = FunctionCall(
    name="do_something_cool",
    arguments='{"steps": []}',
)
ESCAPED_STRING_FUNCTION_OUTPUT = (
    r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]")
ESCAPED_STRING_FUNCTION_CALL = FunctionCall(
    name="get_weather",
    arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}',
)
PYTHON_TAG_FUNCTION_OUTPUT = (
    "<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>")


@pytest.mark.parametrize("streaming", [True, False])
def test_no_tool_call(streaming: bool):
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser(
        "llama4_pythonic")(mock_tokenizer)
    model_output = "How can I help you today?"

    content, tool_calls = run_tool_extraction(tool_parser,
                                              model_output,
                                              streaming=streaming)

    assert content == model_output
    assert len(tool_calls) == 0


test_str = "<|python_start|>"
test_str += "[get_weather(city='LA', metric='C'),"
test_str += "register_user(name='Doe', age=9)]"
TEST_CASES = [
    pytest.param(True,
                 SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL],
                 id="simple_streaming"),
    pytest.param(False,
                 SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL],
                 id="simple_nonstreaming"),
    pytest.param(True,
                 MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL],
                 id="more_types_streaming"),
    pytest.param(False,
                 MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL],
                 id="more_types_nonstreaming"),
    pytest.param(True,
                 PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL],
                 id="parameterless_streaming"),
    pytest.param(False,
                 PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL],
                 id="parameterless_nonstreaming"),
    pytest.param(True,
                 EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL],
                 id="empty_dict_streaming"),
    pytest.param(False,
                 EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL],
                 id="empty_dict_nonstreaming"),
    pytest.param(True,
                 EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL],
                 id="empty_list_streaming"),
    pytest.param(False,
                 EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL],
                 id="empty_list_nonstreaming"),
    pytest.param(True,
                 ESCAPED_STRING_FUNCTION_OUTPUT,
                 [ESCAPED_STRING_FUNCTION_CALL],
                 id="escaped_string_streaming"),
    pytest.param(False,
                 ESCAPED_STRING_FUNCTION_OUTPUT,
                 [ESCAPED_STRING_FUNCTION_CALL],
                 id="escaped_string_nonstreaming"),
    pytest.param(
        True,
        "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
        [
            SIMPLE_FUNCTION_CALL,
            FunctionCall(name="register_user",
                         arguments='{"name": "Doe", "age": 9}')
        ],
        id="parallel_calls_streaming"),
    pytest.param(
        False,
        "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]",
        [
            SIMPLE_FUNCTION_CALL,
            FunctionCall(name="register_user",
                         arguments='{"name": "Doe", "age": 9}')
        ],
        id="parallel_calls_nonstreaming"),
    pytest.param(True,
                 PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL],
                 id="python_tag_streaming"),
    pytest.param(False,
                 PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL],
                 id="python_tag_nonstreaming"),
    pytest.param(True,
                 test_str, [
                     SIMPLE_FUNCTION_CALL,
                     FunctionCall(name="register_user",
                                  arguments='{"name": "Doe", "age": 9}')
                 ],
                 id="python_tag_parallel_calls_streaming"),
    pytest.param(False,
                 "<|python_start|>[get_weather(city='LA', metric='C'), " +
                 "register_user(name='Doe', age=9)]", [
                     SIMPLE_FUNCTION_CALL,
                     FunctionCall(name="register_user",
                                  arguments='{"name": "Doe", "age": 9}')
                 ],
                 id="python_tag_parallel_calls_nonstreaming"),
]


@pytest.mark.parametrize("streaming, model_output, expected_tool_calls",
                         TEST_CASES)
def test_tool_call(streaming: bool, model_output: str,
                   expected_tool_calls: list[FunctionCall]):
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser(
        "llama4_pythonic")(mock_tokenizer)

    content, tool_calls = run_tool_extraction(tool_parser,
                                              model_output,
                                              streaming=streaming)

    assert len(tool_calls) == len(expected_tool_calls)
    for actual, expected in zip(tool_calls, expected_tool_calls):
        assert actual.type == "function"
        assert actual.function == expected


def test_streaming_tool_call_with_large_steps():
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser(
        "llama4_pythonic")(mock_tokenizer)
    model_output_deltas = [
        "<|python_start|>[get_weather(city='LA', metric='C'), "
        "get_weather(), "
        "do_something_cool(steps=[])]<|python_end|>",
    ]

    reconstructor = run_tool_extraction_streaming(
        tool_parser, model_output_deltas, assert_one_tool_per_delta=False)

    assert reconstructor.other_content == ""
    assert len(reconstructor.tool_calls) == 3
    assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL
    assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL
    assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL
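Outside the test suite, the same extraction path can be exercised directly. The sketch below mirrors the tests above and rests on the same assumptions they make (a `MagicMock` stands in for the tokenizer, and the non-streaming path does not use the `request` argument).

```python
# Hedged sketch mirroring the tests: extract tool calls from a pythonic string.
from unittest.mock import MagicMock

from vllm.entrypoints.openai.tool_parsers import ToolParserManager

parser = ToolParserManager.get_tool_parser("llama4_pythonic")(MagicMock())
result = parser.extract_tool_calls(
    "[get_weather(city='LA', metric='C')]", request=None)  # request is unused here
print(result.tools_called)                      # expected: True
print(result.tool_calls[0].function.arguments)  # expected: {"city": "LA", "metric": "C"}
```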
@@ -88,7 +88,7 @@ CONFIGS: dict[str, ServerConfig] = {
        "meta-llama/Llama-4-Scout-17B-16E-Instruct",
        "arguments": [
            "--enforce-eager", "--no-enable-prefix-caching",
            "--tool-call-parser", "pythonic", "--chat-template",
            "--tool-call-parser", "llama4_pythonic", "--chat-template",
            str(VLLM_PATH /
                "examples/tool_chat_template_llama4_pythonic.jinja"), "-tp",
            "4"
@@ -7,6 +7,7 @@ from .granite_tool_parser import GraniteToolParser
from .hermes_tool_parser import Hermes2ProToolParser
from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser
from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
@@ -16,5 +17,6 @@ __all__ = [
    "ToolParser", "ToolParserManager", "Granite20bFCToolParser",
    "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
    "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
    "PythonicToolParser", "Phi4MiniJsonToolParser", "DeepSeekV3ToolParser"
    "Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
    "DeepSeekV3ToolParser"
]
@@ -0,0 +1,303 @@
# SPDX-License-Identifier: Apache-2.0

import ast
import json
import re
from collections.abc import Sequence
from typing import Any, Union

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
                                              ExtractedToolCallInformation,
                                              FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
    ToolParser, ToolParserManager)
from vllm.logger import init_logger

logger = init_logger(__name__)


class _UnexpectedAstError(Exception):
    pass


@ToolParserManager.register_module("llama4_pythonic")
class Llama4PythonicToolParser(ToolParser):
    """
    Tool call parser for Llama 4 models that produce tool calls in a
    pythonic style.

    Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic
    """
    # TODO(mdepinet): Possible future improvements:
    # 1. Support text + tools separated by either <|python_tag|> or \n\n
    # 2. Support tools outside of a list (or separated by a semicolon).
    #    This depends on item 1 for consistent streaming.
    # Neither of these are necessary for e.g. ToolACE, but both would help make
    # Llama3.2 models more reliable.

    TOOL_CALL_REGEX = re.compile(
        r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
        re.DOTALL)
    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

    # Rename for readability. This is NOT a tool id.
    @property
    def current_tool_index(self) -> int:
        return self.current_tool_id

    @current_tool_index.setter
    def current_tool_index(self, value: int) -> None:
        self.current_tool_id = value

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.
        """

        # Remove <|python_start|> and <|python_end|>, since the Llama 4 model
        # will sometimes emit those tokens around the tool call list.
        if model_output.startswith("<|python_start|>"):
            model_output = model_output[len("<|python_start|>"):]
            model_output = model_output.replace("<|python_end|>", "")
        if not (self.TOOL_CALL_REGEX.match(model_output)):
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

        try:
            module = ast.parse(model_output)
            parsed = getattr(module.body[0], "value", None)
            if isinstance(parsed, ast.List) and all(
                    isinstance(e, ast.Call) for e in parsed.elts):
                return ExtractedToolCallInformation(
                    tools_called=True,
                    tool_calls=[
                        _handle_single_tool(e)  # type: ignore
                        for e in parsed.elts
                    ],
                    content=None)
            else:
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls")
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # Treat as regular text
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:

        if not current_text.startswith("[") and not current_text.startswith(
                "<|python_start|>"):
            return DeltaMessage(content=delta_text)

        try:
            # Remove <|python_start|> and <|python_end|>.
            if current_text.startswith("<|python_start|>"):
                current_text = current_text[len("<|python_start|>"):]
            if current_text.endswith("<|python_end|>"):
                current_text = current_text[:current_text.
                                            rfind("<|python_end|>")]
            valid_and_added_text = _make_valid_python(current_text)
            if valid_and_added_text is None:
                return None
            valid_text, added_text = valid_and_added_text

            module = ast.parse(valid_text)
            parsed = getattr(module.body[0], "value", None)
            if not isinstance(parsed, ast.List) or not all(
                    isinstance(e, ast.Call) for e in parsed.elts):
                raise _UnexpectedAstError(
                    "Tool output must be a list of function calls")
            tool_calls = [
                _handle_single_tool(e)  # type: ignore
                for e in parsed.elts
            ]

            tool_deltas = []
            for index, new_call in enumerate(tool_calls):
                if index < self.current_tool_index:
                    continue

                self.current_tool_index = index
                if len(self.streamed_args_for_tool) == index:
                    self.streamed_args_for_tool.append("")

                new_call_complete = index < len(
                    tool_calls) - 1 or ")]" not in added_text
                if new_call_complete:
                    self.current_tool_index += 1

                withheld_suffix = (added_text[:-2]
                                   if not new_call_complete else "")
                if not new_call_complete and added_text[-2] == ")":
                    # Function call is incomplete. Withhold the closing bracket.
                    withheld_suffix = withheld_suffix + "}"
                # Strings get single quotes in the model-produced string.
                # JSON requires double quotes.
                withheld_suffix = withheld_suffix.replace("'", '"')
                delta = _compute_tool_delta(self.streamed_args_for_tool[index],
                                            new_call, index, withheld_suffix)

                if delta is not None:
                    tool_deltas.append(delta)
                    if (delta.function is not None
                            and delta.function.arguments is not None):
                        self.streamed_args_for_tool[
                            index] += delta.function.arguments

            # HACK: serving_chat.py inspects the internal state of tool parsers
            # when determining its final streaming delta, automatically
            # adding autocompleted JSON.
            # These two lines avoid that nonsense while ensuring finish_reason
            # is set to tool_calls when at least one tool is called.
            if tool_deltas and not self.prev_tool_call_arr:
                self.prev_tool_call_arr = [{"arguments": {}}]

            if tool_deltas:
                return DeltaMessage(tool_calls=tool_deltas)
            elif not added_text and self.current_tool_id > 0:
                # Return an empty DeltaMessage once the tool calls are all done
                # so that finish_reason gets set.
                return DeltaMessage(content='')
            else:
                return None
        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None


def _get_parameter_value(val: ast.expr) -> Any:
    if isinstance(val, ast.Constant):
        return val.value
    elif isinstance(val, ast.Dict):
        if not all(isinstance(k, ast.Constant) for k in val.keys):
            raise _UnexpectedAstError(
                "Dict tool call arguments must have literal keys")
        return {
            k.value: _get_parameter_value(v)  # type: ignore
            for k, v in zip(val.keys, val.values)
        }
    elif isinstance(val, ast.List):
        return [_get_parameter_value(v) for v in val.elts]
    else:
        raise _UnexpectedAstError("Tool call arguments must be literals")


def _handle_single_tool(call: ast.Call) -> ToolCall:
    if not isinstance(call.func, ast.Name):
        raise _UnexpectedAstError("Invalid tool call name")
    function_name = call.func.id
    arguments = {}
    for keyword in call.keywords:
        arguments[keyword.arg] = _get_parameter_value(keyword.value)
    return ToolCall(type="function",
                    function=FunctionCall(name=function_name,
                                          arguments=json.dumps(arguments)))


def _make_valid_python(text: str) -> Union[tuple[str, str], None]:
    bracket_stack = []
    for index, char in enumerate(text):
        if char in {"[", "(", "{"}:
            bracket_stack.append(char)
        elif char == "]":
            if not bracket_stack or bracket_stack.pop() != "[":
                raise _UnexpectedAstError("Mismatched square brackets")
        elif char == ")":
            if not bracket_stack or bracket_stack.pop() != "(":
                raise _UnexpectedAstError("Mismatched parentheses")
        elif char == "}":
            if not bracket_stack or bracket_stack.pop() != "{":
                raise _UnexpectedAstError("Mismatched curly braces")
        elif char in {"'", '"'}:
            if bracket_stack and bracket_stack[-1] == char:
                if index > 0 and text[index - 1] == "\\":
                    # Treat an escaped quote as a regular character
                    pass
                else:
                    bracket_stack.pop()
            elif bracket_stack and bracket_stack[-1] in {"'", '"'}:
                # Double quote within a single quote string or vice versa.
                pass
            else:
                bracket_stack.append(char)

    text = text.rstrip()
    if text.endswith("=") or text.endswith(":"):
        # Since we have no type information for this property/parameter value,
        # we can't fill in a valid value.
        return None
    if bracket_stack and bracket_stack[-1] == "{":
        trailing_dict_text = text[:text.rfind("{")]
        num_keys = trailing_dict_text.count(":")
        num_values = trailing_dict_text.count(",")
        if num_keys <= num_values:
            return None  # Incomplete property name within parameter value
    if bracket_stack and bracket_stack[-1] == "(":
        trailing_params_text = text[:text.rfind("(")]
        num_full_param_names = trailing_params_text.count("=")
        num_full_param_values = trailing_params_text.count(",")
        if num_full_param_names <= num_full_param_values:
            return None  # Incomplete parameter name
    if text.endswith(","):
        text = text[:-1]
    if bracket_stack and bracket_stack[-1] == "[" and not text.endswith(
            "[") and not text.endswith(")"):
        return None  # Incomplete function name

    added_text = ""
    for char in reversed(bracket_stack):
        if char == "[":
            added_text += "]"
        elif char == "(":
            added_text += ")"
        elif char == "{":
            added_text += "}"
        elif char == "'":
            added_text += "'"
        elif char == '"':
            added_text += '"'

    return text + added_text, added_text


def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall,
                        index: int,
                        withheld_suffix: str) -> Union[DeltaToolCall, None]:
    new_call_args = new_call.function.arguments
    if withheld_suffix:
        assert new_call_args.endswith(withheld_suffix)
        new_call_args = new_call_args[:-len(withheld_suffix)]
    if not previously_sent_args:
        return DeltaToolCall(id=new_call.id,
                             type="function",
                             index=index,
                             function=DeltaFunctionCall(
                                 name=new_call.function.name,
                                 arguments=new_call_args,
                             ))

    arg_diff = new_call_args[len(previously_sent_args):]
    return DeltaToolCall(
        id=None, index=index, function=DeltaFunctionCall(
            arguments=arg_diff)) if arg_diff else None
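For reference, the `ast`-based mechanism the parser relies on can be shown in a few standalone lines. This is a simplified sketch (constant keyword values only), not the parser itself, which additionally handles dicts, lists, and error cases via `_get_parameter_value` and `_handle_single_tool`.

```python
# Minimal sketch of the ast-based extraction used above: a pythonic tool-call
# list is parsed into function names plus JSON-serializable keyword arguments.
import ast
import json

model_output = "[get_weather(city='LA', metric='C'), get_weather()]"
calls = ast.parse(model_output).body[0].value  # ast.List of ast.Call nodes
for call in calls.elts:
    name = call.func.id  # e.g. "get_weather"
    # Only constant keyword values; the real parser also walks dicts and lists.
    args = {kw.arg: kw.value.value for kw in call.keywords}
    print(name, json.dumps(args))
# get_weather {"city": "LA", "metric": "C"}
# get_weather {}
```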