From c91fe7b1b9c4398c6d4c980fc480ada0da8a0b23 Mon Sep 17 00:00:00 2001 From: Kai Wu Date: Thu, 22 May 2025 16:44:08 -0700 Subject: [PATCH] [Frontend][Bug Fix] Update llama4 pythonic jinja template and llama4_pythonic parser (#17917) Signed-off-by: Kai Wu --- docs/source/features/tool_calling.md | 11 +- .../tool_chat_template_llama4_pythonic.jinja | 100 +++--- .../test_llama4_pythonic_tool_parser.py | 193 +++++++++++ tests/tool_use/utils.py | 2 +- .../openai/tool_parsers/__init__.py | 4 +- .../llama4_pythonic_tool_parser.py | 303 ++++++++++++++++++ 6 files changed, 541 insertions(+), 72 deletions(-) create mode 100644 tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py create mode 100644 vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py diff --git a/docs/source/features/tool_calling.md b/docs/source/features/tool_calling.md index 2795b76934..f76128406b 100644 --- a/docs/source/features/tool_calling.md +++ b/docs/source/features/tool_calling.md @@ -158,13 +158,13 @@ All Llama 3.1, 3.2 and 4 models should be supported. * `meta-llama/Llama-3.2-*` * `meta-llama/Llama-4-*` -The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. +The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). For [pythonic tool calling](https://github.com/meta-llama/llama-models/blob/main/models/llama3_2/text_prompt_format.md#zero-shot-function-calling) introduced by the Llama-3.2 models, see the `pythonic` tool parser below. As for llama 4 models, it is recommended to use the `llama4_pythonic` tool parser. 
Other tool calling formats like the built in python tool calling or custom tool calling are not supported. Known issues: -1. Parallel tool calls are not supported. +1. Parallel tool calls are not supported for Llama 3, but they are supported in Llama 4 models. 2. The model can generate parameters with a wrong format, such as generating an array serialized as string instead of an array. @@ -177,11 +177,10 @@ images. Recommended flags: `--tool-call-parser llama3_json --chat-template {see_above}` -VLLM also provides a JSON based chat template for Llama 4: -* - this is based on the "official" chat template for the Llama 4 -models, but tweaked so that it works better with vLLM. +VLLM also provides a pythonic and JSON based chat template for Llama 4, but pythonic tool calling is recommended: +* - this is based on the [official chat template](https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/) for the Llama 4 models. -For Llama 4 use `--tool-call-parser llama4_json examples/tool_chat_template_llama4_json.jinja`. +For Llama 4 models, use `--tool-call-parser llama4_pythonic --chat-template examples/tool_chat_template_llama4_pythonic.jinja`. #### IBM Granite diff --git a/examples/tool_chat_template_llama4_pythonic.jinja b/examples/tool_chat_template_llama4_pythonic.jinja index bd18a35bdd..bbed3d8205 100644 --- a/examples/tool_chat_template_llama4_pythonic.jinja +++ b/examples/tool_chat_template_llama4_pythonic.jinja @@ -1,16 +1,17 @@ {{- bos_token }} -{%- if custom_tools is defined %} +{%- if custom_tools is defined and custom_tools%} {%- set tools = custom_tools %} {%- endif %} -{%- if not tools_in_user_message is defined %} - {%- set tools_in_user_message = false %} -{%- endif %} -{%- if not tools is defined %} +{%- if tools is defined and tools %} + {%- set tool_definition = tool_definition ~ (tools | tojson(indent=4)) %} +{%- else %} {%- set tools = none %} {%- endif %} + {#- This block extracts the system message, so we can slot it into the right place. 
#} {%- if messages[0]['role'] == 'system' %} + {%- set user_provided_system_message = true %} {%- if messages[0]['content'] is string %} {%- set system_message = messages[0]['content']|trim %} {%- else %} @@ -18,68 +19,33 @@ {%- endif %} {%- set messages = messages[1:] %} {%- else %} - {%- if tools is not none %} - {#- Add default tool system message when tools are provided #} - {%- set system_message = "You are a helpful assistant with tool calling " - "capabilities. Only reply with a tool call if the function exists in the " - "library provided by the user. If it doesn't exist, just reply directly in " - "natural language. When you receive a tool call response, use the output to " - "format an answer to the original user question." %} + {%- if tools is not none %} + {#- Since not system_message was provided by user, if tool is provided, system_message is now default tool system message #} + {#- This system message is from llama website:https://www.llama.com/docs/model-cards-and-prompt-formats/llama4/ #} + {%- set system_message = "You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:\n\n1. 
FUNCTION CALLS:\n- ONLY use functions that are EXPLICITLY listed in the function list below\n- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If a function is not in the list, respond ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\"\n- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)\n- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\nExamples:\nCORRECT: [get_weather(location=\"Vancouver\"), calculate_route(start=\"Boston\", end=\"New York\")] <- Only if get_weather and calculate_route are in function list\nINCORRECT: get_weather(location=\"New York\")\nINCORRECT: Let me check the weather: [get_weather(location=\"New York\")]\nINCORRECT: [get_events(location=\"Singapore\")] <- If function not in list\n\n2. RESPONSE RULES:\n- For pure function requests matching a listed function: ONLY output the function call(s)\n- For knowledge questions: ONLY output text\n- For missing parameters: ONLY request the specific missing parameters\n- For unavailable services (not in function list): output ONLY with internal knowledge or \"I don't have access to [Unavailable service] information\". Do NOT execute a function call.\n- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations\n- NEVER combine text and function calls in the same response\n- NEVER suggest alternative functions when the requested service is unavailable\n- NEVER create or invent new functions not listed below\n\n3. 
STRICT BOUNDARIES:\n- ONLY use functions from the list below - no exceptions\n- NEVER use a function as an alternative to unavailable information\n- NEVER call functions not present in the function list\n- NEVER add explanatory text to function calls\n- NEVER respond with empty brackets\n- Use proper Python/JSON syntax for function calls\n- Check the function list carefully before responding\n\n4. TOOL RESPONSE HANDLING:\n- When receiving tool responses: provide concise, natural language responses\n- Don't repeat tool response verbatim\n- Don't add supplementary information\n\nHere is a list of functions in JSON format that you can invoke:\n" %} {%- else %} {%- set system_message = "" %} {%- endif %} {%- endif %} - -{#- System message if the user supplied one, or if tools are used (default tool system message) #} +{#- Now writing the system message: use the user provided system message if user_provided_system_message, else default tool system message if tools presented #} {%- if system_message %} {#- always use user provided system message to override default tool system message #} {{- "<|header_start|>system<|header_end|>\n\n" }} {{- system_message }} - {%- if tools is not none and not tools_in_user_message %} - {{- "Tools: You have access to the following tools. You might need to use one " - "or more function/tool calls to fulfill the task. \n" - "If none are needed, then proceed to the response.\n\n" - "Tool Call Syntax: You can call tools using the following syntax:\n" - "[func_name1(params_name1=params_value1, params_name2=params_value2, ...), ...]\n" - "Do not include anything else when calling the tools with the syntax above.\n\n" - "Here is a list of functions in JSON format that you can invoke.\n " }} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} + {%- if user_provided_system_message and tools %} + {{- "\nHere is a list of functions in JSON format that you can invoke. 
Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]\n" }} + {{- tool_definition -}} + {%- elif tool_definition %} + {{- tool_definition -}} {%- endif %} {{- "<|eot|>" }} {%- endif %} -{#- Custom tools are passed in a user message with some extra guidance #} -{%- if tools_in_user_message and tools is not none %} - {#- Extract the first user message so we can plug it in here #} - {%- if messages | length != 0 %} - {%- if messages[0]['content'] is string %} - {%- set first_user_message = messages[0]['content']|trim %} - {%- else %} - {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %} - {%- endif %} - {%- set messages = messages[1:] %} - {%- else %} - {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} - {%- endif %} - {{- '<|header_start|>user<|header_end|>\n\n' -}} - {{- first_user_message}} - {{- "\nHere is a list of functions in JSON format that you can invoke:"}} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} - {{- "Should you decide to return the function call(s), put them in the format " - "of [func_name1(params_name1=params_value1, params_name2=params_value2, " - "...), ...]\nDo not include anything else when calling the tools with the " - "syntax above." 
}} -{%- endif %} - +{#- Now deal with all other messages #} {%- for message in messages %} - {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} - {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} + {#- Base case: messages that are not from tool role and has empty tool_call list #} + {%- if not (message.role == 'ipython' or message.role == 'tool' or ('tool_calls' in message and message.tool_calls|length != 0 )) %} + {{- '<|header_start|>' + message['role'] + '<|header_end|>\n\n' }} {%- if message['content'] is string %} {{- message['content'] }} {%- else %} @@ -91,10 +57,12 @@ {%- endif %} {%- endfor %} {%- endif %} - {{- "<|eot|>" }} - {%- elif 'tool_calls' in message and message.tool_calls|length > 0 %} - {%- set tool_call = message.tool_calls[0].function %} - {{- '<|header_start|>assistant<|header_end|>\n\n' -}} + {{- "<|eot|>" }} + {#- Tool case: messages has non-empty tool_call list, must from assistant #} + {%- elif 'tool_calls' in message %} + {#- assume tool_calls are always coming from assistant #} + {%- if message.role == 'assistant' %} + {{- '<|header_start|>assistant<|header_end|>\n\n' -}} {%- if message['content'] is string %} {{- message['content'] }} {%- else %} @@ -106,32 +74,36 @@ {%- endif %} {%- endfor %} {%- endif %} + {{- "[" }} {%- for tool_call in message.tool_calls %} {%- if tool_call.function is defined %} {%- set tool_call = tool_call.function %} {%- endif %} - {{- tool_call.name + '(' -}} + {{- tool_call.name + '(' -}} {%- for param in tool_call.arguments %} - {{- param + '=' -}} + {{- param + '="' -}} {{- "%s" | format(tool_call.arguments[param]) -}} + {{- '"' -}} {% if not loop.last %}, {% endif %} {%- endfor %} {{- ')' -}} {% if not loop.last %}, {% endif %} {%- endfor %} - {{- "<|eom|>" }} + {{- "]<|eot|>" }} +{%- endif %} +{#- Tool_response case: messages are from tool_response #} {%- elif message.role == "tool" or message.role == "ipython" %} {{- 
"<|header_start|>ipython<|header_end|>\n\n" }} {%- if message.content is string %} - {{- message.content | tojson }} + {{- message.content | tojson }} {%- else %} {%- for content in message['content'] %} {%- if content['type'] == 'text' %} - {{- content['text'] | tojson }} + {{- content['text'] | tojson }} {%- endif %} {%- endfor %} {%- endif %} - {{- "<|eom|>" }} + {{- "<|eot|>" }} {%- endif %} {%- endfor %} {%- if add_generation_prompt %} diff --git a/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py new file mode 100644 index 0000000000..92ba1376e2 --- /dev/null +++ b/tests/entrypoints/openai/tool_parsers/test_llama4_pythonic_tool_parser.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock + +import pytest + +from tests.entrypoints.openai.tool_parsers.utils import ( + run_tool_extraction, run_tool_extraction_streaming) +from vllm.entrypoints.openai.protocol import FunctionCall +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager + +# Test cases similar to pythonic parser but with Llama4 specific format +SIMPLE_FUNCTION_OUTPUT = "[get_weather(city='LA', metric='C')]" +SIMPLE_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{"city": "LA", "metric": "C"}', +) +MORE_TYPES_FUNCTION_OUTPUT = ("[register_user(name='Doe', " + "age=9, " + "address={'city': 'LA', 'state': 'CA'}, " + "role=None, " + "passed_test=True, " + "aliases=['John', 'Johnny'])]") +MORE_TYPES_FUNCTION_CALL = FunctionCall( + name="register_user", + arguments='{"name": "Doe", ' + '"age": 9, ' + '"address": {"city": "LA", "state": "CA"}, ' + '"role": null, ' + '"passed_test": true, ' + '"aliases": ["John", "Johnny"]}', +) +PARAMETERLESS_FUNCTION_OUTPUT = "[get_weather()]" +PARAMETERLESS_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{}', +) +EMPTY_DICT_FUNCTION_OUTPUT = 
"[do_something_cool(additional_data={})]" +EMPTY_DICT_FUNCTION_CALL = FunctionCall( + name="do_something_cool", + arguments='{"additional_data": {}}', +) +EMPTY_LIST_FUNCTION_OUTPUT = "[do_something_cool(steps=[])]" +EMPTY_LIST_FUNCTION_CALL = FunctionCall( + name="do_something_cool", + arguments='{"steps": []}', +) +ESCAPED_STRING_FUNCTION_OUTPUT = ( + r"[get_weather(city='Martha\'s Vineyard', metric='\"cool units\"')]") +ESCAPED_STRING_FUNCTION_CALL = FunctionCall( + name="get_weather", + arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', +) +PYTHON_TAG_FUNCTION_OUTPUT = ( + "<|python_start|>[get_weather(city='LA', metric='C')]<|python_end|>") + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_no_tool_call(streaming: bool): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "llama4_pythonic")(mock_tokenizer) + model_output = "How can I help you today?" + + content, tool_calls = run_tool_extraction(tool_parser, + model_output, + streaming=streaming) + + assert content == model_output + assert len(tool_calls) == 0 + + +test_str = "<|python_start|>" +test_str += "[get_weather(city='LA', metric='C')," +test_str += "register_user(name='Doe', age=9)]" +TEST_CASES = [ + pytest.param(True, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="simple_streaming"), + pytest.param(False, + SIMPLE_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], + id="simple_nonstreaming"), + pytest.param(True, + MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL], + id="more_types_streaming"), + pytest.param(False, + MORE_TYPES_FUNCTION_OUTPUT, [MORE_TYPES_FUNCTION_CALL], + id="more_types_nonstreaming"), + pytest.param(True, + PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_streaming"), + pytest.param(False, + PARAMETERLESS_FUNCTION_OUTPUT, [PARAMETERLESS_FUNCTION_CALL], + id="parameterless_nonstreaming"), + pytest.param(True, + EMPTY_DICT_FUNCTION_OUTPUT, 
[EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_streaming"), + pytest.param(False, + EMPTY_DICT_FUNCTION_OUTPUT, [EMPTY_DICT_FUNCTION_CALL], + id="empty_dict_nonstreaming"), + pytest.param(True, + EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_streaming"), + pytest.param(False, + EMPTY_LIST_FUNCTION_OUTPUT, [EMPTY_LIST_FUNCTION_CALL], + id="empty_list_nonstreaming"), + pytest.param(True, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_streaming"), + pytest.param(False, + ESCAPED_STRING_FUNCTION_OUTPUT, + [ESCAPED_STRING_FUNCTION_CALL], + id="escaped_string_nonstreaming"), + pytest.param( + True, + "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]", + [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", + arguments='{"name": "Doe", "age": 9}') + ], + id="parallel_calls_streaming"), + pytest.param( + False, + "[get_weather(city='LA',metric='C'),register_user(name='Doe',age=9)]", + [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", + arguments='{"name": "Doe", "age": 9}') + ], + id="parallel_calls_nonstreaming"), + pytest.param(True, + PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], + id="python_tag_streaming"), + pytest.param(False, + PYTHON_TAG_FUNCTION_OUTPUT, [SIMPLE_FUNCTION_CALL], + id="python_tag_nonstreaming"), + pytest.param(True, + test_str, [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", + arguments='{"name": "Doe", "age": 9}') + ], + id="parallel_calls_streaming"), + pytest.param(False, + "<|python_start|>[get_weather(city='LA', metric='C'), " + + "register_user(name='Doe', age=9)]", [ + SIMPLE_FUNCTION_CALL, + FunctionCall(name="register_user", + arguments='{"name": "Doe", "age": 9}') + ], + id="parallel_calls_nonstreaming"), +] + + +@pytest.mark.parametrize("streaming, model_output, expected_tool_calls", + TEST_CASES) +def test_tool_call(streaming: bool, model_output: str, + expected_tool_calls: list[FunctionCall]): + mock_tokenizer 
= MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "llama4_pythonic")(mock_tokenizer) + + content, tool_calls = run_tool_extraction(tool_parser, + model_output, + streaming=streaming) + + assert len(tool_calls) == len(expected_tool_calls) + for actual, expected in zip(tool_calls, expected_tool_calls): + assert actual.type == "function" + assert actual.function == expected + + +def test_streaming_tool_call_with_large_steps(): + mock_tokenizer = MagicMock() + tool_parser: ToolParser = ToolParserManager.get_tool_parser( + "llama4_pythonic")(mock_tokenizer) + model_output_deltas = [ + "<|python_start|>[get_weather(city='LA', metric='C'), " + "get_weather(), " + "do_something_cool(steps=[])]<|python_end|>", + ] + + reconstructor = run_tool_extraction_streaming( + tool_parser, model_output_deltas, assert_one_tool_per_delta=False) + + assert reconstructor.other_content == "" + assert len(reconstructor.tool_calls) == 3 + assert reconstructor.tool_calls[0].function == SIMPLE_FUNCTION_CALL + assert reconstructor.tool_calls[1].function == PARAMETERLESS_FUNCTION_CALL + assert reconstructor.tool_calls[2].function == EMPTY_LIST_FUNCTION_CALL diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index c14eaf71e9..efa6455c41 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -88,7 +88,7 @@ CONFIGS: dict[str, ServerConfig] = { "meta-llama/Llama-4-Scout-17B-16E-Instruct", "arguments": [ "--enforce-eager", "--no-enable-prefix-caching", - "--tool-call-parser", "pythonic", "--chat-template", + "--tool-call-parser", "llama4_pythonic", "--chat-template", str(VLLM_PATH / "examples/tool_chat_template_llama4_pythonic.jinja"), "-tp", "4" diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index f7c7112b12..054c0b006b 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -7,6 +7,7 @@ from .granite_tool_parser 
import GraniteToolParser from .hermes_tool_parser import Hermes2ProToolParser from .internlm2_tool_parser import Internlm2ToolParser from .jamba_tool_parser import JambaToolParser +from .llama4_pythonic_tool_parser import Llama4PythonicToolParser from .llama_tool_parser import Llama3JsonToolParser from .mistral_tool_parser import MistralToolParser from .phi4mini_tool_parser import Phi4MiniJsonToolParser @@ -16,5 +17,6 @@ __all__ = [ "ToolParser", "ToolParserManager", "Granite20bFCToolParser", "GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser", - "PythonicToolParser", "Phi4MiniJsonToolParser", "DeepSeekV3ToolParser" + "Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser", + "DeepSeekV3ToolParser" ] diff --git a/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py new file mode 100644 index 0000000000..f483ac4eee --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: Apache-2.0 + +import ast +import json +import re +from collections.abc import Sequence +from typing import Any, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class _UnexpectedAstError(Exception): + pass + + +@ToolParserManager.register_module("llama4_pythonic") +class Llama4PythonicToolParser(ToolParser): + """ + Toolcall parser for Llama4 that produce tool calls in a pythonic style + Use --enable-auto-tool-choice --tool-call-parser llama4_pythonic + """ + # TODO(mdepinet): 
Possible future improvements: + # 1. Support text + tools separated by either <|python_tag|> or \n\n + # 2. Support tools outside of a list (or separated by a semicolon). + # This depends on item 1 for consistent streaming. + # Neither of these are necessary for e.g. ToolACE, but both would help make + # Llama3.2 models more reliable. + + TOOL_CALL_REGEX = re.compile( + r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]", + re.DOTALL) + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # Rename for readability. This is NOT a tool id. + @property + def current_tool_index(self) -> int: + return self.current_tool_id + + @current_tool_index.setter + def current_tool_index(self, value: int) -> None: + self.current_tool_id = value + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. 
+ """ + + # remove <|python_start|> and <|python_end|> + # as Llama 4 model sometime will output those tokens + if model_output.startswith("<|python_start|>"): + model_output = model_output[len("<|python_start|>"):] + model_output = model_output.replace("<|python_end|>", "") + if not (self.TOOL_CALL_REGEX.match(model_output)): + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + module = ast.parse(model_output) + parsed = getattr(module.body[0], "value", None) + if isinstance(parsed, ast.List) and all( + isinstance(e, ast.Call) for e in parsed.elts): + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=[ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ], + content=None) + else: + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + except Exception: + logger.exception("Error in extracting tool call from response.") + # Treat as regular text + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if not current_text.startswith("[") and not current_text.startswith( + "<|python_start|>"): + return DeltaMessage(content=delta_text) + + try: + # remove <|python_start|> and <|python_end|> + if current_text.startswith("<|python_start|>"): + current_text = current_text[len("<|python_start|>"):] + if current_text.endswith("<|python_end|>"): + current_text = current_text[:current_text. 
+ rfind("<|python_end|>")] + valid_and_added_text = _make_valid_python(current_text) + if valid_and_added_text is None: + return None + valid_text, added_text = valid_and_added_text + + module = ast.parse(valid_text) + parsed = getattr(module.body[0], "value", None) + if not isinstance(parsed, ast.List) or not all( + isinstance(e, ast.Call) for e in parsed.elts): + raise _UnexpectedAstError( + "Tool output must be a list of function calls") + tool_calls = [ + _handle_single_tool(e) # type: ignore + for e in parsed.elts + ] + + tool_deltas = [] + for index, new_call in enumerate(tool_calls): + if index < self.current_tool_index: + continue + + self.current_tool_index = index + if len(self.streamed_args_for_tool) == index: + self.streamed_args_for_tool.append("") + + new_call_complete = index < len( + tool_calls) - 1 or ")]" not in added_text + if new_call_complete: + self.current_tool_index += 1 + + withheld_suffix = (added_text[:-2] + if not new_call_complete else "") + if not new_call_complete and added_text[-2] == ")": + # Function call is incomplete. Withhold the closing bracket. + withheld_suffix = withheld_suffix + "}" + # Strings get single quotes in the model-produced string. + # JSON requires double quotes. + withheld_suffix = withheld_suffix.replace("'", '"') + delta = _compute_tool_delta(self.streamed_args_for_tool[index], + new_call, index, withheld_suffix) + + if delta is not None: + tool_deltas.append(delta) + if (delta.function is not None + and delta.function.arguments is not None): + self.streamed_args_for_tool[ + index] += delta.function.arguments + + # HACK: serving_chat.py inspects the internal state of tool parsers + # when determining it's final streaming delta, automatically + # adding autocompleted JSON. + # These two lines avoid that nonsense while ensuring finish_reason + # is set to tool_calls when at least one tool is called. 
+ if tool_deltas and not self.prev_tool_call_arr: + self.prev_tool_call_arr = [{"arguments": {}}] + + if tool_deltas: + return DeltaMessage(tool_calls=tool_deltas) + elif not added_text and self.current_tool_id > 0: + # Return an empty DeltaMessage once the tool calls are all done + # so that finish_reason gets set. + return DeltaMessage(content='') + else: + return None + except Exception: + logger.exception("Error trying to handle streaming tool call.") + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + +def _get_parameter_value(val: ast.expr) -> Any: + if isinstance(val, ast.Constant): + return val.value + elif isinstance(val, ast.Dict): + if not all(isinstance(k, ast.Constant) for k in val.keys): + raise _UnexpectedAstError( + "Dict tool call arguments must have literal keys") + return { + k.value: _get_parameter_value(v) # type: ignore + for k, v in zip(val.keys, val.values) + } + elif isinstance(val, ast.List): + return [_get_parameter_value(v) for v in val.elts] + else: + raise _UnexpectedAstError("Tool call arguments must be literals") + + +def _handle_single_tool(call: ast.Call) -> ToolCall: + if not isinstance(call.func, ast.Name): + raise _UnexpectedAstError("Invalid tool call name") + function_name = call.func.id + arguments = {} + for keyword in call.keywords: + arguments[keyword.arg] = _get_parameter_value(keyword.value) + return ToolCall(type="function", + function=FunctionCall(name=function_name, + arguments=json.dumps(arguments))) + + +def _make_valid_python(text: str) -> Union[tuple[str, str], None]: + bracket_stack = [] + for index, char in enumerate(text): + if char in {"[", "(", "{"}: + bracket_stack.append(char) + elif char == "]": + if not bracket_stack or bracket_stack.pop() != "[": + raise _UnexpectedAstError("Mismatched square brackets") + elif char == ")": + if not bracket_stack or bracket_stack.pop() != "(": + raise _UnexpectedAstError("Mismatched parentheses") + elif char == "}": + 
if not bracket_stack or bracket_stack.pop() != "{": + raise _UnexpectedAstError("Mismatched curly braces") + elif char in {"'", '"'}: + if bracket_stack and bracket_stack[-1] == char: + if index > 0 and text[index - 1] == "\\": + # Treat an escaped quote as a regular character + pass + else: + bracket_stack.pop() + elif bracket_stack and bracket_stack[-1] in {"'", '"'}: + # Double quote within a single quote string or vice versa. + pass + else: + bracket_stack.append(char) + + text = text.rstrip() + if text.endswith("=") or text.endswith(":"): + # Since we have no type information for this property/parameter value, + # we can't fill in a valid value. + return None + if bracket_stack and bracket_stack[-1] == "{": + trailing_dict_text = text[:text.rfind("{")] + num_keys = trailing_dict_text.count(":") + num_values = trailing_dict_text.count(",") + if num_keys <= num_values: + return None # Incomplete property name within parameter value + if bracket_stack and bracket_stack[-1] == "(": + trailing_params_text = text[:text.rfind("(")] + num_full_param_names = trailing_params_text.count("=") + num_full_param_values = trailing_params_text.count(",") + if num_full_param_names <= num_full_param_values: + return None # Incomplete parameter name + if text.endswith(","): + text = text[:-1] + if bracket_stack and bracket_stack[-1] == "[" and not text.endswith( + "[") and not text.endswith(")"): + return None # Incomplete function name + + added_text = "" + for char in reversed(bracket_stack): + if char == "[": + added_text += "]" + elif char == "(": + added_text += ")" + elif char == "{": + added_text += "}" + elif char == "'": + added_text += "'" + elif char == '"': + added_text += '"' + + return text + added_text, added_text + + +def _compute_tool_delta(previously_sent_args: str, new_call: ToolCall, + index: int, + withheld_suffix: str) -> Union[DeltaToolCall, None]: + new_call_args = new_call.function.arguments + if withheld_suffix: + assert 
new_call_args.endswith(withheld_suffix) + new_call_args = new_call_args[:-len(withheld_suffix)] + if not previously_sent_args: + return DeltaToolCall(id=new_call.id, + type="function", + index=index, + function=DeltaFunctionCall( + name=new_call.function.name, + arguments=new_call_args, + )) + + arg_diff = new_call_args[len(previously_sent_args):] + return DeltaToolCall( + id=None, index=index, function=DeltaFunctionCall( + arguments=arg_diff)) if arg_diff else None