[Refactor][Frontend] Keep all logic about reasoning into one class (#14428)

Signed-off-by: Ce Gao <cegao@tensorchord.ai>
Ce Gao
2025-03-28 15:23:30 +08:00
committed by GitHub
parent 2d9045fce8
commit 32b14baf8a
18 changed files with 171 additions and 200 deletions
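The refactor collapses the two previous homes of reasoning logic, vllm.entrypoints.openai.reasoning_parsers and vllm.model_executor.guided_decoding.reasoner, into a single vllm.reasoning package. A minimal sketch of the import surface before and after, distilled from the diffs below:

# Before: parsers and reasoners lived in two separate modules.
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
                                                       ReasoningParserManager)
from vllm.model_executor.guided_decoding.reasoner import Reasoner

# After: one consolidated package serves both use cases.
from vllm.reasoning import ReasoningParser, ReasoningParserManager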

View File

@@ -3,74 +3,92 @@
import pytest
from transformers import AutoTokenizer
from tests.entrypoints.openai.reasoning_parsers.utils import (
run_reasoning_extraction)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
from tests.reasoning.utils import run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "deepseek_r1"
start_token = "<think>"
end_token = "</think>"
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
@pytest.fixture(scope="module")
def deepseek_r1_qwen_tokenizer():
return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
SIMPLE_REASONING = {
"output": "This is a reasoning section</think>This is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
COMPLETE_REASONING = {
"output": "This is a reasoning section</think>",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
NO_CONTENT = {
"output": "This is content",
"reasoning_content": "This is content",
"content": None,
"is_reasoning_end": False,
}
NO_REASONING_STREAMING = {
"output": "This is a reasoning section",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": False,
}
MULTIPLE_LINES = {
"output": "This\nThat</think>This is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING = {
"output": "</think>This is the rest",
"reasoning_content": "",
"content": "This is the rest",
"is_reasoning_end": True,
}
SHORTEST_REASONING = {
"output": "</think>This is the rest",
"reasoning_content": None,
"content": "This is the rest",
"is_reasoning_end": True,
}
REASONING_WITH_THINK = {
"output": "<think>This is a reasoning section</think>This is the rest",
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
"is_reasoning_end": True,
}
COMPLETE_REASONING_WITH_THINK = {
"output": "<think>This is a reasoning section</think>",
"reasoning_content": "This is a reasoning section",
"content": None,
"is_reasoning_end": True,
}
MULTIPLE_LINES_WITH_THINK = {
"output": "<think>This\nThat</think>This is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
"is_reasoning_end": True,
}
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
"output": "</think>This is the rest",
"reasoning_content": "",
"content": "This is the rest",
"is_reasoning_end": True,
}
SHORTEST_REASONING_WITH_THINK = {
"output": "</think>This is the rest",
"reasoning_content": None,
"content": "This is the rest",
"is_reasoning_end": True,
}
TEST_CASES = [
@@ -166,23 +184,21 @@ TEST_CASES = [
),
]
# Global tokenizer initialization to avoid repeated loading
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
tokenizer.add_tokens([start_token, end_token])
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
deepseek_r1_qwen_tokenizer,
):
output = tokenizer.tokenize(param_dict["output"])
output = deepseek_r1_qwen_tokenizer.tokenize(param_dict["output"])
# decode everything to tokens
output_tokens: list[str] = [
tokenizer.convert_tokens_to_string([token]) for token in output
deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token])
for token in output
]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
parser_name)(tokenizer)
parser_name)(deepseek_r1_qwen_tokenizer)
reasoning, content = run_reasoning_extraction(parser,
output_tokens,
@@ -190,3 +206,17 @@ def test_reasoning(
assert reasoning == param_dict["reasoning_content"]
assert content == param_dict["content"]
# Test is_reasoning_end
output_ids = deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(output)
is_reasoning_end = parser.is_reasoning_end(output_ids)
assert is_reasoning_end == param_dict["is_reasoning_end"]
# Test extract_content
if param_dict["content"] is not None:
content = parser.extract_content_ids(output_ids)
assert content == deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(
deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"]))
else:
content = parser.extract_content_ids(output_ids)
assert content == []
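For reference, a minimal sketch (not part of the diff) of the token-id API these new assertions exercise; it assumes the distilled R1 tokenizer keeps "</think>" as a single token:

from transformers import AutoTokenizer

from vllm.reasoning import ReasoningParserManager

tok = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
parser = ReasoningParserManager.get_reasoning_parser("deepseek_r1")(tok)

ids = tok.encode("Some thoughts</think>The rest", add_special_tokens=False)
assert parser.is_reasoning_end(ids)            # "</think>" was generated
content_ids = parser.extract_content_ids(ids)  # ids after "</think>"
print(tok.decode(content_ids))                 # -> "The rest"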

View File

@@ -2,10 +2,8 @@
import pytest
from transformers import AutoTokenizer
from tests.entrypoints.openai.reasoning_parsers.utils import (
DeltaMessage, run_reasoning_extraction)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
from tests.reasoning.utils import DeltaMessage, run_reasoning_extraction
from vllm.reasoning import ReasoningParser, ReasoningParserManager
parser_name = "granite"
START_REASONING = "Here is my thought process:"

View File

@@ -4,7 +4,7 @@ from typing import Optional, Union
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParser
from vllm.reasoning import ReasoningParser
class StreamingReasoningReconstructor:

View File

@@ -23,6 +23,7 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.plugins import load_general_plugins
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
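The second hunk in this file switches the --reasoning-parser choices from a hard-coded list to the registry, so any parser registered with the manager (including ones brought in through ReasoningParserManager.import_reasoning_parser) is accepted automatically. A hedged snippet to inspect what is registered:

from vllm.reasoning import ReasoningParserManager

# e.g. ['deepseek_r1', 'granite'] at the time of this commit
print(list(ReasoningParserManager.reasoning_parsers))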
@@ -1119,7 +1120,7 @@ class EngineArgs:
parser.add_argument(
"--reasoning-parser",
type=str,
choices=["deepseek_r1", "granite"],
choices=list(ReasoningParserManager.reasoning_parsers),
default=None,
help=
"Select the reasoning parser depending on the model that you're "

View File

@@ -2080,8 +2080,9 @@ class LLMEngine:
guided_decoding.backend = guided_decoding.backend or \
self.decoding_config.guided_decoding_backend
logger.debug("Reasoning backend: %s",
self.decoding_config.reasoning_backend)
if self.decoding_config.reasoning_backend is not None:
logger.debug("Building with reasoning backend %s",
self.decoding_config.reasoning_backend)
processor = get_local_guided_decoding_logits_processor(
guided_params=guided_decoding,

View File

@@ -68,7 +68,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TranscriptionRequest,
TranscriptionResponse,
UnloadLoRAAdapterRequest)
from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -85,6 +84,7 @@ from vllm.entrypoints.openai.serving_transcription import (
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.utils import load_aware_call, with_cancellation
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import MistralTokenizer

View File

@@ -23,8 +23,6 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
clamp_prompt_logprobs)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@@ -33,6 +31,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall)
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.reasoning import ReasoningParser, ReasoningParserManager
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

View File

@@ -5,10 +5,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.reasoning import ReasoningParserManager
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
@@ -107,7 +107,11 @@ async def get_guided_decoding_logits_processor(
model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None:
reasoner = get_reasoner(tokenizer, reasoning_backend)
reasoner = None
if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
guided_params = maybe_backend_fallback(guided_params)
@@ -146,8 +150,11 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params)
# Get the reasoner if needed, it will be None if reasoning_backend is None
reasoner = get_reasoner(tokenizer, reasoning_backend)
reasoner = None
if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines':
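The deleted get_reasoner() helper is replaced by the same four lines in both the async and the local entry point above. A hedged extraction of that repeated pattern, for illustration only (the helper name is not part of the commit):

from typing import Optional

from vllm.reasoning import ReasoningParser, ReasoningParserManager

def _build_reasoner(
        tokenizer,
        reasoning_backend: Optional[str]) -> Optional[ReasoningParser]:
    # Mirrors the inline wiring above: no backend means no reasoner.
    if reasoning_backend is None:
        return None
    reasoner_class = ReasoningParserManager.get_reasoning_parser(
        reasoning_backend)
    return reasoner_class(tokenizer)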

View File

@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor(
def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]:
"""
@@ -141,7 +141,7 @@ def _get_logits_processor(
tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,

View File

@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.platforms import current_platform
from vllm.reasoning import ReasoningParser
logger = init_logger(__name__)
@@ -49,9 +49,9 @@ else:
class BaseLogitsProcessor:
def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
self._guide: Guide = guide
self._reasoner: Optional[Reasoner] = reasoner
self._reasoner: Optional[ReasoningParser] = reasoner
# CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int)
@@ -69,7 +69,7 @@ class BaseLogitsProcessor:
# Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the
# hash of the input_ids to store the FSM state.
input_ids = self._reasoner.extract_content(input_ids)
input_ids = self._reasoner.extract_content_ids(input_ids)
seq_id = hash(tuple(input_ids))
@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
self,
regex_string: str,
tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner],
reasoner: Optional[ReasoningParser],
):
"""Compile the FSM that drives the regex-structured generation.
@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner]):
reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
return CFGGuide(cfg, tokenizer)
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner]):
reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the context free grammar generation.
Parameters

View File

@@ -1,38 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
@dataclass
class DeepSeekReasoner(Reasoner):
"""
Reasoner for DeepSeek R series models.
"""
start_token_id: int
end_token_id: int
start_token: str = "<think>"
end_token: str = "</think>"
@classmethod
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
return cls(start_token_id=tokenizer.encode(
"<think>", add_special_tokens=False)[0],
end_token_id=tokenizer.encode("</think>",
add_special_tokens=False)[0])
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.end_token_id in input_ids
def extract_content(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
if self.end_token_id not in input_ids or \
input_ids.index(self.end_token_id) + 1 == len(input_ids):
return []
else:
return input_ids[input_ids.index(self.end_token_id) + 1:]

View File

@@ -1,23 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
@dataclass
class Reasoner(ABC):
@abstractmethod
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
pass
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
pass
@abstractmethod
def extract_content(self, input_ids: list[int]) -> list[int]:
pass

View File

@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__)
@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer,
model_config: ModelConfig,
reasoner: Reasoner | None,
reasoner: ReasoningParser | None,
max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config,
@@ -280,7 +280,7 @@ class GrammarConfig:
class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol"""
config: GrammarConfig
reasoner: Reasoner | None = None
reasoner: ReasoningParser | None = None
ctx: xgr.CompiledGrammar | None = None
tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment]

View File

@@ -17,7 +17,7 @@ logger = init_logger(__name__)
class ReasoningParser:
"""
Abstract reasoning parser class that should not be used directly.
Provided methods should be used in derived classes.
It is used to extract reasoning content from the model output.
@@ -32,6 +32,36 @@ class ReasoningParser:
# whereas all tokenizers have .get_vocab()
return self.model_tokenizer.get_vocab()
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
"""
Check if the reasoning content ends in the input_ids.
It is used in structured engines like `xgrammar` to check if the
reasoning content ends in the model output.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
@abstractmethod
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
@@ -53,10 +83,7 @@ class ReasoningParser:
A tuple containing the reasoning content and the content.
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_reasoning_calls "
"has not been implemented!")
@abstractmethod
def extract_reasoning_content_streaming(
self,
previous_text: str,
@@ -73,43 +100,6 @@ class ReasoningParser:
the current tokens/diffs, but also the information about what has
previously been parsed and extracted (see constructor)
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_reasoning_content_streaming "
"has not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
"""
Check if the reasoning content ends in the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
bool
True if the reasoning content ends in the input_ids.
"""
raise NotImplementedError(
"AbstractReasoningParser.is_reasoning_end has"
"not been implemented!")
# TODO: need to rebase by PR #14428
@abstractmethod
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract content token ids from the input_ids.
Parameters:
input_ids: list[int]
The input_ids of the model output.
Returns:
list[int]
The extracted content from the input_ids.
"""
raise NotImplementedError(
"AbstractReasoningParser.extract_content_ids has"
" not been implemented!")
class ReasoningParserManager:
@@ -125,14 +115,16 @@ class ReasoningParserManager:
if name in cls.reasoning_parsers:
return cls.reasoning_parsers[name]
raise KeyError(f"reasoning helper: '{name}' not found in "
"reasoning_parsers")
raise KeyError(
f"reasoning helper: '{name}' not found in reasoning_parsers")
@classmethod
def _register_module(cls,
module: type,
module_name: Optional[Union[str, list[str]]] = None,
force: bool = True) -> None:
def _register_module(
cls,
module: type,
module_name: Optional[Union[str, list[str]]] = None,
force: bool = True,
) -> None:
if not issubclass(module, ReasoningParser):
raise TypeError("module must be subclass of ReasoningParser, "
f"but got {type(module)}")
@@ -149,13 +141,14 @@ class ReasoningParserManager:
@classmethod
def register_module(
cls,
name: Optional[Union[str, list[str]]] = None,
force: bool = True,
module: Union[type, None] = None) -> Union[type, Callable]:
cls,
name: Optional[Union[str, list[str]]] = None,
force: bool = True,
module: Union[type, None] = None,
) -> Union[type, Callable]:
"""
Register module with the given name or name list. It can be used as a
decorator (with module as None) or as a normal function (with module as
not None).
"""
if not isinstance(force, bool):
@@ -183,7 +176,7 @@ class ReasoningParserManager:
@classmethod
def import_reasoning_parser(cls, plugin_path: str) -> None:
"""
Import a user-defined reasoning parser from the path of the reasoning
parser definition file.
"""
module_name = os.path.splitext(os.path.basename(plugin_path))[0]
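A hedged sketch of the decorator form of register_module described above; the name "my_r1" and the trivial method bodies are illustrative only:

from vllm.reasoning import ReasoningParser, ReasoningParserManager

@ReasoningParserManager.register_module("my_r1")
class MyReasoningParser(ReasoningParser):

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return True  # toy: treat every output as past the reasoning phase

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        return input_ids

    def extract_reasoning_content(self, model_output, request):
        return None, model_output

    def extract_reasoning_content_streaming(self, previous_text, *args,
                                            **kwargs):
        return None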

View File

@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)
@@ -20,43 +19,45 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"""
Reasoning parser for DeepSeek R1 model.
The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
text. This parser extracts the reasoning content from the model output.
"""
start_token_id: int
end_token_id: int
start_token: str = "<think>"
end_token: str = "</think>"
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
self.think_start_token = "<think>"
self.think_end_token = "</think>"
self.reasoning_regex = re.compile(
rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)
rf"{self.start_token}(.*?){self.end_token}", re.DOTALL)
if not self.model_tokenizer:
raise ValueError(
"The model tokenizer must be passed to the ReasoningParser "
"constructor during construction.")
self.think_start_token_id = self.vocab.get(self.think_start_token)
self.think_end_token_id = self.vocab.get(self.think_end_token)
if (self.think_start_token_id is None
or self.think_end_token_id is None):
self.start_token_id = self.vocab.get(self.start_token)
self.end_token_id = self.vocab.get(self.end_token)
if self.start_token_id is None or self.end_token_id is None:
raise RuntimeError(
"DeepSeek R1 reasoning parser could not locate think start/end "
"tokens in the tokenizer!")
# TODO: need to rebase by PR #14428
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.think_end_token_id in input_ids
return self.end_token_id in input_ids
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
if self.think_end_token_id not in input_ids[:-1]:
if self.end_token_id not in input_ids[:-1]:
return []
else:
return input_ids[input_ids.index(self.think_end_token_id) + 1:]
return input_ids[input_ids.index(self.end_token_id) + 1:]
def extract_reasoning_content_streaming(
self,
@@ -77,22 +78,24 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
"""
# Skip single special tokens
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
self.think_start_token_id, self.think_end_token_id
self.start_token_id, self.end_token_id
]):
return None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
if self.think_start_token_id in previous_token_ids:
if self.think_end_token_id in delta_token_ids:
if self.start_token_id in previous_token_ids:
if self.end_token_id in delta_token_ids:
# <think> in previous, </think> in delta,
# extract reasoning content
end_index = delta_text.find(self.think_end_token)
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
elif self.think_end_token_id in previous_token_ids:
content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# <think> in previous, </think> in previous,
# reasoning content continues
return DeltaMessage(content=delta_text)
@@ -100,17 +103,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# <think> in previous, no </think> in previous or delta,
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
elif self.think_start_token_id in delta_token_ids:
if self.think_end_token_id in delta_token_ids:
elif self.start_token_id in delta_token_ids:
if self.end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content
start_index = delta_text.find(self.think_start_token)
end_index = delta_text.find(self.think_end_token)
start_index = delta_text.find(self.start_token)
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[start_index +
len(self.think_start_token
):end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
len(self.start_token):end_index]
content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
else:
# <think> in delta, no </think> in delta,
# reasoning content continues
@@ -119,15 +123,17 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token_id in delta_token_ids:
if self.end_token_id in delta_token_ids:
# </think> in delta with more tokens,
# extract reasoning content and content
end_index = delta_text.find(self.think_end_token)
end_index = delta_text.find(self.end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
elif self.think_end_token_id in previous_token_ids:
content = delta_text[end_index + len(self.end_token):]
return DeltaMessage(
reasoning_content=reasoning_content,
content=content if content else None,
)
elif self.end_token_id in previous_token_ids:
# </think> in previous, thinking content ends
return DeltaMessage(content=delta_text)
else:
@@ -137,22 +143,20 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
# DeepSeek R1 doesn't generate <think> now.
# Thus we assume the reasoning content is always at the start.
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token not in model_output:
if self.end_token not in model_output:
return model_output, None
else:
# Add a start token if it's missing to keep compatibility.
if self.think_start_token not in model_output:
model_output = f"{self.think_start_token}{model_output}"
if self.start_token not in model_output:
model_output = f"{self.start_token}{model_output}"
# Use a regex to find the reasoning content
reasoning_content = self.reasoning_regex.findall(model_output)[0]
end_index = len(
f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
)
f"{self.start_token}{reasoning_content}{self.end_token}")
final_output = model_output[end_index:]
if len(final_output) == 0:
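A behaviour sketch of the non-streaming contract implemented above, reusing the parser from the earlier test sketch (the ChatCompletionRequest construction is illustrative and not taken from the diff):

from vllm.entrypoints.openai.protocol import ChatCompletionRequest

request = ChatCompletionRequest(model="deepseek-r1", messages=[])

# No "</think>" at all: the whole output is treated as reasoning.
assert parser.extract_reasoning_content(
    "still thinking...", request) == ("still thinking...", None)

# A missing "<think>" is tolerated; the tag is prepended internally.
assert parser.extract_reasoning_content(
    "abc</think>xyz", request) == ("abc", "xyz")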

View File

@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
logger = init_logger(__name__)