[Bugfix]fix Qwen3 xml tool parser (#26345)

Signed-off-by: Zhikaiiii <1658973216@qq.com>
This commit is contained in:
Zhikaiiii
2025-10-15 09:50:30 +08:00
committed by GitHub
parent 07ca70af8d
commit 9354660036
2 changed files with 179 additions and 26 deletions

View File

@ -40,7 +40,7 @@ def qwen3_xml_tool_parser(qwen3_tokenizer):
return Qwen3XMLToolParser(qwen3_tokenizer) return Qwen3XMLToolParser(qwen3_tokenizer)
@pytest.fixture(params=["original", "xml"]) @pytest.fixture(params=["xml"])
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request): def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request):
"""Parameterized fixture that provides both parser types for testing""" """Parameterized fixture that provides both parser types for testing"""
if request.param == "original": if request.param == "original":
@ -664,6 +664,9 @@ def test_extract_tool_calls_streaming(
# Verify we got all expected tool calls # Verify we got all expected tool calls
assert len(tool_states) == len(expected_tool_calls) assert len(tool_states) == len(expected_tool_calls)
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == len(
expected_tool_calls
)
# Verify each tool call # Verify each tool call
for idx, expected_tool in enumerate(expected_tool_calls): for idx, expected_tool in enumerate(expected_tool_calls):
@ -780,9 +783,10 @@ fahrenheit
# Verify content was streamed # Verify content was streamed
assert "Let me check the weather for you:" in other_content assert "Let me check the weather for you:" in other_content
# Verify we got the tool call # Verify we got the tool call
assert len(tool_states) == 1 assert len(tool_states) == 1
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1
state = tool_states[0] state = tool_states[0]
assert state["id"] is not None assert state["id"] is not None
assert state["type"] == "function" assert state["type"] == "function"
@ -892,3 +896,83 @@ def test_extract_tool_calls_complex_type_with_single_quote(
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
assert args["obj_param"] == {"key": "value"} assert args["obj_param"] == {"key": "value"}
def test_extract_tool_calls_streaming_missing_opening_tag(
    qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
):
    """Stream a tool call whose opening <tool_call> tag is absent.

    The model output goes straight from plain text to <function=...> and
    still ends with </tool_call>. The streamed deltas must carry the leading
    text as content and add up to exactly one complete tool call.
    """
    model_output = """I'll check the weather for you.
<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"""
    request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)

    content_parts = []
    tool_states = {}
    deltas = stream_delta_message_generator(
        qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
    )
    for delta in deltas:
        if delta.content:
            content_parts.append(delta.content)

        for call in delta.tool_calls or ():
            # One accumulator per tool-call index, created on first sight.
            entry = tool_states.setdefault(
                call.index,
                {"id": None, "name": None, "arguments": "", "type": None},
            )
            if call.id:
                entry["id"] = call.id
            if call.type:
                assert call.type == "function"
                entry["type"] = call.type

            fn = call.function
            if not fn:
                continue
            if fn.name:
                entry["name"] = fn.name
            if fn.arguments is not None:
                # Argument text arrives in fragments; stitch them together.
                entry["arguments"] += fn.arguments

    other_content = "".join(content_parts)

    # The plain text before the tool call must have been streamed as content.
    assert "I'll check the weather for you." in other_content

    # Exactly one tool call, both in the deltas and in the parser's tracking.
    assert len(tool_states) == 1
    assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1

    state = tool_states[0]
    assert state["id"] is not None
    assert state["type"] == "function"
    assert state["name"] == "get_current_weather"

    # The accumulated fragments must form valid JSON with every parameter,
    # even though the opening tag was missing.
    assert state["arguments"] is not None
    args = json.loads(state["arguments"])
    assert args["city"] == "Dallas"
    assert args["state"] == "TX"
    assert args["unit"] == "fahrenheit"

View File

@ -2,13 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast import ast
import json import json
import uuid
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any from typing import Any
from xml.parsers.expat import ParserCreate from xml.parsers.expat import ParserCreate
import regex as re import regex as re
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import ( from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionRequest,
ChatCompletionToolsParam, ChatCompletionToolsParam,
@ -375,14 +375,21 @@ class StreamingXMLToolCallParser:
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1 return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
else: else:
# If currently not parsing tool calls (entering a tool_call), # If currently not parsing tool calls (entering a tool_call),
# check if starts with <tool_call> # check if starts with <tool_call> or <function=
if self.current_call_id is None: if self.current_call_id is None:
# Check if might be start of <tool_call> # Check if might be start of <tool_call>
if buffer == "<tool_call>"[: len(buffer)]: if buffer == "<tool_call>"[: len(buffer)]:
# Might be start of <tool_call>, wait for more data # Might be start of <tool_call>, wait for more data
return None, start_pos return None, start_pos
elif (
buffer.startswith("<function=")
or buffer == "<function="[: len(buffer)]
):
# Might be start of <function=, wait for more data
# to get the complete function tag
return None, start_pos
else: else:
# Not start of <tool_call>, treat as text # Not start of <tool_call> or <function=, treat as text
return buffer, start_pos + len(buffer) return buffer, start_pos + len(buffer)
else: else:
# When parsing tool calls, # When parsing tool calls,
@ -621,7 +628,7 @@ class StreamingXMLToolCallParser:
self._auto_close_open_parameter_if_needed("tool_call") self._auto_close_open_parameter_if_needed("tool_call")
self.parameters = {} self.parameters = {}
self.current_call_id = self._get_next_call_id() self.current_call_id = make_tool_call_id()
self.current_param_is_first = True self.current_param_is_first = True
self.tool_call_index += 1 self.tool_call_index += 1
elif name.startswith("function") or (name == "function"): elif name.startswith("function") or (name == "function"):
@ -957,10 +964,6 @@ class StreamingXMLToolCallParser:
"""Set tool configuration information""" """Set tool configuration information"""
self.tools = tools self.tools = tools
def _get_next_call_id(self):
    """Return a new call ID: "call_" plus the first 24 hex chars of a UUID4."""
    return f"call_{uuid.uuid4().hex[:24]}"
def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None: def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None:
"""Extract function name from various formats""" """Extract function name from various formats"""
if attrs and "name" in attrs: if attrs and "name" in attrs:
@ -1168,6 +1171,10 @@ class Qwen3XMLToolParser(ToolParser):
super().__init__(tokenizer) super().__init__(tokenizer)
self.parser = StreamingXMLToolCallParser() self.parser = StreamingXMLToolCallParser()
# Add missing attributes for compatibility with serving_chat.py
self.prev_tool_call_arr: list[dict] = []
self.streamed_args_for_tool: list[str] = []
logger.info( logger.info(
"vLLM Successfully import tool parser %s !", self.__class__.__name__ "vLLM Successfully import tool parser %s !", self.__class__.__name__
) )
@ -1178,6 +1185,9 @@ class Qwen3XMLToolParser(ToolParser):
request: ChatCompletionRequest, request: ChatCompletionRequest,
) -> ExtractedToolCallInformation: ) -> ExtractedToolCallInformation:
self.parser.reset_streaming_state() self.parser.reset_streaming_state()
# Reset tool call tracking arrays for new extraction
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
if request: if request:
self.parser.set_tools(request.tools) self.parser.set_tools(request.tools)
result = self.parser.parse_single_streaming_chunks(model_output) result = self.parser.parse_single_streaming_chunks(model_output)
@ -1201,6 +1211,34 @@ class Qwen3XMLToolParser(ToolParser):
), ),
) )
) )
# Update tool call tracking arrays for compatibility
tool_index = (
tool_call.index
if tool_call.index is not None
else len(self.prev_tool_call_arr) - 1
)
# Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= tool_index:
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
while len(self.streamed_args_for_tool) <= tool_index:
self.streamed_args_for_tool.append("")
# Update tool call information
self.prev_tool_call_arr[tool_index]["name"] = (
tool_call.function.name
)
self.prev_tool_call_arr[tool_index]["arguments"] = (
tool_call.function.arguments
)
# Update streamed arguments
if tool_call.function.arguments:
self.streamed_args_for_tool[tool_index] = (
tool_call.function.arguments
)
return ExtractedToolCallInformation( return ExtractedToolCallInformation(
tool_calls=tool_calls, tool_calls=tool_calls,
tools_called=len(tool_calls) > 0, tools_called=len(tool_calls) > 0,
@ -1219,6 +1257,9 @@ class Qwen3XMLToolParser(ToolParser):
) -> DeltaMessage | None: ) -> DeltaMessage | None:
if not previous_text: if not previous_text:
self.parser.reset_streaming_state() self.parser.reset_streaming_state()
# Reset tool call tracking arrays for new streaming session
self.prev_tool_call_arr = []
self.streamed_args_for_tool = []
if request: if request:
self.parser.set_tools(request.tools) self.parser.set_tools(request.tools)
@ -1230,20 +1271,48 @@ class Qwen3XMLToolParser(ToolParser):
open_calls = current_text.count( open_calls = current_text.count(
self.parser.tool_call_start_token self.parser.tool_call_start_token
) - current_text.count(self.parser.tool_call_end_token) ) - current_text.count(self.parser.tool_call_end_token)
if open_calls == 0 and self.parser.tool_call_index > 0: if (
# If current_call_id is None, use last_completed_call_id open_calls == 0
call_id = ( and self.parser.tool_call_index > 0
self.parser.current_call_id or self.parser.last_completed_call_id or not self.parser.tool_call_index
) and current_text
return DeltaMessage( ):
tool_calls=[ return DeltaMessage(content="")
DeltaToolCall( return None
index=self.parser.tool_call_index - 1,
id=call_id, # Parse the delta text and get the result
function=DeltaFunctionCall(arguments=""), result = self.parser.parse_single_streaming_chunks(delta_text)
type="function",
) # Update tool call tracking arrays based on incremental parsing results
] if result and result.tool_calls:
for tool_call in result.tool_calls:
if tool_call.function:
tool_index = (
tool_call.index
if tool_call.index is not None
else len(self.prev_tool_call_arr) - 1
) )
return self.parser.parse_single_streaming_chunks(delta_text) # Ensure we have enough entries in our tracking arrays
while len(self.prev_tool_call_arr) <= tool_index:
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
while len(self.streamed_args_for_tool) <= tool_index:
self.streamed_args_for_tool.append("")
# Update tool name if provided
if tool_call.function.name:
self.prev_tool_call_arr[tool_index]["name"] = (
tool_call.function.name
)
# Update arguments incrementally
if tool_call.function.arguments is not None:
# Concatenate the incremental arguments
# to the existing streamed arguments
self.prev_tool_call_arr[tool_index]["arguments"] += (
tool_call.function.arguments
)
self.streamed_args_for_tool[tool_index] += (
tool_call.function.arguments
)
return result