Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
Signed-off-by: taohui <taohui3@gmail.com>
Signed-off-by: Tao Hui <taohui3@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
docs/features/reasoning_outputs.md
@@ -11,6 +11,7 @@ vLLM currently supports the following reasoning models:
 | Model Series | Parser Name | Structured Output Support | Tool Calling |
 |--------------|-------------|------------------|-------------|
 | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `json`, `regex` | ❌ |
+| [DeepSeek-V3.1](https://huggingface.co/collections/deepseek-ai/deepseek-v31-68a491bed32bd77e7fca048f) | `deepseek_v3` | `json`, `regex` | ❌ |
 | [ERNIE-4.5-VL series](https://huggingface.co/baidu/ERNIE-4.5-VL-28B-A3B-PT) | `ernie45` | `json`, `regex` | ❌ |
 | [ERNIE-4.5-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking) | `ernie45` | `json`, `regex` | ✅ |
 | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `json`, `regex` | ✅ |
@@ -20,8 +21,9 @@ vLLM currently supports the following reasoning models:
 | [GLM-4.5 series](https://huggingface.co/collections/zai-org/glm-45-687c621d34bda8c9e4bf503b) | `glm45` | `json`, `regex` | ✅ |

 !!! note
-    IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
+    IBM Granite 3.2 and DeepSeek-V3.1 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
     The reasoning feature for the Qwen3 series is enabled by default. To disable it, you must pass `enable_thinking=False` in your `chat_template_kwargs`.
+    DeepSeek-V3.1 tool calling is supported in non-thinking mode.

 ## Quickstart
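The note above is the user-facing contract: for DeepSeek-V3.1, reasoning stays off unless the request carries `thinking=True` in `chat_template_kwargs`. A minimal sketch of what that looks like from the OpenAI Python client, assuming a local vLLM server was already launched with something like `vllm serve deepseek-ai/DeepSeek-V3.1 --reasoning-parser deepseek_v3` (the URL, port, and API key below are placeholder assumptions):

from openai import OpenAI

# Placeholder endpoint; vLLM's OpenAI-compatible server defaults to port 8000.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-V3.1",
    messages=[{"role": "user", "content": "What is 17 * 24?"}],
    # chat_template_kwargs travels in extra_body; thinking=True turns
    # DeepSeek-V3.1 reasoning on (it is disabled by default).
    extra_body={"chat_template_kwargs": {"thinking": True}},
)

print(response.choices[0].message.reasoning_content)  # separated reasoning
print(response.choices[0].message.content)  # final answer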
tests/reasoning/test_deepseekv3_reasoning_parser.py (new file, 76 lines)
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.reasoning import (
    DeepSeekR1ReasoningParser,
    DeepSeekV3ReasoningParser,
    IdentityReasoningParser,
)

REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-V3.1"


@pytest.fixture(scope="module")
def tokenizer():
    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


@pytest.mark.parametrize(
    "thinking,expected_parser_type",
    [
        (True, DeepSeekR1ReasoningParser),
        (False, IdentityReasoningParser),
    ],
)
def test_parser_selection(tokenizer, thinking, expected_parser_type):
    parser = DeepSeekV3ReasoningParser(
        tokenizer, chat_template_kwargs={"thinking": thinking}
    )

    assert isinstance(parser._parser, expected_parser_type)


def test_identity_reasoning_parser_basic(tokenizer):
    parser = IdentityReasoningParser(tokenizer)

    # Test is_reasoning_end always returns True
    input_text = "This is some output"
    input_tokens = tokenizer.tokenize(input_text)
    input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
    assert parser.is_reasoning_end(input_ids) is True

    # Test extract_content_ids returns all input_ids
    assert parser.extract_content_ids(input_ids) == input_ids

    # Test extract_reasoning_content returns (None, model_output)
    request = ChatCompletionRequest(model="test-model", messages=[], temperature=1.0)
    reasoning, content = parser.extract_reasoning_content(input_text, request)
    assert reasoning is None
    assert content == input_text

    # Test extract_reasoning_content_streaming returns DeltaMessage or None
    result = parser.extract_reasoning_content_streaming(
        previous_text="",
        current_text="Hello world",
        delta_text="Hello world",
        previous_token_ids=[],
        current_token_ids=input_ids,
        delta_token_ids=input_ids,
    )
    assert isinstance(result, DeltaMessage)
    assert result.content == "Hello world"

    # If delta_text is empty, should return None
    result_none = parser.extract_reasoning_content_streaming(
        previous_text="Hello world",
        current_text="Hello world",
        delta_text="",
        previous_token_ids=input_ids,
        current_token_ids=input_ids,
        delta_token_ids=[],
    )
    assert result_none is None
vllm/entrypoints/openai/serving_chat.py
@@ -570,7 +570,10 @@ class OpenAIServingChat(OpenAIServing):

         try:
             if self.reasoning_parser:
-                reasoning_parser = self.reasoning_parser(tokenizer)
+                reasoning_parser = self.reasoning_parser(
+                    tokenizer,
+                    chat_template_kwargs=request.chat_template_kwargs,  # type: ignore
+                )
         except RuntimeError as e:
             logger.exception("Error in reasoning parser creation.")
             data = self.create_streaming_error_response(str(e))
@@ -1335,7 +1338,10 @@ class OpenAIServingChat(OpenAIServing):

         if self.reasoning_parser:
             try:
-                reasoning_parser = self.reasoning_parser(tokenizer)
+                reasoning_parser = self.reasoning_parser(
+                    tokenizer,
+                    chat_template_kwargs=request.chat_template_kwargs,  # type: ignore
+                )
             except RuntimeError as e:
                 logger.exception("Error in reasoning parser creation.")
                 return self.create_error_response(str(e))
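Both call sites change the same way: the per-request `chat_template_kwargs` is now forwarded into the reasoning parser's constructor, so a parser can adapt its behavior to how the chat template was rendered for that request. Because the parser constructors accept `*args, **kwargs` (see the new parsers below), parsers that don't care about the keyword simply ignore it. A sketch of a parser constructor that opts in, with a hypothetical registration name:

from transformers import PreTrainedTokenizerBase

from vllm.reasoning import ReasoningParser, ReasoningParserManager


# "my_toggle" is a hypothetical name, not a real vLLM parser.
@ReasoningParserManager.register_module("my_toggle")
class MyToggleReasoningParser(ReasoningParser):
    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        # Pop the request-scoped kwargs before delegating the rest upward.
        chat_kwargs = kwargs.pop("chat_template_kwargs", {}) or {}
        self.thinking = bool(chat_kwargs.get("thinking", False))
        super().__init__(tokenizer, *args, **kwargs)
        # ...the extract_* methods would branch on self.thinking...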
vllm/reasoning/__init__.py
@@ -4,11 +4,13 @@
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .basic_parsers import BaseThinkingReasoningParser
 from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
+from .deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser
 from .ernie45_reasoning_parser import Ernie45ReasoningParser
 from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
 from .gptoss_reasoning_parser import GptOssReasoningParser
 from .granite_reasoning_parser import GraniteReasoningParser
 from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
+from .identity_reasoning_parser import IdentityReasoningParser
 from .mistral_reasoning_parser import MistralReasoningParser
 from .olmo3_reasoning_parser import Olmo3ReasoningParser
 from .qwen3_reasoning_parser import Qwen3ReasoningParser
@@ -20,6 +22,8 @@ __all__ = [
     "BaseThinkingReasoningParser",
     "ReasoningParserManager",
     "DeepSeekR1ReasoningParser",
+    "IdentityReasoningParser",
+    "DeepSeekV3ReasoningParser",
     "Ernie45ReasoningParser",
     "GraniteReasoningParser",
     "HunyuanA13BReasoningParser",
vllm/reasoning/deepseek_v3_reasoning_parser.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.logger import init_logger
from vllm.reasoning import (
    DeepSeekR1ReasoningParser,
    ReasoningParser,
    ReasoningParserManager,
)

from .identity_reasoning_parser import IdentityReasoningParser

logger = init_logger(__name__)


@ReasoningParserManager.register_module("deepseek_v3")
class DeepSeekV3ReasoningParser(ReasoningParser):
    """
    V3 parser that delegates to either DeepSeekR1ReasoningParser or
    IdentityReasoningParser based on the `thinking` flag in
    `chat_template_kwargs`.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)

        chat_kwargs = kwargs.pop("chat_template_kwargs", {}) or {}
        thinking = bool(chat_kwargs.pop("thinking", False))

        if thinking:
            self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
        else:
            self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs)

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        return self._parser.is_reasoning_end(input_ids)

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        return self._parser.extract_content_ids(input_ids)

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[str | None, str | None]:
        return self._parser.extract_reasoning_content(model_output, request)

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        return self._parser.extract_reasoning_content_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
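Putting the delegation together, a short sketch: the class is registered under the `deepseek_v3` name (the same name passed to `--reasoning-parser deepseek_v3`), and the `thinking` flag picks the delegate. This mirrors `test_parser_selection` above; direct construction here is just for illustration:

from transformers import AutoTokenizer

from vllm.reasoning import (
    DeepSeekR1ReasoningParser,
    IdentityReasoningParser,
    ReasoningParserManager,
)

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3.1")

# Resolve the class by its registered name, as the server does at startup.
parser_cls = ReasoningParserManager.get_reasoning_parser("deepseek_v3")

# thinking=True -> R1-style <think>...</think> separation.
on = parser_cls(tokenizer, chat_template_kwargs={"thinking": True})
assert isinstance(on._parser, DeepSeekR1ReasoningParser)

# thinking absent or False -> identity pass-through (the default).
off = parser_cls(tokenizer, chat_template_kwargs={})
assert isinstance(off._parser, IdentityReasoningParser)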
vllm/reasoning/identity_reasoning_parser.py (new file, 58 lines)
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser

logger = init_logger(__name__)


class IdentityReasoningParser(ReasoningParser):
    """
    Identity reasoning parser.

    This parser does not attempt to parse or strip out reasoning tokens.
    It treats the entire model output as content and ignores reasoning.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        # Always return True, since we never treat reasoning specially
        return True

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        # Identity: return all tokens as content
        return input_ids

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        # Just wrap delta_text as content, ignore reasoning
        if delta_text:
            return DeltaMessage(content=delta_text)
        return None

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[str | None, str | None]:
        # No reasoning separation: return None for reasoning_content,
        # and full model_output as content
        return None, model_output
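One consequence worth spelling out: with the identity parser selected (thinking off), even literal reasoning-style markup in the model output is not stripped; everything comes back as content. A small sketch (the model name is only used to load a tokenizer):

from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.reasoning import IdentityReasoningParser

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3.1")
parser = IdentityReasoningParser(tokenizer)

# Reasoning-style tags pass through untouched; nothing is separated out.
output = "<think>scratch work</think>The answer is 408."
request = ChatCompletionRequest(model="test-model", messages=[])
reasoning, content = parser.extract_reasoning_content(output, request)
assert reasoning is None
assert content == output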