[Bugfix] Fix Qwen3 XML tool parser (#26345)

Signed-off-by: Zhikaiiii <1658973216@qq.com>
This commit is contained in:
Zhikaiiii
2025-10-15 09:50:30 +08:00
committed by GitHub
parent 07ca70af8d
commit 9354660036
2 changed files with 179 additions and 26 deletions

View File

@ -40,7 +40,7 @@ def qwen3_xml_tool_parser(qwen3_tokenizer):
return Qwen3XMLToolParser(qwen3_tokenizer)
@pytest.fixture(params=["original", "xml"])
@pytest.fixture(params=["xml"])
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request):
"""Parameterized fixture that provides both parser types for testing"""
if request.param == "original":
@ -664,6 +664,9 @@ def test_extract_tool_calls_streaming(
# Verify we got all expected tool calls
assert len(tool_states) == len(expected_tool_calls)
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == len(
expected_tool_calls
)
# Verify each tool call
for idx, expected_tool in enumerate(expected_tool_calls):
@ -780,9 +783,10 @@ fahrenheit
# Verify content was streamed
assert "Let me check the weather for you:" in other_content
# Verify we got the tool call
assert len(tool_states) == 1
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1
state = tool_states[0]
assert state["id"] is not None
assert state["type"] == "function"
@ -892,3 +896,83 @@ def test_extract_tool_calls_complex_type_with_single_quote(
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
assert args["obj_param"] == {"key": "value"}
def test_extract_tool_calls_streaming_missing_opening_tag(
    qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
):
    """Stream a tool call whose opening ``<tool_call>`` tag is absent.

    The model output moves straight from plain text to ``<function=...>``
    and still ends with ``</tool_call>``. The test asserts that the text is
    streamed as content, that exactly one tool call is reported, and that
    its accumulated arguments decode to the expected values.
    """
    model_output = """I'll check the weather for you.
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"""
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)

    content_parts = []
    # Per-index accumulator for the pieces of each streamed tool call.
    tool_states = {}

    deltas = stream_delta_message_generator(
        qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
    )
    for delta_message in deltas:
        if delta_message.content:
            content_parts.append(delta_message.content)

        for tool_call in delta_message.tool_calls or ():
            entry = tool_states.setdefault(
                tool_call.index,
                {"id": None, "name": None, "arguments": "", "type": None},
            )

            if tool_call.id:
                entry["id"] = tool_call.id

            if tool_call.type:
                assert tool_call.type == "function"
                entry["type"] = tool_call.type

            function = tool_call.function
            if not function:
                continue
            if function.name:
                entry["name"] = function.name
            if function.arguments is not None:
                # Argument text arrives in fragments; stitch them together.
                entry["arguments"] += function.arguments

    other_content = "".join(content_parts)

    # Verify content was streamed
    assert "I'll check the weather for you." in other_content

    # Verify we got the tool call
    assert len(tool_states) == 1
    assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1

    state = tool_states[0]
    assert state["id"] is not None
    assert state["type"] == "function"
    assert state["name"] == "get_current_weather"

    # Verify arguments were parsed correctly despite missing opening tag
    assert state["arguments"] is not None
    args = json.loads(state["arguments"])
    assert args["city"] == "Dallas"
    assert args["state"] == "TX"
    assert args["unit"] == "fahrenheit"

View File

@ -2,13 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import json
import uuid
from collections.abc import Sequence
from typing import Any
from xml.parsers.expat import ParserCreate
import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionToolsParam,
@ -375,14 +375,21 @@ class StreamingXMLToolCallParser:
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
else:
# If currently not parsing tool calls (entering a tool_call),
# check if starts with <tool_call>
# check if starts with <tool_call> or <function=
if self.current_call_id is None:
# Check if might be start of <tool_call>
if buffer == "<tool_call>"[: len(buffer)]:
# Might be start of <tool_call>, wait for more data
return None, start_pos
elif (
buffer.startswith("<function=")
or buffer == "<function="[: len(buffer)]
):
# Might be start of <function=, wait for more data
# to get the complete function tag
return None, start_pos
else:
# Not start of <tool_call>, treat as text
# Not start of <tool_call> or <function=, treat as text
return buffer, start_pos + len(buffer)
else:
# When parsing tool calls,
@ -621,7 +628,7 @@ class StreamingXMLToolCallParser:
self._auto_close_open_parameter_if_needed("tool_call")
self.parameters = {}
self.current_call_id = self._get_next_call_id()
self.current_call_id = make_tool_call_id()
self.current_param_is_first = True
self.tool_call_index += 1
elif name.startswith("function") or (name == "function"):
@ -957,10 +964,6 @@ class StreamingXMLToolCallParser:
"""Set tool configuration information"""
self.tools = tools
def _get_next_call_id(self):
"""Generate unique call ID"""
return f"call_{uuid.uuid4().hex[:24]}"
def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None:
"""Extract function name from various formats"""
if attrs and "name" in attrs:
@ -1168,6 +1171,10 @@ class Qwen3XMLToolParser(ToolParser):
super().__init__(tokenizer)
self.parser = StreamingXMLToolCallParser()
# Add missing attributes for compatibility with serving_chat.py
self.prev_tool_call_arr: list[dict] = []
self.streamed_args_for_tool: list[str] = []
logger.info(
"vLLM Successfully import tool parser %s !", self.__class__.__name__
)
@ -1178,6 +1185,9 @@ class Qwen3XMLToolParser(ToolParser):
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
self.parser.reset_streaming_state()
# Reset tool call tracking arrays for new extraction
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
if request:
self.parser.set_tools(request.tools)
result = self.parser.parse_single_streaming_chunks(model_output)
@ -1201,6 +1211,34 @@ class Qwen3XMLToolParser(ToolParser):
),
)
)
# Update tool call tracking arrays for compatibility
tool_index = (
tool_call.index
if tool_call.index is not None
else len(self.prev_tool_call_arr) - 1
)
# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= tool_index:
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
while len(self.streamed_args_for_tool) <= tool_index:
self.streamed_args_for_tool.append("")
# Update tool call information
self.prev_tool_call_arr[tool_index]["name"] = (
tool_call.function.name
)
self.prev_tool_call_arr[tool_index]["arguments"] = (
tool_call.function.arguments
)
# Update streamed arguments
if tool_call.function.arguments:
self.streamed_args_for_tool[tool_index] = (
tool_call.function.arguments
)
return ExtractedToolCallInformation(
tool_calls=tool_calls,
tools_called=len(tool_calls) > 0,
@ -1219,6 +1257,9 @@ class Qwen3XMLToolParser(ToolParser):
) -> DeltaMessage | None:
if not previous_text:
self.parser.reset_streaming_state()
# Reset tool call tracking arrays for new streaming session
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
if request:
self.parser.set_tools(request.tools)
@ -1230,20 +1271,48 @@ class Qwen3XMLToolParser(ToolParser):
open_calls = current_text.count(
self.parser.tool_call_start_token
) - current_text.count(self.parser.tool_call_end_token)
if open_calls == 0 and self.parser.tool_call_index > 0:
# If current_call_id is None, use last_completed_call_id
call_id = (
self.parser.current_call_id or self.parser.last_completed_call_id
)
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.parser.tool_call_index - 1,
id=call_id,
function=DeltaFunctionCall(arguments=""),
type="function",
)
]
if (
open_calls == 0
and self.parser.tool_call_index > 0
or not self.parser.tool_call_index
and current_text
):
return DeltaMessage(content="")
return None
# Parse the delta text and get the result
result = self.parser.parse_single_streaming_chunks(delta_text)
# Update tool call tracking arrays based on incremental parsing results
if result and result.tool_calls:
for tool_call in result.tool_calls:
if tool_call.function:
tool_index = (
tool_call.index
if tool_call.index is not None
else len(self.prev_tool_call_arr) - 1
)
return self.parser.parse_single_streaming_chunks(delta_text)
# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= tool_index:
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
while len(self.streamed_args_for_tool) <= tool_index:
self.streamed_args_for_tool.append("")
# Update tool name if provided
if tool_call.function.name:
self.prev_tool_call_arr[tool_index]["name"] = (
tool_call.function.name
)
# Update arguments incrementally
if tool_call.function.arguments is not None:
# Concatenate the incremental arguments
# to the existing streamed arguments
self.prev_tool_call_arr[tool_index]["arguments"] += (
tool_call.function.arguments
)
self.streamed_args_for_tool[tool_index] += (
tool_call.function.arguments
)
return result