Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Bugfix] Fix top_logprobs: -1 not working as intended (#26470)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
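
For context: the new test asserts that requesting top_logprobs: -1 returns one logprob entry per vocabulary token. Below is a minimal client-side sketch of that request against a running vLLM OpenAI-compatible server; the base_url, api_key, prompt, and port are placeholders, not part of the patch.

# Illustrative only; assumes a vLLM OpenAI-compatible server is already
# serving Qwen/Qwen2-1.5B-Instruct on localhost:8000 (all placeholders).
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="Qwen/Qwen2-1.5B-Instruct",
    messages=[{"role": "user", "content": "Say hello."}],
    max_tokens=1,
    extra_body={"top_logprobs": -1, "logprobs": True},
)

# With the fix, each generated token carries one top_logprobs entry per
# vocabulary token rather than a truncated (or empty) list.
print(len(completion.choices[0].logprobs.content[0].top_logprobs))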
@@ -7,12 +7,23 @@ import openai  # use the official client for correctness check
import pytest
import pytest_asyncio

from vllm.config import ModelConfig

from ...utils import RemoteOpenAIServer

# # any model with a chat template should work here
MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"


def get_vocab_size(model_name):
    config = ModelConfig(
        model=model_name,
        seed=0,
        dtype="float16",
    )
    return config.get_vocab_size()


@pytest.fixture(scope="module")
def server():
    args = [
@@ -107,6 +118,7 @@ async def test_top_logprobs(client: openai.AsyncOpenAI):
    completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=1,
        extra_body={
            "top_logprobs": -1,
            "logprobs": "true",
@@ -115,3 +127,6 @@ async def test_top_logprobs(client: openai.AsyncOpenAI):
    assert completion.choices[0].logprobs is not None
    assert completion.choices[0].logprobs.content is not None
    assert len(completion.choices[0].logprobs.content) > 0
    assert len(
        completion.choices[0].logprobs.content[0].top_logprobs
    ) == get_vocab_size(MODEL_NAME)
@@ -1643,7 +1643,7 @@ class OpenAIServingChat(OpenAIServing):
                bytes=list(token.encode("utf-8", errors="replace")),
            )
            for i, p in enumerate(logprobs.items())
-           if top_logprobs and i < top_logprobs
+           if (top_logprobs and i < top_logprobs or top_logprobs == -1)
        ]

    def _create_chat_logprobs(
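
A note on the one-line change above: because "and" binds tighter than "or" in Python, the new filter reads as (top_logprobs and i < top_logprobs) or (top_logprobs == -1), so a value of -1 disables truncation while positive values still cap the list. The following is a self-contained sketch of that predicate; the helper name keep_entry is illustrative, not code from the patch.

# Standalone illustration of the patched comprehension filter; not the
# actual serving_chat.py code path.
def keep_entry(i: int, top_logprobs: int | None) -> bool:
    # "and" binds tighter than "or": -1 bypasses the i < top_logprobs cap.
    return bool(top_logprobs and i < top_logprobs or top_logprobs == -1)

assert [i for i in range(5) if keep_entry(i, 2)] == [0, 1]            # normal top-k
assert [i for i in range(5) if keep_entry(i, -1)] == [0, 1, 2, 3, 4]  # full vocabulary
assert [i for i in range(5) if keep_entry(i, 0)] == []                # 0 or None keeps nothing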