[Feature] Support MiniMax-M1 function calls features (#20297)

Signed-off-by: QscQ <qscqesze@gmail.com>
Signed-off-by: qingjun <qingjun@minimaxi.com>
This commit is contained in:
qscqesze
2025-07-03 14:48:27 +08:00
committed by GitHub
parent 4ff61ababa
commit 363528de27
5 changed files with 842 additions and 1 deletions

View File

@ -264,6 +264,15 @@ For Qwen2.5, the chat template in tokenizer_config.json has already included sup
Flags: `--tool-call-parser hermes`
### MiniMax Models (`minimax_m1`)
Supported models:
* `MiniMaxAi/MiniMax-M1-40k` (use with <gh-file:examples/tool_chat_template_minimax.jinja>)
* `MiniMaxAi/MiniMax-M1-80k` (use with <gh-file:examples/tool_chat_template_minimax.jinja>)
Flags: `--tool-call-parser minimax --chat-template examples/tool_chat_template_minimax.jinja`
### DeepSeek-V3 Models (`deepseek_v3`)
Supported models:

View File

@ -0,0 +1,91 @@
{{ '<begin_of_document>' -}}
{%- if custom_tools is defined %}
{%- set tools = custom_tools %}
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{#- Extract system message #}
{% set ns = namespace(system_prompt='') -%}
{%- if messages[0]['role'] == 'system' %}
{%- if messages[0]['content'] is string %}
{%- set ns.system_prompt = messages[0]['content']|trim %}
{%- else %}
{%- set ns.system_prompt = messages[0]['content'][0]['text']|trim %}
{%- endif %}
{%- set messages = messages[1:] %}
{%- else %}
{%- if tools is not none %}
{%- set ns.system_prompt = "You are a helpful assistant created by Minimax based on MiniMax-M1 model." %}
{%- else %}
{%- set ns.system_prompt = "You are a helpful assistant created by Minimax based on MiniMax-M1 model." %}
{%- endif %}
{%- endif %}
{#- System message #}
{%- if ns.system_prompt != '' %}
{{ '<beginning_of_sentence>system ai_setting=assistant\n' + ns.system_prompt + '<end_of_sentence>\n' -}}
{%- endif %}
{#- Tools configuration #}
{%- if tools is not none %}
{{ '<beginning_of_sentence>system tool_setting=tools\nYou are provided with these tools:\n<tools>\n' -}}
{%- for tool in tools %}
{{ tool | tojson ~ '\n' -}}
{%- endfor %}
{{ '</tools>\n\nIf you need to call tools, please respond with <tool_calls></tool_calls> XML tags, and provide tool-name and json-object of arguments, following the format below:\n<tool_calls>\n{"name": <tool-name>, "arguments": <args-json-object>}\n...\n</tool_calls><end_of_sentence>\n' -}}
{%- endif %}
{#- Process messages #}
{%- for message in messages %}
{%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
{%- if message['role'] == 'user' %}
{{ '<beginning_of_sentence>user name=user\n' -}}
{%- if message['content'] is string %}
{{ message['content']|trim -}}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'text' %}
{{ content['text']|trim -}}
{%- endif %}
{%- endfor %}
{%- endif %}
{{ '<end_of_sentence>\n' -}}
{%- elif message['role'] == 'assistant' %}
{{ '<beginning_of_sentence>ai name=assistant\n' -}}
{%- if message['content'] is string %}
{{ message['content']|trim -}}
{%- else %}
{%- for content in message['content'] | selectattr('type', 'equalto', 'text') %}
{{ content['text']|trim -}}
{%- endfor %}
{%- endif %}
{{ '<end_of_sentence>\n' -}}
{%- endif %}
{%- elif 'tool_calls' in message %}
{{ '<beginning_of_sentence>ai name=assistant\n<tool_calls>\n' -}}
{%- for tool_call in message.tool_calls %}
{{ '{"name": "' + tool_call.function.name + '", "arguments": ' + tool_call.function.arguments | tojson + '}\n' -}}
{%- endfor %}
{{ '</tool_calls><end_of_sentence>\n' -}}
{%- elif message.role == "tool" or message.role == "ipython" %}
{{ '<beginning_of_sentence>tool name=tools\n' -}}
{%- if message.content is string %}
{{ 'tool result: ' + message.content + '\n\n' -}}
{%- else %}
{%- for content in message['content'] %}
{%- if content['type'] == 'text' %}
{{ 'tool result: ' + content['text'] + '\n\n' -}}
{%- elif content.get('name') %}
{{ 'tool name: ' + content['name'] + '\ntool result: ' + content['text'] + '\n\n' -}}
{%- endif %}
{%- endfor %}
{%- endif %}
{{ '<end_of_sentence>\n' -}}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{ '<beginning_of_sentence>ai name=assistant\n' -}}
{%- endif %}

View File

@ -0,0 +1,371 @@
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: E501
import json
import pytest
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.tool_parsers import MinimaxToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
# Use a common model that is likely to be available
MODEL = "MiniMaxAi/MiniMax-M1-40k"
@pytest.fixture(scope="module")
def minimax_tokenizer():
return get_tokenizer(tokenizer_name=MODEL)
@pytest.fixture
def minimax_tool_parser(minimax_tokenizer):
return MinimaxToolParser(minimax_tokenizer)
def assert_tool_calls(actual_tool_calls: list[ToolCall],
expected_tool_calls: list[ToolCall]):
assert len(actual_tool_calls) == len(expected_tool_calls)
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
expected_tool_calls):
assert isinstance(actual_tool_call.id, str)
assert len(actual_tool_call.id) > 16
assert actual_tool_call.type == "function"
assert actual_tool_call.function == expected_tool_call.function
def test_extract_tool_calls_no_tools(minimax_tool_parser):
model_output = "This is a test"
extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content == model_output
@pytest.mark.parametrize(
ids=[
"single_tool_call",
"multiple_tool_calls",
"tool_call_with_content_before",
"tool_call_with_single_line_json",
"tool_call_incomplete_tag",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}
</tool_calls>""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
None,
),
(
"""<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}
{"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}
</tool_calls>""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
)),
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit",
}),
)),
],
None,
),
(
"""I'll help you check the weather. <tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA", "unit": "celsius"}}
</tool_calls>""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Seattle",
"state": "WA",
"unit": "celsius",
}),
))
],
"I'll help you check the weather.",
),
(
"""<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "New York", "state": "NY", "unit": "celsius"}}
</tool_calls>""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "New York",
"state": "NY",
"unit": "celsius",
}),
))
],
None,
),
(
"""<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Boston", "state": "MA"}}""",
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Boston",
"state": "MA",
}),
))
],
None,
),
],
)
def test_extract_tool_calls(minimax_tool_parser, model_output,
expected_tool_calls, expected_content):
extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
assert extracted_tool_calls.content == expected_content
def test_preprocess_model_output_with_thinking_tags(minimax_tool_parser):
"""Test that tool calls within thinking tags are removed during preprocessing."""
model_output = """<think>Let me think about this. <tool_calls>
{"name": "fake_tool", "arguments": {"param": "value"}}
</tool_calls> This should be removed.</think>
I'll help you with that. <tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Seattle", "state": "WA"}}
</tool_calls>"""
processed_output = minimax_tool_parser.preprocess_model_output(
model_output)
# The tool call within thinking tags should be removed
assert "fake_tool" not in processed_output
# But the thinking tag itself should remain
assert "<think>" in processed_output
assert "</think>" in processed_output
# The actual tool call outside thinking tags should remain
assert "get_current_weather" in processed_output
def test_extract_tool_calls_with_thinking_tags(minimax_tool_parser):
"""Test tool extraction when thinking tags contain tool calls that should be ignored."""
model_output = """<think>I should use a tool. <tool_calls>
{"name": "ignored_tool", "arguments": {"should": "ignore"}}
</tool_calls></think>
Let me help you with the weather. <tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Miami", "state": "FL", "unit": "fahrenheit"}}
</tool_calls>"""
extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
assert len(extracted_tool_calls.tool_calls) == 1
assert extracted_tool_calls.tool_calls[
0].function.name == "get_current_weather"
# Content extraction is based on the position of the first <tool_calls> in the original model_output
# Since preprocessing removes tool calls within thinking tags, the actual first <tool_calls> is the external one
expected_content = """<think>I should use a tool. <tool_calls>
{"name": "ignored_tool", "arguments": {"should": "ignore"}}
</tool_calls></think>
Let me help you with the weather."""
assert extracted_tool_calls.content == expected_content
def test_extract_tool_calls_invalid_json(minimax_tool_parser):
"""Test that invalid JSON in tool calls is handled gracefully."""
model_output = """<tool_calls>
{"name": "valid_tool", "arguments": {"city": "Seattle"}}
{invalid json here}
{"name": "another_valid_tool", "arguments": {"param": "value"}}
</tool_calls>"""
extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
# Should extract only the valid JSON tool calls
assert len(extracted_tool_calls.tool_calls) == 2
assert extracted_tool_calls.tool_calls[0].function.name == "valid_tool"
assert extracted_tool_calls.tool_calls[
1].function.name == "another_valid_tool"
def test_extract_tool_calls_missing_name_or_arguments(minimax_tool_parser):
"""Test that tool calls missing name or arguments are filtered out."""
model_output = """<tool_calls>
{"name": "valid_tool", "arguments": {"city": "Seattle"}}
{"name": "missing_args"}
{"arguments": {"city": "Portland"}}
{"name": "another_valid_tool", "arguments": {"param": "value"}}
</tool_calls>"""
extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
assert extracted_tool_calls.tools_called
# Should extract only the valid tool calls with both name and arguments
assert len(extracted_tool_calls.tool_calls) == 2
assert extracted_tool_calls.tool_calls[0].function.name == "valid_tool"
assert extracted_tool_calls.tool_calls[
1].function.name == "another_valid_tool"
def test_streaming_basic_functionality(minimax_tool_parser):
"""Test basic streaming functionality."""
# Reset streaming state
minimax_tool_parser.current_tool_name_sent = False
minimax_tool_parser.prev_tool_call_arr = []
minimax_tool_parser.current_tool_id = -1
minimax_tool_parser.streamed_args_for_tool = []
# Test with a simple tool call
current_text = """<tool_calls>
{"name": "get_current_weather", "arguments": {"city": "Seattle"}}
</tool_calls>"""
# First call should handle the initial setup
result = minimax_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text=current_text,
delta_text="</tool_calls>",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
# The result might be None or contain tool call information
# This depends on the internal state management
if result is not None and hasattr(result,
'tool_calls') and result.tool_calls:
assert len(result.tool_calls) >= 0
def test_streaming_with_content_before_tool_calls(minimax_tool_parser):
"""Test streaming when there's content before tool calls."""
# Reset streaming state
minimax_tool_parser.current_tool_name_sent = False
minimax_tool_parser.prev_tool_call_arr = []
minimax_tool_parser.current_tool_id = -1
minimax_tool_parser.streamed_args_for_tool = []
current_text = "I'll help you with that. <tool_calls>"
# When there's content before tool calls, it should be returned as content
result = minimax_tool_parser.extract_tool_calls_streaming(
previous_text="I'll help you",
current_text=current_text,
delta_text=" with that. <tool_calls>",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
if result is not None and hasattr(result, 'content'):
# Should contain some content
assert result.content is not None
def test_streaming_no_tool_calls(minimax_tool_parser):
"""Test streaming when there are no tool calls."""
current_text = "This is just regular text without any tool calls."
result = minimax_tool_parser.extract_tool_calls_streaming(
previous_text="This is just regular text",
current_text=current_text,
delta_text=" without any tool calls.",
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
# Should return the delta text as content
assert result is not None
assert hasattr(result, 'content')
assert result.content == " without any tool calls."
def test_streaming_with_thinking_tags(minimax_tool_parser):
"""Test streaming with thinking tags that contain tool calls."""
# Reset streaming state
minimax_tool_parser.current_tool_name_sent = False
minimax_tool_parser.prev_tool_call_arr = []
minimax_tool_parser.current_tool_id = -1
minimax_tool_parser.streamed_args_for_tool = []
current_text = """<think><tool_calls>{"name": "ignored", "arguments": {}}</tool_calls></think><tool_calls>{"name": "real_tool", "arguments": {"param": "value"}}</tool_calls>"""
result = minimax_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text=current_text,
delta_text=current_text,
previous_token_ids=[],
current_token_ids=[],
delta_token_ids=[],
request=None,
)
# The preprocessing should remove tool calls from thinking tags
# and only process the real tool call
if result is not None and hasattr(result,
'tool_calls') and result.tool_calls:
for tool_call in result.tool_calls:
assert tool_call.function.name != "ignored"
def test_extract_tool_calls_multiline_json_not_supported(minimax_tool_parser):
"""Test that multiline JSON in tool calls is not currently supported."""
model_output = """<tool_calls>
{
"name": "get_current_weather",
"arguments": {
"city": "New York",
"state": "NY",
"unit": "celsius"
}
}
</tool_calls>"""
extracted_tool_calls = minimax_tool_parser.extract_tool_calls(
model_output, request=None) # type: ignore[arg-type]
# Multiline JSON is currently not supported, should return no tools called
assert not extracted_tool_calls.tools_called
assert extracted_tool_calls.tool_calls == []
assert extracted_tool_calls.content is None

View File

@ -10,6 +10,7 @@ from .internlm2_tool_parser import Internlm2ToolParser
from .jamba_tool_parser import JambaToolParser
from .llama4_pythonic_tool_parser import Llama4PythonicToolParser
from .llama_tool_parser import Llama3JsonToolParser
from .minimax_tool_parser import MinimaxToolParser
from .mistral_tool_parser import MistralToolParser
from .phi4mini_tool_parser import Phi4MiniJsonToolParser
from .pythonic_tool_parser import PythonicToolParser
@ -20,5 +21,5 @@ __all__ = [
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
"Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
"DeepSeekV3ToolParser", "xLAMToolParser"
"DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser"
]

View File

@ -0,0 +1,369 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Sequence
from typing import Union
import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall, ToolCall)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser, ToolParserManager)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import AnyTokenizer
logger = init_logger(__name__)
@ToolParserManager.register_module("minimax")
class MinimaxToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
self.current_tool_name_sent: bool = False
self.prev_tool_call_arr: list[dict] = []
self.current_tool_id: int = -1
self.streamed_args_for_tool: list[str] = []
self.tool_call_start_token: str = "<tool_calls>"
self.tool_call_end_token: str = "</tool_calls>"
self.tool_call_regex = re.compile(
r"<tool_calls>(.*?)</tool_calls>|<tool_calls>(.*)", re.DOTALL)
# Add regex pattern for thinking tag
self.thinking_tag_pattern = r"<think>(.*?)</think>"
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ToolParser "
"constructor during construction.")
self.tool_call_start_token_id = self.vocab.get(
self.tool_call_start_token)
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
if (self.tool_call_start_token_id is None
or self.tool_call_end_token_id is None):
logger.warning(
"Minimax Tool parser could not locate tool call start/end "
"tokens in the tokenizer. Falling back to string matching.")
def preprocess_model_output(self, model_output: str) -> str:
"""
Remove tool calls from within thinking tags to avoid processing them.
"""
def remove_tool_calls_from_think(match):
think_content = match.group(1)
# Remove tool_calls from within the think tag
cleaned_content = re.sub(r"<tool_calls>.*?</tool_calls>",
"",
think_content,
flags=re.DOTALL)
return f"<think>{cleaned_content}</think>"
# Process thinking tags and remove tool_calls from within them
processed_output = re.sub(self.thinking_tag_pattern,
remove_tool_calls_from_think,
model_output,
flags=re.DOTALL)
return processed_output
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
# Preprocess to remove tool calls from thinking tags
processed_output = self.preprocess_model_output(model_output)
if self.tool_call_start_token not in processed_output:
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
try:
function_call_tuples = (
self.tool_call_regex.findall(processed_output))
raw_function_calls = []
for match in function_call_tuples:
tool_call_content = match[0] if match[0] else match[1]
if tool_call_content.strip():
lines = tool_call_content.strip().split('\n')
for line in lines:
line = line.strip()
if line and line.startswith('{') and line.endswith(
'}'):
try:
parsed_call = json.loads(line)
raw_function_calls.append(parsed_call)
except json.JSONDecodeError:
continue
tool_calls = []
for function_call in raw_function_calls:
if "name" in function_call and "arguments" in function_call:
tool_calls.append(
ToolCall(type="function",
function=FunctionCall(
name=function_call["name"],
arguments=json.dumps(
function_call["arguments"],
ensure_ascii=False))))
# Extract content before the first valid tool call
# Find the position in processed output, then map back to original
processed_pos = processed_output.find(self.tool_call_start_token)
if processed_pos != -1:
# Get the content before tool calls in processed output
processed_content = processed_output[:processed_pos].strip()
if processed_content:
# Find the end of this content in the original output
# Look for the last non-empty line of processed content
lines = processed_content.split('\n')
for line in reversed(lines):
line = line.strip()
if line:
# Find this line in original output
pos = model_output.find(line)
if pos != -1:
content = model_output[:pos + len(line)]
break
else:
content = ""
else:
content = ""
else:
content = model_output
return ExtractedToolCallInformation(
tools_called=len(tool_calls) > 0,
tool_calls=tool_calls,
content=content.strip() if content.strip() else None)
except Exception:
logger.exception(
"An unexpected error occurred during tool call extraction.")
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
# Preprocess to remove tool calls from thinking tags
processed_current_text = self.preprocess_model_output(current_text)
if self.tool_call_start_token not in processed_current_text:
return DeltaMessage(content=delta_text)
if (self.tool_call_start_token_id is not None
and self.tool_call_start_token_id in delta_token_ids
and len(delta_token_ids) == 1):
return None
original_tool_call_start_pos = current_text.find(
self.tool_call_start_token)
if original_tool_call_start_pos > 0:
delta_start_pos = len(current_text) - len(delta_text)
if delta_start_pos < original_tool_call_start_pos:
content_part = delta_text
if delta_start_pos + len(
delta_text) > original_tool_call_start_pos:
content_part = delta_text[:original_tool_call_start_pos -
delta_start_pos]
if content_part:
return DeltaMessage(content=content_part)
flags = Allow.ALL if self.current_tool_name_sent \
else Allow.ALL & ~Allow.STR
try:
parsable_content = processed_current_text.split(
self.tool_call_start_token)[-1].split(
self.tool_call_end_token)[0]
tool_call_arr = []
if parsable_content.strip():
lines = parsable_content.strip().split('\n')
for line in lines:
line = line.strip()
if line and (line.startswith('{') or '"name"' in line):
try:
if line.endswith('}'):
parsed_call = json.loads(line)
tool_call_arr.append(parsed_call)
else:
parsed_call = partial_json_parser.loads(
line, flags)
if parsed_call and isinstance(
parsed_call, dict):
tool_call_arr.append(parsed_call)
except (json.JSONDecodeError, partial_json_parser.core.
exceptions.MalformedJSON):
continue
current_tool_call: dict = tool_call_arr[self.current_tool_id] \
if len(tool_call_arr) > self.current_tool_id >= 0 else {}
if len(tool_call_arr) == 0:
return None
# Starting a new tool in the array
elif (len(tool_call_arr) > 0
and len(tool_call_arr) > self.current_tool_id + 1):
# Handle any missed arguments from previous tool
if self.current_tool_id >= 0 and self.current_tool_id < len(
self.prev_tool_call_arr):
prev_tool_call = self.prev_tool_call_arr[
self.current_tool_id]
diff_arguments = prev_tool_call.get("arguments")
if diff_arguments:
diff_arguments_json = json.dumps(diff_arguments,
ensure_ascii=False)
already_streamed = self.streamed_args_for_tool[
self.
current_tool_id] if self.current_tool_id < len(
self.streamed_args_for_tool) else ""
if diff_arguments_json != already_streamed:
diff = diff_arguments_json[len(already_streamed):]
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=diff).model_dump(
exclude_none=True))
])
if self.current_tool_id < len(
self.streamed_args_for_tool):
self.streamed_args_for_tool[
self.current_tool_id] = diff_arguments_json
else:
delta = None
else:
delta = None
else:
delta = None
self.current_tool_id = len(tool_call_arr) - 1
self.current_tool_name_sent = False
self.streamed_args_for_tool.append("")
logger.debug("starting on new tool %d", self.current_tool_id)
return delta
# Send tool name if not sent yet
if not self.current_tool_name_sent:
function_name = current_tool_call.get("name")
if function_name:
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))
])
self.current_tool_name_sent = True
else:
delta = None
# Stream arguments
else:
prev_arguments = None
if (self.current_tool_id < len(self.prev_tool_call_arr)
and self.prev_tool_call_arr[self.current_tool_id]):
prev_arguments = self.prev_tool_call_arr[
self.current_tool_id].get("arguments")
cur_arguments = current_tool_call.get("arguments")
if not cur_arguments and not prev_arguments:
delta = None
elif not cur_arguments and prev_arguments:
logger.error(
"Arguments reset mid-call, skipping streaming")
delta = None
elif cur_arguments and not prev_arguments:
cur_arguments_json = json.dumps(cur_arguments,
ensure_ascii=False)
logger.debug("First tokens in arguments received: %s",
cur_arguments_json)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=cur_arguments_json).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] = cur_arguments_json
elif cur_arguments and prev_arguments:
cur_args_json = json.dumps(cur_arguments,
ensure_ascii=False)
prev_args_json = json.dumps(prev_arguments,
ensure_ascii=False)
logger.debug("Searching for diff between \n%s\n%s",
cur_args_json, prev_args_json)
already_streamed = self.streamed_args_for_tool[
self.current_tool_id] if self.current_tool_id < len(
self.streamed_args_for_tool) else ""
if cur_args_json.startswith(already_streamed):
argument_diff = cur_args_json[len(already_streamed):]
elif cur_args_json != already_streamed:
argument_diff = cur_args_json
self.streamed_args_for_tool[self.current_tool_id] = ""
else:
argument_diff = ""
if argument_diff:
logger.debug("got arguments diff: %s", argument_diff)
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
function=DeltaFunctionCall(
arguments=argument_diff).
model_dump(exclude_none=True))
])
self.streamed_args_for_tool[
self.current_tool_id] += argument_diff
else:
delta = None
else:
delta = None
self.prev_tool_call_arr = tool_call_arr
return delta
except Exception:
logger.exception("An unexpected error occurred",
"during streaming tool call handling.")
return None