[Bugfix] Respect min_tokens in scheduler stop check (#26317)

Signed-off-by: Elaine Zhao <elaineyz@amazon.com>
Author:       Elaine Zhao
Date:         2025-10-08 14:08:24 -07:00
Committed by: GitHub
Parent:       93f2c0aa08
Commit:       f08919b7d1
2 changed files with 95 additions and 0 deletions
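
In short: the scheduler's check_stop helper could finish a request on EOS or a stop token even when fewer than min_tokens output tokens had been generated. The change below adds an early return so the stop conditions are only evaluated once the minimum has been reached. A minimal standalone sketch of that control flow (check_stop_sketch and its plain-value arguments are illustrative stand-ins, not vLLM's actual Request-based API):

def check_stop_sketch(output_token_ids, eos_token_id, min_tokens,
                      ignore_eos=False, stop_token_ids=()):
    # Never stop before the minimum number of output tokens has been generated.
    if len(output_token_ids) < min_tokens:
        return False
    last = output_token_ids[-1]
    # Stop on EOS (unless ignored) or on any configured stop token.
    if not ignore_eos and last == eos_token_id:
        return True
    return last in stop_token_ids


# EOS arrives after only 3 tokens but min_tokens=5, so generation continues.
assert check_stop_sketch([10, 11, 2], eos_token_id=2, min_tokens=5) is False
# Once at least 5 tokens exist, EOS ends the request as usual.
assert check_stop_sketch([10, 11, 12, 13, 14, 2], eos_token_id=2, min_tokens=5) is True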


@@ -497,6 +497,96 @@ def test_stop_via_update_from_output():
    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]


def test_check_stop_min_tokens():
    """Test that requests don't stop when the min_tokens requirement isn't met."""
    from vllm.v1.core.sched.utils import check_stop

    # Test case 1: num_output_tokens < min_tokens.
    # Should return False (don't stop).
    sampling_params = SamplingParams(
        ignore_eos=False,
        max_tokens=20,
        min_tokens=5,
    )
    request = Request(
        request_id="0",
        prompt_token_ids=[0, 1, 2],
        sampling_params=sampling_params,
        pooling_params=None,
        eos_token_id=EOS_TOKEN_ID,
    )
    # Simulate having generated 3 output tokens (fewer than min_tokens=5),
    # the last of which is EOS.
    request.append_output_token_ids([10, 11, EOS_TOKEN_ID])
    result = check_stop(request, max_model_len=100)
    assert result is False, "Should not stop when num_output_tokens < min_tokens"

    # Test case 2: num_output_tokens >= min_tokens.
    # Should follow the normal stopping logic (stop on EOS).
    request.append_output_token_ids(
        [
            10,
            11,
            12,
            13,
            14,
            EOS_TOKEN_ID,
        ]
    )  # 9 output tokens in total now, well above min_tokens=5
    result = check_stop(request, max_model_len=100)
    assert result is True, "Should stop on EOS when min_tokens is met"
    assert request.status == RequestStatus.FINISHED_STOPPED

    # Test case 3: min_tokens = 0, should follow the normal stopping logic.
    sampling_params_no_min = SamplingParams(
        ignore_eos=False,
        max_tokens=20,
        min_tokens=0,
    )
    request_no_min = Request(
        request_id="1",
        prompt_token_ids=[0, 1, 2],
        sampling_params=sampling_params_no_min,
        pooling_params=None,
        eos_token_id=EOS_TOKEN_ID,
    )
    request_no_min.append_output_token_ids([10, EOS_TOKEN_ID])
    result = check_stop(request_no_min, max_model_len=100)
    assert result is True, "Should stop on EOS when min_tokens=0"
    assert request_no_min.status == RequestStatus.FINISHED_STOPPED

    # Test case 4: min_tokens > 0 with a stop token (not EOS).
    sampling_params_stop = SamplingParams(
        ignore_eos=False,
        max_tokens=20,
        min_tokens=5,
        stop_token_ids=[42],
    )
    request_stop = Request(
        request_id="2",
        prompt_token_ids=[0, 1, 2],
        sampling_params=sampling_params_stop,
        pooling_params=None,
        eos_token_id=EOS_TOKEN_ID,
    )
    # Only 3 output tokens, fewer than min_tokens=5, but the last one is a stop token.
    request_stop.append_output_token_ids([10, 11, 42])
    result = check_stop(request_stop, max_model_len=100)
    assert result is False, "Should not stop when num_output_tokens < min_tokens"

    # Test case 5: min_tokens met, should stop on the stop token.
    request_stop.append_output_token_ids(
        [10, 11, 12, 13, 14, 42]
    )  # 9 output tokens in total now, well above min_tokens=5
    result = check_stop(request_stop, max_model_len=100)
    assert result is True, "Should stop on stop token when min_tokens is met"
    assert request_stop.status == RequestStatus.FINISHED_STOPPED
    assert request_stop.stop_reason == 42

@pytest.mark.parametrize(
    "enable_prefix_caching, prompt_logprobs",
    [


@@ -58,6 +58,11 @@ def check_stop(
    sampling_params = request.sampling_params
    assert sampling_params is not None
    min_tokens = sampling_params.min_tokens
    if request.num_output_tokens < min_tokens:
        return False

    last_token_id = request.output_token_ids[-1]
    if not sampling_params.ignore_eos and last_token_id == request.eos_token_id:
        request.status = RequestStatus.FINISHED_STOPPED
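
For context, min_tokens reaches this check through the request's sampling parameters. A caller-side configuration that exercises the new guard might look like the following (values are hypothetical, mirroring the test above):

from vllm import SamplingParams

# Require at least 5 generated tokens before EOS or the stop token 42 may
# finish the request; cap generation at 20 tokens.
params = SamplingParams(min_tokens=5, max_tokens=20, stop_token_ids=[42])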