diff --git a/examples/tool_chat_template_qwen3coder.jinja b/examples/tool_chat_template_qwen3coder.jinja new file mode 100644 index 0000000000..49b0e8d0ee --- /dev/null +++ b/examples/tool_chat_template_qwen3coder.jinja @@ -0,0 +1,117 @@ +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '' }} + {%- else %} + {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{% endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = [] %} +{%- endif %} + +{%- if system_message is defined %} + {{- "<|im_start|>system\n" + system_message }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." }} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }} + {{- "" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n\n" ~ tool.name ~ "" }} + {%- if tool.description is defined %} + {{- '\n' ~ (tool.description | trim) ~ '' }} + {%- endif %} + {{- '\n' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- '\n' }} + {{- '\n' ~ param_name ~ '' }} + {%- if param_fields.type is defined %} + {{- '\n' ~ (param_fields.type | string) ~ '' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n' ~ (param_fields.description | trim) ~ '' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {% set handled_keys = ['type', 'properties'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {{- '\n' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} +{%- endif %} +{%- if system_message is defined %} + {{- '<|im_end|>\n' }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- for message in loop_messages %} + {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %} + {{- '<|im_start|>' + message.role }} + {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %} + {{- '\n' + message.content | trim + '\n' }} + {%- endif %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n\n' }} + {%- if tool_call.arguments is defined %} + {%- for args_name, args_value in tool_call.arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>user\n' }} + {%- endif %} + {{- '\n' }} + {{- message.content }} + {{- '\n\n' }} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>\n' }} + {%- elif loop.last %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/tests/tool_use/test_qwen3coder_tool_parser.py b/tests/tool_use/test_qwen3coder_tool_parser.py index 40c3158e9e..ccb2acf512 100644 --- a/tests/tool_use/test_qwen3coder_tool_parser.py +++ b/tests/tool_use/test_qwen3coder_tool_parser.py @@ -16,7 +16,7 @@ from vllm.entrypoints.openai.tool_parsers.qwen3coder_tool_parser import ( from vllm.transformers_utils.detokenizer import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer -MODEL = "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8" +MODEL = "Qwen/Qwen3-Coder-30B-A3B-Instruct-FP8" @pytest.fixture(scope="module") @@ -397,7 +397,9 @@ hello world "no_tools", "single_tool", "single_tool_with_content", + "single_tool_multiline_param", "parallel_tools", + "tool_with_typed_params", # Added this test case ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ @@ -422,7 +424,7 @@ fahrenheit "state": "TX", "unit": "fahrenheit" }))) - ], ""), + ], None), ('''Sure! Let me check the weather for you. @@ -445,6 +447,30 @@ fahrenheit }))) ], "Sure! Let me check the weather for you."), (''' + + +rectangle + + +{"width": 10, + "height": 20} + + +2 + + +''', [ + ToolCall(function=FunctionCall(name="calculate_area", + arguments=json.dumps({ + "shape": "rectangle", + "dimensions": { + "width": 10, + "height": 20 + }, + "precision": 2 + }))) + ], None), + (''' Dallas @@ -484,13 +510,36 @@ celsius "state": "FL", "unit": "celsius" }))) - ], ""), + ], None), + # Added tool_with_typed_params test case + ('''Let me calculate that area for you. + + +circle + + +{"radius": 15.5} + + +3 + + +''', [ + ToolCall(function=FunctionCall(name="calculate_area", + arguments=json.dumps({ + "shape": "circle", + "dimensions": { + "radius": 15.5 + }, + "precision": 3 + }))) + ], "Let me calculate that area for you."), ], ) def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, sample_tools, model_output, expected_tool_calls, expected_content): - """Test incremental streaming behavior""" + """Test incremental streaming behavior including typed parameters""" request = ChatCompletionRequest(model=MODEL, messages=[], tools=sample_tools) @@ -539,7 +588,7 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, "arguments"] += tool_call.function.arguments # Verify final content - assert other_content == expected_content + assert other_content == (expected_content or "") # Handle None case # Verify we got all expected tool calls assert len(tool_states) == len(expected_tool_calls) @@ -559,6 +608,125 @@ def test_extract_tool_calls_streaming(qwen3_tool_parser, qwen3_tokenizer, assert actual_args == expected_args +def test_extract_tool_calls_missing_closing_parameter_tag( + qwen3_tool_parser, sample_tools): + """Test handling of missing closing tag""" + # Using get_current_weather from sample_tools but with malformed XML + model_output = '''Let me check the weather for you: + + + +Dallas + +TX + + +fahrenheit + + +''' + + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + extracted_tool_calls = qwen3_tool_parser.extract_tool_calls( + model_output, request=request) + + # The parser should handle the malformed XML gracefully + assert extracted_tool_calls.tools_called + assert len(extracted_tool_calls.tool_calls) == 1 + + # Verify the function name is correct + assert extracted_tool_calls.tool_calls[ + 0].function.name == "get_current_weather" + + # Verify the arguments are parsed despite the missing closing tag + args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments) + assert "city" in args + assert args["city"] == "Dallas" + assert args["state"] == "TX" + assert args["unit"] == "fahrenheit" + + # Check that content before the tool call is preserved + assert "Let me check the weather for you:" in extracted_tool_calls.content + + +def test_extract_tool_calls_streaming_missing_closing_tag( + qwen3_tool_parser, qwen3_tokenizer, sample_tools): + """Test streaming with missing closing tag""" + # Using get_current_weather from sample_tools but with malformed XML + model_output = '''Let me check the weather for you: + + + +Dallas + +TX + + +fahrenheit + + +''' + + request = ChatCompletionRequest(model=MODEL, + messages=[], + tools=sample_tools) + + other_content = '' + tool_states = {} + + for delta_message in stream_delta_message_generator( + qwen3_tool_parser, qwen3_tokenizer, model_output, request): + + if delta_message.content: + other_content += delta_message.content + + if delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index + + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None + } + + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + + if tool_call.type: + assert tool_call.type == "function" + tool_states[idx]["type"] = tool_call.type + + if tool_call.function: + if tool_call.function.name: + tool_states[idx]["name"] = tool_call.function.name + + if tool_call.function.arguments is not None: + tool_states[idx][ + "arguments"] += tool_call.function.arguments + + # Verify content was streamed + assert "Let me check the weather for you:" in other_content + + # Verify we got the tool call + assert len(tool_states) == 1 + state = tool_states[0] + assert state["id"] is not None + assert state["type"] == "function" + assert state["name"] == "get_current_weather" + + # Verify arguments were parsed correctly despite missing closing tag + assert state["arguments"] is not None + args = json.loads(state["arguments"]) + assert args["city"] == "Dallas" + assert args["state"] == "TX" + assert args["unit"] == "fahrenheit" + + def test_extract_tool_calls_streaming_incremental(qwen3_tool_parser, qwen3_tokenizer, sample_tools): diff --git a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py index 2501d6739e..955813ddd3 100644 --- a/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +import ast import json import uuid from collections.abc import Sequence @@ -22,7 +22,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) -@ToolParserManager.register_module(["qwen3_coder"]) +@ToolParserManager.register_module("qwen3_coder") class Qwen3CoderToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): @@ -30,6 +30,8 @@ class Qwen3CoderToolParser(ToolParser): self.current_tool_name_sent: bool = False self.prev_tool_call_arr: list[dict] = [] + # Override base class type - we use string IDs for tool calls + self.current_tool_id: Optional[str] = None # type: ignore self.streamed_args_for_tool: list[str] = [] # Sentinel tokens for streaming mode @@ -42,20 +44,6 @@ class Qwen3CoderToolParser(ToolParser): self.is_tool_call_started: bool = False self.failed_count: int = 0 - # Streaming state variables - self.current_tool_index: int = 0 - self.header_sent: bool = False - self.current_tool_string_id: Optional[str] = None - self.current_function_name: Optional[str] = None - self.current_param_name: Optional[str] = None - self.current_param_value: str = "" - self.param_count: int = 0 - self.in_param: bool = False - self.in_function: bool = False - self.accumulated_text: str = "" - self.json_started: bool = False - self.json_closed: bool = False - # Enhanced streaming state - reset for each new message self._reset_streaming_state() @@ -67,7 +55,8 @@ class Qwen3CoderToolParser(ToolParser): self.tool_call_function_regex = re.compile( r"|||(?=)|$)", + re.DOTALL) if not self.model_tokenizer: raise ValueError( @@ -84,8 +73,8 @@ class Qwen3CoderToolParser(ToolParser): "Qwen3 XML Tool parser could not locate tool call start/end " "tokens in the tokenizer!") - logger.debug("vLLM Successfully import tool parser %s !", - self.__class__.__name__) + logger.info("vLLM Successfully import tool parser %s !", + self.__class__.__name__) def _generate_tool_call_id(self) -> str: """Generate a unique tool call ID.""" @@ -96,7 +85,7 @@ class Qwen3CoderToolParser(ToolParser): self.current_tool_index = 0 self.is_tool_call_started = False self.header_sent = False - self.current_tool_string_id = None + self.current_tool_id = None self.current_function_name = None self.current_param_name = None self.current_param_value = "" @@ -106,122 +95,122 @@ class Qwen3CoderToolParser(ToolParser): self.accumulated_text = "" self.json_started = False self.json_closed = False + # Store accumulated parameters for type conversion + self.accumulated_params = {} + self.streaming_request = None + + def _get_arguments_config( + self, func_name: str, + tools: Optional[list[ChatCompletionToolsParam]]) -> dict: + """Extract argument configuration for a function.""" + if tools is None: + return {} + for config in tools: + if not hasattr(config, "type") or not (hasattr( + config, "function") and hasattr(config.function, "name")): + continue + if config.type == "function" and config.function.name == func_name: + if not hasattr(config.function, "parameters"): + return {} + params = config.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + else: + return {} + logger.warning("Tool '%s' is not defined in the tools list.", + func_name) + return {} + + def _convert_param_value(self, param_value: str, param_name: str, + param_config: dict, func_name: str) -> Any: + """Convert parameter value based on its type in the schema.""" + # Handle null value for any type + if param_value.lower() == "null": + return None + + if param_name not in param_config: + if param_config != {}: + logger.warning( + "Parsed parameter '%s' is not defined in the tool " + "parameters for tool '%s', directly returning the " + "string value.", param_name, func_name) + return param_value + + if isinstance(param_config[param_name], + dict) and "type" in param_config[param_name]: + param_type = str(param_config[param_name]["type"]).strip().lower() + else: + param_type = "string" + if param_type in ["string", "str", "text", "varchar", "char", "enum"]: + return param_value + elif param_type.startswith("int") or param_type.startswith( + "uint") or param_type.startswith( + "long") or param_type.startswith( + "short") or param_type.startswith("unsigned"): + try: + return int(param_value) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not an " + "integer in tool '%s', degenerating to string.", + param_value, param_name, func_name) + return param_value + elif param_type.startswith("num") or param_type.startswith("float"): + try: + float_param_value = float(param_value) + return float_param_value if float_param_value - int( + float_param_value) != 0 else int(float_param_value) + except (ValueError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' is not a float " + "in tool '%s', degenerating to string.", param_value, + param_name, func_name) + return param_value + elif param_type in ["boolean", "bool", "binary"]: + param_value = param_value.lower() + if param_value not in ["true", "false"]: + logger.warning( + "Parsed value '%s' of parameter '%s' is not a boolean " + "(`true` or `false`) in tool '%s', degenerating to " + "false.", param_value, param_name, func_name) + return param_value == "true" + else: + if param_type in ["object", "array", "arr" + ] or param_type.startswith( + "dict") or param_type.startswith("list"): + try: + param_value = json.loads(param_value) + return param_value + except (json.JSONDecodeError, TypeError, ValueError): + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be " + "parsed with json.loads in tool '%s', will try " + "other methods to parse it.", param_value, param_name, + func_name) + try: + param_value = ast.literal_eval(param_value) # safer + except (ValueError, SyntaxError, TypeError): + logger.warning( + "Parsed value '%s' of parameter '%s' cannot be " + "converted via Python `ast.literal_eval()` in tool " + "'%s', degenerating to string.", param_value, param_name, + func_name) + return param_value def _parse_xml_function_call( self, function_call_str: str, tools: Optional[list[ChatCompletionToolsParam]] ) -> Optional[ToolCall]: - def get_arguments_config(func_name: str) -> dict: - if tools is None: - return {} - for config in tools: - if not hasattr(config, "type") or not ( - hasattr(config, "function") - and hasattr(config.function, "name")): - continue - if (config.type == "function" - and config.function.name == func_name): - if not hasattr(config.function, "parameters"): - return {} - params = config.function.parameters - if isinstance(params, dict) and "properties" in params: - return params["properties"] - elif isinstance(params, dict): - return params - else: - return {} - logger.warning("Tool '%s' is not defined in the tools list.", - func_name) - return {} - - def convert_param_value(param_value: str, param_name: str, - param_config: dict, func_name: str) -> Any: - # Handle null value for any type - if param_value.lower() == "null": - return None - - converted_value: Any - - if param_name not in param_config: - if param_config != {}: - logger.warning( - "Parsed parameter '%s' is not defined in the tool " - "parameters for tool '%s', directly returning the " - "string value.", param_name, func_name) - return param_value - - if (isinstance(param_config[param_name], dict) - and "type" in param_config[param_name]): - param_type = str( - param_config[param_name]["type"]).strip().lower() - else: - param_type = "string" - if param_type in [ - "string", "str", "text", "varchar", "char", "enum" - ]: - return param_value - elif (param_type.startswith("int") or param_type.startswith("uint") - or param_type.startswith("long") - or param_type.startswith("short") - or param_type.startswith("unsigned")): - try: - converted_value = int(param_value) - return converted_value - except ValueError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not an " - "integer in tool '%s', degenerating to string.", - param_value, param_name, func_name) - return param_value - elif (param_type.startswith("num") - or param_type.startswith("float")): - try: - float_param_value = float(param_value) - converted_value = (float_param_value if float_param_value - - int(float_param_value) != 0 else - int(float_param_value)) - return converted_value - except ValueError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a float " - "in tool '%s', degenerating to string.", param_value, - param_name, func_name) - return param_value - elif param_type in ["boolean", "bool", "binary"]: - param_value = param_value.lower() - if param_value not in ["true", "false"]: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a " - "boolean (`true` of `false`) in tool '%s', " - "degenerating to false.", param_value, param_name, - func_name) - return param_value == "true" - else: - if param_type == "object" or param_type.startswith("dict"): - try: - converted_value = json.loads(param_value) - return converted_value - except json.JSONDecodeError: - logger.warning( - "Parsed value '%s' of parameter '%s' is not a " - "valid JSON object in tool '%s', will try other " - "methods to parse it.", param_value, param_name, - func_name) - logger.warning( - "Parameter '%s' has unknown type '%s'. " - "The value will be treated as a string.", param_name, - param_type) - return param_value - # Extract function name end_index = function_call_str.index(">") function_name = function_call_str[:end_index] - param_config = get_arguments_config(function_name) + param_config = self._get_arguments_config(function_name, tools) parameters = function_call_str[end_index + 1:] param_dict = {} - for match in self.tool_call_parameter_regex.findall(parameters): - match_text = match[0] if match[0] else match[1] + for match_text in self.tool_call_parameter_regex.findall(parameters): idx = match_text.index(">") param_name = match_text[:idx] param_value = str(match_text[idx + 1:]) @@ -231,7 +220,7 @@ class Qwen3CoderToolParser(ToolParser): if param_value.endswith("\n"): param_value = param_value[:-1] - param_dict[param_name] = convert_param_value( + param_dict[param_name] = self._convert_param_value( param_value, param_name, param_config, function_name) return ToolCall( type="function", @@ -284,8 +273,7 @@ class Qwen3CoderToolParser(ToolParser): for function_call_str in function_calls ] - # Populate prev_tool_call_arr for serving layer to set - # finish_reason + # Populate prev_tool_call_arr for serving layer to set finish_reason self.prev_tool_call_arr.clear() # Clear previous calls for tool_call in tool_calls: if tool_call: @@ -298,8 +286,8 @@ class Qwen3CoderToolParser(ToolParser): # Extract content before tool calls content_index = model_output.find(self.tool_call_start_token) - content_index = (content_index if content_index >= 0 else - model_output.find(self.tool_call_prefix)) + idx = model_output.find(self.tool_call_prefix) + content_index = content_index if content_index >= 0 else idx content = model_output[:content_index] # .rstrip() return ExtractedToolCallInformation( @@ -324,13 +312,16 @@ class Qwen3CoderToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - # If no delta text, return None unless it's an EOS token after tool - # calls + # Store request for type conversion + if not previous_text: + self._reset_streaming_state() + self.streaming_request = request + + # If no delta text, return None unless it's an EOS token after tools if not delta_text: # Check if this is an EOS token after all tool calls are complete - # We check for tool calls in the text even if is_tool_call_started - # is False because it might have been reset after processing all - # tools + # Check for tool calls in text even if is_tool_call_started + # is False (might have been reset after processing all tools) if (delta_token_ids and self.tool_call_end_token_id not in delta_token_ids): # Count complete tool calls @@ -339,24 +330,19 @@ class Qwen3CoderToolParser(ToolParser): # If we have completed tool calls and populated # prev_tool_call_arr - if (complete_calls > 0 and len(self.prev_tool_call_arr) > 0): + if complete_calls > 0 and len(self.prev_tool_call_arr) > 0: # Check if all tool calls are closed - open_calls = ( - current_text.count(self.tool_call_start_token) - - current_text.count(self.tool_call_end_token)) + open_calls = current_text.count( + self.tool_call_start_token) - current_text.count( + self.tool_call_end_token) if open_calls == 0: - # Return empty delta message to allow finish_reason - # processing + # Return empty delta for finish_reason processing return DeltaMessage(content="") elif not self.is_tool_call_started and current_text: # This is a regular content response that's now complete return DeltaMessage(content="") return None - # Check if this is the first call (reset state if needed) - if not previous_text: - self._reset_streaming_state() - # Update accumulated text self.accumulated_text = current_text @@ -371,11 +357,11 @@ class Qwen3CoderToolParser(ToolParser): self.param_count = 0 self.json_started = False self.json_closed = False + self.accumulated_params = {} # Check if there are more tool calls - tool_starts_count = current_text.count( - self.tool_call_start_token) - if self.current_tool_index >= tool_starts_count: + tool_starts = current_text.count(self.tool_call_start_token) + if self.current_tool_index >= tool_starts: # No more tool calls self.is_tool_call_started = False # Continue processing next tool @@ -412,20 +398,20 @@ class Qwen3CoderToolParser(ToolParser): # We're in a tool call, find the current tool call portion # Need to find the correct tool call based on current_tool_index - tool_starts: list[int] = [] + tool_start_positions: list[int] = [] idx = 0 while True: idx = current_text.find(self.tool_call_start_token, idx) if idx == -1: break - tool_starts.append(idx) + tool_start_positions.append(idx) idx += len(self.tool_call_start_token) - if self.current_tool_index >= len(tool_starts): + if self.current_tool_index >= len(tool_start_positions): # No more tool calls to process yet return None - tool_start_idx = tool_starts[self.current_tool_index] + tool_start_idx = tool_start_positions[self.current_tool_index] # Find where this tool call ends (or current position if not ended yet) tool_end_idx = current_text.find(self.tool_call_end_token, tool_start_idx) @@ -438,19 +424,19 @@ class Qwen3CoderToolParser(ToolParser): # Looking for function header if not self.header_sent: if self.tool_call_prefix in tool_text: - func_start = (tool_text.find(self.tool_call_prefix) + - len(self.tool_call_prefix)) + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) func_end = tool_text.find(">", func_start) if func_end != -1: # Found complete function name self.current_function_name = tool_text[func_start:func_end] - self.current_tool_string_id = self._generate_tool_call_id() + self.current_tool_id = self._generate_tool_call_id() self.header_sent = True self.in_function = True - # IMPORTANT: Add to prev_tool_call_arr immediately when we - # detect a tool call. This ensures + # IMPORTANT: Add to prev_tool_call_arr immediately when + # we detect a tool call. This ensures # finish_reason="tool_calls" even if parsing isn't complete already_added = any( tool.get("name") == self.current_function_name @@ -466,7 +452,7 @@ class Qwen3CoderToolParser(ToolParser): return DeltaMessage(tool_calls=[ DeltaToolCall( index=self.current_tool_index, - id=self.current_tool_string_id, + id=self.current_tool_id, function=DeltaFunctionCall( name=self.current_function_name, arguments=""), type="function", @@ -496,10 +482,11 @@ class Qwen3CoderToolParser(ToolParser): # Close JSON self.json_closed = True - # Extract the complete tool call to update prev_tool_call_arr - # with final arguments. Find the function content - func_start = (tool_text.find(self.tool_call_prefix) + - len(self.tool_call_prefix)) + # Extract complete tool call to update + # prev_tool_call_arr with final arguments + # Find the function content + func_start = tool_text.find(self.tool_call_prefix) + len( + self.tool_call_prefix) func_content_end = tool_text.find(self.function_end_token, func_start) if func_content_end != -1: @@ -507,15 +494,17 @@ class Qwen3CoderToolParser(ToolParser): # Parse to get the complete arguments try: parsed_tool = self._parse_xml_function_call( - func_content, request.tools if request else None) + func_content, self.streaming_request.tools + if self.streaming_request else None) if parsed_tool: - # Update existing entry in prev_tool_call_arr with - # complete arguments + # Update existing entry in + # prev_tool_call_arr with complete args for i, tool in enumerate(self.prev_tool_call_arr): - if (tool.get("name") == - parsed_tool.function.name): - self.prev_tool_call_arr[i]["arguments"] = ( - parsed_tool.function.arguments) + if tool.get( + "name") == parsed_tool.function.name: + args = parsed_tool.function.arguments + self.prev_tool_call_arr[i][ + "arguments"] = args break except Exception: pass # Ignore parsing errors during streaming @@ -530,73 +519,110 @@ class Qwen3CoderToolParser(ToolParser): # Reset state for next tool self.in_function = False self.json_closed = True + self.accumulated_params = {} return result # Look for parameters - # Count how many complete parameters we have processed - complete_params = tool_text.count(self.parameter_end_token) + # Find all parameter starts + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: + break + param_starts.append(idx) + idx += len(self.parameter_prefix) # Check if we should start a new parameter - if not self.in_param and self.param_count < complete_params: - # Find the unprocessed parameter - # Count parameter starts - param_starts = [] - idx = 0 - while True: - idx = tool_text.find(self.parameter_prefix, idx) - if idx == -1: - break - param_starts.append(idx) - idx += len(self.parameter_prefix) + if (not self.in_param and self.param_count < len(param_starts) + and len(param_starts) > self.param_count): + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] - if len(param_starts) > self.param_count: - # Process the next parameter - param_idx = param_starts[self.param_count] - param_start = param_idx + len(self.parameter_prefix) - remaining = tool_text[param_start:] + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + self.current_param_name = remaining[:name_end] - if ">" in remaining: - # We have the complete parameter name - name_end = remaining.find(">") - self.current_param_name = remaining[:name_end] + # Find the parameter value + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + if value_text.startswith("\n"): + value_text = value_text[1:] - # Find the parameter value - value_start = param_start + name_end + 1 - value_text = tool_text[value_start:] - if value_text.startswith("\n"): - value_text = value_text[1:] + # Find where this parameter ends + param_end_idx = value_text.find(self.parameter_end_token) + if param_end_idx == -1: + # No closing tag, look for next parameter or + # function end + next_param_idx = value_text.find(self.parameter_prefix) + func_end_idx = value_text.find(self.function_end_token) - # Find where this parameter ends - param_end_idx = value_text.find( - self.parameter_end_token) - if param_end_idx != -1: - # Complete parameter found - param_value = value_text[:param_end_idx] - if param_value.endswith("\n"): - param_value = param_value[:-1] - - # Build complete JSON fragment for this parameter - if self.param_count == 0: - json_fragment = ( - '"' + self.current_param_name + '": "' + - json.dumps(param_value)[1:-1] + '"') + if next_param_idx != -1 and (func_end_idx == -1 + or next_param_idx + < func_end_idx): + param_end_idx = next_param_idx + elif func_end_idx != -1: + param_end_idx = func_end_idx + else: + # Neither found, check if tool call is complete + if self.tool_call_end_token in tool_text: + # Tool call is complete, so parameter + # must be complete too. Use all + # remaining text before function end + param_end_idx = len(value_text) else: - json_fragment = ( - ', "' + self.current_param_name + '": "' + - json.dumps(param_value)[1:-1] + '"') + # Still streaming, wait for more content + return None - self.param_count += 1 + if param_end_idx != -1: + # Complete parameter found + param_value = value_text[:param_end_idx] + if param_value.endswith("\n"): + param_value = param_value[:-1] - return DeltaMessage(tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall( - arguments=json_fragment), - ) - ]) + # Store raw value for later processing + self.accumulated_params[ + self.current_param_name] = param_value - # Continue parameter value + # Get parameter configuration for type conversion + param_config = self._get_arguments_config( + self.current_function_name or "", + self.streaming_request.tools + if self.streaming_request else None) + + # Convert param value to appropriate type + converted_value = self._convert_param_value( + param_value, self.current_param_name, param_config, + self.current_function_name or "") + + # Build JSON fragment based on the converted type + # Use json.dumps to properly serialize the value + serialized_value = json.dumps(converted_value, + ensure_ascii=False) + + if self.param_count == 0: + json_fragment = (f'"{self.current_param_name}": ' + f'{serialized_value}') + else: + json_fragment = (f', "{self.current_param_name}": ' + f'{serialized_value}') + + self.param_count += 1 + + return DeltaMessage(tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall( + arguments=json_fragment), + ) + ]) + + # Continue parameter value - Not used in the current implementation + # since we process complete parameters above if self.in_param: if self.parameter_end_token in delta_text: # End of parameter @@ -608,25 +634,42 @@ class Qwen3CoderToolParser(ToolParser): gt_idx = value_chunk.find(">") value_chunk = value_chunk[gt_idx + 1:] - if (not self.current_param_value - and value_chunk.startswith("\n")): + if not self.current_param_value and value_chunk.startswith( + "\n"): value_chunk = value_chunk[1:] - # Calculate incremental JSON + # Store complete value full_value = self.current_param_value + value_chunk - prev_escaped = (json.dumps(self.current_param_value)[1:-1] - if self.current_param_value else "") - full_escaped = json.dumps(full_value)[1:-1] - delta_escaped = full_escaped[len(prev_escaped):] + self.accumulated_params[ + self.current_param_name] = full_value + # Get parameter configuration for type conversion + param_config = self._get_arguments_config( + self.current_function_name or "", + self.streaming_request.tools + if self.streaming_request else None) + + # Convert the parameter value to the appropriate type + converted_value = self._convert_param_value( + full_value, self.current_param_name or "", + param_config, self.current_function_name or "") + + # Serialize the converted value + serialized_value = json.dumps(converted_value, + ensure_ascii=False) + + # Since we've been streaming the quoted version, + # we need to close it properly + # This is complex - for now just complete the value self.in_param = False self.current_param_value = "" + # Just close the current parameter string return DeltaMessage(tool_calls=[ DeltaToolCall( index=self.current_tool_index, function=DeltaFunctionCall( - arguments=delta_escaped + '"'), + arguments='"'), # Close the string quote ) ]) else: @@ -638,18 +681,18 @@ class Qwen3CoderToolParser(ToolParser): gt_idx = value_chunk.find(">") value_chunk = value_chunk[gt_idx + 1:] - if (not self.current_param_value - and value_chunk.startswith("\n")): + if not self.current_param_value and value_chunk.startswith( + "\n"): value_chunk = value_chunk[1:] if value_chunk: # Stream the escaped delta - prev_escaped = (json.dumps( - self.current_param_value)[1:-1] - if self.current_param_value else "") + prev_escaped = json.dumps( + self.current_param_value, ensure_ascii=False + )[1:-1] if self.current_param_value else "" self.current_param_value += value_chunk - full_escaped = json.dumps( - self.current_param_value)[1:-1] + full_escaped = json.dumps(self.current_param_value, + ensure_ascii=False)[1:-1] delta_escaped = full_escaped[len(prev_escaped):] if delta_escaped: @@ -661,4 +704,4 @@ class Qwen3CoderToolParser(ToolParser): ) ]) - return None + return None \ No newline at end of file