[V1] support min_tokens for detokenizer (#22014)

Signed-off-by: calvin chen <wen.chen@dynamia.ai>
Co-authored-by: Nick Hill <nhill@redhat.com>
Calvin Chen authored 2025-08-16 10:28:10 +08:00, committed by GitHub
parent f6b5040590
commit e4e37ded56
2 changed files with 58 additions and 3 deletions
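
With this change, the V1 detokenizer honors SamplingParams.min_tokens: stop strings are not evaluated until at least min_tokens tokens have been generated. A minimal usage sketch (illustrative values, not taken from this diff):

from vllm import SamplingParams

# Stop on "e", but only once at least five tokens have been generated;
# any "e" produced inside the first five tokens is ignored.
params = SamplingParams(stop="e", min_tokens=5, max_tokens=64)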


@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Optional

import pytest
from transformers import AutoTokenizer

from vllm import SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import FastIncrementalDetokenizer

PROMPT = "Hello, my name is Lee, and I'm a student in the " + \
    "college of engineering"


@pytest.mark.parametrize("min_tokens,stop,truth", [
    # No stop string: the full continuation is kept.
    (0, None, " is Lee, and I'm a student in the college of engineering"),
    # min_tokens=0: the first "e" (in "Lee") terminates the output.
    (0, "e", " is L"),
    # min_tokens=5: the "e"s inside the first five tokens are ignored,
    # so the first eligible "e" is the one in "student".
    (5, "e", " is Lee, and I'm a stud"),
])
def test_min_tokens_with_stop(min_tokens: int, stop: Optional[str],
                              truth: str):
    """Test stop-string handling for a specific min_tokens value.

    See https://github.com/vllm-project/vllm/pull/22014
    """
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    all_prompt_ids = tokenizer(PROMPT, add_special_tokens=False).input_ids

    # The first four tokens decode to "Hello, my name"; the remaining
    # tokens stand in for generated output.
    prompt_token_ids = all_prompt_ids[:4]
    params = SamplingParams(
        stop=stop,
        min_tokens=min_tokens,
    )
    request = EngineCoreRequest("",
                                prompt_token_ids,
                                None,
                                None,
                                None,
                                params,
                                None,
                                None,
                                0.0,
                                None,
                                cache_salt=None,
                                data_parallel_rank=None)

    detokenizer = FastIncrementalDetokenizer(tokenizer, request)

    # Feed the remaining tokens as if they were generated incrementally.
    detokenizer.update(all_prompt_ids[4:], False)

    assert detokenizer.output_text == truth

vllm/v1/engine/detokenizer.py

@@ -74,6 +74,7 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
         params = request.sampling_params
         assert params is not None
         self.stop = stop = params.stop
+        self.min_tokens = params.min_tokens
         self.include_stop_str_in_output = params.include_stop_str_in_output

         # Number of chars to hold back when stop strings are to be excluded
@@ -111,10 +112,14 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
         # 1) Detokenize the new token ids incrementally.
         # TODO(woosuk): This method becomes very inefficient when the number of
         # new_token_ids is more than 1. We need to optimize this.
-        offset_before = len(self.output_text)
+        stop_check_offset = len(self.output_text)
         for new_token_id in new_token_ids:
             self.token_ids.append(new_token_id)
             self.output_text += self.decode_next(new_token_id)
+            # Support min_tokens by advancing the stop-check offset; see https://github.com/vllm-project/vllm/pull/22014
+            if self.min_tokens and len(
+                    self.output_token_ids) <= self.min_tokens:
+                stop_check_offset = len(self.output_text)

         if stop_terminated:
             if skipped_stop_token_id is not None:
@@ -125,10 +130,10 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
         # 2) Evaluate stop strings.
         stop_string = None
-        if self.stop:
+        if self.stop and len(self.output_token_ids) > self.min_tokens:
             stop = StopChecker.check_stop_strings(
                 output_text=self.output_text,
-                new_char_count=len(self.output_text) - offset_before,
+                new_char_count=len(self.output_text) - stop_check_offset,
                 stop=self.stop,
                 include_in_output=self.include_stop_str_in_output,
             )
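
To make the deferral concrete, here is a simplified standalone sketch (a hypothetical helper, not vLLM's code): stop strings may only complete a match in text produced after the min_tokens boundary, though a match may begin slightly before it.

from typing import Optional

def first_stop_index(output_text: str, stop: str,
                     stop_check_offset: int) -> Optional[int]:
    # A match may begin up to len(stop) - 1 chars before the offset,
    # but must end in text appended after it.
    start = max(0, stop_check_offset - (len(stop) - 1))
    idx = output_text.find(stop, start)
    return None if idx == -1 else idx

# With stop="e" and the first five tokens (" is Lee, and I") protected,
# the match lands on the "e" of "student", not the ones in "Lee":
text = " is Lee, and I'm a student"
idx = first_stop_index(text, "e", stop_check_offset=len(" is Lee, and I"))
assert text[:idx] == " is Lee, and I'm a stud"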