[Model] Add DeepSeek-V3.1 reasoning parser (split from PR #24972) (#25589)

Signed-off-by: taohui <taohui3@gmail.com>
Signed-off-by: Tao Hui <taohui3@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
Tao Hui authored on 2025-10-15 11:09:52 +08:00 (committed by GitHub)
parent a2986b3e33, commit 85a65e7f51
6 changed files with 215 additions and 3 deletions


@@ -11,6 +11,7 @@ vLLM currently supports the following reasoning models:
 | Model Series | Parser Name | Structured Output Support | Tool Calling |
 |--------------|-------------|---------------------------|--------------|
 | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `json`, `regex` | ❌ |
+| [DeepSeek-V3.1](https://huggingface.co/collections/deepseek-ai/deepseek-v31-68a491bed32bd77e7fca048f) | `deepseek_v3` | `json`, `regex` | ❌ |
 | [ERNIE-4.5-VL series](https://huggingface.co/baidu/ERNIE-4.5-VL-28B-A3B-PT) | `ernie45` | `json`, `regex` | ❌ |
 | [ERNIE-4.5-21B-A3B-Thinking](https://huggingface.co/baidu/ERNIE-4.5-21B-A3B-Thinking) | `ernie45` | `json`, `regex` | ✅ |
 | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `json`, `regex` | ✅ |
@@ -20,8 +21,9 @@ vLLM currently supports the following reasoning models:
 | [GLM-4.5 series](https://huggingface.co/collections/zai-org/glm-45-687c621d34bda8c9e4bf503b) | `glm45` | `json`, `regex` | ✅ |

 !!! note
-    IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
+    IBM Granite 3.2 and DeepSeek-V3.1 reasoning are disabled by default; to enable them, you must also pass `thinking=True` in your `chat_template_kwargs`.
     The reasoning feature for the Qwen3 series is enabled by default. To disable it, you must pass `enable_thinking=False` in your `chat_template_kwargs`.
+    DeepSeek-V3.1 tool calling is supported in non-thinking mode.

 ## Quickstart
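The note above pairs with the `deepseek_v3` parser registered later in this commit. A minimal client-side sketch of turning DeepSeek-V3.1 reasoning on, assuming a server launched with `vllm serve deepseek-ai/DeepSeek-V3.1 --reasoning-parser deepseek_v3` (the URL, API key, and prompt are illustrative):

```python
from openai import OpenAI

# Illustrative endpoint; vLLM's OpenAI-compatible server defaults to port 8000.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-V3.1",
    messages=[{"role": "user", "content": "Which is larger, 9.11 or 9.8?"}],
    # Per the note above, thinking=True enables DeepSeek-V3.1 reasoning.
    extra_body={"chat_template_kwargs": {"thinking": True}},
)

# With thinking enabled, the deepseek_v3 parser delegates to the R1 parser,
# so the reasoning text is returned separately from the final answer.
print(response.choices[0].message.reasoning_content)
print(response.choices[0].message.content)
```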


@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.reasoning import (
    DeepSeekR1ReasoningParser,
    DeepSeekV3ReasoningParser,
    IdentityReasoningParser,
)

REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-V3.1"


@pytest.fixture(scope="module")
def tokenizer():
    return AutoTokenizer.from_pretrained(REASONING_MODEL_NAME)


@pytest.mark.parametrize(
    "thinking,expected_parser_type",
    [
        (True, DeepSeekR1ReasoningParser),
        (False, IdentityReasoningParser),
    ],
)
def test_parser_selection(tokenizer, thinking, expected_parser_type):
    parser = DeepSeekV3ReasoningParser(
        tokenizer, chat_template_kwargs={"thinking": thinking}
    )
    assert isinstance(parser._parser, expected_parser_type)


def test_identity_reasoning_parser_basic(tokenizer):
    parser = IdentityReasoningParser(tokenizer)

    # Test is_reasoning_end always returns True
    input_text = "This is some output"
    input_tokens = tokenizer.tokenize(input_text)
    input_ids = tokenizer.convert_tokens_to_ids(input_tokens)
    assert parser.is_reasoning_end(input_ids) is True

    # Test extract_content_ids returns all input_ids
    assert parser.extract_content_ids(input_ids) == input_ids

    # Test extract_reasoning_content returns (None, model_output)
    request = ChatCompletionRequest(model="test-model", messages=[], temperature=1.0)
    reasoning, content = parser.extract_reasoning_content(input_text, request)
    assert reasoning is None
    assert content == input_text

    # Test extract_reasoning_content_streaming returns DeltaMessage or None
    result = parser.extract_reasoning_content_streaming(
        previous_text="",
        current_text="Hello world",
        delta_text="Hello world",
        previous_token_ids=[],
        current_token_ids=input_ids,
        delta_token_ids=input_ids,
    )
    assert isinstance(result, DeltaMessage)
    assert result.content == "Hello world"

    # If delta_text is empty, should return None
    result_none = parser.extract_reasoning_content_streaming(
        previous_text="Hello world",
        current_text="Hello world",
        delta_text="",
        previous_token_ids=input_ids,
        current_token_ids=input_ids,
        delta_token_ids=[],
    )
    assert result_none is None

vllm/entrypoints/openai/serving_chat.py

@@ -570,7 +570,10 @@ class OpenAIServingChat(OpenAIServing):
         try:
             if self.reasoning_parser:
-                reasoning_parser = self.reasoning_parser(tokenizer)
+                reasoning_parser = self.reasoning_parser(
+                    tokenizer,
+                    chat_template_kwargs=request.chat_template_kwargs,  # type: ignore
+                )
         except RuntimeError as e:
             logger.exception("Error in reasoning parser creation.")
             data = self.create_streaming_error_response(str(e))
@@ -1335,7 +1338,10 @@ class OpenAIServingChat(OpenAIServing):
         if self.reasoning_parser:
             try:
-                reasoning_parser = self.reasoning_parser(tokenizer)
+                reasoning_parser = self.reasoning_parser(
+                    tokenizer,
+                    chat_template_kwargs=request.chat_template_kwargs,  # type: ignore
+                )
             except RuntimeError as e:
                 logger.exception("Error in reasoning parser creation.")
                 return self.create_error_response(str(e))
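The only new plumbing above is that the per-request `chat_template_kwargs` now reaches the parser constructor in both the streaming and non-streaming paths. A hypothetical request body showing where that value originates (standard OpenAI-compatible chat completions shape; vLLM parses the field into `request.chat_template_kwargs`):

```python
# Hypothetical body for POST /v1/chat/completions. The "chat_template_kwargs"
# field becomes request.chat_template_kwargs, which the handlers above now
# forward to self.reasoning_parser(...).
payload = {
    "model": "deepseek-ai/DeepSeek-V3.1",
    "messages": [{"role": "user", "content": "Why is the sky blue?"}],
    "chat_template_kwargs": {"thinking": True},  # selects the R1 delegate
}
```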

vllm/reasoning/__init__.py

@@ -4,11 +4,13 @@
 from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
 from .basic_parsers import BaseThinkingReasoningParser
 from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
+from .deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser
 from .ernie45_reasoning_parser import Ernie45ReasoningParser
 from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
 from .gptoss_reasoning_parser import GptOssReasoningParser
 from .granite_reasoning_parser import GraniteReasoningParser
 from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
+from .identity_reasoning_parser import IdentityReasoningParser
 from .mistral_reasoning_parser import MistralReasoningParser
 from .olmo3_reasoning_parser import Olmo3ReasoningParser
 from .qwen3_reasoning_parser import Qwen3ReasoningParser
@@ -20,6 +22,8 @@ __all__ = [
     "BaseThinkingReasoningParser",
     "ReasoningParserManager",
     "DeepSeekR1ReasoningParser",
+    "IdentityReasoningParser",
+    "DeepSeekV3ReasoningParser",
     "Ernie45ReasoningParser",
     "GraniteReasoningParser",
     "HunyuanA13BReasoningParser",
vllm/reasoning/deepseek_v3_reasoning_parser.py

@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.logger import init_logger
from vllm.reasoning import (
    DeepSeekR1ReasoningParser,
    ReasoningParser,
    ReasoningParserManager,
)

from .identity_reasoning_parser import IdentityReasoningParser

logger = init_logger(__name__)


@ReasoningParserManager.register_module("deepseek_v3")
class DeepSeekV3ReasoningParser(ReasoningParser):
    """
    V3 parser that delegates to either DeepSeekR1ReasoningParser or
    IdentityReasoningParser, based on the `thinking` flag in
    `chat_template_kwargs`.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        chat_kwargs = kwargs.pop("chat_template_kwargs", {}) or {}
        thinking = bool(chat_kwargs.pop("thinking", False))
        if thinking:
            self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
        else:
            self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs)

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        return self._parser.is_reasoning_end(input_ids)

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        return self._parser.extract_content_ids(input_ids)

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[str | None, str | None]:
        return self._parser.extract_reasoning_content(model_output, request)

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        return self._parser.extract_reasoning_content_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
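A short usage sketch of the delegation, assuming the DeepSeek-V3.1 tokenizer is available locally and that, in thinking mode, model output closes its reasoning with `</think>` (the case the R1 delegate handles); the strings are illustrative:

```python
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.reasoning import DeepSeekV3ReasoningParser

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3.1")
request = ChatCompletionRequest(model="deepseek-ai/DeepSeek-V3.1", messages=[])

# thinking=True -> R1 delegate: text before "</think>" is reasoning.
parser = DeepSeekV3ReasoningParser(tokenizer, chat_template_kwargs={"thinking": True})
reasoning, content = parser.extract_reasoning_content(
    "Compare the decimals.</think>9.8 is larger.", request
)
# reasoning == "Compare the decimals.", content == "9.8 is larger."

# thinking=False -> identity delegate: the whole output is content.
parser = DeepSeekV3ReasoningParser(tokenizer, chat_template_kwargs={"thinking": False})
reasoning, content = parser.extract_reasoning_content("9.8 is larger.", request)
# reasoning is None, content == "9.8 is larger."
```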

vllm/reasoning/identity_reasoning_parser.py

@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import ChatCompletionRequest, DeltaMessage
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser

logger = init_logger(__name__)


class IdentityReasoningParser(ReasoningParser):
    """
    Identity reasoning parser.

    This parser does not attempt to parse or strip out reasoning tokens.
    It treats the entire model output as content and ignores reasoning.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ReasoningParser "
                "constructor during construction."
            )

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        # Always return True, since we never treat reasoning specially
        return True

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        # Identity: return all tokens as content
        return input_ids

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        # Just wrap delta_text as content, ignore reasoning
        if delta_text:
            return DeltaMessage(content=delta_text)
        return None

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[str | None, str | None]:
        # No reasoning separation: return None for reasoning_content,
        # and full model_output as content
        return None, model_output
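To make the pass-through contract concrete, a small sketch mirroring the unit tests earlier in this commit (the tokenizer choice is illustrative; any tokenizer the base class accepts works):

```python
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.reasoning import IdentityReasoningParser

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3.1")
parser = IdentityReasoningParser(tokenizer)
request = ChatCompletionRequest(model="test-model", messages=[])

text = "<think>internal notes</think>final answer"
reasoning, content = parser.extract_reasoning_content(text, request)
assert reasoning is None
assert content == text  # even literal think tags pass through untouched
```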