vllm-project/vllm: [Fix] correct tool_id for kimi-k2 when using tool_choice=required (#21259)
Co-authored-by: wangzhengtao <wangzhengtao@msh.team>
@@ -13,6 +13,127 @@ from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "Qwen/Qwen3-0.6B"

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description":
                        "The city to find the weather for, e.g. 'Vienna'",
                        "default": "Vienna",
                    },
                    "country": {
                        "type": "string",
                        "description":
                        "The country that the city is in, e.g. 'Austria'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                    "options": {
                        "$ref": "#/$defs/WeatherOptions",
                        "description": "Optional parameters for weather query",
                    },
                },
                "required": ["country", "unit"],
                "$defs": {
                    "WeatherOptions": {
                        "title": "WeatherOptions",
                        "type": "object",
                        "additionalProperties": False,
                        "properties": {
                            "unit": {
                                "type": "string",
                                "enum": ["celsius", "fahrenheit"],
                                "default": "celsius",
                                "description": "Temperature unit",
                                "title": "Temperature Unit",
                            },
                            "include_forecast": {
                                "type": "boolean",
                                "default": False,
                                "description":
                                "Whether to include a 24-hour forecast",
                                "title": "Include Forecast",
                            },
                            "language": {
                                "type": "string",
                                "default": "zh-CN",
                                "description": "Language of the response",
                                "title": "Language",
                                "enum": ["zh-CN", "en-US", "ja-JP"],
                            },
                        },
                    },
                },
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "get_forecast",
            "description": "Get the weather forecast for a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description":
                        "The city to get the forecast for, e.g. 'Vienna'",
                        "default": "Vienna",
                    },
                    "country": {
                        "type": "string",
                        "description":
                        "The country that the city is in, e.g. 'Austria'",
                    },
                    "days": {
                        "type": "integer",
                        "description":
                        "Number of days to get the forecast for (1-7)",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["country", "days", "unit"],
            },
        },
    },
]

messages = [
    {
        "role": "user",
        "content": "Hi! How are you doing today?"
    },
    {
        "role": "assistant",
        "content": "I'm doing well! How can I help you?"
    },
    {
        "role": "user",
        "content":
        "Can you tell me what the current weather is in Berlin and the "\
        "forecast for the next 5 days, in fahrenheit?",
    },
]


@pytest.fixture(scope="module")
def server():  # noqa: F811
@@ -27,6 +148,8 @@ def server():  # noqa: F811
        "hermes",
        "--reasoning-parser",
        "qwen3",
        "--gpu-memory-utilization",
        "0.4"
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
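The schemas above exercise required fields, enums, defaults, and a $ref into $defs. Purely as an illustration (not part of the commit), an arguments payload a model could emit for get_current_weather might look like the following; the concrete values are invented for the example:

example_arguments = {
    "city": "Berlin",
    "country": "Germany",
    "unit": "fahrenheit",
    "options": {  # resolved via "$ref": "#/$defs/WeatherOptions"
        "unit": "fahrenheit",
        "include_forecast": False,
        "language": "en-US",
    },
}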
@@ -54,129 +177,6 @@ async def client(server):
async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
                                 stream: bool, tool_choice: Union[str, dict],
                                 enable_thinking: bool):
    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description":
                            "The city to find the weather for, e.g. 'Vienna'",
                            "default": "Vienna",
                        },
                        "country": {
                            "type": "string",
                            "description":
                            "The country that the city is in, e.g. 'Austria'",
                        },
                        "unit": {
                            "type": "string",
                            "description":
                            "The unit to fetch the temperature in",
                            "enum": ["celsius", "fahrenheit"],
                        },
                        "options": {
                            "$ref": "#/$defs/WeatherOptions",
                            "description":
                            "Optional parameters for weather query",
                        },
                    },
                    "required": ["country", "unit"],
                    "$defs": {
                        "WeatherOptions": {
                            "title": "WeatherOptions",
                            "type": "object",
                            "additionalProperties": False,
                            "properties": {
                                "unit": {
                                    "type": "string",
                                    "enum": ["celsius", "fahrenheit"],
                                    "default": "celsius",
                                    "description": "Temperature unit",
                                    "title": "Temperature Unit",
                                },
                                "include_forecast": {
                                    "type": "boolean",
                                    "default": False,
                                    "description":
                                    "Whether to include a 24-hour forecast",
                                    "title": "Include Forecast",
                                },
                                "language": {
                                    "type": "string",
                                    "default": "zh-CN",
                                    "description": "Language of the response",
                                    "title": "Language",
                                    "enum": ["zh-CN", "en-US", "ja-JP"],
                                },
                            },
                        },
                    },
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "get_forecast",
                "description": "Get the weather forecast for a given location",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {
                            "type": "string",
                            "description":
                            "The city to get the forecast for, e.g. 'Vienna'",
                            "default": "Vienna",
                        },
                        "country": {
                            "type": "string",
                            "description":
                            "The country that the city is in, e.g. 'Austria'",
                        },
                        "days": {
                            "type": "integer",
                            "description":
                            "Number of days to get the forecast for (1-7)",
                        },
                        "unit": {
                            "type": "string",
                            "description":
                            "The unit to fetch the temperature in",
                            "enum": ["celsius", "fahrenheit"],
                        },
                    },
                    "required": ["country", "days", "unit"],
                },
            },
        },
    ]

    messages = [
        {
            "role": "user",
            "content": "Hi! How are you doing today?"
        },
        {
            "role": "assistant",
            "content": "I'm doing well! How can I help you?"
        },
        {
            "role": "user",
            "content":
            "Can you tell me what the current weather is in Berlin and the "\
            "forecast for the next 5 days, in fahrenheit?",
        },
    ]
    if not stream:
        # Non-streaming test
        chat_completion = await client.chat.completions.create(
@@ -216,3 +216,71 @@ async def test_function_tool_use(client: openai.AsyncOpenAI, model_name: str,
                output.extend(chunk.choices[0].delta.tool_calls)

        assert len(output) > 0


@pytest.fixture(scope="module")
def k2_server():  # noqa: F811
    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "half",
        "--enable-auto-tool-choice",
        "--guided-decoding-backend",
        "xgrammar",
        "--tool-call-parser",
        "hermes",
        "--reasoning-parser",
        "qwen3",
        "--gpu-memory-utilization",
        "0.4",
    ]
    # hack to test kimi_k2 tool use tool_id format.
    # avoid error in is_deepseek_mla check by setting kv_lora_rank=null
    with RemoteOpenAIServer(MODEL_NAME,
                            args,
                            override_hf_configs={
                                "model_type": 'kimi_k2',
                                'kv_lora_rank': None
                            }) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def k2_client(k2_server):
    async with k2_server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("tool_choice", ["required"])
async def test_tool_id_kimi_k2(k2_client: openai.AsyncOpenAI, model_name: str,
                               stream: bool, tool_choice: str):

    if not stream:
        # Non-streaming test
        chat_completion = await k2_client.chat.completions.create(
            messages=messages,
            model=model_name,
            tools=tools,
            tool_choice=tool_choice)
        assert chat_completion.choices[0].message.tool_calls is not None
        assert len(chat_completion.choices[0].message.tool_calls) > 0
        assert chat_completion.choices[0].message.tool_calls[
            0].id == 'functions.get_current_weather:0'
    else:
        # Streaming test
        output_stream = await k2_client.chat.completions.create(
            messages=messages,
            model=model_name,
            tools=tools,
            tool_choice=tool_choice,
            stream=True)

        output = []
        async for chunk in output_stream:
            if chunk.choices and chunk.choices[0].delta.tool_calls:
                output.extend(chunk.choices[0].delta.tool_calls)
        for o in output:
            assert o.id is None or o.id == 'functions.get_current_weather:0'
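The assertions above pin down the kimi-k2 id scheme: functions.<function_name>:<index>, where the index counts tool calls already seen in the conversation. A minimal sketch of the expected ids (values inferred from the test and helper code, not from a live run):

expected_first_id = "functions.get_current_weather:0"  # index 0: no prior tool calls
expected_next_id = "functions.get_forecast:1"          # a later call would continue the count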
@@ -5,6 +5,7 @@ import asyncio
import copy
import functools
import importlib
+import json
import os
import signal
import subprocess
@@ -101,7 +102,8 @@ class RemoteOpenAIServer:
                 env_dict: Optional[dict[str, str]] = None,
                 seed: Optional[int] = 0,
                 auto_port: bool = True,
-                 max_wait_seconds: Optional[float] = None) -> None:
+                 max_wait_seconds: Optional[float] = None,
+                 override_hf_configs: Optional[dict[str, Any]] = None) -> None:
        if auto_port:
            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
                raise ValueError("You have manually specified the port "
@@ -120,6 +122,12 @@ class RemoteOpenAIServer:

        vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]

+        if override_hf_configs is not None:
+            vllm_serve_args = vllm_serve_args + [
+                "--hf-overrides",
+                json.dumps(override_hf_configs)
+            ]
+
        parser = FlexibleArgumentParser(
            description="vLLM's remote OpenAI server.")
        subparsers = parser.add_subparsers(required=False, dest="subparser")
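The new override_hf_configs parameter is simply serialized onto the CLI. A quick sketch of what the k2_server fixture above ends up appending (the dict is copied from the test; the result is shown in the comment):

import json

override_hf_configs = {"model_type": "kimi_k2", "kv_lora_rank": None}
extra_args = ["--hf-overrides", json.dumps(override_hf_configs)]
# extra_args == ['--hf-overrides', '{"model_type": "kimi_k2", "kv_lora_rank": null}']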
@@ -1345,5 +1345,18 @@ def apply_mistral_chat_template(
                         "template")
        raise ValueError(str(e)) from e

-def random_tool_call_id() -> str:
-    return f"chatcmpl-tool-{random_uuid()}"
+def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
+    idx = 0
+    for msg in conversation:
+        if msg['role'] == 'assistant':
+            tool_calls = msg.get('tool_calls')
+            idx += len(list(tool_calls)) if tool_calls is not None else 0  # noqa
+    return idx
+
+
+def make_tool_call_id(id_type: str = 'random', func_name=None, idx=None):
+
+    if id_type == 'kimi_k2':
+        return f'functions.{func_name}:{idx}'
+    else:
+        # by default return random
+        return f"chatcmpl-tool-{random_uuid()}"
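A short usage sketch of the two helpers added here; the id strings match the formats asserted in the test above, and the example conversation is made up:

from vllm.entrypoints.chat_utils import (get_history_tool_calls_cnt,
                                         make_tool_call_id)

make_tool_call_id()  # default format: "chatcmpl-tool-<random_uuid>"
make_tool_call_id(id_type='kimi_k2',
                  func_name='get_current_weather',
                  idx=0)  # "functions.get_current_weather:0"

# Count tool calls already present in assistant messages to pick the next index.
conversation = [{
    "role": "assistant",
    "tool_calls": [{"id": "functions.get_forecast:0"}],
}]
get_history_tool_calls_cnt(conversation)  # 1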
@@ -38,7 +38,7 @@ from typing_extensions import TypeAlias

from vllm import envs
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
-                                         random_tool_call_id)
+                                         make_tool_call_id)
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam)
from vllm.logger import init_logger
@@ -1634,7 +1634,7 @@ class FunctionCall(OpenAIBaseModel):


class ToolCall(OpenAIBaseModel):
-    id: str = Field(default_factory=random_tool_call_id)
+    id: str = Field(default_factory=make_tool_call_id)
    type: Literal["function"] = "function"
    function: FunctionCall
@@ -19,7 +19,8 @@ from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
                                         ConversationMessage,
-                                         random_tool_call_id)
+                                         get_history_tool_calls_cnt,
+                                         make_tool_call_id)
from vllm.entrypoints.harmony_utils import (
    get_developer_message, get_stop_tokens_for_assistant_actions,
    get_streamable_parser_for_assistant, get_system_message, parse_chat_input,
@@ -133,6 +134,10 @@ class OpenAIServingChat(OpenAIServing):
            source = "model" if source == "auto" else source
            logger.info("Using default chat sampling params from %s: %s",
                        source, self.default_sampling_params)
+        if self.model_config.hf_config.model_type == 'kimi_k2':
+            self.tool_call_id_type = 'kimi_k2'
+        else:
+            self.tool_call_id_type = 'random'

        self.use_harmony = model_config.hf_config.model_type == "gpt_oss"
        if self.use_harmony:
@@ -379,6 +384,7 @@ class OpenAIServingChat(OpenAIServing):
        current_text: Optional[str],
        delta_text: str,
        function_name_returned: bool,
+        tool_call_idx: Optional[int] = None
    ) -> tuple[Optional[DeltaMessage], bool]:
        if current_text is None or current_text == "":
            # if the current text is empty, we cannot parse it
@@ -424,8 +430,12 @@ class OpenAIServingChat(OpenAIServing):
                current_tool_call = obj[-2]

            function_name_returned = True
+            tool_call_id = make_tool_call_id(
+                id_type=self.tool_call_id_type,
+                func_name=current_tool_call["name"],
+                idx=tool_call_idx)
            delta_message = DeltaMessage(tool_calls=[
-                DeltaToolCall(id=random_tool_call_id(),
+                DeltaToolCall(id=tool_call_id,
                              function=DeltaFunctionCall(
                                  name=current_tool_call["name"],
                                  arguments=arguments),
@@ -491,6 +501,10 @@ class OpenAIServingChat(OpenAIServing):

        all_previous_token_ids: Optional[list[list[int]]]
        function_name_returned = [False] * num_choices
+        if self.tool_call_id_type == 'kimi_k2':
+            history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
+        else:
+            history_tool_call_cnt = 0

        # Always track previous_texts for comprehensive output logging
        previous_texts = [""] * num_choices
@@ -673,7 +687,6 @@ class OpenAIServingChat(OpenAIServing):
                previous_text = previous_texts[i]
                previous_token_ids = all_previous_token_ids[i]
                current_text = previous_text + delta_text

                # avoid the None + list error.
                if previous_token_ids:
                    current_token_ids = previous_token_ids + as_list(
@@ -733,7 +746,7 @@ class OpenAIServingChat(OpenAIServing):
                            index=i)
                    else:
                        delta_tool_call = DeltaToolCall(
-                            id=random_tool_call_id(),
+                            id=make_tool_call_id(),
                            type="function",
                            function=DeltaFunctionCall(
                                name=tool_choice_function_name,
@@ -764,7 +777,11 @@ class OpenAIServingChat(OpenAIServing):
                                previous_text=previous_text,
                                current_text=content,
                                delta_text=delta_text,
-                                function_name_returned=fn_name_returned))
+                                function_name_returned=fn_name_returned,
+                                tool_call_idx=history_tool_call_cnt))
+                        if (delta_message and delta_message.tool_calls and
+                                delta_message.tool_calls[0].id is not None):
+                            history_tool_call_cnt += 1

                    # update the previous values for the next iteration
                    previous_texts[i] = current_text
@@ -1089,6 +1106,10 @@ class OpenAIServingChat(OpenAIServing):
        assert final_res is not None

        choices: list[ChatCompletionResponseChoice] = []
+        if self.tool_call_id_type == 'kimi_k2':
+            history_tool_call_cnt = get_history_tool_calls_cnt(conversation)
+        else:
+            history_tool_call_cnt = 0

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
@@ -1194,17 +1215,26 @@ class OpenAIServingChat(OpenAIServing):
                assert content is not None
                tool_calls = TypeAdapter(
                    list[FunctionDefinition]).validate_json(content)
+                tool_call_ids = []
+                for tool_call in tool_calls:
+                    tool_call_ids.append(
+                        make_tool_call_id(id_type=self.tool_call_id_type,
+                                          func_name=tool_call.name,
+                                          idx=history_tool_call_cnt))
+                    history_tool_call_cnt += 1
                message = ChatMessage(
                    role=role,
                    content="",
-                    reasoning_content=reasoning_content,
                    tool_calls=[
-                        tool_call_class(function=FunctionCall(
-                            name=tool_call.name,
-                            arguments=json.dumps(tool_call.parameters,
-                                                 ensure_ascii=False)))
-                        for tool_call in tool_calls
-                    ])
+                        tool_call_class(id=tool_call_ids[i],
+                                        function=FunctionCall(
+                                            name=tool_call.name,
+                                            arguments=json.dumps(
+                                                tool_call.parameters,
+                                                ensure_ascii=False)))
+                        for i, tool_call in enumerate(tool_calls)
+                    ],
+                    reasoning_content=reasoning_content)

            # if the request doesn't use tool choice
            # OR specifies to not use a tool
@@ -1248,7 +1278,6 @@ class OpenAIServingChat(OpenAIServing):
                    if (tool_call_info.content
                            and len(tool_call_info.content) > 0):
                        ret_content = tool_call_info.content

                message = ChatMessage(role=role,
                                      reasoning_content=reasoning_content,
                                      content=ret_content)
@@ -1327,12 +1356,11 @@ class OpenAIServingChat(OpenAIServing):
            elif choice.message.tool_calls:
                # For tool calls, log the function name and arguments
                tool_call_descriptions = []
-                for tool_call in choice.message.tool_calls:
-                    if hasattr(tool_call.function, "name") and hasattr(
-                            tool_call.function, "arguments"):
+                for tc in choice.message.tool_calls:
+                    if hasattr(tc.function, "name") and hasattr(
+                            tc.function, "arguments"):
                        tool_call_descriptions.append(
-                            f"{tool_call.function.name}({tool_call.function.arguments})"
-                        )
+                            f"{tc.function.name}({tc.function.arguments})")
                tool_calls_str = ", ".join(tool_call_descriptions)
                output_text = f"[tool_calls: {tool_calls_str}]"
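Taken together, the serving path counts the tool calls already present in the conversation and hands that count to make_tool_call_id, so kimi-k2 ids stay unique across turns. A hedged sketch of the intended numbering (indices inferred from the code above, not from a live run):

# One assistant tool call already in the history -> counting starts at 1.
history_tool_call_cnt = 1
next_id = make_tool_call_id(id_type='kimi_k2',
                            func_name='get_current_weather',
                            idx=history_tool_call_cnt)
# next_id == "functions.get_current_weather:1"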
@@ -6,7 +6,7 @@ from typing import Union

import regex as re

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
@@ -267,7 +267,7 @@ class DeepSeekV3ToolParser(ToolParser):
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
-                            id=random_tool_call_id(),
+                            id=make_tool_call_id(),
                            function=DeltaFunctionCall(
                                name=function_name).model_dump(
                                    exclude_none=True),

@@ -10,7 +10,7 @@ import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
@@ -203,7 +203,7 @@ class Granite20bFCToolParser(ToolParser):
            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              type="function",
-                              id=random_tool_call_id(),
+                              id=make_tool_call_id(),
                              function=DeltaFunctionCall(
                                  name=function_name).model_dump(
                                      exclude_none=True))

@@ -8,7 +8,7 @@ from typing import Union
import partial_json_parser
from partial_json_parser.core.options import Allow

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
@@ -185,7 +185,7 @@ class GraniteToolParser(ToolParser):
            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              type="function",
-                              id=random_tool_call_id(),
+                              id=make_tool_call_id(),
                              function=DeltaFunctionCall(
                                  name=function_name).model_dump(
                                      exclude_none=True))

@@ -9,7 +9,7 @@ import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
@@ -307,7 +307,7 @@ class Hermes2ProToolParser(ToolParser):
            return DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              type="function",
-                              id=random_tool_call_id(),
+                              id=make_tool_call_id(),
                              function=DeltaFunctionCall(
                                  name=function_name).model_dump(
                                      exclude_none=True))

@@ -8,7 +8,7 @@ from typing import Union
import partial_json_parser
from partial_json_parser.core.options import Allow

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
@@ -107,7 +107,7 @@ class Internlm2ToolParser(ToolParser):
            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              type="function",
-                              id=random_tool_call_id(),
+                              id=make_tool_call_id(),
                              function=DeltaFunctionCall(
                                  name=function_name).model_dump(
                                      exclude_none=True))

@@ -9,7 +9,7 @@ import partial_json_parser
import regex as re
from partial_json_parser.core.options import Allow

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
@@ -222,7 +222,7 @@ class JambaToolParser(ToolParser):
            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              type="function",
-                              id=random_tool_call_id(),
+                              id=make_tool_call_id(),
                              function=DeltaFunctionCall(
                                  name=function_name).model_dump(
                                      exclude_none=True))

@@ -10,7 +10,7 @@ import regex as re
from partial_json_parser.core.options import Allow
from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
@@ -213,7 +213,7 @@ class Llama3JsonToolParser(ToolParser):
            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(index=self.current_tool_id,
                              type="function",
-                              id=random_tool_call_id(),
+                              id=make_tool_call_id(),
                              function=DeltaFunctionCall(
                                  name=function_name).model_dump(
                                      exclude_none=True))

@@ -7,7 +7,7 @@ from typing import Any, Optional, Union

import regex as re

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
@@ -394,7 +394,7 @@ class MinimaxToolParser(ToolParser):
                sent_tools.append({
                    "sent_name": False,
                    "sent_arguments": "",
-                    "id": random_tool_call_id(),
+                    "id": make_tool_call_id(),
                })

            while len(tool_ids) < tool_count:

@@ -8,7 +8,7 @@ from typing import Any, Optional
import regex as re
from transformers import PreTrainedTokenizerBase

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage,
                                              ExtractedToolCallInformation,
@@ -74,7 +74,7 @@ class Phi4MiniJsonToolParser(ToolParser):

        tool_calls: list[ToolCall] = [
            ToolCall(
-                id=random_tool_call_id(),
+                id=make_tool_call_id(),
                type="function",
                function=FunctionCall(
                    name=raw_function_call["name"],

@@ -7,7 +7,7 @@ from typing import Any, Optional, Union

import regex as re

-from vllm.entrypoints.chat_utils import random_tool_call_id
+from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaFunctionCall, DeltaMessage,
                                              DeltaToolCall,
@@ -226,7 +226,7 @@ class xLAMToolParser(ToolParser):
            function_name = name_match.group(1)

            # The test expects us to send just the name first
-            tool_id = random_tool_call_id()
+            tool_id = make_tool_call_id()
            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(
                    index=0,