# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
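"""End-to-end tests for a reasoning model served with tool calling enabled.

MODEL_NAME is launched behind vLLM's OpenAI-compatible server with the
deepseek_r1 reasoning parser and the hermes tool-call parser; the tests check
that streamed and non-streamed chat completions both return reasoning content
alongside a correctly parsed tool call.
"""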

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from ...utils import RemoteOpenAIServer

# a reasoning and tool calling model
MODEL_NAME = "Qwen/QwQ-32B"


@pytest.fixture(scope="module")
def server():  # noqa: F811
    args = [
        "--max-model-len",
        "8192",
        "--enforce-eager",
        # split the model's <think> section into reasoning_content
        "--reasoning-parser",
        "deepseek_r1",
        # detect and parse tool calls without requiring an explicit tool_choice
        "--enable-auto-tool-choice",
        "--tool-call-parser",
        "hermes",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {
                        "type": "string",
                        "description": "The city to find the weather for, e.g. "
                        "'San Francisco'",
                    },
                    "state": {
                        "type": "string",
                        "description": "The two-letter abbreviation for the state "
                        "that the city is in, e.g. 'CA' for 'California'",
                    },
                    "unit": {
                        "type": "string",
                        "description": "The unit to fetch the temperature in",
                        "enum": ["celsius", "fahrenheit"],
                    },
                },
                "required": ["city", "state", "unit"],
            },
        },
    }
]

MESSAGES = [
    {"role": "user", "content": "Hi! How are you doing today?"},
    {"role": "assistant", "content": "I'm doing well! How can I help you?"},
    {
        "role": "user",
        "content": "Can you tell me what the temperature will be in Dallas, "
        "in fahrenheit?",
    },
]

FUNC_NAME = "get_current_weather"
FUNC_ARGS = """{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}"""
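
# Both requests below use temperature=0.0, so the model is expected to emit
# this exact tool call; the assertions compare the function name and the
# accumulated JSON argument string against these constants verbatim.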


def extract_reasoning_and_calls(chunks: list):
    """Accumulate streamed reasoning text and tool-call deltas from chunks."""
    reasoning_content = ""
    tool_call_idx = -1
    arguments = []
    function_names = []
    for chunk in chunks:
        if chunk.choices[0].delta.tool_calls:
            tool_call = chunk.choices[0].delta.tool_calls[0]
            # a new index marks the start of a new tool call; its name and
            # arguments then arrive incrementally across chunks
            if tool_call.index != tool_call_idx:
                tool_call_idx = tool_call.index
                arguments.append("")
                function_names.append("")

            if tool_call.function:
                if tool_call.function.name:
                    function_names[tool_call_idx] = tool_call.function.name

                if tool_call.function.arguments:
                    arguments[tool_call_idx] += tool_call.function.arguments
        else:
            # reasoning_content may be absent or None on some deltas, so
            # guard before concatenating to avoid a TypeError
            reasoning = getattr(chunk.choices[0].delta, "reasoning_content", None)
            if reasoning:
                reasoning_content += reasoning
    return reasoning_content, arguments, function_names
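

# For reference, a rough sketch of the delta sequence the helper above
# consumes (field values are illustrative, not captured from a real run):
#   delta.reasoning_content = "Okay, the user wants the weather in Dallas..."
#   delta.tool_calls[0] = {index: 0, function: {name: "get_current_weather"}}
#   delta.tool_calls[0] = {index: 0, function: {arguments: '{"city": "Dal'}}
#   delta.tool_calls[0] = {index: 0, function: {arguments: 'las", ...}'}}
# The name arrives once per tool call; the arguments accumulate across chunks.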


# test streaming
@pytest.mark.asyncio
async def test_chat_streaming_of_tool_and_reasoning(client: openai.AsyncOpenAI):
    stream = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES,
        tools=TOOLS,
        temperature=0.0,
        stream=True,
    )

    chunks = []
    async for chunk in stream:
        chunks.append(chunk)

    reasoning_content, arguments, function_names = extract_reasoning_and_calls(chunks)
    assert len(reasoning_content) > 0
    assert len(function_names) > 0 and function_names[0] == FUNC_NAME
    assert len(arguments) > 0 and arguments[0] == FUNC_ARGS


# test non-streaming generation
@pytest.mark.asyncio
async def test_chat_full_of_tool_and_reasoning(client: openai.AsyncOpenAI):
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=MESSAGES,
        tools=TOOLS,
        temperature=0.0,
        stream=False,
    )

    assert len(response.choices[0].message.reasoning_content) > 0
    assert response.choices[0].message.tool_calls[0].function.name == FUNC_NAME
    assert response.choices[0].message.tool_calls[0].function.arguments == FUNC_ARGS