Add return_token_ids parameter to OpenAI API endpoints (#22587)

Signed-off-by: Yuge Zhang <scottyugochang@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
Authored by Yuge Zhang on 2025-08-20 00:48:31 +08:00; committed by GitHub
parent 4f510bc2a1
commit 24f4d1a224
5 changed files with 477 additions and 27 deletions
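
As a quick orientation before the diff, here is a minimal client-side sketch of the new parameter, distilled from the tests added in this commit. The base URL, API key, model name, and prompts are placeholders; `return_token_ids` is passed through `extra_body` because it is a vLLM extension, not part of the official OpenAI client surface.

```python
# Minimal usage sketch (assumes a vLLM OpenAI-compatible server on
# localhost:8000; all concrete values below are illustrative).
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# /v1/completions: both new fields live on each choice.
completion = client.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    prompt="I love coding",
    max_tokens=10,
    extra_body={"return_token_ids": True},  # vLLM extension
)
choice = completion.model_dump()["choices"][0]
print(choice["prompt_token_ids"])  # token IDs of the prompt
print(choice["token_ids"])         # token IDs of the generated text

# /v1/chat/completions: prompt_token_ids is top-level, token_ids is per choice.
chat = client.chat.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    messages=[{"role": "user", "content": "Say hi"}],
    max_tokens=10,
    extra_body={"return_token_ids": True},
)
chat_dict = chat.model_dump()
print(chat_dict["prompt_token_ids"])
print(chat_dict["choices"][0]["token_ids"])
```

The tests below access the same fields as attributes on the client response objects; `model_dump()` is used here only to make the extra fields explicit.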


@ -74,31 +74,44 @@ def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
-d '{"messages": [{"role": "assistant", "tool_calls": [{"custom": {"input": "", "name": ""}, "id": "", "type": "custom"}]}]}' \
http://localhost:8000/v1/chat/completions
""" # noqa: E501
if hasattr(case, "body") and isinstance(case.body, dict):
    if ("messages" in case.body
            and isinstance(case.body["messages"], list)
            and len(case.body["messages"]) > 0):
        for message in case.body["messages"]:
            if not isinstance(message, dict):
                continue

            # Check for invalid file type in tokenize endpoint
            if op.method.lower() == "post" and op.path == "/tokenize":
                content = message.get("content", [])
                if (isinstance(content, list) and len(content) > 0
                        and any(
                            item.get("type") == "file"
                            for item in content)):
                    return False

            # Check for invalid tool_calls with non-function types
            tool_calls = message.get("tool_calls", [])
            if isinstance(tool_calls, list):
                for tool_call in tool_calls:
                    if isinstance(tool_call, dict):
                        if tool_call.get("type") != "function":
                            return False
                        if "custom" in tool_call:
                            return False

    # Sometimes guided_grammar is generated to be empty
    # Causing a server error in EBNF grammar parsing
    # https://github.com/vllm-project/vllm/pull/22587#issuecomment-3195253421
    guided_grammar = case.body.get("guided_grammar")
    if guided_grammar == '':
        # Allow None (will be handled as no grammar)
        # But skip empty strings
        return False

return True
return strategy.filter(no_invalid_types)
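
For context, this hunk is the body of the `no_invalid_types` filter inside the `before_generate_case` hook named in the hunk header. A rough sketch of the enclosing structure, under the assumption that the hook is registered with schemathesis's `@schemathesis.hook` decorator and that `op` comes from the hook context (neither is shown in the hunk itself):

```python
# Assumed enclosing structure; only the function name/signature and the final
# strategy.filter(...) call are taken from the diff itself.
import schemathesis


@schemathesis.hook
def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
    op = context.operation  # assumed source of the `op` used in the filter

    def no_invalid_types(case) -> bool:
        # ... filtering logic shown in the hunk above ...
        return True

    return strategy.filter(no_invalid_types)
```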


@ -0,0 +1,374 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...utils import RemoteOpenAIServer
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
@pytest.fixture(scope="module")
def server():
args = [
"--max-model-len",
"2048",
"--max-num-seqs",
"128",
"--enable-auto-tool-choice",
"--tool-call-parser",
"hermes",
"--enforce-eager",
]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.mark.asyncio
async def test_basic_completion_with_emoji(server):
"""Test basic completion with emoji to verify token_ids field."""
async with server.get_async_client() as client:
# Test with return_token_ids enabled
completion = await client.completions.create(
model=MODEL_NAME,
prompt="Complete this sentence with emojis: I love coding 🚀",
max_tokens=10,
temperature=0,
logprobs=1,
extra_body={"return_token_ids": True},
)
# Check the raw response to see the structure
completion_dict = completion.model_dump()
# Verify prompt_token_ids field is present in the completion response
assert "prompt_token_ids" in completion_dict["choices"][0]
assert isinstance(completion.choices[0].prompt_token_ids, list)
# Check against the expected prompt token IDs
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
encoded_tokens = tokenizer.encode(
"Complete this sentence with emojis: I love coding 🚀")
# Check that encoded_tokens is a subsequence of prompt_token_ids
assert any(completion.choices[0].prompt_token_ids[i:i +
len(encoded_tokens)]
== encoded_tokens for i in range(
len(completion.choices[0].prompt_token_ids) -
len(encoded_tokens) + 1))
# Verify token_ids field is present in the choice
assert completion.choices[0].token_ids is not None
assert isinstance(completion.choices[0].token_ids, list)
assert len(completion.choices[0].token_ids) > 0
# Verify decoding works correctly
decoded_text = tokenizer.decode(completion.choices[0].token_ids)
# The decoded text should contain a <|im_end|> at the end
assert decoded_text.startswith(completion.choices[0].text)
# Test without return_token_ids (should be None)
completion_without = await client.completions.create(
model=MODEL_NAME,
prompt="Complete this sentence with emojis: I love coding 🚀",
max_tokens=10,
temperature=0,
logprobs=1,
extra_body={"return_token_ids": False},
)
completion_without_dict = completion_without.model_dump()
assert completion_without_dict["choices"][0].get("token_ids") is None
assert completion_without_dict.get("prompt_token_ids") is None
@pytest.mark.asyncio
async def test_chat_completion_with_tool_use(server):
"""Test chat completion with tool use (get_weather function)."""
tools = [{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type":
"string",
"description":
"The city and state, e.g. San Francisco, CA",
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The unit of temperature",
},
},
"required": ["location"],
},
},
}]
async with server.get_async_client() as client:
# Test with return_token_ids enabled
response = await client.chat.completions.create(
model=MODEL_NAME,
messages=[
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What's the weather like in Paris?"
},
],
tools=tools,
tool_choice="auto",
max_tokens=100,
temperature=0,
logprobs=True,
extra_body={"return_token_ids": True},
)
# Verify token_ids field is present in choices
assert response.choices[0].token_ids is not None
assert isinstance(response.choices[0].token_ids, list)
# Verify prompt_token_ids field is present
assert response.prompt_token_ids is not None
assert isinstance(response.prompt_token_ids, list)
# Verify the prompt texts and response texts
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
prompt_text = tokenizer.decode(response.prompt_token_ids)
assert prompt_text.startswith(
"<|im_start|>system\nYou are a helpful assistant.")
assert prompt_text.endswith(
"What's the weather like in Paris?<|im_end|>\n"
"<|im_start|>assistant\n")
response_text = tokenizer.decode(response.choices[0].token_ids)
assert response_text.startswith('<tool_call>\n{"name": "get_weather"')
assert response_text.endswith("</tool_call><|im_end|>")
# If tool call was made, verify the response structure
if response.choices[0].message.tool_calls:
assert len(response.choices[0].message.tool_calls) > 0
tool_call = response.choices[0].message.tool_calls[0]
assert tool_call.function.name == "get_weather"
# Test without return_token_ids
response_without = await client.chat.completions.create(
model=MODEL_NAME,
messages=[
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "What's the weather like in Paris?"
},
],
tools=tools,
tool_choice="auto",
max_tokens=100,
temperature=0,
logprobs=True,
extra_body={"return_token_ids": False},
)
assert response_without.choices[0].token_ids is None
assert response_without.prompt_token_ids is None
@pytest.mark.asyncio
async def test_comparison_with_prompt_logprobs_and_logprobs(server):
"""
Test that token_ids align with prompt_logprobs and
logprobs when return_tokens_as_token_ids is enabled.
"""
async with server.get_async_client() as client:
# Test with both return_token_ids and return_tokens_as_token_ids enabled
completion = await client.completions.create(
model=MODEL_NAME,
prompt="Hello, world! How are you today?",
max_tokens=20,
temperature=0,
echo=True,
logprobs=1,
extra_body={
"return_token_ids": True,
"return_tokens_as_token_ids": True,
"prompt_logprobs": 1
},
)
# Verify all fields are present
assert completion.choices[0].token_ids is not None
assert completion.choices[0].prompt_token_ids is not None
assert completion.choices[0].prompt_logprobs is not None
assert completion.choices[0].logprobs is not None
# Extract token IDs from logprobs
# (when return_tokens_as_token_ids is True)
logprobs_token_ids = []
for token_str in completion.choices[0].logprobs.tokens:
# Token format is "token_id:12345" when
# return_tokens_as_token_ids is True
if token_str.startswith("token_id:"):
token_id = int(token_str.removeprefix("token_id:"))
logprobs_token_ids.append(token_id)
# When echo=True, the logprobs include both prompt and response tokens
# The token_ids field should match the suffix of the response portion
# The prompt_token_ids should match the prompt portion
assert len(completion.choices[0].token_ids) < len(logprobs_token_ids)
response_token_ids_length = len(completion.choices[0].token_ids)
assert logprobs_token_ids[-response_token_ids_length:] == \
completion.choices[0].token_ids
# Verify tokenizer consistency
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
# Decode prompt tokens
if completion.choices[0].prompt_token_ids:
prompt_text = tokenizer.decode(
completion.choices[0].prompt_token_ids)
# The decoded prompt should match or be close to the original prompt
assert "Hello, world" in prompt_text
# Decode response tokens
if completion.choices[0].token_ids:
response_text = tokenizer.decode(completion.choices[0].token_ids)
assert completion.choices[0].text.endswith(response_text)
# Test streaming mode
stream = await client.completions.create(
model=MODEL_NAME,
prompt="Tell me a short fact about Python:",
max_tokens=30,
temperature=0,
stream=True,
echo=False,
logprobs=1,
extra_body={
"return_token_ids": True,
"return_tokens_as_token_ids": True
},
)
# Collect streamed tokens
streamed_prompt_token_ids = []
streamed_token_ids = []
streamed_logprob_token_ids = []
first_chunk = True
async for chunk in stream:
for token_str in chunk.choices[0].logprobs.tokens:
# Token format is "token_id:12345" when
# return_tokens_as_token_ids is True
if token_str.startswith("token_id:"):
token_id = int(token_str.removeprefix("token_id:"))
streamed_logprob_token_ids.append(token_id)
if first_chunk:
streamed_prompt_token_ids = chunk.choices[0].prompt_token_ids
first_chunk = False
streamed_token_ids += chunk.choices[0].token_ids
# Verify we collected some tokens and first chunk had prompt_token_ids
assert len(streamed_prompt_token_ids) > 0
assert streamed_token_ids == streamed_logprob_token_ids
@pytest.mark.asyncio
async def test_chat_completion_with_emoji_and_token_ids(server):
"""Test chat completion with emojis to verify token_ids handling."""
chat_messages = [
{
"role": "system",
"content": "You like to use emojis in your responses."
},
{
"role": "user",
"content": "Repeat after me: I love cats 🐱"
},
]
async with server.get_async_client() as client:
response = await client.chat.completions.create(
model=MODEL_NAME,
messages=chat_messages,
max_tokens=50,
temperature=0,
logprobs=True,
extra_body={"return_token_ids": True},
)
# Verify token_ids are present
response_dict = response.model_dump()
assert response.choices[0].token_ids is not None
assert "prompt_token_ids" in response_dict
# Verify the response contains the expected fields
assert response.choices[0].message.content is not None
# Decode token_ids and verify consistency
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
decoded_prompt = tokenizer.decode(response.prompt_token_ids)
assert decoded_prompt.startswith(
"<|im_start|>system\nYou like to use emojis in your responses.")
assert decoded_prompt.endswith(
"I love cats 🐱<|im_end|>\n<|im_start|>assistant\n")
decoded_response = tokenizer.decode(response.choices[0].token_ids)
# The content should match the response text
# except the ending <|im_end|>
assert decoded_response == response.choices[
0].message.content + "<|im_end|>"
# Test with streaming
stream = await client.chat.completions.create(
model=MODEL_NAME,
messages=chat_messages,
max_tokens=50,
temperature=0,
stream=True,
extra_body={"return_token_ids": True},
)
collected_content = ""
collected_token_ids = []
first_chunk = True
async for chunk in stream:
if first_chunk:
assert chunk.prompt_token_ids is not None
assert isinstance(chunk.prompt_token_ids, list)
# Check the prompt_token_ids match the initial prompt
decoded_prompt_stream = tokenizer.decode(
chunk.prompt_token_ids)
assert decoded_prompt_stream == decoded_prompt
first_chunk = False
else:
chunk_dump = chunk.model_dump()
assert "prompt_token_ids" not in chunk_dump, \
"Subsequent chunks should not have prompt_token_ids"
if chunk.choices:
if chunk.choices[0].delta.content:
collected_content += chunk.choices[0].delta.content
# token_ids may not be present in all chunks
choice_dump = chunk.choices[0].model_dump()
if "token_ids" in choice_dump:
collected_token_ids.extend(chunk.choices[0].token_ids)
# Verify we got response and token_ids
assert len(collected_content) > 0
assert len(collected_token_ids) > 0
# Verify token_ids decode properly
decoded_response = tokenizer.decode(collected_token_ids)
assert decoded_response == collected_content + "<|im_end|>"


@ -576,6 +576,14 @@ class ChatCompletionRequest(OpenAIBaseModel):
"If specified with 'logprobs', tokens are represented "
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified."))
return_token_ids: Optional[bool] = Field(
default=None,
description=(
"If specified, the result will include token IDs alongside the "
"generated text. In streaming mode, prompt_token_ids is included "
"only in the first chunk, and token_ids contains the delta tokens "
"for each chunk. This is useful for debugging or when you "
"need to map generated text back to input tokens."))
cache_salt: Optional[str] = Field(
default=None,
description=(
@ -1062,6 +1070,14 @@ class CompletionRequest(OpenAIBaseModel):
"If specified with 'logprobs', tokens are represented "
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified."))
return_token_ids: Optional[bool] = Field(
default=None,
description=(
"If specified, the result will include token IDs alongside the "
"generated text. In streaming mode, prompt_token_ids is included "
"only in the first chunk, and token_ids contains the delta tokens "
"for each chunk. This is useful for debugging or when you "
"need to map generated text back to input tokens."))
cache_salt: Optional[str] = Field(
default=None,
@ -1480,7 +1496,9 @@ class CompletionResponseChoice(OpenAIBaseModel):
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
token_ids: Optional[list[int]] = None # For response
prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
prompt_token_ids: Optional[list[int]] = None # For prompt
class CompletionResponse(OpenAIBaseModel):
@ -1511,6 +1529,10 @@ class CompletionResponseStreamChoice(OpenAIBaseModel):
"to stop, None if the completion finished for some other reason "
"including encountering the EOS token"),
)
# not part of the OpenAI spec but for tracing the tokens
# prompt token IDs are put on the choice to align with CompletionResponseChoice
prompt_token_ids: Optional[list[int]] = None
token_ids: Optional[list[int]] = None
class CompletionStreamResponse(OpenAIBaseModel):
@ -1680,6 +1702,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
finish_reason: Optional[str] = "stop"
# not part of the OpenAI spec but included in vLLM for legacy reasons
stop_reason: Optional[Union[int, str]] = None
# not part of the OpenAI spec but is useful for tracing the tokens
# in agent scenarios
token_ids: Optional[list[int]] = None
class ChatCompletionResponse(OpenAIBaseModel):
@ -1695,6 +1720,7 @@ class ChatCompletionResponse(OpenAIBaseModel):
# vLLM-specific fields that are not in OpenAI spec
prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None
prompt_token_ids: Optional[list[int]] = None
kv_transfer_params: Optional[dict[str, Any]] = Field(
default=None, description="KVTransfer parameters.")
@ -1712,6 +1738,8 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
logprobs: Optional[ChatCompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = None
# not part of the OpenAI spec but for tracing the tokens
token_ids: Optional[list[int]] = None
class ChatCompletionStreamResponse(OpenAIBaseModel):
@ -1721,6 +1749,8 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
model: str
choices: list[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
# not part of the OpenAI spec but for tracing the tokens
prompt_token_ids: Optional[list[int]] = None
class TranscriptionResponseStreamChoice(OpenAIBaseModel):
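
Taken together, the protocol changes above place the new fields differently per endpoint: completion responses carry both `prompt_token_ids` and `token_ids` on each choice, while chat responses carry `prompt_token_ids` at the top level and `token_ids` on each choice. An illustrative sketch of the resulting non-streaming payloads when `return_token_ids` is enabled (all values made up):

```python
# Illustrative (not real) response shapes with return_token_ids=True.
completion_response = {
    "choices": [{
        "text": "...",
        "token_ids": [9707, 11, 1879],          # generated tokens, per choice
        "prompt_token_ids": [40, 2948, 9706],   # prompt tokens, per choice
    }],
}

chat_completion_response = {
    "prompt_token_ids": [151644, 8948, 198],    # prompt tokens, top level
    "choices": [{
        "message": {"role": "assistant", "content": "..."},
        "token_ids": [9707, 151645],            # generated tokens, per choice
    }],
}
```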


@ -568,12 +568,17 @@ class OpenAIServingChat(OpenAIServing):
),
logprobs=None,
finish_reason=None)
# return prompt_token_ids only in the first chunk
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
model=model_name,
prompt_token_ids=(res.prompt_token_ids
if request.return_token_ids else
None))
# if continuous usage stats are requested, add it
if include_continuous_usage:
@ -912,7 +917,9 @@ class OpenAIServingChat(OpenAIServing):
index=i,
delta=delta_message,
logprobs=logprobs,
finish_reason=None)
finish_reason=None,
token_ids=(as_list(output.token_ids)
if request.return_token_ids else None))
# if the model is finished generating
else:
@ -973,7 +980,9 @@ class OpenAIServingChat(OpenAIServing):
logprobs=logprobs,
finish_reason=output.finish_reason
if not auto_tools_called else "tool_calls",
stop_reason=output.stop_reason)
stop_reason=output.stop_reason,
token_ids=(as_list(output.token_ids)
if request.return_token_ids else None))
finish_reason_sent[i] = True
@ -1260,7 +1269,10 @@ class OpenAIServingChat(OpenAIServing):
logprobs=logprobs,
finish_reason="tool_calls" if auto_tools_called else
output.finish_reason if output.finish_reason else "stop",
stop_reason=output.stop_reason)
stop_reason=output.stop_reason,
token_ids=(as_list(output.token_ids)
if request.return_token_ids else None),
)
choices.append(choice_data)
@ -1301,6 +1313,8 @@ class OpenAIServingChat(OpenAIServing):
choices=choices,
usage=usage,
prompt_logprobs=clamp_prompt_logprobs(final_res.prompt_logprobs),
prompt_token_ids=(final_res.prompt_token_ids
if request.return_token_ids else None),
kv_transfer_params=final_res.kv_transfer_params,
)
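
In the streaming chat path above, `prompt_token_ids` is attached only to the first chunk, and each later chunk carries the delta `token_ids` for its choice. A minimal consumption sketch modeled on the streaming test earlier in this commit, assuming an `openai.AsyncOpenAI` client pointed at the same server (model and prompt are placeholders):

```python
# Sketch: collect prompt and delta token IDs from a streamed chat completion.
async def collect_stream_token_ids(client, model: str):
    stream = await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "Say hi"}],
        max_tokens=20,
        stream=True,
        extra_body={"return_token_ids": True},
    )
    prompt_token_ids, token_ids = None, []
    first = True
    async for chunk in stream:
        if first:
            # Only the first chunk carries prompt_token_ids.
            prompt_token_ids = chunk.prompt_token_ids
            first = False
        for choice in chunk.choices:
            delta_ids = choice.model_dump().get("token_ids")
            if delta_ids:
                # token_ids holds the delta tokens for this chunk.
                token_ids.extend(delta_ids)
    return prompt_token_ids, token_ids
```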


@ -42,7 +42,7 @@ from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import merge_async_iterators
from vllm.utils import as_list, merge_async_iterators
logger = init_logger(__name__)
@ -365,6 +365,11 @@ class OpenAIServingCompletion(OpenAIServing):
for output in res.outputs:
i = output.index + prompt_idx * num_choices
# Useful when request.return_token_ids is True
# Returning prompt token IDs shares the same logic
# with the echo implementation.
prompt_token_ids_to_return: Optional[list[int]] = None
assert request.max_tokens is not None
if request.echo and not has_echoed[i]:
assert prompt_token_ids is not None
@ -385,6 +390,7 @@ class OpenAIServingCompletion(OpenAIServing):
*(prompt_logprobs or []),
*(output.logprobs or []),
]
prompt_token_ids_to_return = prompt_token_ids
has_echoed[i] = True
else:
# return just the delta
@ -392,6 +398,12 @@ class OpenAIServingCompletion(OpenAIServing):
delta_token_ids = output.token_ids
out_logprobs = output.logprobs
# has_echoed[i] is reused here to indicate whether
# we have already returned the prompt token IDs.
if not has_echoed[i]:
prompt_token_ids_to_return = prompt_token_ids
has_echoed[i] = True
if (not delta_text and not delta_token_ids
and not previous_num_tokens[i]):
# Chunked prefill case, don't return empty chunks
@ -428,6 +440,9 @@ class OpenAIServingCompletion(OpenAIServing):
logprobs=logprobs,
finish_reason=finish_reason,
stop_reason=stop_reason,
prompt_token_ids=prompt_token_ids_to_return,
token_ids=(as_list(output.token_ids) if
request.return_token_ids else None),
)
],
)
@ -548,6 +563,10 @@ class OpenAIServingCompletion(OpenAIServing):
finish_reason=output.finish_reason,
stop_reason=output.stop_reason,
prompt_logprobs=final_res.prompt_logprobs,
prompt_token_ids=(prompt_token_ids
if request.return_token_ids else None),
token_ids=(as_list(output.token_ids)
if request.return_token_ids else None),
)
choices.append(choice_data)
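
Because the completion path reuses the echo bookkeeping above, the returned `token_ids` stay aligned with the logprobs stream: with `return_tokens_as_token_ids` also enabled, the per-choice `token_ids` form the trailing (response) portion of the IDs recoverable from the `token_id:<n>` logprob tokens. A small cross-check distilled from `test_comparison_with_prompt_logprobs_and_logprobs`; the helper name is illustrative:

```python
# Sketch: verify token_ids against "token_id:<n>" logprob tokens. Assumes a
# completion made with echo=True, logprobs=1, and extra_body containing both
# return_token_ids=True and return_tokens_as_token_ids=True.
def check_token_id_alignment(choice) -> bool:
    logprob_ids = [
        int(tok.removeprefix("token_id:"))
        for tok in choice.logprobs.tokens
        if tok.startswith("token_id:")
    ]
    n = len(choice.token_ids)
    # With echo=True the logprobs cover prompt + response, so the returned
    # token_ids should match the trailing (response) portion.
    return logprob_ids[-n:] == list(choice.token_ids)
```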