Mirror of https://github.com/vllm-project/vllm.git
[Refactor][Frontend] Keep all logic about reasoning into one class (#14428)
Signed-off-by: Ce Gao <cegao@tensorchord.ai>
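In short, this commit deletes the standalone Reasoner/DeepSeekReasoner classes under vllm.model_executor.guided_decoding.reasoner and folds their is_reasoning_end/extract_content logic into the ReasoningParser hierarchy, now exposed from vllm.reasoning, so the OpenAI frontend and guided decoding share one class. A minimal sketch of the consolidated API (a sketch assuming vLLM at this commit; the tokenizer download needs network access):

# Sketch of the consolidated API introduced by this commit; mirrors what
# the updated test below does.
from transformers import AutoTokenizer

from vllm.reasoning import ReasoningParser, ReasoningParserManager

tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_r1")
parser: ReasoningParser = parser_cls(tokenizer)

token_ids = tokenizer.encode("2 + 2 = 4</think>The answer is 4.")
print(parser.is_reasoning_end(token_ids))     # True once </think> appears
print(parser.extract_content_ids(token_ids))  # ids after the </think> token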
@@ -3,74 +3,92 @@
 import pytest
 from transformers import AutoTokenizer
 
-from tests.entrypoints.openai.reasoning_parsers.utils import (
-    run_reasoning_extraction)
-from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
-                                                       ReasoningParserManager)
+from tests.reasoning.utils import run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
 
 parser_name = "deepseek_r1"
 start_token = "<think>"
 end_token = "</think>"
 
+REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+
+
+@pytest.fixture(scope="module")
+def deepseek_r1_qwen_tokenizer():
+    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)
+
+
 SIMPLE_REASONING = {
     "output": "This is a reasoning section</think>This is the rest",
     "reasoning_content": "This is a reasoning section",
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 COMPLETE_REASONING = {
     "output": "This is a reasoning section</think>",
     "reasoning_content": "This is a reasoning section",
     "content": None,
+    "is_reasoning_end": True,
 }
 NO_CONTENT = {
     "output": "This is content",
     "reasoning_content": "This is content",
     "content": None,
+    "is_reasoning_end": False,
 }
 NO_REASONING_STREAMING = {
     "output": "This is a reasoning section",
     "reasoning_content": "This is a reasoning section",
     "content": None,
+    "is_reasoning_end": False,
 }
 MULTIPLE_LINES = {
     "output": "This\nThat</think>This is the rest\nThat",
     "reasoning_content": "This\nThat",
     "content": "This is the rest\nThat",
+    "is_reasoning_end": True,
 }
 SHORTEST_REASONING_NO_STREAMING = {
     "output": "</think>This is the rest",
     "reasoning_content": "",
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 SHORTEST_REASONING = {
     "output": "</think>This is the rest",
     "reasoning_content": None,
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 REASONING_WITH_THINK = {
     "output": "<think>This is a reasoning section</think>This is the rest",
     "reasoning_content": "This is a reasoning section",
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 COMPLETE_REASONING_WITH_THINK = {
     "output": "<think>This is a reasoning section</think>",
     "reasoning_content": "This is a reasoning section",
     "content": None,
+    "is_reasoning_end": True,
 }
 MULTIPLE_LINES_WITH_THINK = {
     "output": "<think>This\nThat</think>This is the rest\nThat",
     "reasoning_content": "This\nThat",
     "content": "This is the rest\nThat",
+    "is_reasoning_end": True,
 }
 SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
     "output": "</think>This is the rest",
     "reasoning_content": "",
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 SHORTEST_REASONING_WITH_THINK = {
     "output": "</think>This is the rest",
     "reasoning_content": None,
     "content": "This is the rest",
+    "is_reasoning_end": True,
 }
 
 TEST_CASES = [
@@ -166,23 +184,21 @@ TEST_CASES = [
     ),
 ]
 
-# Global tokenizer initialization to avoid repeated loading
-tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
-tokenizer.add_tokens([start_token, end_token])
-
-
 @pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
 def test_reasoning(
     streaming: bool,
     param_dict: dict,
+    deepseek_r1_qwen_tokenizer,
 ):
-    output = tokenizer.tokenize(param_dict["output"])
+    output = deepseek_r1_qwen_tokenizer.tokenize(param_dict["output"])
     # decode everything to tokens
     output_tokens: list[str] = [
-        tokenizer.convert_tokens_to_string([token]) for token in output
+        deepseek_r1_qwen_tokenizer.convert_tokens_to_string([token])
+        for token in output
     ]
     parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
-        parser_name)(tokenizer)
+        parser_name)(deepseek_r1_qwen_tokenizer)
 
     reasoning, content = run_reasoning_extraction(parser,
                                                   output_tokens,
@@ -190,3 +206,17 @@ def test_reasoning(
 
     assert reasoning == param_dict["reasoning_content"]
     assert content == param_dict["content"]
+
+    # Test is_reasoning_end
+    output_ids = deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(output)
+    is_reasoning_end = parser.is_reasoning_end(output_ids)
+    assert is_reasoning_end == param_dict["is_reasoning_end"]
+
+    # Test extract_content
+    if param_dict["content"] is not None:
+        content = parser.extract_content_ids(output_ids)
+        assert content == deepseek_r1_qwen_tokenizer.convert_tokens_to_ids(
+            deepseek_r1_qwen_tokenizer.tokenize(param_dict["content"]))
+    else:
+        content = parser.extract_content_ids(output)
+        assert content == []
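The new extract_content_ids assertions boil down to simple list surgery on token ids: content is whatever follows the first end token, unless the end token is the final id. A self-contained toy version (the ids are invented for illustration, not real tokenizer output):

# Toy re-implementation of the id-level logic the test exercises.
END = 100  # pretend id of "</think>"

def extract_content_ids(input_ids: list[int]) -> list[int]:
    # No content yet if </think> is absent or is the very last id.
    if END not in input_ids[:-1]:
        return []
    return input_ids[input_ids.index(END) + 1:]

assert extract_content_ids([1, 2, END, 3, 4]) == [3, 4]  # SIMPLE_REASONING shape
assert extract_content_ids([1, 2, END]) == []            # COMPLETE_REASONING shape
assert extract_content_ids([1, 2, 3]) == []              # NO_CONTENT shape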
@@ -2,10 +2,8 @@
 import pytest
 from transformers import AutoTokenizer
 
-from tests.entrypoints.openai.reasoning_parsers.utils import (
-    DeltaMessage, run_reasoning_extraction)
-from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
-                                                       ReasoningParserManager)
+from tests.reasoning.utils import DeltaMessage, run_reasoning_extraction
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
 
 parser_name = "granite"
 START_REASONING = "Here is my thought process:"
@@ -4,7 +4,7 @@ from typing import Optional, Union
 
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage)
-from vllm.entrypoints.openai.reasoning_parsers import ReasoningParser
+from vllm.reasoning import ReasoningParser
 
 
 class StreamingReasoningReconstructor:
@@ -23,6 +23,7 @@ from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.plugins import load_general_plugins
+from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
@@ -1119,7 +1120,7 @@ class EngineArgs:
         parser.add_argument(
             "--reasoning-parser",
             type=str,
-            choices=["deepseek_r1", "granite"],
+            choices=list(ReasoningParserManager.reasoning_parsers),
             default=None,
             help=
             "Select the reasoning parser depending on the model that you're "
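Because ReasoningParserManager.reasoning_parsers is the registry mapping parser names to classes (see the reasoning_parsers[name] lookup later in this diff), the CLI choices now track whatever has been registered rather than a hard-coded list. A hedged sketch of what that enumeration yields:

# Sketch: enumerate registered reasoning parsers the way the new
# --reasoning-parser choices do; exact names depend on your vLLM build.
from vllm.reasoning import ReasoningParserManager

print(list(ReasoningParserManager.reasoning_parsers))
# e.g. ['deepseek_r1', 'granite'] at the time of this commit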
@@ -2080,8 +2080,9 @@ class LLMEngine:
             guided_decoding.backend = guided_decoding.backend or \
                 self.decoding_config.guided_decoding_backend
 
-            logger.debug("Reasoning backend: %s",
-                         self.decoding_config.reasoning_backend)
+            if self.decoding_config.reasoning_backend is not None:
+                logger.debug("Building with reasoning backend %s",
+                             self.decoding_config.reasoning_backend)
 
             processor = get_local_guided_decoding_logits_processor(
                 guided_params=guided_decoding,
@@ -68,7 +68,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               TranscriptionRequest,
                                               TranscriptionResponse,
                                               UnloadLoRAAdapterRequest)
-from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
 # yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -85,6 +84,7 @@ from vllm.entrypoints.openai.serving_transcription import (
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
 from vllm.entrypoints.utils import load_aware_call, with_cancellation
 from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParserManager
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.transformers_utils.tokenizer import MistralTokenizer
@@ -23,8 +23,6 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
     DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
     RequestResponseMetadata, ToolCall, UsageInfo)
-from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
-                                                       ReasoningParserManager)
 from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
                                                     clamp_prompt_logprobs)
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
@@ -33,6 +31,7 @@ from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
     MistralToolCall)
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@@ -5,10 +5,10 @@ from __future__ import annotations
 from typing import TYPE_CHECKING
 
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding.reasoner import get_reasoner
 from vllm.model_executor.guided_decoding.utils import (
     convert_lark_to_gbnf, grammar_is_likely_lark,
     has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
+from vllm.reasoning import ReasoningParserManager
 
 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
@@ -107,7 +107,11 @@ async def get_guided_decoding_logits_processor(
         model_config: ModelConfig,
         reasoning_backend: str | None = None) -> LogitsProcessor | None:
 
-    reasoner = get_reasoner(tokenizer, reasoning_backend)
+    reasoner = None
+    if reasoning_backend is not None:
+        reasoner_class = ReasoningParserManager.get_reasoning_parser(
+            reasoning_backend)
+        reasoner = reasoner_class(tokenizer)
 
     guided_params = maybe_backend_fallback(guided_params)
 
@@ -146,8 +150,11 @@ def get_local_guided_decoding_logits_processor(
         reasoning_backend: str | None = None) -> LogitsProcessor | None:
     guided_params = maybe_backend_fallback(guided_params)
 
-    # Get the reasoner if needed, it will be None if reasoning_backend is None
-    reasoner = get_reasoner(tokenizer, reasoning_backend)
+    reasoner = None
+    if reasoning_backend is not None:
+        reasoner_class = ReasoningParserManager.get_reasoning_parser(
+            reasoning_backend)
+        reasoner = reasoner_class(tokenizer)
 
     # CFG grammar not supported by LMFE, so we use outlines instead
     if guided_params.backend_name == 'outlines':
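Both call sites now build the reasoner the same way: look the parser class up by name, then instantiate it with the tokenizer. A compact sketch of that repeated pattern (make_reasoner is a hypothetical helper, not part of vLLM):

# Sketch of the lookup-and-instantiate pattern repeated in both functions.
from typing import Optional

from transformers import PreTrainedTokenizerBase

from vllm.reasoning import ReasoningParser, ReasoningParserManager

def make_reasoner(
        tokenizer: PreTrainedTokenizerBase,
        reasoning_backend: Optional[str]) -> Optional[ReasoningParser]:
    if reasoning_backend is None:
        return None  # guided decoding proceeds without a reasoner
    reasoner_class = ReasoningParserManager.get_reasoning_parser(
        reasoning_backend)
    return reasoner_class(tokenizer)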
@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase
 
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
-from vllm.model_executor.guided_decoding.reasoner import Reasoner
+from vllm.reasoning import ReasoningParser
 from vllm.sampling_params import GuidedDecodingParams
 
 
@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16
 async def get_outlines_guided_decoding_logits_processor(
     guided_params: GuidedDecodingParams,
     tokenizer: PreTrainedTokenizerBase,
-    reasoner: Optional[Reasoner],
+    reasoner: Optional[ReasoningParser],
 ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
            None]:
     """
@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor(
 def get_local_outlines_guided_decoding_logits_processor(
     guided_params: GuidedDecodingParams,
     tokenizer: PreTrainedTokenizerBase,
-    reasoner: Optional[Reasoner],
+    reasoner: Optional[ReasoningParser],
 ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
            None]:
     """
@@ -141,7 +141,7 @@ def _get_logits_processor(
     tokenizer: PreTrainedTokenizerBase,
     mode: GuidedDecodingMode,
     whitespace_pattern: Union[str, None],
-    reasoner: Optional[Reasoner],
+    reasoner: Optional[ReasoningParser],
 ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
     if mode == GuidedDecodingMode.JSON:
         return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase
 
 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.model_executor.guided_decoding.reasoner import Reasoner
 from vllm.platforms import current_platform
+from vllm.reasoning import ReasoningParser
 
 logger = init_logger(__name__)
 
@@ -49,9 +49,9 @@ else:
 
 class BaseLogitsProcessor:
 
-    def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
+    def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
         self._guide: Guide = guide
-        self._reasoner: Optional[Reasoner] = reasoner
+        self._reasoner: Optional[ReasoningParser] = reasoner
         # CFGState is used for the FSM state for CFGGuide
         self._fsm_state: DefaultDict[int, Union[int,
                                                 CFGState]] = defaultdict(int)
@@ -69,7 +69,7 @@ class BaseLogitsProcessor:
                 # Remove the reasoning tokens from the input_ids
                 # We need this because our implementation relies on the
                 # hash of the input_ids to store the FSM state.
-                input_ids = self._reasoner.extract_content(input_ids)
+                input_ids = self._reasoner.extract_content_ids(input_ids)
 
         seq_id = hash(tuple(input_ids))
 
@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
         self,
         regex_string: str,
         tokenizer: PreTrainedTokenizerBase,
-        reasoner: Optional[Reasoner],
+        reasoner: Optional[ReasoningParser],
     ):
         """Compile the FSM that drives the regex-structured generation.
 
@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
     def __init__(self, schema: Union[str, Dict, BaseModel],
                  tokenizer: PreTrainedTokenizerBase,
                  whitespace_pattern: Union[str, None],
-                 reasoner: Optional[Reasoner]):
+                 reasoner: Optional[ReasoningParser]):
         """Compile the FSM that drives the JSON-guided generation.
 
         Parameters
@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
         return CFGGuide(cfg, tokenizer)
 
     def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
-                 reasoner: Optional[Reasoner]):
+                 reasoner: Optional[ReasoningParser]):
         """Compile the FSM that drives the context free grammar generation.
 
         Parameters
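The rename from extract_content to extract_content_ids matters here because BaseLogitsProcessor keys its FSM state on a hash of the input ids; stripping the reasoning tokens first keeps that key dependent only on the structured content. A toy illustration of the idea (simplified, not the outlines implementation; ids invented):

# Two sequences with different reasoning prefixes but identical content
# hash alike, so they map to the same cached FSM state.
END = 100  # pretend id of "</think>"

def content_ids(input_ids: list[int]) -> list[int]:
    if END not in input_ids[:-1]:
        return []
    return input_ids[input_ids.index(END) + 1:]

a = content_ids([1, 2, END, 7, 8])
b = content_ids([3, 4, 5, END, 7, 8])
assert hash(tuple(a)) == hash(tuple(b))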
@@ -1,38 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-from dataclasses import dataclass
-
-from transformers import PreTrainedTokenizer
-
-from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
-
-
-@dataclass
-class DeepSeekReasoner(Reasoner):
-    """
-    Reasoner for DeepSeek R series models.
-    """
-    start_token_id: int
-    end_token_id: int
-
-    start_token: str = "<think>"
-    end_token: str = "</think>"
-
-    @classmethod
-    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
-        return cls(start_token_id=tokenizer.encode(
-            "<think>", add_special_tokens=False)[0],
-                   end_token_id=tokenizer.encode("</think>",
-                                                 add_special_tokens=False)[0])
-
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
-        return self.end_token_id in input_ids
-
-    def extract_content(self, input_ids: list[int]) -> list[int]:
-        """
-        Extract the content after the end tokens
-        """
-        if self.end_token_id not in input_ids or \
-            input_ids.index(self.end_token_id) + 1 == len(input_ids):
-            return []
-        else:
-            return input_ids[input_ids.index(self.end_token_id) + 1:]
@@ -1,23 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-from __future__ import annotations
-
-from abc import ABC, abstractmethod
-from dataclasses import dataclass
-
-from transformers import PreTrainedTokenizer
-
-
-@dataclass
-class Reasoner(ABC):
-
-    @abstractmethod
-    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
-        pass
-
-    @abstractmethod
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
-        pass
-
-    @abstractmethod
-    def extract_content(self, input_ids: list[int]) -> list[int]:
-        pass
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
     from vllm.config import ModelConfig
-    from vllm.model_executor.guided_decoding.reasoner import Reasoner
+    from vllm.reasoning import ReasoningParser
     from vllm.sampling_params import GuidedDecodingParams
 
 logger = init_logger(__name__)
@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor(
         guided_params: GuidedDecodingParams,
         tokenizer: PreTrainedTokenizer,
         model_config: ModelConfig,
-        reasoner: Reasoner | None,
+        reasoner: ReasoningParser | None,
         max_threads: int = 8):
     config = GrammarConfig.from_guided_params(guided_params=guided_params,
                                               model_config=model_config,
@@ -280,7 +280,7 @@ class GrammarConfig:
 class XGrammarLogitsProcessor:
     """Wrapper class to support pickle protocol"""
     config: GrammarConfig
-    reasoner: Reasoner | None = None
+    reasoner: ReasoningParser | None = None
 
     ctx: xgr.CompiledGrammar | None = None
     tokenizer_info: xgr.TokenizerInfo = None  # type: ignore[assignment]
@@ -17,7 +17,7 @@ logger = init_logger(__name__)
 
 class ReasoningParser:
     """
-    Abstract reasoning parser class that should not be used directly.
+    Abstract reasoning parser class that should not be used directly.
     Provided and methods should be used in derived classes.
 
     It is used to extract reasoning content from the model output.
@@ -32,6 +32,36 @@ class ReasoningParser:
         # whereas all tokenizers have .get_vocab()
         return self.model_tokenizer.get_vocab()
 
+    @abstractmethod
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        """
+        Check if the reasoning content ends in the input_ids.
+
+        It is used in structured engines like `xgrammar` to check if the
+        reasoning content ends in the model output.
+
+        Parameters:
+        input_ids: list[int]
+            The input_ids of the model output.
+
+        Returns:
+        bool
+            True if the reasoning content ends in the input_ids.
+        """
+
+    @abstractmethod
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        """
+        Extract content token ids from the input_ids.
+
+        Parameters:
+        input_ids: list[int]
+            The input_ids of the model output.
+
+        Returns:
+        list[int]
+            The extracted content from the input_ids.
+        """
+
     @abstractmethod
     def extract_reasoning_content(
             self, model_output: str, request: ChatCompletionRequest
     ) -> tuple[Optional[str], Optional[str]]:
@@ -53,10 +83,7 @@ class ReasoningParser:
         A tuple containing the reasoning content and the content.
         """
 
-        raise NotImplementedError(
-            "AbstractReasoningParser.extract_reasoning_calls "
-            "has not been implemented!")
-
     @abstractmethod
     def extract_reasoning_content_streaming(
         self,
         previous_text: str,
@@ -73,43 +100,6 @@ class ReasoningParser:
         the current tokens/diffs, but also the information about what has
         previously been parsed and extracted (see constructor)
         """
-        raise NotImplementedError(
-            "AbstractReasoningParser.extract_reasoning_content_streaming "
-            "has not been implemented!")
-
-    # TODO: need to rebase by PR #14428
-    @abstractmethod
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
-        """
-        Check if the reasoning content ends in the input_ids.
-        Parameters:
-        input_ids: list[int]
-            The input_ids of the model output.
-        Returns:
-        bool
-            True if the reasoning content ends in the input_ids.
-        """
-
-        raise NotImplementedError(
-            "AbstractReasoningParser.is_reasoning_end has"
-            "not been implemented!")
-
-    # TODO: need to rebase by PR #14428
-    @abstractmethod
-    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
-        """
-        Extract content token ids from the input_ids.
-        Parameters:
-        input_ids: list[int]
-            The input_ids of the model output.
-        Returns:
-        list[int]
-            The extracted content from the input_ids.
-        """
-
-        raise NotImplementedError(
-            "AbstractReasoningParser.extract_content_ids has"
-            " not been implemented!")
 
 
 class ReasoningParserManager:
@@ -125,14 +115,16 @@ class ReasoningParserManager:
         if name in cls.reasoning_parsers:
             return cls.reasoning_parsers[name]
 
-        raise KeyError(f"reasoning helper: '{name}' not found in "
-                       "reasoning_parsers")
+        raise KeyError(
+            f"reasoning helper: '{name}' not found in reasoning_parsers")
 
     @classmethod
-    def _register_module(cls,
-                         module: type,
-                         module_name: Optional[Union[str, list[str]]] = None,
-                         force: bool = True) -> None:
+    def _register_module(
+        cls,
+        module: type,
+        module_name: Optional[Union[str, list[str]]] = None,
+        force: bool = True,
+    ) -> None:
         if not issubclass(module, ReasoningParser):
             raise TypeError("module must be subclass of ReasoningParser, "
                             f"but got {type(module)}")
@@ -149,13 +141,14 @@ class ReasoningParserManager:
 
     @classmethod
     def register_module(
-            cls,
-            name: Optional[Union[str, list[str]]] = None,
-            force: bool = True,
-            module: Union[type, None] = None) -> Union[type, Callable]:
+        cls,
+        name: Optional[Union[str, list[str]]] = None,
+        force: bool = True,
+        module: Union[type, None] = None,
+    ) -> Union[type, Callable]:
         """
         Register module with the given name or name list. it can be used as a
-        decoder(with module as None) or normal function(with module as not
+        decoder(with module as None) or normal function(with module as not
         None).
         """
         if not isinstance(force, bool):
@@ -183,7 +176,7 @@ class ReasoningParserManager:
     @classmethod
     def import_reasoning_parser(cls, plugin_path: str) -> None:
         """
-        Import a user-defined reasoning parser by the path
+        Import a user-defined reasoning parser by the path
        of the reasoning parser define file.
         """
         module_name = os.path.splitext(os.path.basename(plugin_path))[0]
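Putting the manager to use: per the docstring above, register_module can decorate a ReasoningParser subclass, after which the name works wherever a reasoning backend string is accepted (CLI choices, guided decoding, the frontend). A hedged sketch of a user-defined parser; the name "my_parser" and the trivial method bodies are invented for illustration:

# Sketch of registering a custom parser via the decorator form.
from vllm.reasoning import ReasoningParser, ReasoningParserManager


@ReasoningParserManager.register_module("my_parser")
class MyReasoningParser(ReasoningParser):

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return True  # pretend reasoning is always finished

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        return input_ids  # pretend everything is content

    def extract_reasoning_content(self, model_output, request):
        return None, model_output  # no reasoning, all content

    def extract_reasoning_content_streaming(self, previous_text, current_text,
                                            delta_text, previous_token_ids,
                                            current_token_ids,
                                            delta_token_ids):
        return None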
@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase
 
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage)
-from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
-    ReasoningParser, ReasoningParserManager)
 from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
 
 logger = init_logger(__name__)
 
@@ -20,43 +19,45 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
     """
     Reasoning parser for DeepSeek R1 model.
 
-    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
+    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
     text. This parser extracts the reasoning content from the model output.
     """
 
+    start_token_id: int
+    end_token_id: int
+
+    start_token: str = "<think>"
+    end_token: str = "</think>"
+
     def __init__(self, tokenizer: PreTrainedTokenizerBase):
         super().__init__(tokenizer)
-        self.think_start_token = "<think>"
-        self.think_end_token = "</think>"
 
         self.reasoning_regex = re.compile(
-            rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)
+            rf"{self.start_token}(.*?){self.end_token}", re.DOTALL)
 
         if not self.model_tokenizer:
             raise ValueError(
                 "The model tokenizer must be passed to the ReasoningParser "
                 "constructor during construction.")
 
-        self.think_start_token_id = self.vocab.get(self.think_start_token)
-        self.think_end_token_id = self.vocab.get(self.think_end_token)
-        if (self.think_start_token_id is None
-                or self.think_end_token_id is None):
+        self.start_token_id = self.vocab.get(self.start_token)
+        self.end_token_id = self.vocab.get(self.end_token)
+        if self.start_token_id is None or self.end_token_id is None:
             raise RuntimeError(
                 "DeepSeek R1 reasoning parser could not locate think start/end "
                 "tokens in the tokenizer!")
 
-    # TODO: need to rebase by PR #14428
     def is_reasoning_end(self, input_ids: list[int]) -> bool:
-        return self.think_end_token_id in input_ids
+        return self.end_token_id in input_ids
 
     def extract_content_ids(self, input_ids: list[int]) -> list[int]:
         """
         Extract the content after the end tokens
         """
-        if self.think_end_token_id not in input_ids[:-1]:
+        if self.end_token_id not in input_ids[:-1]:
             return []
         else:
-            return input_ids[input_ids.index(self.think_end_token_id) + 1:]
+            return input_ids[input_ids.index(self.end_token_id) + 1:]
 
     def extract_reasoning_content_streaming(
         self,
@@ -77,22 +78,24 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
         """
         # Skip single special tokens
         if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
-                self.think_start_token_id, self.think_end_token_id
+                self.start_token_id, self.end_token_id
         ]):
             return None
 
         # Check if <think> is present in previous or delta.
         # Keep compatibility with models that don't generate <think> tokens.
-        if self.think_start_token_id in previous_token_ids:
-            if self.think_end_token_id in delta_token_ids:
+        if self.start_token_id in previous_token_ids:
+            if self.end_token_id in delta_token_ids:
                 # <think> in previous, </think> in delta,
                 # extract reasoning content
-                end_index = delta_text.find(self.think_end_token)
+                end_index = delta_text.find(self.end_token)
                 reasoning_content = delta_text[:end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content,
-                                    content=content if content else None)
-            elif self.think_end_token_id in previous_token_ids:
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
+            elif self.end_token_id in previous_token_ids:
                 # <think> in previous, </think> in previous,
                 # reasoning content continues
                 return DeltaMessage(content=delta_text)
@@ -100,17 +103,18 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
                 # <think> in previous, no </think> in previous or delta,
                 # reasoning content continues
                 return DeltaMessage(reasoning_content=delta_text)
-        elif self.think_start_token_id in delta_token_ids:
-            if self.think_end_token_id in delta_token_ids:
+        elif self.start_token_id in delta_token_ids:
+            if self.end_token_id in delta_token_ids:
                 # <think> in delta, </think> in delta, extract reasoning content
-                start_index = delta_text.find(self.think_start_token)
-                end_index = delta_text.find(self.think_end_token)
+                start_index = delta_text.find(self.start_token)
+                end_index = delta_text.find(self.end_token)
                 reasoning_content = delta_text[start_index +
-                                               len(self.think_start_token
-                                                   ):end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content,
-                                    content=content if content else None)
+                                               len(self.start_token):end_index]
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
             else:
                 # <think> in delta, no </think> in delta,
                 # reasoning content continues
@@ -119,15 +123,17 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
             # No <think> in previous or delta, also need to check for </think>.
             # Because the model may have generated </think> without <think>
             # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
-            if self.think_end_token_id in delta_token_ids:
+            if self.end_token_id in delta_token_ids:
                 # </think> in delta with more tokens,
                 # extract reasoning content and content
-                end_index = delta_text.find(self.think_end_token)
+                end_index = delta_text.find(self.end_token)
                 reasoning_content = delta_text[:end_index]
-                content = delta_text[end_index + len(self.think_end_token):]
-                return DeltaMessage(reasoning_content=reasoning_content,
-                                    content=content if content else None)
-            elif self.think_end_token_id in previous_token_ids:
+                content = delta_text[end_index + len(self.end_token):]
+                return DeltaMessage(
+                    reasoning_content=reasoning_content,
+                    content=content if content else None,
+                )
+            elif self.end_token_id in previous_token_ids:
                 # </think> in previous, thinking content ends
                 return DeltaMessage(content=delta_text)
             else:
@@ -137,22 +143,20 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
     def extract_reasoning_content(
             self, model_output: str, request: ChatCompletionRequest
     ) -> tuple[Optional[str], Optional[str]]:
 
         # DeepSeek R1 doesn't generate <think> now.
         # Thus we assume the reasoning content is always at the start.
         # Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
-        if self.think_end_token not in model_output:
+        if self.end_token not in model_output:
             return model_output, None
         else:
             # Add a start token if it's missing to keep compatibility.
-            if self.think_start_token not in model_output:
-                model_output = f"{self.think_start_token}{model_output}"
+            if self.start_token not in model_output:
+                model_output = f"{self.start_token}{model_output}"
             # Use a regex to find the reasoning content
             reasoning_content = self.reasoning_regex.findall(model_output)[0]
 
             end_index = len(
-                f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
-            )
+                f"{self.start_token}{reasoning_content}{self.end_token}")
             final_output = model_output[end_index:]
 
             if len(final_output) == 0:
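The non-streaming path above handles R1's quirk of omitting <think> by prepending it before the regex match. A standalone toy reproduction of that fallback (not vLLM code; the strings are invented):

# Toy illustration of extract_reasoning_content's fallback for outputs
# that start reasoning without an explicit <think> token.
import re

start_token, end_token = "<think>", "</think>"
reasoning_regex = re.compile(rf"{start_token}(.*?){end_token}", re.DOTALL)

model_output = "First, consider the premise.</think>The answer is 42."
if start_token not in model_output:
    model_output = f"{start_token}{model_output}"
reasoning = reasoning_regex.findall(model_output)[0]
content = model_output[len(f"{start_token}{reasoning}{end_token}"):]
print(repr(reasoning), repr(content))
# 'First, consider the premise.' 'The answer is 42.'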
@@ -8,9 +8,8 @@ from transformers import PreTrainedTokenizerBase
 
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               DeltaMessage)
-from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
-    ReasoningParser, ReasoningParserManager)
 from vllm.logger import init_logger
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
 
 logger = init_logger(__name__)