mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Frontend] Skip stop in reasoning content (#14550)
Signed-off-by: Ce Gao <cegao@tensorchord.ai>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
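In brief: this change adds an optional reasoning parser to the engine's StopChecker so that user-supplied stop strings and stop_token_ids are ignored while the model is still emitting reasoning content. EOS handling, min_tokens, max_tokens, and max_model_len checks are unaffected. A minimal sketch of the rule, illustrative only (the actual check lives in StopChecker.maybe_stop_sequence, shown in the last hunk below):

    # Illustrative sketch, not the vLLM implementation: custom stop
    # strings/tokens apply only once the reasoner reports that the
    # reasoning content has ended.
    def custom_stops_apply(reasoner, token_ids: list[int]) -> bool:
        return reasoner is None or reasoner.is_reasoning_end(token_ids)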
tests/engine/test_stop_checker.py (new file, +228 lines)
@@ -0,0 +1,228 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from transformers import AutoTokenizer

from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceStatus

REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"


class MockReasoningParser(ReasoningParser):
    """Mock reasoning parser for testing purposes."""

    def __init__(self,
                 tokenizer: AutoTokenizer,
                 reasoning_active: bool = False):
        super().__init__(tokenizer)
        self.reasoning_active = reasoning_active

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return not self.reasoning_active

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        return input_ids


class MockSequence(Sequence):
    """Mock sequence for testing purposes."""

    def __init__(self, token_ids, output_text="test_output", eos_token_id=0):
        self.token_ids = token_ids
        self.output_text = output_text
        self.eos_token_id = eos_token_id
        self.status = SequenceStatus.RUNNING
        self.stop_reason = None

    def get_token_ids(self):
        return self.token_ids

    def get_last_token_id(self):
        return self.token_ids[-1] if self.token_ids else None

    def get_len(self):
        return len(self.token_ids)

    def get_output_len(self):
        return len(self.token_ids) - 1  # Simulating prompt + outputs


@pytest.fixture
def deepseek_r1_qwen_tokenizer():
    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


@pytest.fixture
def stop_checker():
    return StopChecker(max_model_len=10,
                       get_tokenizer_for_seq=deepseek_r1_qwen_tokenizer)


@pytest.fixture
def stop_checker_with_reasoner():
    reasoner = MockReasoningParser(deepseek_r1_qwen_tokenizer)
    return StopChecker(max_model_len=10,
                       get_tokenizer_for_seq=deepseek_r1_qwen_tokenizer,
                       reasoner=reasoner)


def test_eos_token_stopping(stop_checker):
    """Test sequence stopping when EOS token is encountered."""
    seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
    sampling_params = SamplingParams()

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_STOPPED


def test_ignore_eos(stop_checker):
    """Test sequence continuing when EOS token is ignored."""
    seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
    sampling_params = SamplingParams(ignore_eos=True)

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.RUNNING


def test_min_tokens(stop_checker):
    """Test min_tokens prevents early stopping."""
    seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
    sampling_params = SamplingParams(min_tokens=3)

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.RUNNING


def test_stop_token_ids(stop_checker):
    """Test sequence stopping with custom stop token IDs."""
    seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
    sampling_params = SamplingParams(stop_token_ids=[3])

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_STOPPED
    assert seq.stop_reason == 3


def test_stop_strings(stop_checker):
    """Test sequence stopping with stop strings."""
    seq = MockSequence(token_ids=[1, 2, 3],
                       output_text="test output with STOP",
                       eos_token_id=0)
    sampling_params = SamplingParams(stop=["STOP"])

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_STOPPED
    assert seq.stop_reason == "STOP"
    assert "STOP" not in seq.output_text  # Default behavior removes stop string


def test_include_stop_str_in_output(stop_checker):
    """Test keeping stop strings in output."""
    seq = MockSequence(token_ids=[1, 2, 3],
                       output_text="test output with STOP",
                       eos_token_id=0)
    sampling_params = SamplingParams(stop=["STOP"],
                                     include_stop_str_in_output=True)

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=5,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_STOPPED
    assert "STOP" in seq.output_text


def test_max_tokens(stop_checker):
    """Test sequence stopping at max_tokens."""
    seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
    sampling_params = SamplingParams(max_tokens=2)

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_LENGTH_CAPPED


def test_max_model_len(stop_checker):
    """Test sequence stopping at max_model_len."""
    seq = MockSequence(token_ids=list(range(11)),
                       eos_token_id=0)  # 11 tokens, max is 10
    sampling_params = SamplingParams()

    stop_checker.maybe_stop_sequence(seq,
                                     new_char_count=1,
                                     sampling_params=sampling_params)

    assert seq.status == SequenceStatus.FINISHED_LENGTH_CAPPED


def test_reasoning_skip_stops(stop_checker_with_reasoner):
    """Test that stop tokens and strings are ignored during reasoning."""
    # Set reasoning_active to True to simulate being in reasoning mode
    stop_checker_with_reasoner.reasoner.reasoning_active = True

    # Test with stop token
    seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
    sampling_params = SamplingParams(stop_token_ids=[3])

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=1, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.RUNNING

    # Test with stop string
    seq = MockSequence(token_ids=[1, 2, 3], output_text="test STOP")
    sampling_params = SamplingParams(stop=["STOP"])

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=4, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.RUNNING

    # But EOS token still stops the sequence
    seq = MockSequence(token_ids=[1, 2, 0], eos_token_id=0)
    sampling_params = SamplingParams()

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=1, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.FINISHED_STOPPED


def test_reasoning_end_enables_stops(stop_checker_with_reasoner):
    """Test that stop tokens work after reasoning ends."""
    # Set reasoning_active to False to simulate being out of reasoning mode
    stop_checker_with_reasoner.reasoner.reasoning_active = False

    # Test with stop token
    seq = MockSequence(token_ids=[1, 2, 3], eos_token_id=0)
    sampling_params = SamplingParams(stop_token_ids=[3])

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=1, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.FINISHED_STOPPED

    # Test with stop string
    seq = MockSequence(token_ids=[1, 2, 3], output_text="test STOP")
    sampling_params = SamplingParams(stop=["STOP"])

    stop_checker_with_reasoner.maybe_stop_sequence(
        seq, new_char_count=4, sampling_params=sampling_params)
    assert seq.status == SequenceStatus.FINISHED_STOPPED
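The engine wiring below (vllm/engine/llm_engine.py) instantiates the configured reasoning parser when DecodingConfig.reasoning_backend is set, and passes it into the StopChecker: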
vllm/engine/llm_engine.py
@@ -40,6 +40,7 @@ from vllm.multimodal.cache import processor_only_cache_from_config
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.outputs import (PoolingRequestOutput, RequestOutput,
                           RequestOutputFactory)
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.sequence import (ExecuteModelRequest, ParallelSampleSequenceGroup,
                            Sequence, SequenceGroup, SequenceGroupBase,
@@ -372,6 +373,14 @@ class LLMEngine:
             "vllm.llm_engine",
             self.observability_config.otlp_traces_endpoint)

+        # Initialize reasoning parser if reasoning backend is set.
+        if self.decoding_config.reasoning_backend and \
+                self.tokenizer:
+            reasoner_class = ReasoningParserManager.get_reasoning_parser(
+                self.decoding_config.reasoning_backend)
+            self.reasoner: ReasoningParser = reasoner_class(
+                self.tokenizer.get_lora_tokenizer())
+
         # Create sequence output processor, e.g. for beam search or
         # speculative decoding.
         self.output_processor = (
@@ -381,8 +390,12 @@ class LLMEngine:
                 self.scheduler,
                 self.seq_counter,
                 get_tokenizer_for_seq,
-                stop_checker=StopChecker(self.scheduler_config.max_model_len,
-                                         get_tokenizer_for_seq),
+                stop_checker=StopChecker(
+                    self.scheduler_config.max_model_len,
+                    get_tokenizer_for_seq,
+                    self.reasoner if self.decoding_config.reasoning_backend
+                    and self.tokenizer else None,
+                ),
             ))

         self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
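Note that self.reasoner is only assigned when a reasoning backend is configured, hence the guarded argument in the StopChecker construction above. A hypothetical local rewrite of that expression, for readability only:

    # Equivalent to the guarded constructor argument above (illustrative):
    reasoner = None
    if self.decoding_config.reasoning_backend and self.tokenizer:
        reasoner = self.reasoner

The StopChecker itself gains the optional reasoner parameter in vllm/engine/output_processor/stop_checker.py: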
vllm/engine/output_processor/stop_checker.py
@@ -4,6 +4,7 @@
 from typing import Callable, List, Optional, Tuple

 from vllm.lora.request import LoRARequest
+from vllm.reasoning import ReasoningParser
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import Sequence, SequenceStatus
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -16,11 +17,16 @@ class StopChecker:
     emitted, or if we have exceeded the max model len.
     """

-    def __init__(self, max_model_len: int,
-                 get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]):
+    def __init__(
+        self,
+        max_model_len: int,
+        get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer],
+        reasoner: Optional[ReasoningParser] = None,
+    ):
         # Do not use it directly, but use `self._get_max_model_len`.
         self._max_model_len = max_model_len
         self.get_tokenizer_for_seq = get_tokenizer_for_seq
+        self.reasoner = reasoner

     def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
         if lora_req and lora_req.long_lora_max_len:
@@ -57,6 +63,11 @@ class StopChecker:
             seq.status = SequenceStatus.FINISHED_STOPPED
             return

+        # Skip stop string/token checks if in reasoning content generation
+        if self.reasoner is not None and \
+                not self.reasoner.is_reasoning_end(seq.get_token_ids()):
+            return
+
         # Check if a stop token was encountered.
         # This assumes a single token produced per step.
         last_token_id = seq.get_last_token_id()
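For reference, a minimal usage sketch mirroring test_reasoning_skip_stops above; it assumes the mock helpers from the new test file and a loaded tokenizer, and is illustrative rather than part of the commit:

    # While the parser reports reasoning is active, a matched stop string
    # does not finish the sequence.
    parser = MockReasoningParser(tokenizer, reasoning_active=True)
    checker = StopChecker(max_model_len=10,
                          get_tokenizer_for_seq=lambda seq: tokenizer,
                          reasoner=parser)
    seq = MockSequence(token_ids=[1, 2, 3], output_text="test STOP")
    checker.maybe_stop_sequence(seq,
                                new_char_count=4,
                                sampling_params=SamplingParams(stop=["STOP"]))
    assert seq.status == SequenceStatus.RUNNING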