[Fix] correct tool_id for kimi-k2 when using tool_choice=required (#21259)

Co-authored-by: wangzhengtao <wangzhengtao@msh.team>
bigmoyan
2025-08-21 03:59:54 +08:00
committed by GitHub
parent 0cdbf5e61c
commit 582bbe6bd7
15 changed files with 283 additions and 166 deletions
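
In short: with tool_choice="required", chat completions previously always generated random tool call ids of the form chatcmpl-tool-<uuid>. Kimi-K2 instead expects ids of the form functions.<function_name>:<running index>, where the index keeps counting across the conversation. A minimal sketch of the id scheme this change introduces (mirroring the make_tool_call_id helper added in chat_utils.py below; random_uuid() is approximated here with uuid.uuid4()):

import uuid

def make_tool_call_id(id_type: str = "random", func_name=None, idx=None) -> str:
    # Kimi-K2 expects "functions.<name>:<running index>"; everything else keeps
    # the previous random "chatcmpl-tool-<uuid>" format.
    if id_type == "kimi_k2":
        return f"functions.{func_name}:{idx}"
    return f"chatcmpl-tool-{uuid.uuid4()}"

print(make_tool_call_id())                                     # chatcmpl-tool-...
print(make_tool_call_id("kimi_k2", "get_current_weather", 0))  # functions.get_current_weather:0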

View File

@@ -13,6 +13,127 @@ from ...utils import RemoteOpenAIServer
# any model with a chat template should work here
MODEL_NAME = "Qwen/Qwen3-0.6B"
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description":
"The city to find the weather for, e.g. 'Vienna'",
"default": "Vienna",
},
"country": {
"type":
"string",
"description":
"The country that the city is in, e.g. 'Austria'",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
"options": {
"$ref": "#/$defs/WeatherOptions",
"description": "Optional parameters for weather query",
},
},
"required": ["country", "unit"],
"$defs": {
"WeatherOptions": {
"title": "WeatherOptions",
"type": "object",
"additionalProperties": False,
"properties": {
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"default": "celsius",
"description": "Temperature unit",
"title": "Temperature Unit",
},
"include_forecast": {
"type": "boolean",
"default": False,
"description":
"Whether to include a 24-hour forecast",
"title": "Include Forecast",
},
"language": {
"type": "string",
"default": "zh-CN",
"description": "Language of the response",
"title": "Language",
"enum": ["zh-CN", "en-US", "ja-JP"],
},
},
},
},
},
},
},
{
"type": "function",
"function": {
"name": "get_forecast",
"description": "Get the weather forecast for a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description":
"The city to get the forecast for, e.g. 'Vienna'",
"default": "Vienna",
},
"country": {
"type":
"string",
"description":
"The country that the city is in, e.g. 'Austria'",
},
"days": {
"type":
"integer",
"description":
"Number of days to get the forecast for (1-7)",
},
"unit": {
"type": "string",
"description": "The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["country", "days", "unit"],
},
},
},
]
messages = [
{
"role": "user",
"content": "Hi! How are you doing today?"
},
{
"role": "assistant",
"content": "I'm doing well! How can I help you?"
},
{
"role":
"user",
"content":
"Can you tell me what the current weather is in Berlin and the "\
"forecast for the next 5 days, in fahrenheit?",
},
]
@pytest.fixture(scope="module")
def server(): # noqa: F811
@@ -27,6 +148,8 @@ def server(): # noqa: F811
"hermes",
"--reasoning-parser",
"qwen3",
"--gpu-memory-utilization",
"0.4"
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -54,129 +177,6 @@ async def client(server):
async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
stream: bool, tool_choice: Union[str, dict],
enable_thinking: bool):
tools = [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description":
"The city to find the weather for, e.g. 'Vienna'",
"default": "Vienna",
},
"country": {
"type":
"string",
"description":
"The country that the city is in, e.g. 'Austria'",
},
"unit": {
"type": "string",
"description":
"The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
"options": {
"$ref": "#/$defs/WeatherOptions",
"description":
"Optional parameters for weather query",
},
},
"required": ["country", "unit"],
"$defs": {
"WeatherOptions": {
"title": "WeatherOptions",
"type": "object",
"additionalProperties": False,
"properties": {
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"default": "celsius",
"description": "Temperature unit",
"title": "Temperature Unit",
},
"include_forecast": {
"type": "boolean",
"default": False,
"description":
"Whether to include a 24-hour forecast",
"title": "Include Forecast",
},
"language": {
"type": "string",
"default": "zh-CN",
"description": "Language of the response",
"title": "Language",
"enum": ["zh-CN", "en-US", "ja-JP"],
},
},
},
},
},
},
},
{
"type": "function",
"function": {
"name": "get_forecast",
"description": "Get the weather forecast for a given location",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description":
"The city to get the forecast for, e.g. 'Vienna'",
"default": "Vienna",
},
"country": {
"type":
"string",
"description":
"The country that the city is in, e.g. 'Austria'",
},
"days": {
"type":
"integer",
"description":
"Number of days to get the forecast for (1-7)",
},
"unit": {
"type": "string",
"description":
"The unit to fetch the temperature in",
"enum": ["celsius", "fahrenheit"],
},
},
"required": ["country", "days", "unit"],
},
},
},
]
messages = [
{
"role": "user",
"content": "Hi! How are you doing today?"
},
{
"role": "assistant",
"content": "I'm doing well! How can I help you?"
},
{
"role":
"user",
"content":
"Can you tell me what the current weather is in Berlin and the "\
"forecast for the next 5 days, in fahrenheit?",
},
]
if not stream:
# Non-streaming test
chat_completion = await client.chat.completions.create(
@@ -216,3 +216,71 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
output.extend(chunk.choices[0].delta.tool_calls)
assert len(output) > 0
@pytest.fixture(scope="module")
def k2_server(): # noqa: F811
args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"half",
"--enable-auto-tool-choice",
"--guided-decoding-backend",
"xgrammar",
"--tool-call-parser",
"hermes",
"--reasoning-parser",
"qwen3",
"--gpu-memory-utilization",
"0.4",
]
# hack to test kimi_k2 tool use tool_id format.
# avoid error in is_deepseek_mla check by setting kv_lora_rank=null
with RemoteOpenAIServer(MODEL_NAME,
args,
override_hf_configs={
"model_type": 'kimi_k2',
'kv_lora_rank': None
}) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def k2_client(k2_server):
async with k2_server.get_async_client() as async_client:
yield async_client
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("tool_choice", ["required"])
async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str,
stream: bool, tool_choice: str):
if not stream:
# Non-streaming test
chat_completion = await k2_client.chat.completions.create(
messages=messages,
model=model_name,
tools=tools,
tool_choice=tool_choice)
assert chat_completion.choices[0].message.tool_calls is not None
assert len(chat_completion.choices[0].message.tool_calls) > 0
assert chat_completion.choices[0].message.tool_calls[
0].id == 'functions.get_current_weather:0'
else:
# Streaming test
output_stream = await k2_client.chat.completions.create(
messages=messages,
model=model_name,
tools=tools,
tool_choice=tool_choice,
stream=True)
output = []
async for chunk in output_stream:
if chunk.choices and chunk.choices[0].delta.tool_calls:
output.extend(chunk.choices[0].delta.tool_calls)
for o in output:
assert o.id is None or o.id == 'functions.get_current_weather:0'
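
Why the streaming assertion above allows o.id to be None: in OpenAI-style streaming, only the first delta of a tool call carries the id, and later chunks stream the arguments with id=None. A rough sketch of the collected deltas (values are illustrative, not taken from an actual run):

deltas = [
    {"index": 0, "id": "functions.get_current_weather:0",
     "function": {"name": "get_current_weather", "arguments": ""}},
    {"index": 0, "id": None, "function": {"arguments": '{"city": "Berlin"'}},
    {"index": 0, "id": None, "function": {"arguments": ', "unit": "fahrenheit"}'}},
]
assert all(d["id"] in (None, "functions.get_current_weather:0") for d in deltas)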

View File

@@ -5,6 +5,7 @@ import asyncio
import copy
import functools
import importlib
import json
import os
import signal
import subprocess
@@ -101,7 +102,8 @@ class RemoteOpenAIServer:
env_dict: Optional[dict[str, str]] = None,
seed: Optional[int] = 0,
auto_port: bool = True,
max_wait_seconds: Optional[float] = None) -> None:
max_wait_seconds: Optional[float] = None,
override_hf_configs: Optional[dict[str, Any]] = None) -> None:
if auto_port:
if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
raise ValueError("You have manually specified the port "
@@ -120,6 +122,12 @@
vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
if override_hf_configs is not None:
vllm_serve_args = vllm_serve_args + [
"--hf-overrides",
json.dumps(override_hf_configs)
]
parser = FlexibleArgumentParser(
description="vLLM's remote OpenAI server.")
subparsers = parser.add_subparsers(required=False, dest="subparser")
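
For reference, the new override_hf_configs hook above simply JSON-serializes the dict into a --hf-overrides flag; this is how the k2_server fixture makes a small Qwen3 checkpoint report model_type "kimi_k2" without loading an actual Kimi-K2 model. A rough sketch of the extra CLI args it produces (assuming the code above):

import json

override_hf_configs = {"model_type": "kimi_k2", "kv_lora_rank": None}
extra_args = ["--hf-overrides", json.dumps(override_hf_configs)]
print(extra_args)
# ['--hf-overrides', '{"model_type": "kimi_k2", "kv_lora_rank": null}']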

View File

@@ -1345,5 +1345,18 @@ def apply_mistral_chat_template(
"template")
raise ValueError(str(e)) from e
def random_tool_call_id() -> str:
return f"chatcmpl-tool-{random_uuid()}"
def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
idx = 0
for msg in conversation:
if msg['role'] == 'assistant':
tool_calls = msg.get('tool_calls')
idx += len(list(tool_calls)) if tool_calls is not None else 0 # noqa
return idx
def make_tool_call_id(id_type:str='random', func_name=None, idx=None):
if id_type=='kimi_k2':
return f'functions.{func_name}:{idx}'
else:
# by default return random
return f"chatcmpl-tool-{random_uuid()}"

View File

@@ -38,7 +38,7 @@ from typing_extensions import TypeAlias
from vllm import envs
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
random_tool_call_id)
make_tool_call_id)
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
ScoreMultiModalParam)
from vllm.logger import init_logger
@@ -1634,7 +1634,7 @@ class FunctionCall(OpenAIBaseModel):
class ToolCall(OpenAIBaseModel):
id: str = Field(default_factory=random_tool_call_id)
id: str = Field(default_factory=make_tool_call_id)
type: Literal["function"] = "function"
function: FunctionCall

View File

@@ -19,7 +19,8 @@ from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
ConversationMessage,
random_tool_call_id)
get_history_tool_calls_cnt,
make_tool_call_id)
from vllm.entrypoints.harmony_utils import (
get_developer_message, get_stop_tokens_for_assistant_actions,
get_streamable_parser_for_assistant, get_system_message, parse_chat_input,
@@ -133,6 +134,10 @@ class OpenAIServingChat(OpenAIServing):
source = "model" if source == "auto" else source
logger.info("Using default chat sampling params from %s: %s",
source, self.default_sampling_params)
if self.model_config.hf_config.model_type == 'kimi_k2':
self.tool_call_id_type = 'kimi_k2'
else:
self.tool_call_id_type = 'random'
self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
if self.use_harmony:
@@ -379,6 +384,7 @@ class OpenAIServingChat(OpenAIServing):
current_text: Optional[str],
delta_text: str,
function_name_returned: bool,
tool_call_idx: Optional[int] = None
) -> tuple[Optional[DeltaMessage], bool]:
if current_text is None or current_text == "":
# if the current text is empty, we cannot parse it
@@ -424,8 +430,12 @@
current_tool_call = obj[-2]
function_name_returned = True
tool_call_id = make_tool_call_id(
id_type=self.tool_call_id_type,
func_name=current_tool_call["name"],
idx=tool_call_idx)
delta_message = DeltaMessage(tool_calls=[
DeltaToolCall(id=random_tool_call_id(),
DeltaToolCall(id=tool_call_id,
function=DeltaFunctionCall(
name=current_tool_call["name"],
arguments=arguments),
@@ -491,6 +501,10 @@
all_previous_token_ids: Optional[list[list[int]]]
function_name_returned = [False] * num_choices
if self.tool_call_id_type == 'kimi_k2':
history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
else:
history_tool_call_cnt = 0
# Always track previous_texts for comprehensive output logging
previous_texts = [""] * num_choices
@@ -673,7 +687,6 @@
previous_text = previous_texts[i]
previous_token_ids = all_previous_token_ids[i]
current_text = previous_text + delta_text
# avoid the None + list error.
if previous_token_ids:
current_token_ids = previous_token_ids + as_list(
@@ -733,7 +746,7 @@
index=i)
else:
delta_tool_call = DeltaToolCall(
id=random_tool_call_id(),
id=make_tool_call_id(),
type="function",
function=DeltaFunctionCall(
name=tool_choice_function_name,
@@ -764,7 +777,11 @@
previous_text=previous_text,
current_text=content,
delta_text=delta_text,
function_name_returned=fn_name_returned))
function_name_returned=fn_name_returned,
tool_call_idx=history_tool_call_cnt))
if (delta_message and delta_message.tool_calls and
delta_message.tool_calls[0].id is not None):
history_tool_call_cnt += 1
# update the previous values for the next iteration
previous_texts[i] = current_text
@@ -1089,6 +1106,10 @@ class OpenAIServingChat(OpenAIServing):
assert final_res is not None
choices: list[ChatCompletionResponseChoice] = []
if self.tool_call_id_type == 'kimi_k2':
history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
else:
history_tool_call_cnt = 0
role = self.get_chat_request_role(request)
for output in final_res.outputs:
@@ -1194,17 +1215,26 @@
assert content is not None
tool_calls = TypeAdapter(
list[FunctionDefinition]).validate_json(content)
tool_call_ids = []
for tool_call in tool_calls:
tool_call_ids.append(
make_tool_call_id(id_type=self.tool_call_id_type,
func_name=tool_call.name,
idx=history_tool_call_cnt))
history_tool_call_cnt += 1
message = ChatMessage(
role=role,
content="",
reasoning_content=reasoning_content,
tool_calls=[
tool_call_class(function=FunctionCall(
name=tool_call.name,
arguments=json.dumps(tool_call.parameters,
ensure_ascii=False)))
for tool_call in tool_calls
])
tool_call_class(id=tool_call_ids[i],
function=FunctionCall(
name=tool_call.name,
arguments=json.dumps(
tool_call.parameters,
ensure_ascii=False)))
for i, tool_call in enumerate(tool_calls)
],
reasoning_content=reasoning_content)
# if the request doesn't use tool choice
# OR specifies to not use a tool
@@ -1248,7 +1278,6 @@ class OpenAIServingChat(OpenAIServing):
if (tool_call_info.content
and len(tool_call_info.content) > 0):
ret_content = tool_call_info.content
message = ChatMessage(role=role,
reasoning_content=reasoning_content,
content=ret_content)
@@ -1327,12 +1356,11 @@
elif choice.message.tool_calls:
# For tool calls, log the function name and arguments
tool_call_descriptions = []
for tool_call in choice.message.tool_calls:
if hasattr(tool_call.function, "name") and hasattr(
tool_call.function, "arguments"):
for tc in choice.message.tool_calls:
if hasattr(tc.function, "name") and hasattr(
tc.function, "arguments"):
tool_call_descriptions.append(
f"{tool_call.function.name}({tool_call.function.arguments})"
)
f"{tc.function.name}({tc.function.arguments})")
tool_calls_str = ", ".join(tool_call_descriptions)
output_text = f"[tool_calls: {tool_calls_str}]"

View File

@@ -6,7 +6,7 @@ from typing import Union
import regex as re
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@@ -267,7 +267,7 @@ class DeepSeekV3ToolParser(ToolParser):
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True),

View File

@@ -10,7 +10,7 @@ import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@@ -203,7 +203,7 @@ class Granite20bFCToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))

View File

@@ -8,7 +8,7 @@ from typing import Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@@ -185,7 +185,7 @@ class GraniteToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))

View File

@@ -9,7 +9,7 @@ import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@@ -307,7 +307,7 @@ class Hermes2ProToolParser(ToolParser):
return DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))

View File

@@ -8,7 +8,7 @@ from typing import Union
import partial_json_parser
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@@ -107,7 +107,7 @@ class Internlm2ToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))

View File

@@ -9,7 +9,7 @@ import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@@ -222,7 +222,7 @@ class JambaToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))

View File

@@ -10,7 +10,7 @@ import regex as re
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@@ -213,7 +213,7 @@ class Llama3JsonToolParser(ToolParser):
delta = DeltaMessage(tool_calls=[
DeltaToolCall(index=self.current_tool_id,
type="function",
id=random_tool_call_id(),
id=make_tool_call_id(),
function=DeltaFunctionCall(
name=function_name).model_dump(
exclude_none=True))

View File

@@ -7,7 +7,7 @@ from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@@ -394,7 +394,7 @@ class MinimaxToolParser(ToolParser):
sent_tools.append({
"sent_name": False,
"sent_arguments": "",
"id": random_tool_call_id(),
"id": make_tool_call_id(),
})
while len(tool_ids) < tool_count:

View File

@@ -8,7 +8,7 @@ from typing import Any, Optional
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage,
ExtractedToolCallInformation,
@@ -74,7 +74,7 @@ class Phi4MiniJsonToolParser(ToolParser):
tool_calls: list[ToolCall] = [
ToolCall(
id=random_tool_call_id(),
id=make_tool_call_id(),
type="function",
function=FunctionCall(
name=raw_function_call["name"],

View File

@@ -7,7 +7,7 @@ from typing import Any, Optional, Union
import regex as re
from vllm.entrypoints.chat_utils import random_tool_call_id
from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaFunctionCall, DeltaMessage,
DeltaToolCall,
@@ -226,7 +226,7 @@ class xLAMToolParser(ToolParser):
function_name = name_match.group(1)
# The test expects us to send just the name first
tool_id = random_tool_call_id()
tool_id = make_tool_call_id()
delta = DeltaMessage(tool_calls=[
DeltaToolCall(
index=0,