Add think chunk (#21333)
Signed-off-by: Julien Denize <julien.denize@mistral.ai>
This commit is contained in:
@ -33,7 +33,7 @@ pyzmq >= 25.0.0
|
||||
msgspec
|
||||
gguf >= 0.13.0
|
||||
importlib_metadata; python_version < '3.10'
|
||||
mistral_common[opencv] >= 1.8.0
|
||||
mistral_common[image,audio] >= 1.8.2
|
||||
opencv-python-headless >= 4.11.0 # required for video IO
|
||||
pyyaml
|
||||
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
|
||||
|
@ -23,7 +23,7 @@ jiwer # required for audio tests
|
||||
timm # required for internvl test
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
matplotlib # required for qwen-vl test
|
||||
mistral_common[opencv] >= 1.8.0 # required for voxtral test
|
||||
mistral_common[image,audio] >= 1.8.2 # required for voxtral test
|
||||
num2words # required for smolvlm test
|
||||
opencv-python-headless >= 4.11.0 # required for video test
|
||||
datamodel_code_generator # required for minicpm3 test
|
||||
|
@ -28,7 +28,7 @@ torchvision==0.22.1
|
||||
transformers_stream_generator # required for qwen-vl test
|
||||
mamba_ssm # required for plamo2 test
|
||||
matplotlib # required for qwen-vl test
|
||||
mistral_common[opencv] >= 1.8.0 # required for voxtral test
|
||||
mistral_common[image,audio] >= 1.8.2 # required for voxtral test
|
||||
num2words # required for smolvlm test
|
||||
open_clip_torch==2.32.0 # Required for nemotron_vl test
|
||||
opencv-python-headless >= 4.11.0 # required for video test
|
||||
|
@ -447,7 +447,7 @@ mbstrdecoder==1.1.3
|
||||
# typepy
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
mistral-common==1.8.0
|
||||
mistral-common==1.8.2
|
||||
# via -r requirements/test.in
|
||||
mlflow==2.22.0
|
||||
# via terratorch
|
||||
@ -999,8 +999,11 @@ soundfile==0.12.1
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# librosa
|
||||
# mistral-common
|
||||
soxr==0.5.0.post1
|
||||
# via librosa
|
||||
# via
|
||||
# librosa
|
||||
# mistral-common
|
||||
sqlalchemy==2.0.41
|
||||
# via
|
||||
# alembic
|
||||
|
@ -6,6 +6,10 @@ from collections.abc import Mapping
|
||||
from typing import Literal, Optional
|
||||
|
||||
import pytest
|
||||
from mistral_common.tokens.tokenizers.base import (SpecialTokenPolicy,
|
||||
SpecialTokens)
|
||||
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo,
|
||||
Tekkenizer)
|
||||
|
||||
from vllm.assets.audio import AudioAsset
|
||||
from vllm.assets.image import ImageAsset
|
||||
@ -21,6 +25,7 @@ from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.multimodal.utils import (encode_audio_base64, encode_image_base64,
|
||||
encode_video_base64)
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
from ..models.registry import HF_EXAMPLE_MODELS
|
||||
from ..utils import VLLM_PATH
|
||||
@ -1374,3 +1379,165 @@ def test_resolve_content_format_examples(template_path, expected_format):
|
||||
)
|
||||
|
||||
assert resolved_format == expected_format
|
||||
|
||||
|
||||
def test_parse_chat_messages_include_thinking_chunk(mistral_model_config,
|
||||
mistral_tokenizer):
|
||||
messages = [{
|
||||
"role":
|
||||
"system",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "You are a helpful assistant."
|
||||
}, {
|
||||
"type":
|
||||
"thinking",
|
||||
"closed":
|
||||
True,
|
||||
"thinking":
|
||||
"Only return the answer when you are confident."
|
||||
}]
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "What is 2+2?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "Let me think about it."
|
||||
}, {
|
||||
"type": "thinking",
|
||||
"closed": True,
|
||||
"thinking": "2+2 = 4"
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "The answer is 4.",
|
||||
}],
|
||||
}]
|
||||
|
||||
conversation_with_thinking, _ = parse_chat_messages(
|
||||
messages,
|
||||
mistral_model_config,
|
||||
mistral_tokenizer,
|
||||
content_format="openai",
|
||||
)
|
||||
|
||||
expected_conversation = [{
|
||||
"role":
|
||||
"system",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "You are a helpful assistant."
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "Only return the answer when you are confident."
|
||||
}],
|
||||
}, {
|
||||
"role":
|
||||
"user",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "What is 2+2?"
|
||||
}],
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Let me think about it."
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "2+2 = 4"
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "The answer is 4."
|
||||
},
|
||||
]
|
||||
}]
|
||||
|
||||
assert conversation_with_thinking == expected_conversation
|
||||
|
||||
|
||||
def test_apply_mistral_chat_template_thinking_chunk():
|
||||
# Moved import here to avoid yapf and isort conflicts
|
||||
from vllm.entrypoints.chat_utils import apply_mistral_chat_template
|
||||
messages = [{
|
||||
"role":
|
||||
"system",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "You are a helpful assistant."
|
||||
}, {
|
||||
"type":
|
||||
"thinking",
|
||||
"closed":
|
||||
True,
|
||||
"thinking":
|
||||
"Only return the answer when you are confident."
|
||||
}]
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "What is 2+2?"
|
||||
}, {
|
||||
"role":
|
||||
"assistant",
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": "Let me think about it."
|
||||
}, {
|
||||
"type": "thinking",
|
||||
"closed": True,
|
||||
"thinking": "2+2 = 4"
|
||||
}, {
|
||||
"type": "text",
|
||||
"text": "The answer is 4.",
|
||||
}],
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Thanks, what is 3+3?"
|
||||
}]
|
||||
|
||||
# TODO(Julien): upon model release change to a tokenizer already configured.
|
||||
# =================================================================
|
||||
mistral_tokenizer = MistralTokenizer.from_pretrained(
|
||||
"mistralai/Devstral-Small-2507")
|
||||
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
|
||||
# Add think special tokens to the tokenizer
|
||||
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
|
||||
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value)
|
||||
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
|
||||
rank=36, is_control=True, token_str=SpecialTokens.end_think.value)
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
|
||||
k: v
|
||||
for k, v in
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
|
||||
if v not in {35, 36}
|
||||
}
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
|
||||
SpecialTokens.begin_think.value] = 35
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
|
||||
SpecialTokens.end_think.value] = 36
|
||||
mistral_tokenizer.instruct.BEGIN_THINK = 35
|
||||
mistral_tokenizer.instruct.END_THINK = 36
|
||||
# =================================================================
|
||||
|
||||
tokens_ids = apply_mistral_chat_template(mistral_tokenizer,
|
||||
messages,
|
||||
chat_template=None,
|
||||
tools=None)
|
||||
|
||||
string_tokens = mistral_tokenizer.mistral.decode(
|
||||
tokens_ids, special_token_policy=SpecialTokenPolicy.KEEP)
|
||||
|
||||
expected_tokens = (
|
||||
r"<s>[SYSTEM_PROMPT]You are a helpful assistant.[THINK]Only return the"
|
||||
r" answer when you are confident.[/THINK][/SYSTEM_PROMPT]"
|
||||
r"[INST]What is 2+2?[/INST]"
|
||||
r"Let me think about it.[THINK]2+2 = 4[/THINK]The answer is 4.</s>"
|
||||
r"[INST]Thanks, what is 3+3?[/INST]")
|
||||
|
||||
assert string_tokens == expected_tokens
|
||||
|
341
tests/reasoning/test_mistral_reasoning_parser.py
Normal file
341
tests/reasoning/test_mistral_reasoning_parser.py
Normal file
@ -0,0 +1,341 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
||||
from mistral_common.tokens.tokenizers.tekken import (SpecialTokenInfo,
|
||||
Tekkenizer)
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction_mistral
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
parser_name = "mistral"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mistral_tokenizer():
|
||||
# TODO(Julien): upon model release change to a tokenizer already configured.
|
||||
# =================================================================
|
||||
mistral_tokenizer = MistralTokenizer.from_pretrained(
|
||||
"mistralai/Devstral-Small-2507")
|
||||
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
|
||||
# Add think special tokens to the tokenizer
|
||||
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
|
||||
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value)
|
||||
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
|
||||
rank=36, is_control=True, token_str=SpecialTokens.end_think.value)
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
|
||||
k: v
|
||||
for k, v in
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
|
||||
if v not in {35, 36}
|
||||
}
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
|
||||
SpecialTokens.begin_think.value] = 35
|
||||
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
|
||||
SpecialTokens.end_think.value] = 36
|
||||
mistral_tokenizer.instruct.BEGIN_THINK = 35
|
||||
mistral_tokenizer.instruct.END_THINK = 36
|
||||
# =================================================================
|
||||
return mistral_tokenizer
|
||||
|
||||
|
||||
SIMPLE_REASONING = {
|
||||
"output": "This is a reasoning section[/THINK]This is the rest",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
COMPLETE_REASONING = {
|
||||
"output": "This is a reasoning section[/THINK]",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
NO_CONTENT = {
|
||||
"output": "This is content",
|
||||
"reasoning_content": "This is content",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NO_REASONING_STREAMING = {
|
||||
"output": "This is a reasoning section",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
MULTIPLE_LINES = {
|
||||
"output": "This\nThat[/THINK]This is the rest\nThat",
|
||||
"reasoning_content": "This\nThat",
|
||||
"content": "This is the rest\nThat",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING_NO_STREAMING = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning_content": "",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning_content": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
REASONING_WITH_THINK = {
|
||||
"output": "[THINK]This is a reasoning section[/THINK]This is the rest",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
COMPLETE_REASONING_WITH_THINK = {
|
||||
"output": "[THINK]This is a reasoning section[/THINK]",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
MULTIPLE_LINES_WITH_THINK = {
|
||||
"output": "[THINK]This\nThat[/THINK]This is the rest\nThat",
|
||||
"reasoning_content": "This\nThat",
|
||||
"content": "This is the rest\nThat",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING_NO_STREAMING_WITH_THINK = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning_content": "",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
SHORTEST_REASONING_WITH_THINK = {
|
||||
"output": "[/THINK]This is the rest",
|
||||
"reasoning_content": None,
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
THINK_NO_END = {
|
||||
"output": "[THINK]This is a reasoning section",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
EMPTY = {
|
||||
"output": "",
|
||||
"reasoning_content": "",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
EMPTY_STREAMING = {
|
||||
"output": "",
|
||||
"reasoning_content": None,
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NEW_LINE = {
|
||||
"output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
|
||||
"reasoning_content": "This is a reasoning section",
|
||||
"content": "\nThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
# Streaming cannot handle new lines at the beginning of the output
|
||||
# because we need to support [THINK]...[/THINK] and [/THINK]...
|
||||
# We cannot know if the text before [THINK] is reasoning content
|
||||
# or not.
|
||||
NEW_LINE_STREAMING = {
|
||||
"output": "\n[THINK]This is a reasoning section[/THINK]\nThis is the rest",
|
||||
"reasoning_content": "\nThis is a reasoning section",
|
||||
"content": "\nThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
False,
|
||||
SIMPLE_REASONING,
|
||||
id="simple_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SIMPLE_REASONING,
|
||||
id="simple_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLETE_REASONING,
|
||||
id="complete_reasoning_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
NO_CONTENT,
|
||||
id="no_content_token",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
NO_REASONING_STREAMING,
|
||||
id="no_reasoning_token_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MULTIPLE_LINES,
|
||||
id="multiple_lines",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MULTIPLE_LINES,
|
||||
id="multiple_lines_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SHORTEST_REASONING,
|
||||
id="shortest",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SHORTEST_REASONING_NO_STREAMING,
|
||||
id="shortest_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
REASONING_WITH_THINK,
|
||||
id="reasoning_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
REASONING_WITH_THINK,
|
||||
id="reasoning_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLETE_REASONING_WITH_THINK,
|
||||
id="complete_reasoning_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLETE_REASONING_WITH_THINK,
|
||||
id="complete_reasoning_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MULTIPLE_LINES_WITH_THINK,
|
||||
id="multiple_lines_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MULTIPLE_LINES_WITH_THINK,
|
||||
id="multiple_lines_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SHORTEST_REASONING_NO_STREAMING_WITH_THINK,
|
||||
id="shortest_with_think",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SHORTEST_REASONING_WITH_THINK,
|
||||
id="shortest_with_think_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
THINK_NO_END,
|
||||
id="think_no_end",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
THINK_NO_END,
|
||||
id="think_no_end_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
EMPTY,
|
||||
id="empty",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
EMPTY_STREAMING,
|
||||
id="empty_streaming",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
NEW_LINE,
|
||||
id="new_line",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
NEW_LINE_STREAMING,
|
||||
id="new_line_streaming",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
|
||||
def test_mistral_reasoning(
|
||||
streaming: bool,
|
||||
param_dict: dict,
|
||||
mistral_tokenizer: MistralTokenizer,
|
||||
):
|
||||
output = param_dict["output"]
|
||||
|
||||
index_think = output.find("[THINK]")
|
||||
len_think = len("[THINK]")
|
||||
index_end_think = output.find("[/THINK]")
|
||||
len_end_think = len("[/THINK]")
|
||||
|
||||
# encode everything to tokens ids
|
||||
output_tokens = []
|
||||
if index_think != -1:
|
||||
output_before_think = output[:index_think]
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output_before_think, False, False)
|
||||
output_tokens += [mistral_tokenizer.instruct.BEGIN_THINK]
|
||||
|
||||
if index_end_think != -1:
|
||||
output_middle = output[index_think + len_think:index_end_think]
|
||||
output_after_think = output[index_end_think + len_end_think:]
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output_middle, False, False)
|
||||
output_tokens += [mistral_tokenizer.instruct.END_THINK]
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output_after_think, False, False)
|
||||
else:
|
||||
output_middle = output[index_think + len_think:]
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output_middle, False, False)
|
||||
elif index_end_think != -1:
|
||||
output_before_think = output[:index_end_think]
|
||||
output_after_think = output[index_end_think + len_end_think:]
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output_before_think, False, False)
|
||||
output_tokens += [mistral_tokenizer.instruct.END_THINK]
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output_after_think, False, False)
|
||||
else:
|
||||
output_tokens += mistral_tokenizer.tokenizer.encode(
|
||||
output, False, False)
|
||||
|
||||
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
|
||||
parser_name)(mistral_tokenizer)
|
||||
|
||||
reasoning, content = run_reasoning_extraction_mistral(parser,
|
||||
output_tokens,
|
||||
streaming=streaming)
|
||||
|
||||
assert reasoning == param_dict["reasoning_content"]
|
||||
assert content == param_dict["content"]
|
||||
|
||||
# Test is_reasoning_end
|
||||
is_reasoning_end = parser.is_reasoning_end(output_tokens)
|
||||
assert is_reasoning_end == param_dict["is_reasoning_end"]
|
||||
|
||||
# Test extract_content
|
||||
if param_dict["content"] is not None:
|
||||
content = parser.extract_content_ids(output_tokens)
|
||||
assert content == mistral_tokenizer.tokenizer.encode(
|
||||
param_dict["content"], bos=False, eos=False)
|
||||
else:
|
||||
content = parser.extract_content_ids(output_tokens)
|
||||
assert content == []
|
@ -6,6 +6,7 @@ from typing import Optional, Union
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage)
|
||||
from vllm.reasoning import ReasoningParser
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
|
||||
class StreamingReasoningReconstructor:
|
||||
@ -54,6 +55,32 @@ def run_reasoning_extraction(
|
||||
return reasoning, content
|
||||
|
||||
|
||||
def run_reasoning_extraction_mistral(
|
||||
reasoning_parser: ReasoningParser,
|
||||
model_output: list[int],
|
||||
request: Union[ChatCompletionRequest, None] = None,
|
||||
streaming: bool = False,
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
assert isinstance(reasoning_parser.model_tokenizer,
|
||||
MistralTokenizer), type(reasoning_parser.model_tokenizer)
|
||||
if streaming:
|
||||
reconstructor = run_reasoning_extraction_streaming_mistral(
|
||||
reasoning_parser,
|
||||
model_output,
|
||||
request,
|
||||
)
|
||||
return (
|
||||
reconstructor.reasoning_content,
|
||||
reconstructor.other_content or None,
|
||||
)
|
||||
else:
|
||||
str_output = reasoning_parser.model_tokenizer.convert_ids_to_tokens(
|
||||
model_output)
|
||||
reasoning, content = run_reasoning_extraction_nonstreaming(
|
||||
reasoning_parser, str_output, request)
|
||||
return reasoning, content
|
||||
|
||||
|
||||
def run_reasoning_extraction_nonstreaming(
|
||||
reasoning_parser: ReasoningParser,
|
||||
model_output: list[str],
|
||||
@ -94,3 +121,35 @@ def run_reasoning_extraction_streaming(
|
||||
previous_text = current_text
|
||||
previous_tokens = current_tokens
|
||||
return reconstructor
|
||||
|
||||
|
||||
def run_reasoning_extraction_streaming_mistral(
|
||||
reasoning_parser: ReasoningParser,
|
||||
model_deltas: list[int],
|
||||
request: Union[ChatCompletionRequest, None] = None,
|
||||
) -> StreamingReasoningReconstructor:
|
||||
assert isinstance(reasoning_parser.model_tokenizer,
|
||||
MistralTokenizer), type(reasoning_parser.model_tokenizer)
|
||||
request = request or ChatCompletionRequest(messages=[], model="test-model")
|
||||
reconstructor = StreamingReasoningReconstructor()
|
||||
previous_text = ""
|
||||
previous_tokens: list[int] = []
|
||||
for model_delta in model_deltas:
|
||||
token_delta = [model_delta]
|
||||
delta = reasoning_parser.model_tokenizer.convert_ids_to_tokens(
|
||||
[model_delta])[0]
|
||||
current_text = previous_text + delta
|
||||
current_tokens = previous_tokens + token_delta
|
||||
delta_message = reasoning_parser.extract_reasoning_content_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta,
|
||||
previous_tokens,
|
||||
current_tokens,
|
||||
token_delta,
|
||||
)
|
||||
if delta_message is not None:
|
||||
reconstructor.append_delta(delta_message)
|
||||
previous_text = current_text
|
||||
previous_tokens = current_tokens
|
||||
return reconstructor
|
||||
|
@ -151,6 +151,27 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
|
||||
video_url: Required[str]
|
||||
|
||||
|
||||
class CustomThinkCompletionContentParam(TypedDict, total=False):
|
||||
"""A Think Completion Content Param that accepts a plain text and a boolean.
|
||||
|
||||
Example:
|
||||
{
|
||||
"thinking": "I am thinking about the answer",
|
||||
"closed": True,
|
||||
"type": "thinking"
|
||||
}
|
||||
"""
|
||||
|
||||
thinking: Required[str]
|
||||
"""The thinking content."""
|
||||
|
||||
closed: bool
|
||||
"""Whether the thinking is closed."""
|
||||
|
||||
type: Required[Literal["thinking"]]
|
||||
"""The thinking type."""
|
||||
|
||||
|
||||
ChatCompletionContentPartParam: TypeAlias = Union[
|
||||
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
|
||||
ChatCompletionContentPartInputAudioParam,
|
||||
@ -159,7 +180,8 @@ ChatCompletionContentPartParam: TypeAlias = Union[
|
||||
CustomChatCompletionContentSimpleImageParam,
|
||||
ChatCompletionContentPartImageEmbedsParam,
|
||||
CustomChatCompletionContentSimpleAudioParam,
|
||||
CustomChatCompletionContentSimpleVideoParam, str]
|
||||
CustomChatCompletionContentSimpleVideoParam, str,
|
||||
CustomThinkCompletionContentParam]
|
||||
|
||||
|
||||
class CustomChatCompletionMessageParam(TypedDict, total=False):
|
||||
@ -938,6 +960,7 @@ _ImageEmbedsParser = partial(cast, ChatCompletionContentPartImageEmbedsParam)
|
||||
_InputAudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
|
||||
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
|
||||
_PILImageParser = partial(cast, CustomChatCompletionContentPILImageParam)
|
||||
_ThinkParser = partial(cast, CustomThinkCompletionContentParam)
|
||||
# Need to validate url objects
|
||||
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
|
||||
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
|
||||
@ -954,6 +977,8 @@ MM_PARSER_MAP: dict[
|
||||
] = {
|
||||
"text":
|
||||
lambda part: _TextParser(part).get("text", None),
|
||||
"thinking":
|
||||
lambda part: _ThinkParser(part).get("thinking", None),
|
||||
"input_text":
|
||||
lambda part: _TextParser(part).get("text", None),
|
||||
"input_image":
|
||||
@ -1100,7 +1125,7 @@ def _parse_chat_message_content_part(
|
||||
"with empty / unparsable content.", part, part_type)
|
||||
return None
|
||||
|
||||
if part_type in ("text", "input_text", "refusal"):
|
||||
if part_type in ("text", "input_text", "refusal", "thinking"):
|
||||
str_content = cast(str, content)
|
||||
if wrap_dicts:
|
||||
return {'type': 'text', 'text': str_content}
|
||||
|
@ -6,6 +6,7 @@ from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
||||
from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
|
||||
from .granite_reasoning_parser import GraniteReasoningParser
|
||||
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
|
||||
from .mistral_reasoning_parser import MistralReasoningParser
|
||||
from .qwen3_reasoning_parser import Qwen3ReasoningParser
|
||||
|
||||
__all__ = [
|
||||
@ -16,4 +17,5 @@ __all__ = [
|
||||
"HunyuanA13BReasoningParser",
|
||||
"Qwen3ReasoningParser",
|
||||
"Glm4MoeModelReasoningParser",
|
||||
"MistralReasoningParser",
|
||||
]
|
||||
|
47
vllm/reasoning/mistral_reasoning_parser.py
Normal file
47
vllm/reasoning/mistral_reasoning_parser.py
Normal file
@ -0,0 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
from vllm.reasoning.deepseek_r1_reasoning_parser import (
|
||||
DeepSeekR1ReasoningParser)
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ReasoningParserManager.register_module("mistral")
|
||||
class MistralReasoningParser(DeepSeekR1ReasoningParser):
|
||||
"""
|
||||
Reasoning parser for Mistral models.
|
||||
|
||||
The Mistral models uses [THINK]...[/THINK] tokens to denote reasoning
|
||||
text. This parser extracts the reasoning content from the model output.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: MistralTokenizer):
|
||||
if not isinstance(tokenizer, MistralTokenizer):
|
||||
raise ValueError(
|
||||
"The tokenizer must be an instance of MistralTokenizer.")
|
||||
|
||||
ReasoningParser.__init__(self, tokenizer)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ReasoningParser "
|
||||
"constructor during construction.")
|
||||
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
||||
|
||||
self.start_token = SpecialTokens.begin_think
|
||||
self.end_token = SpecialTokens.end_think
|
||||
|
||||
self.start_token_id = tokenizer.tokenizer.get_control_token(
|
||||
self.start_token)
|
||||
self.end_token_id = tokenizer.tokenizer.get_control_token(
|
||||
self.end_token)
|
||||
|
||||
if self.start_token_id is None or self.end_token_id is None:
|
||||
raise RuntimeError(
|
||||
"Mistral reasoning parser could not locate think start/end "
|
||||
"tokens in the tokenizer!")
|
@ -145,6 +145,21 @@ def find_tokenizer_file(files: list[str]):
|
||||
return matched_files[0]
|
||||
|
||||
|
||||
def _aggregate_content(content: list) -> list[dict[str, Any]]:
|
||||
aggregated_content: list[dict[str, Any]] = []
|
||||
for chunk in content:
|
||||
if chunk.get("type"
|
||||
) == "text" and aggregated_content and aggregated_content[
|
||||
-1].get("type") == "text":
|
||||
aggregated_content[-1]["text"] += "\n\n" + chunk.get("text")
|
||||
else:
|
||||
aggregated_content.append(chunk)
|
||||
if len(aggregated_content) == 1 and aggregated_content[0].get(
|
||||
"type") == "text":
|
||||
content = aggregated_content[0]["text"]
|
||||
return content
|
||||
|
||||
|
||||
def make_mistral_chat_completion_request(
|
||||
messages: list["ChatCompletionMessageParam"],
|
||||
tools: Optional[list[dict[str,
|
||||
@ -162,10 +177,10 @@ def make_mistral_chat_completion_request(
|
||||
|
||||
# Convert list text content to string
|
||||
if message.get("role") in ("assistant", "tool"):
|
||||
content = message.get("content")
|
||||
content: Any = message.get("content")
|
||||
if isinstance(content, list):
|
||||
content = "\n".join(chunk.get("text") for chunk in content)
|
||||
message["content"] = content
|
||||
content = _aggregate_content(content)
|
||||
message["content"] = content
|
||||
|
||||
# The Mistral client, in comparison to the OpenAI client, requires the
|
||||
# "parameters" dict to be present, even if it's empty.
|
||||
@ -465,6 +480,8 @@ class MistralTokenizer(TokenizerBase):
|
||||
skip_special_tokens: bool = True,
|
||||
) -> list[str]:
|
||||
from mistral_common.tokens.tokenizers.base import SpecialTokens
|
||||
from mistral_common.tokens.tokenizers.instruct import (
|
||||
InstructTokenizerV13)
|
||||
|
||||
# TODO(Patrick) - potentially allow special tokens to not be skipped
|
||||
assert (
|
||||
@ -474,10 +491,18 @@ class MistralTokenizer(TokenizerBase):
|
||||
assert self.is_tekken or self.is_spm, type(self.tokenizer)
|
||||
|
||||
if self.is_tekken:
|
||||
# skip special tokens except tool call
|
||||
ids = [
|
||||
i for i in ids if i > self.tokenizer.num_special_tokens or i ==
|
||||
# skip special tokens except tool call and think tokens
|
||||
non_skip_special_tokens = {
|
||||
self.tokenizer.get_control_token(SpecialTokens.tool_calls)
|
||||
}
|
||||
if isinstance(self.instruct, InstructTokenizerV13):
|
||||
if self.instruct.BEGIN_THINK:
|
||||
non_skip_special_tokens.add(self.instruct.BEGIN_THINK)
|
||||
if self.instruct.END_THINK:
|
||||
non_skip_special_tokens.add(self.instruct.END_THINK)
|
||||
ids = [
|
||||
i for i in ids if i > self.tokenizer.num_special_tokens
|
||||
or i in non_skip_special_tokens
|
||||
]
|
||||
|
||||
tokens = [self.tokenizer.id_to_piece(id) for id in ids]
|
||||
|
Reference in New Issue
Block a user