Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
Add return_token_ids parameter to OpenAI API endpoints (#22587)
Signed-off-by: Yuge Zhang <scottyugochang@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
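The new return_token_ids flag is a vLLM-specific extension passed through the OpenAI client's extra_body. A minimal usage sketch, adapted from the tests added in this commit (the model name and localhost address are taken from those tests and are assumptions, not requirements):

from openai import OpenAI

# Assumes a vLLM OpenAI-compatible server is already running locally,
# e.g. serving Qwen/Qwen2.5-1.5B-Instruct as in the new tests below.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    prompt="Hello, world!",
    max_tokens=10,
    extra_body={"return_token_ids": True},  # vLLM extension, not in the OpenAI spec
)

choice = completion.choices[0]
print(choice.prompt_token_ids)  # token IDs of the prompt
print(choice.token_ids)         # token IDs of the generated text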
@@ -74,31 +74,44 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
         -d '{"messages": [{"role": "assistant", "tool_calls": [{"custom": {"input": "", "name": ""}, "id": "", "type": "custom"}]}]}' \
         http://localhost:8000/v1/chat/completions
         """  # noqa: E501
-        if (hasattr(case, "body") and isinstance(case.body, dict)
-                and "messages" in case.body
-                and isinstance(case.body["messages"], list)
-                and len(case.body["messages"]) > 0):
-
-            for message in case.body["messages"]:
-                if not isinstance(message, dict):
-                    continue
-
-                # Check for invalid file type in tokenize endpoint
-                if op.method.lower() == "post" and op.path == "/tokenize":
-                    content = message.get("content", [])
-                    if (isinstance(content, list) and len(content) > 0 and any(
-                            item.get("type") == "file" for item in content)):
-                        return False
-
-                # Check for invalid tool_calls with non-function types
-                tool_calls = message.get("tool_calls", [])
-                if isinstance(tool_calls, list):
-                    for tool_call in tool_calls:
-                        if isinstance(tool_call, dict):
-                            if tool_call.get("type") != "function":
-                                return False
-                            if "custom" in tool_call:
-                                return False
+        if hasattr(case, "body") and isinstance(case.body, dict):
+            if ("messages" in case.body
+                    and isinstance(case.body["messages"], list)
+                    and len(case.body["messages"]) > 0):
+
+                for message in case.body["messages"]:
+                    if not isinstance(message, dict):
+                        continue
+
+                    # Check for invalid file type in tokenize endpoint
+                    if op.method.lower() == "post" and op.path == "/tokenize":
+                        content = message.get("content", [])
+                        if (isinstance(content, list) and len(content) > 0
+                                and any(
+                                    item.get("type") == "file"
+                                    for item in content)):
+                            return False
+
+                    # Check for invalid tool_calls with non-function types
+                    tool_calls = message.get("tool_calls", [])
+                    if isinstance(tool_calls, list):
+                        for tool_call in tool_calls:
+                            if isinstance(tool_call, dict):
+                                if tool_call.get("type") != "function":
+                                    return False
+                                if "custom" in tool_call:
+                                    return False
+
+            # Sometimes guided_grammar is generated to be empty
+            # Causing a server error in EBNF grammar parsing
+            # https://github.com/vllm-project/vllm/pull/22587#issuecomment-3195253421
+            guided_grammar = case.body.get("guided_grammar")
+
+            if guided_grammar == '':
+                # Allow None (will be handled as no grammar)
+                # But skip empty strings
+                return False
+
         return True
 
     return strategy.filter(no_invalid_types)
tests/entrypoints/openai/test_return_token_ids.py (new file, 374 lines)
@@ -0,0 +1,374 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

from vllm.transformers_utils.tokenizer import get_tokenizer

from ...utils import RemoteOpenAIServer

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"


@pytest.fixture(scope="module")
def server():
    args = [
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "128",
        "--enable-auto-tool-choice",
        "--tool-call-parser",
        "hermes",
        "--enforce-eager",
    ]
    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.mark.asyncio
async def test_basic_completion_with_emoji(server):
    """Test basic completion with emoji to verify token_ids field."""
    async with server.get_async_client() as client:
        # Test with return_token_ids enabled
        completion = await client.completions.create(
            model=MODEL_NAME,
            prompt="Complete this sentence with emojis: I love coding 🚀",
            max_tokens=10,
            temperature=0,
            logprobs=1,
            extra_body={"return_token_ids": True},
        )

        # Check the raw response to see the structure
        completion_dict = completion.model_dump()

        # Verify prompt_token_ids field is present in the completion response
        assert "prompt_token_ids" in completion_dict["choices"][0]
        assert isinstance(completion.choices[0].prompt_token_ids, list)

        # Check against the expected prompt token IDs
        tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
        encoded_tokens = tokenizer.encode(
            "Complete this sentence with emojis: I love coding 🚀")
        # Check that encoded_tokens is a subsequence of prompt_token_ids
        assert any(completion.choices[0].prompt_token_ids[i:i +
                                                          len(encoded_tokens)]
                   == encoded_tokens for i in range(
                       len(completion.choices[0].prompt_token_ids) -
                       len(encoded_tokens) + 1))

        # Verify token_ids field is present in the choice
        assert completion.choices[0].token_ids is not None
        assert isinstance(completion.choices[0].token_ids, list)
        assert len(completion.choices[0].token_ids) > 0

        # Verify decoding works correctly
        decoded_text = tokenizer.decode(completion.choices[0].token_ids)
        # The decoded text should contain a <|im_end|> at the end
        assert decoded_text.startswith(completion.choices[0].text)

        # Test without return_token_ids (should be None)
        completion_without = await client.completions.create(
            model=MODEL_NAME,
            prompt="Complete this sentence with emojis: I love coding 🚀",
            max_tokens=10,
            temperature=0,
            logprobs=1,
            extra_body={"return_token_ids": False},
        )

        completion_without_dict = completion_without.model_dump()
        assert completion_without_dict["choices"][0].get("token_ids") is None
        assert completion_without_dict.get("prompt_token_ids") is None


@pytest.mark.asyncio
async def test_chat_completion_with_tool_use(server):
    """Test chat completion with tool use (get_weather function)."""
    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type":
                        "string",
                        "description":
                        "The city and state, e.g. San Francisco, CA",
                    },
                    "unit": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The unit of temperature",
                    },
                },
                "required": ["location"],
            },
        },
    }]

    async with server.get_async_client() as client:
        # Test with return_token_ids enabled
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What's the weather like in Paris?"
                },
            ],
            tools=tools,
            tool_choice="auto",
            max_tokens=100,
            temperature=0,
            logprobs=True,
            extra_body={"return_token_ids": True},
        )

        # Verify token_ids field is present in choices
        assert response.choices[0].token_ids is not None
        assert isinstance(response.choices[0].token_ids, list)

        # Verify prompt_token_ids field is present
        assert response.prompt_token_ids is not None
        assert isinstance(response.prompt_token_ids, list)

        # Verify the prompt texts and response texts
        tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
        prompt_text = tokenizer.decode(response.prompt_token_ids)
        assert prompt_text.startswith(
            "<|im_start|>system\nYou are a helpful assistant.")
        assert prompt_text.endswith(
            "What's the weather like in Paris?<|im_end|>\n"
            "<|im_start|>assistant\n")

        response_text = tokenizer.decode(response.choices[0].token_ids)
        assert response_text.startswith('<tool_call>\n{"name": "get_weather"')
        assert response_text.endswith("</tool_call><|im_end|>")

        # If tool call was made, verify the response structure
        if response.choices[0].message.tool_calls:
            assert len(response.choices[0].message.tool_calls) > 0
            tool_call = response.choices[0].message.tool_calls[0]
            assert tool_call.function.name == "get_weather"

        # Test without return_token_ids
        response_without = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": "What's the weather like in Paris?"
                },
            ],
            tools=tools,
            tool_choice="auto",
            max_tokens=100,
            temperature=0,
            logprobs=True,
            extra_body={"return_token_ids": False},
        )

        assert response_without.choices[0].token_ids is None
        assert response_without.prompt_token_ids is None


@pytest.mark.asyncio
async def test_comparison_with_prompt_logprobs_and_logprobs(server):
    """
    Test that token_ids align with prompt_logprobs and
    logprobs when return_tokens_as_token_ids is enabled.
    """
    async with server.get_async_client() as client:
        # Test with both return_token_ids and return_tokens_as_token_ids enabled
        completion = await client.completions.create(
            model=MODEL_NAME,
            prompt="Hello, world! How are you today?",
            max_tokens=20,
            temperature=0,
            echo=True,
            logprobs=1,
            extra_body={
                "return_token_ids": True,
                "return_tokens_as_token_ids": True,
                "prompt_logprobs": 1
            },
        )

        # Verify all fields are present
        assert completion.choices[0].token_ids is not None
        assert completion.choices[0].prompt_token_ids is not None
        assert completion.choices[0].prompt_logprobs is not None
        assert completion.choices[0].logprobs is not None

        # Extract token IDs from logprobs
        # (when return_tokens_as_token_ids is True)
        logprobs_token_ids = []
        for token_str in completion.choices[0].logprobs.tokens:
            # Token format is "token_id:12345" when
            # return_tokens_as_token_ids is True
            if token_str.startswith("token_id:"):
                token_id = int(token_str.removeprefix("token_id:"))
                logprobs_token_ids.append(token_id)

        # When echo=True, the logprobs include both prompt and response tokens
        # The token_ids field should match the suffix of the response portion
        # The prompt_token_ids should match the prompt portion
        assert len(completion.choices[0].token_ids) < len(logprobs_token_ids)
        response_token_ids_length = len(completion.choices[0].token_ids)
        assert logprobs_token_ids[-response_token_ids_length:] == \
            completion.choices[0].token_ids

        # Verify tokenizer consistency
        tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

        # Decode prompt tokens
        if completion.choices[0].prompt_token_ids:
            prompt_text = tokenizer.decode(
                completion.choices[0].prompt_token_ids)
            # The decoded prompt should match or be close to the original prompt
            assert "Hello, world" in prompt_text

        # Decode response tokens
        if completion.choices[0].token_ids:
            response_text = tokenizer.decode(completion.choices[0].token_ids)
            assert completion.choices[0].text.endswith(response_text)

        # Test streaming mode
        stream = await client.completions.create(
            model=MODEL_NAME,
            prompt="Tell me a short fact about Python:",
            max_tokens=30,
            temperature=0,
            stream=True,
            echo=False,
            logprobs=1,
            extra_body={
                "return_token_ids": True,
                "return_tokens_as_token_ids": True
            },
        )

        # Collect streamed tokens
        streamed_prompt_token_ids = []
        streamed_token_ids = []
        streamed_logprob_token_ids = []
        first_chunk = True
        async for chunk in stream:
            for token_str in chunk.choices[0].logprobs.tokens:
                # Token format is "token_id:12345" when
                # return_tokens_as_token_ids is True
                if token_str.startswith("token_id:"):
                    token_id = int(token_str.removeprefix("token_id:"))
                    streamed_logprob_token_ids.append(token_id)
            if first_chunk:
                streamed_prompt_token_ids = chunk.choices[0].prompt_token_ids
                first_chunk = False
            streamed_token_ids += chunk.choices[0].token_ids

        # Verify we collected some tokens and first chunk had prompt_token_ids
        assert len(streamed_prompt_token_ids) > 0
        assert streamed_token_ids == streamed_logprob_token_ids


@pytest.mark.asyncio
async def test_chat_completion_with_emoji_and_token_ids(server):
    """Test chat completion with emojis to verify token_ids handling."""
    chat_messages = [
        {
            "role": "system",
            "content": "You like to use emojis in your responses."
        },
        {
            "role": "user",
            "content": "Repeat after me: I love cats 🐱"
        },
    ]
    async with server.get_async_client() as client:
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=chat_messages,
            max_tokens=50,
            temperature=0,
            logprobs=True,
            extra_body={"return_token_ids": True},
        )

        # Verify token_ids are present
        response_dict = response.model_dump()
        assert response.choices[0].token_ids is not None
        assert "prompt_token_ids" in response_dict

        # Verify the response contains the expected fields
        assert response.choices[0].message.content is not None

        # Decode token_ids and verify consistency
        tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

        decoded_prompt = tokenizer.decode(response.prompt_token_ids)
        assert decoded_prompt.startswith(
            "<|im_start|>system\nYou like to use emojis in your responses.")
        assert decoded_prompt.endswith(
            "I love cats 🐱<|im_end|>\n<|im_start|>assistant\n")

        decoded_response = tokenizer.decode(response.choices[0].token_ids)
        # The content should match the response text
        # except the ending <|im_end|>
        assert decoded_response == response.choices[
            0].message.content + "<|im_end|>"

        # Test with streaming
        stream = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=chat_messages,
            max_tokens=50,
            temperature=0,
            stream=True,
            extra_body={"return_token_ids": True},
        )

        collected_content = ""
        collected_token_ids = []
        first_chunk = True

        async for chunk in stream:
            if first_chunk:
                assert chunk.prompt_token_ids is not None
                assert isinstance(chunk.prompt_token_ids, list)
                # Check the prompt_token_ids match the initial prompt
                decoded_prompt_stream = tokenizer.decode(
                    chunk.prompt_token_ids)
                assert decoded_prompt_stream == decoded_prompt
                first_chunk = False
            else:
                chunk_dump = chunk.model_dump()
                assert "prompt_token_ids" not in chunk_dump, \
                    "Subsequent chunks should not have prompt_token_ids"

            if chunk.choices:
                if chunk.choices[0].delta.content:
                    collected_content += chunk.choices[0].delta.content
                # token_ids may not be present in all chunks
                choice_dump = chunk.choices[0].model_dump()
                if "token_ids" in choice_dump:
                    collected_token_ids.extend(chunk.choices[0].token_ids)

        # Verify we got response and token_ids
        assert len(collected_content) > 0
        assert len(collected_token_ids) > 0

        # Verify token_ids decode properly
        decoded_response = tokenizer.decode(collected_token_ids)
        assert decoded_response == collected_content + "<|im_end|>"
@@ -576,6 +576,14 @@ class ChatCompletionRequest(OpenAIBaseModel):
             "If specified with 'logprobs', tokens are represented "
             " as strings of the form 'token_id:{token_id}' so that tokens "
             "that are not JSON-encodable can be identified."))
+    return_token_ids: Optional[bool] = Field(
+        default=None,
+        description=(
+            "If specified, the result will include token IDs alongside the "
+            "generated text. In streaming mode, prompt_token_ids is included "
+            "only in the first chunk, and token_ids contains the delta tokens "
+            "for each chunk. This is useful for debugging or when you "
+            "need to map generated text back to input tokens."))
     cache_salt: Optional[str] = Field(
         default=None,
         description=(
@@ -1062,6 +1070,14 @@ class CompletionRequest(OpenAIBaseModel):
             "If specified with 'logprobs', tokens are represented "
             " as strings of the form 'token_id:{token_id}' so that tokens "
             "that are not JSON-encodable can be identified."))
+    return_token_ids: Optional[bool] = Field(
+        default=None,
+        description=(
+            "If specified, the result will include token IDs alongside the "
+            "generated text. In streaming mode, prompt_token_ids is included "
+            "only in the first chunk, and token_ids contains the delta tokens "
+            "for each chunk. This is useful for debugging or when you "
+            "need to map generated text back to input tokens."))
 
     cache_salt: Optional[str] = Field(
         default=None,
@@ -1480,7 +1496,9 @@ class CompletionResponseChoice(OpenAIBaseModel):
         "to stop, None if the completion finished for some other reason "
         "including encountering the EOS token"),
     )
+    token_ids: Optional[list[int]] = None  # For response
     prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
+    prompt_token_ids: Optional[list[int]] = None  # For prompt
 
 
 class CompletionResponse(OpenAIBaseModel):
@@ -1511,6 +1529,10 @@ class CompletionResponseStreamChoice(OpenAIBaseModel):
         "to stop, None if the completion finished for some other reason "
         "including encountering the EOS token"),
     )
+    # not part of the OpenAI spec but for tracing the tokens
+    # prompt tokens is put into choice to align with CompletionResponseChoice
+    prompt_token_ids: Optional[list[int]] = None
+    token_ids: Optional[list[int]] = None
 
 
 class CompletionStreamResponse(OpenAIBaseModel):
@@ -1680,6 +1702,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
     finish_reason: Optional[str] = "stop"
     # not part of the OpenAI spec but included in vLLM for legacy reasons
     stop_reason: Optional[Union[int, str]] = None
+    # not part of the OpenAI spec but is useful for tracing the tokens
+    # in agent scenarios
+    token_ids: Optional[list[int]] = None
 
 
 class ChatCompletionResponse(OpenAIBaseModel):
@@ -1695,6 +1720,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
 
     # vLLM-specific fields that are not in OpenAI spec
     prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
+    prompt_token_ids: Optional[list[int]] = None
     kv_transfer_params: Optional[dict[str, Any]] = Field(
         default=None, description="KVTransfer parameters.")
 
@@ -1712,6 +1738,8 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     logprobs: Optional[ChatCompletionLogProbs] = None
     finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = None
+    # not part of the OpenAI spec but for tracing the tokens
+    token_ids: Optional[list[int]] = None
 
 
 class ChatCompletionStreamResponse(OpenAIBaseModel):
@@ -1721,6 +1749,8 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
     model: str
     choices: list[ChatCompletionResponseStreamChoice]
     usage: Optional[UsageInfo] = Field(default=None)
+    # not part of the OpenAI spec but for tracing the tokens
+    prompt_token_ids: Optional[list[int]] = None
 
 
 class TranscriptionResponseStreamChoice(OpenAIBaseModel):
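Per the field descriptions above, in streaming mode prompt_token_ids is attached only to the first chunk, and each chunk's choice carries the delta token_ids. A rough consumer-side sketch of those semantics (the endpoint, model name, and guarding style are assumptions based on the tests in this PR, not a prescribed API):

from openai import AsyncOpenAI


async def collect_token_ids():
    # Assumes a vLLM server on localhost:8000 built with this change.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    stream = await client.chat.completions.create(
        model="Qwen/Qwen2.5-1.5B-Instruct",
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
        extra_body={"return_token_ids": True},
    )
    prompt_token_ids = None
    generated_token_ids = []
    async for chunk in stream:
        if prompt_token_ids is None and chunk.model_dump().get("prompt_token_ids"):
            # Only the first chunk carries the prompt token IDs.
            prompt_token_ids = chunk.prompt_token_ids
        for choice in chunk.choices:
            # Each chunk's choice carries the delta token IDs when present.
            if choice.model_dump().get("token_ids"):
                generated_token_ids.extend(choice.token_ids)
    return prompt_token_ids, generated_token_ids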
@@ -568,12 +568,17 @@ class OpenAIServingChat(OpenAIServing):
                     ),
                     logprobs=None,
                     finish_reason=None)
 
+                # return prompt_token_ids at the first chunk ever
                 chunk = ChatCompletionStreamResponse(
                     id=request_id,
                     object=chunk_object_type,
                     created=created_time,
                     choices=[choice_data],
-                    model=model_name)
+                    model=model_name,
+                    prompt_token_ids=(res.prompt_token_ids
+                                      if request.return_token_ids else
+                                      None))
 
                 # if continuous usage stats are requested, add it
                 if include_continuous_usage:
@@ -912,7 +917,9 @@ class OpenAIServingChat(OpenAIServing):
                         index=i,
                         delta=delta_message,
                         logprobs=logprobs,
-                        finish_reason=None)
+                        finish_reason=None,
+                        token_ids=(as_list(output.token_ids)
+                                   if request.return_token_ids else None))
 
                 # if the model is finished generating
                 else:
@@ -973,7 +980,9 @@ class OpenAIServingChat(OpenAIServing):
                         logprobs=logprobs,
                         finish_reason=output.finish_reason
                         if not auto_tools_called else "tool_calls",
-                        stop_reason=output.stop_reason)
+                        stop_reason=output.stop_reason,
+                        token_ids=(as_list(output.token_ids)
+                                   if request.return_token_ids else None))
 
                     finish_reason_sent[i] = True
 
@@ -1260,7 +1269,10 @@ class OpenAIServingChat(OpenAIServing):
                 logprobs=logprobs,
                 finish_reason="tool_calls" if auto_tools_called else
                 output.finish_reason if output.finish_reason else "stop",
-                stop_reason=output.stop_reason)
+                stop_reason=output.stop_reason,
+                token_ids=(as_list(output.token_ids)
+                           if request.return_token_ids else None),
+            )
 
             choices.append(choice_data)
 
@@ -1301,6 +1313,8 @@ class OpenAIServingChat(OpenAIServing):
             choices=choices,
             usage=usage,
             prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
+            prompt_token_ids=(final_res.prompt_token_ids
+                              if request.return_token_ids else None),
             kv_transfer_params=final_res.kv_transfer_params,
         )
 
@@ -42,7 +42,7 @@ from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer
-from vllm.utils import merge_async_iterators
+from vllm.utils import as_list, merge_async_iterators
 
 logger = init_logger(__name__)
 
@@ -365,6 +365,11 @@ class OpenAIServingCompletion(OpenAIServing):
             for output in res.outputs:
                 i = output.index + prompt_idx * num_choices
 
+                # Useful when request.return_token_ids is True
+                # Returning prompt token IDs shares the same logic
+                # with the echo implementation.
+                prompt_token_ids_to_return: Optional[list[int]] = None
+
                 assert request.max_tokens is not None
                 if request.echo and not has_echoed[i]:
                     assert prompt_token_ids is not None
@@ -385,6 +390,7 @@ class OpenAIServingCompletion(OpenAIServing):
                         *(prompt_logprobs or []),
                         *(output.logprobs or []),
                     ]
+                    prompt_token_ids_to_return = prompt_token_ids
                     has_echoed[i] = True
                 else:
                     # return just the delta
@@ -392,6 +398,12 @@ class OpenAIServingCompletion(OpenAIServing):
                     delta_token_ids = output.token_ids
                     out_logprobs = output.logprobs
 
+                    # has_echoed[i] is reused here to indicate whether
+                    # we have already returned the prompt token IDs.
+                    if not has_echoed[i]:
+                        prompt_token_ids_to_return = prompt_token_ids
+                        has_echoed[i] = True
+
                 if (not delta_text and not delta_token_ids
                         and not previous_num_tokens[i]):
                     # Chunked prefill case, don't return empty chunks
@@ -428,6 +440,9 @@ class OpenAIServingCompletion(OpenAIServing):
                             logprobs=logprobs,
                             finish_reason=finish_reason,
                             stop_reason=stop_reason,
+                            prompt_token_ids=prompt_token_ids_to_return,
+                            token_ids=(as_list(output.token_ids) if
+                                       request.return_token_ids else None),
                         )
                     ],
                 )
@@ -548,6 +563,10 @@ class OpenAIServingCompletion(OpenAIServing):
                 finish_reason=output.finish_reason,
                 stop_reason=output.stop_reason,
                 prompt_logprobs=final_res.prompt_logprobs,
+                prompt_token_ids=(prompt_token_ids
+                                  if request.return_token_ids else None),
+                token_ids=(as_list(output.token_ids)
+                           if request.return_token_ids else None),
             )
             choices.append(choice_data)
 