mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix]fix Qwen3 xml tool parser (#26345)
Signed-off-by: Zhikaiiii <1658973216@qq.com>
This commit is contained in:
@ -40,7 +40,7 @@ def qwen3_xml_tool_parser(qwen3_tokenizer):
|
||||
return Qwen3XMLToolParser(qwen3_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture(params=["original", "xml"])
|
||||
@pytest.fixture(params=["xml"])
|
||||
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request):
|
||||
"""Parameterized fixture that provides both parser types for testing"""
|
||||
if request.param == "original":
|
||||
@ -664,6 +664,9 @@ def test_extract_tool_calls_streaming(
|
||||
|
||||
# Verify we got all expected tool calls
|
||||
assert len(tool_states) == len(expected_tool_calls)
|
||||
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == len(
|
||||
expected_tool_calls
|
||||
)
|
||||
|
||||
# Verify each tool call
|
||||
for idx, expected_tool in enumerate(expected_tool_calls):
|
||||
@ -780,9 +783,10 @@ fahrenheit
|
||||
|
||||
# Verify content was streamed
|
||||
assert "Let me check the weather for you:" in other_content
|
||||
|
||||
# Verify we got the tool call
|
||||
assert len(tool_states) == 1
|
||||
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1
|
||||
|
||||
state = tool_states[0]
|
||||
assert state["id"] is not None
|
||||
assert state["type"] == "function"
|
||||
@ -892,3 +896,83 @@ def test_extract_tool_calls_complex_type_with_single_quote(
|
||||
|
||||
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
assert args["obj_param"] == {"key": "value"}
|
||||
|
||||
|
||||
def test_extract_tool_calls_streaming_missing_opening_tag(
    qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
):
    """Stream a tool call whose opening <tool_call> tag is absent.

    The model output starts the call directly with <function=...> and
    still closes it with </tool_call>. The streamed deltas must yield the
    leading plain-text content and exactly one tool call whose
    accumulated arguments decode to the expected JSON object.
    """
    model_output = """I'll check the weather for you.

<function=get_current_weather>
<parameter=city>
Dallas
</parameter>
<parameter=state>
TX
</parameter>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call>"""

    request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)

    content_parts = []
    tool_states = {}

    deltas = stream_delta_message_generator(
        qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
    )
    for delta in deltas:
        if delta.content:
            content_parts.append(delta.content)

        # A delta without tool calls contributes nothing to the per-index state.
        for call in delta.tool_calls or ():
            entry = tool_states.setdefault(
                call.index,
                {"id": None, "name": None, "arguments": "", "type": None},
            )

            if call.id:
                entry["id"] = call.id

            if call.type:
                assert call.type == "function"
                entry["type"] = call.type

            fn = call.function
            if not fn:
                continue
            if fn.name:
                entry["name"] = fn.name
            # Argument text arrives in fragments; stitch them back together.
            if fn.arguments is not None:
                entry["arguments"] += fn.arguments

    other_content = "".join(content_parts)

    # The plain-text preamble must have been streamed as content.
    assert "I'll check the weather for you." in other_content

    # Exactly one tool call, both in the deltas and in the parser's tracking.
    assert len(tool_states) == 1
    assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1

    state = tool_states[0]
    assert state["id"] is not None
    assert state["type"] == "function"
    assert state["name"] == "get_current_weather"

    # Arguments must parse correctly despite the missing opening tag.
    assert state["arguments"] is not None
    args = json.loads(state["arguments"])
    expected_args = {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}
    for key, value in expected_args.items():
        assert args[key] == value
|
||||
|
@ -2,13 +2,13 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
import json
|
||||
import uuid
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from xml.parsers.expat import ParserCreate
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionToolsParam,
|
||||
@ -375,14 +375,21 @@ class StreamingXMLToolCallParser:
|
||||
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
|
||||
else:
|
||||
# If currently not parsing tool calls (entering a tool_call),
|
||||
# check if starts with <tool_call>
|
||||
# check if starts with <tool_call> or <function=
|
||||
if self.current_call_id is None:
|
||||
# Check if might be start of <tool_call>
|
||||
if buffer == "<tool_call>"[: len(buffer)]:
|
||||
# Might be start of <tool_call>, wait for more data
|
||||
return None, start_pos
|
||||
elif (
|
||||
buffer.startswith("<function=")
|
||||
or buffer == "<function="[: len(buffer)]
|
||||
):
|
||||
# Might be start of <function=, wait for more data
|
||||
# to get the complete function tag
|
||||
return None, start_pos
|
||||
else:
|
||||
# Not start of <tool_call>, treat as text
|
||||
# Not start of <tool_call> or <function=, treat as text
|
||||
return buffer, start_pos + len(buffer)
|
||||
else:
|
||||
# When parsing tool calls,
|
||||
@ -621,7 +628,7 @@ class StreamingXMLToolCallParser:
|
||||
self._auto_close_open_parameter_if_needed("tool_call")
|
||||
|
||||
self.parameters = {}
|
||||
self.current_call_id = self._get_next_call_id()
|
||||
self.current_call_id = make_tool_call_id()
|
||||
self.current_param_is_first = True
|
||||
self.tool_call_index += 1
|
||||
elif name.startswith("function") or (name == "function"):
|
||||
@ -957,10 +964,6 @@ class StreamingXMLToolCallParser:
|
||||
"""Set tool configuration information"""
|
||||
self.tools = tools
|
||||
|
||||
def _get_next_call_id(self):
|
||||
"""Generate unique call ID"""
|
||||
return f"call_{uuid.uuid4().hex[:24]}"
|
||||
|
||||
def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None:
|
||||
"""Extract function name from various formats"""
|
||||
if attrs and "name" in attrs:
|
||||
@ -1168,6 +1171,10 @@ class Qwen3XMLToolParser(ToolParser):
|
||||
super().__init__(tokenizer)
|
||||
self.parser = StreamingXMLToolCallParser()
|
||||
|
||||
# Add missing attributes for compatibility with serving_chat.py
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
|
||||
logger.info(
|
||||
"vLLM Successfully import tool parser %s !", self.__class__.__name__
|
||||
)
|
||||
@ -1178,6 +1185,9 @@ class Qwen3XMLToolParser(ToolParser):
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
self.parser.reset_streaming_state()
|
||||
# Reset tool call tracking arrays for new extraction
|
||||
self.prev_tool_call_arr = []
|
||||
self.streamed_args_for_tool = []
|
||||
if request:
|
||||
self.parser.set_tools(request.tools)
|
||||
result = self.parser.parse_single_streaming_chunks(model_output)
|
||||
@ -1201,6 +1211,34 @@ class Qwen3XMLToolParser(ToolParser):
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Update tool call tracking arrays for compatibility
|
||||
tool_index = (
|
||||
tool_call.index
|
||||
if tool_call.index is not None
|
||||
else len(self.prev_tool_call_arr) - 1
|
||||
)
|
||||
|
||||
# Ensure we have enough entries in our tracking arrays
|
||||
while len(self.prev_tool_call_arr) <= tool_index:
|
||||
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
|
||||
while len(self.streamed_args_for_tool) <= tool_index:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
# Update tool call information
|
||||
self.prev_tool_call_arr[tool_index]["name"] = (
|
||||
tool_call.function.name
|
||||
)
|
||||
self.prev_tool_call_arr[tool_index]["arguments"] = (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
# Update streamed arguments
|
||||
if tool_call.function.arguments:
|
||||
self.streamed_args_for_tool[tool_index] = (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
return ExtractedToolCallInformation(
|
||||
tool_calls=tool_calls,
|
||||
tools_called=len(tool_calls) > 0,
|
||||
@ -1219,6 +1257,9 @@ class Qwen3XMLToolParser(ToolParser):
|
||||
) -> DeltaMessage | None:
|
||||
if not previous_text:
|
||||
self.parser.reset_streaming_state()
|
||||
# Reset tool call tracking arrays for new streaming session
|
||||
self.prev_tool_call_arr = []
|
||||
self.streamed_args_for_tool = []
|
||||
if request:
|
||||
self.parser.set_tools(request.tools)
|
||||
|
||||
@ -1230,20 +1271,48 @@ class Qwen3XMLToolParser(ToolParser):
|
||||
open_calls = current_text.count(
|
||||
self.parser.tool_call_start_token
|
||||
) - current_text.count(self.parser.tool_call_end_token)
|
||||
if open_calls == 0 and self.parser.tool_call_index > 0:
|
||||
# If current_call_id is None, use last_completed_call_id
|
||||
call_id = (
|
||||
self.parser.current_call_id or self.parser.last_completed_call_id
|
||||
)
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.parser.tool_call_index - 1,
|
||||
id=call_id,
|
||||
function=DeltaFunctionCall(arguments=""),
|
||||
type="function",
|
||||
)
|
||||
]
|
||||
)
|
||||
if (
|
||||
open_calls == 0
|
||||
and self.parser.tool_call_index > 0
|
||||
or not self.parser.tool_call_index
|
||||
and current_text
|
||||
):
|
||||
return DeltaMessage(content="")
|
||||
return None
|
||||
|
||||
return self.parser.parse_single_streaming_chunks(delta_text)
|
||||
# Parse the delta text and get the result
|
||||
result = self.parser.parse_single_streaming_chunks(delta_text)
|
||||
|
||||
# Update tool call tracking arrays based on incremental parsing results
|
||||
if result and result.tool_calls:
|
||||
for tool_call in result.tool_calls:
|
||||
if tool_call.function:
|
||||
tool_index = (
|
||||
tool_call.index
|
||||
if tool_call.index is not None
|
||||
else len(self.prev_tool_call_arr) - 1
|
||||
)
|
||||
|
||||
# Ensure we have enough entries in our tracking arrays
|
||||
while len(self.prev_tool_call_arr) <= tool_index:
|
||||
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
|
||||
while len(self.streamed_args_for_tool) <= tool_index:
|
||||
self.streamed_args_for_tool.append("")
|
||||
|
||||
# Update tool name if provided
|
||||
if tool_call.function.name:
|
||||
self.prev_tool_call_arr[tool_index]["name"] = (
|
||||
tool_call.function.name
|
||||
)
|
||||
|
||||
# Update arguments incrementally
|
||||
if tool_call.function.arguments is not None:
|
||||
# Concatenate the incremental arguments
|
||||
# to the existing streamed arguments
|
||||
self.prev_tool_call_arr[tool_index]["arguments"] += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
self.streamed_args_for_tool[tool_index] += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
return result
|
||||
|
Reference in New Issue
Block a user