[Bugfix] fixed top_logprobs: -1 does not appear to work as intended (#26470)

Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
This commit is contained in:
Chauncey
2025-10-11 00:41:17 +08:00
committed by GitHub
parent cddce79fda
commit 910abdbd08
2 changed files with 16 additions and 1 deletions

View File

@ -7,12 +7,23 @@ import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from vllm.config import ModelConfig
from ...utils import RemoteOpenAIServer
# # any model with a chat template should work here
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
def get_vocab_size(model_name):
config = ModelConfig(
model=model_name,
seed=0,
dtype="float16",
)
return config.get_vocab_size()
@pytest.fixture(scope="module")
def server():
args = [
@ -107,6 +118,7 @@ async def test_top_logprobs(client: openai.AsyncOpenAI):
completion = await client.chat.completions.create(
model=MODEL_NAME,
messages=messages,
max_tokens=1,
extra_body={
"top_logprobs": -1,
"logprobs": "true",
@ -115,3 +127,6 @@ async def test_top_logprobs(client: openai.AsyncOpenAI):
assert completion.choices[0].logprobs is not None
assert completion.choices[0].logprobs.content is not None
assert len(completion.choices[0].logprobs.content) > 0
assert len(
completion.choices[0].logprobs.content[0].top_logprobs
) == get_vocab_size(MODEL_NAME)

View File

@ -1643,7 +1643,7 @@ class OpenAIServingChat(OpenAIServing):
bytes=list(token.encode("utf-8", errors="replace")),
)
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
if (top_logprobs and i < top_logprobs or top_logprobs == -1)
]
def _create_chat_logprobs(