mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Bugfix]fix Qwen3 xml tool parser (#26345)
Signed-off-by: Zhikaiiii <1658973216@qq.com>
This commit is contained in:
@ -40,7 +40,7 @@ def qwen3_xml_tool_parser(qwen3_tokenizer):
|
|||||||
return Qwen3XMLToolParser(qwen3_tokenizer)
|
return Qwen3XMLToolParser(qwen3_tokenizer)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(params=["original", "xml"])
|
@pytest.fixture(params=["xml"])
|
||||||
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request):
|
def qwen3_tool_parser_parametrized(qwen3_tool_parser, qwen3_xml_tool_parser, request):
|
||||||
"""Parameterized fixture that provides both parser types for testing"""
|
"""Parameterized fixture that provides both parser types for testing"""
|
||||||
if request.param == "original":
|
if request.param == "original":
|
||||||
@ -664,6 +664,9 @@ def test_extract_tool_calls_streaming(
|
|||||||
|
|
||||||
# Verify we got all expected tool calls
|
# Verify we got all expected tool calls
|
||||||
assert len(tool_states) == len(expected_tool_calls)
|
assert len(tool_states) == len(expected_tool_calls)
|
||||||
|
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == len(
|
||||||
|
expected_tool_calls
|
||||||
|
)
|
||||||
|
|
||||||
# Verify each tool call
|
# Verify each tool call
|
||||||
for idx, expected_tool in enumerate(expected_tool_calls):
|
for idx, expected_tool in enumerate(expected_tool_calls):
|
||||||
@ -780,9 +783,10 @@ fahrenheit
|
|||||||
|
|
||||||
# Verify content was streamed
|
# Verify content was streamed
|
||||||
assert "Let me check the weather for you:" in other_content
|
assert "Let me check the weather for you:" in other_content
|
||||||
|
|
||||||
# Verify we got the tool call
|
# Verify we got the tool call
|
||||||
assert len(tool_states) == 1
|
assert len(tool_states) == 1
|
||||||
|
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1
|
||||||
|
|
||||||
state = tool_states[0]
|
state = tool_states[0]
|
||||||
assert state["id"] is not None
|
assert state["id"] is not None
|
||||||
assert state["type"] == "function"
|
assert state["type"] == "function"
|
||||||
@ -892,3 +896,83 @@ def test_extract_tool_calls_complex_type_with_single_quote(
|
|||||||
|
|
||||||
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||||
assert args["obj_param"] == {"key": "value"}
|
assert args["obj_param"] == {"key": "value"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_tool_calls_streaming_missing_opening_tag(
|
||||||
|
qwen3_tool_parser_parametrized, qwen3_tokenizer, sample_tools
|
||||||
|
):
|
||||||
|
"""Test streaming with missing opening <tool_call> tag
|
||||||
|
|
||||||
|
This tests that the streaming parser correctly handles
|
||||||
|
tool calls that start directly with <function=...>
|
||||||
|
"""
|
||||||
|
model_output = """I'll check the weather for you.
|
||||||
|
|
||||||
|
<function=get_current_weather>
|
||||||
|
<parameter=city>
|
||||||
|
Dallas
|
||||||
|
</parameter>
|
||||||
|
<parameter=state>
|
||||||
|
TX
|
||||||
|
</parameter>
|
||||||
|
<parameter=unit>
|
||||||
|
fahrenheit
|
||||||
|
</parameter>
|
||||||
|
</function>
|
||||||
|
</tool_call>"""
|
||||||
|
|
||||||
|
request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools)
|
||||||
|
|
||||||
|
other_content = ""
|
||||||
|
tool_states = {}
|
||||||
|
|
||||||
|
for delta_message in stream_delta_message_generator(
|
||||||
|
qwen3_tool_parser_parametrized, qwen3_tokenizer, model_output, request
|
||||||
|
):
|
||||||
|
if delta_message.content:
|
||||||
|
other_content += delta_message.content
|
||||||
|
|
||||||
|
if delta_message.tool_calls:
|
||||||
|
for tool_call in delta_message.tool_calls:
|
||||||
|
idx = tool_call.index
|
||||||
|
|
||||||
|
if idx not in tool_states:
|
||||||
|
tool_states[idx] = {
|
||||||
|
"id": None,
|
||||||
|
"name": None,
|
||||||
|
"arguments": "",
|
||||||
|
"type": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tool_call.id:
|
||||||
|
tool_states[idx]["id"] = tool_call.id
|
||||||
|
|
||||||
|
if tool_call.type:
|
||||||
|
assert tool_call.type == "function"
|
||||||
|
tool_states[idx]["type"] = tool_call.type
|
||||||
|
|
||||||
|
if tool_call.function:
|
||||||
|
if tool_call.function.name:
|
||||||
|
tool_states[idx]["name"] = tool_call.function.name
|
||||||
|
|
||||||
|
if tool_call.function.arguments is not None:
|
||||||
|
tool_states[idx]["arguments"] += tool_call.function.arguments
|
||||||
|
|
||||||
|
# Verify content was streamed
|
||||||
|
assert "I'll check the weather for you." in other_content
|
||||||
|
|
||||||
|
# Verify we got the tool call
|
||||||
|
assert len(tool_states) == 1
|
||||||
|
assert len(qwen3_tool_parser_parametrized.prev_tool_call_arr) == 1
|
||||||
|
|
||||||
|
state = tool_states[0]
|
||||||
|
assert state["id"] is not None
|
||||||
|
assert state["type"] == "function"
|
||||||
|
assert state["name"] == "get_current_weather"
|
||||||
|
|
||||||
|
# Verify arguments were parsed correctly despite missing opening tag
|
||||||
|
assert state["arguments"] is not None
|
||||||
|
args = json.loads(state["arguments"])
|
||||||
|
assert args["city"] == "Dallas"
|
||||||
|
assert args["state"] == "TX"
|
||||||
|
assert args["unit"] == "fahrenheit"
|
||||||
|
@ -2,13 +2,13 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import ast
|
import ast
|
||||||
import json
|
import json
|
||||||
import uuid
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from xml.parsers.expat import ParserCreate
|
from xml.parsers.expat import ParserCreate
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
|
|
||||||
|
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||||
from vllm.entrypoints.openai.protocol import (
|
from vllm.entrypoints.openai.protocol import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionToolsParam,
|
ChatCompletionToolsParam,
|
||||||
@ -375,14 +375,21 @@ class StreamingXMLToolCallParser:
|
|||||||
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
|
return buffer[: tag_end2 + 1], start_pos + tag_end2 + 1
|
||||||
else:
|
else:
|
||||||
# If currently not parsing tool calls (entering a tool_call),
|
# If currently not parsing tool calls (entering a tool_call),
|
||||||
# check if starts with <tool_call>
|
# check if starts with <tool_call> or <function=
|
||||||
if self.current_call_id is None:
|
if self.current_call_id is None:
|
||||||
# Check if might be start of <tool_call>
|
# Check if might be start of <tool_call>
|
||||||
if buffer == "<tool_call>"[: len(buffer)]:
|
if buffer == "<tool_call>"[: len(buffer)]:
|
||||||
# Might be start of <tool_call>, wait for more data
|
# Might be start of <tool_call>, wait for more data
|
||||||
return None, start_pos
|
return None, start_pos
|
||||||
|
elif (
|
||||||
|
buffer.startswith("<function=")
|
||||||
|
or buffer == "<function="[: len(buffer)]
|
||||||
|
):
|
||||||
|
# Might be start of <function=, wait for more data
|
||||||
|
# to get the complete function tag
|
||||||
|
return None, start_pos
|
||||||
else:
|
else:
|
||||||
# Not start of <tool_call>, treat as text
|
# Not start of <tool_call> or <function=, treat as text
|
||||||
return buffer, start_pos + len(buffer)
|
return buffer, start_pos + len(buffer)
|
||||||
else:
|
else:
|
||||||
# When parsing tool calls,
|
# When parsing tool calls,
|
||||||
@ -621,7 +628,7 @@ class StreamingXMLToolCallParser:
|
|||||||
self._auto_close_open_parameter_if_needed("tool_call")
|
self._auto_close_open_parameter_if_needed("tool_call")
|
||||||
|
|
||||||
self.parameters = {}
|
self.parameters = {}
|
||||||
self.current_call_id = self._get_next_call_id()
|
self.current_call_id = make_tool_call_id()
|
||||||
self.current_param_is_first = True
|
self.current_param_is_first = True
|
||||||
self.tool_call_index += 1
|
self.tool_call_index += 1
|
||||||
elif name.startswith("function") or (name == "function"):
|
elif name.startswith("function") or (name == "function"):
|
||||||
@ -957,10 +964,6 @@ class StreamingXMLToolCallParser:
|
|||||||
"""Set tool configuration information"""
|
"""Set tool configuration information"""
|
||||||
self.tools = tools
|
self.tools = tools
|
||||||
|
|
||||||
def _get_next_call_id(self):
|
|
||||||
"""Generate unique call ID"""
|
|
||||||
return f"call_{uuid.uuid4().hex[:24]}"
|
|
||||||
|
|
||||||
def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None:
|
def _extract_function_name(self, name: str, attrs: dict[str, str]) -> str | None:
|
||||||
"""Extract function name from various formats"""
|
"""Extract function name from various formats"""
|
||||||
if attrs and "name" in attrs:
|
if attrs and "name" in attrs:
|
||||||
@ -1168,6 +1171,10 @@ class Qwen3XMLToolParser(ToolParser):
|
|||||||
super().__init__(tokenizer)
|
super().__init__(tokenizer)
|
||||||
self.parser = StreamingXMLToolCallParser()
|
self.parser = StreamingXMLToolCallParser()
|
||||||
|
|
||||||
|
# Add missing attributes for compatibility with serving_chat.py
|
||||||
|
self.prev_tool_call_arr: list[dict] = []
|
||||||
|
self.streamed_args_for_tool: list[str] = []
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"vLLM Successfully import tool parser %s !", self.__class__.__name__
|
"vLLM Successfully import tool parser %s !", self.__class__.__name__
|
||||||
)
|
)
|
||||||
@ -1178,6 +1185,9 @@ class Qwen3XMLToolParser(ToolParser):
|
|||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> ExtractedToolCallInformation:
|
) -> ExtractedToolCallInformation:
|
||||||
self.parser.reset_streaming_state()
|
self.parser.reset_streaming_state()
|
||||||
|
# Reset tool call tracking arrays for new extraction
|
||||||
|
self.prev_tool_call_arr = []
|
||||||
|
self.streamed_args_for_tool = []
|
||||||
if request:
|
if request:
|
||||||
self.parser.set_tools(request.tools)
|
self.parser.set_tools(request.tools)
|
||||||
result = self.parser.parse_single_streaming_chunks(model_output)
|
result = self.parser.parse_single_streaming_chunks(model_output)
|
||||||
@ -1201,6 +1211,34 @@ class Qwen3XMLToolParser(ToolParser):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Update tool call tracking arrays for compatibility
|
||||||
|
tool_index = (
|
||||||
|
tool_call.index
|
||||||
|
if tool_call.index is not None
|
||||||
|
else len(self.prev_tool_call_arr) - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure we have enough entries in our tracking arrays
|
||||||
|
while len(self.prev_tool_call_arr) <= tool_index:
|
||||||
|
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
|
||||||
|
while len(self.streamed_args_for_tool) <= tool_index:
|
||||||
|
self.streamed_args_for_tool.append("")
|
||||||
|
|
||||||
|
# Update tool call information
|
||||||
|
self.prev_tool_call_arr[tool_index]["name"] = (
|
||||||
|
tool_call.function.name
|
||||||
|
)
|
||||||
|
self.prev_tool_call_arr[tool_index]["arguments"] = (
|
||||||
|
tool_call.function.arguments
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update streamed arguments
|
||||||
|
if tool_call.function.arguments:
|
||||||
|
self.streamed_args_for_tool[tool_index] = (
|
||||||
|
tool_call.function.arguments
|
||||||
|
)
|
||||||
|
|
||||||
return ExtractedToolCallInformation(
|
return ExtractedToolCallInformation(
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
tools_called=len(tool_calls) > 0,
|
tools_called=len(tool_calls) > 0,
|
||||||
@ -1219,6 +1257,9 @@ class Qwen3XMLToolParser(ToolParser):
|
|||||||
) -> DeltaMessage | None:
|
) -> DeltaMessage | None:
|
||||||
if not previous_text:
|
if not previous_text:
|
||||||
self.parser.reset_streaming_state()
|
self.parser.reset_streaming_state()
|
||||||
|
# Reset tool call tracking arrays for new streaming session
|
||||||
|
self.prev_tool_call_arr = []
|
||||||
|
self.streamed_args_for_tool = []
|
||||||
if request:
|
if request:
|
||||||
self.parser.set_tools(request.tools)
|
self.parser.set_tools(request.tools)
|
||||||
|
|
||||||
@ -1230,20 +1271,48 @@ class Qwen3XMLToolParser(ToolParser):
|
|||||||
open_calls = current_text.count(
|
open_calls = current_text.count(
|
||||||
self.parser.tool_call_start_token
|
self.parser.tool_call_start_token
|
||||||
) - current_text.count(self.parser.tool_call_end_token)
|
) - current_text.count(self.parser.tool_call_end_token)
|
||||||
if open_calls == 0 and self.parser.tool_call_index > 0:
|
if (
|
||||||
# If current_call_id is None, use last_completed_call_id
|
open_calls == 0
|
||||||
call_id = (
|
and self.parser.tool_call_index > 0
|
||||||
self.parser.current_call_id or self.parser.last_completed_call_id
|
or not self.parser.tool_call_index
|
||||||
)
|
and current_text
|
||||||
return DeltaMessage(
|
):
|
||||||
tool_calls=[
|
return DeltaMessage(content="")
|
||||||
DeltaToolCall(
|
return None
|
||||||
index=self.parser.tool_call_index - 1,
|
|
||||||
id=call_id,
|
# Parse the delta text and get the result
|
||||||
function=DeltaFunctionCall(arguments=""),
|
result = self.parser.parse_single_streaming_chunks(delta_text)
|
||||||
type="function",
|
|
||||||
)
|
# Update tool call tracking arrays based on incremental parsing results
|
||||||
]
|
if result and result.tool_calls:
|
||||||
|
for tool_call in result.tool_calls:
|
||||||
|
if tool_call.function:
|
||||||
|
tool_index = (
|
||||||
|
tool_call.index
|
||||||
|
if tool_call.index is not None
|
||||||
|
else len(self.prev_tool_call_arr) - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.parser.parse_single_streaming_chunks(delta_text)
|
# Ensure we have enough entries in our tracking arrays
|
||||||
|
while len(self.prev_tool_call_arr) <= tool_index:
|
||||||
|
self.prev_tool_call_arr.append({"name": "", "arguments": ""})
|
||||||
|
while len(self.streamed_args_for_tool) <= tool_index:
|
||||||
|
self.streamed_args_for_tool.append("")
|
||||||
|
|
||||||
|
# Update tool name if provided
|
||||||
|
if tool_call.function.name:
|
||||||
|
self.prev_tool_call_arr[tool_index]["name"] = (
|
||||||
|
tool_call.function.name
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update arguments incrementally
|
||||||
|
if tool_call.function.arguments is not None:
|
||||||
|
# Concatenate the incremental arguments
|
||||||
|
# to the existing streamed arguments
|
||||||
|
self.prev_tool_call_arr[tool_index]["arguments"] += (
|
||||||
|
tool_call.function.arguments
|
||||||
|
)
|
||||||
|
self.streamed_args_for_tool[tool_index] += (
|
||||||
|
tool_call.function.arguments
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
Reference in New Issue
Block a user