Refactor MistralTokenizer (#26358)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
This commit is contained in:
Julien Denize
2025-10-10 00:48:58 +02:00
committed by GitHub
parent 8983e0216f
commit c6187f55f7
18 changed files with 2351 additions and 463 deletions

View File

@ -145,7 +145,7 @@ Supported models:
Known issues:
1. Mistral 7B struggles to generate parallel tool calls correctly.
2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
2. **For Transformers tokenization backend only**: Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
much shorter than what vLLM generates. Since an exception is thrown when this condition
is not met, the following additional chat templates are provided:
@ -154,7 +154,14 @@ Known issues:
* <gh-file:examples/tool_chat_template_mistral_parallel.jinja> - this is a "better" version that adds a tool-use system prompt
when tools are provided, that results in much better reliability when working with parallel tool calling.
Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
Recommended flags:
1. To use [mistral-common](https://github.com/mistralai/mistral-common) the official Mistral tokenization backend:
`--tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral`
2. To use the default Transformers tokenization backend:
`--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
### Llama Models (`llama3_json`)

View File

@ -45,10 +45,12 @@ class ModelRequestData(NamedTuple):
# Voxtral
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import (
from mistral_common.protocol.instruct.chunk import (
AudioChunk,
RawAudio,
TextChunk,
)
from mistral_common.protocol.instruct.messages import (
UserMessage,
)
from mistral_common.protocol.instruct.request import ChatCompletionRequest

View File

@ -32,7 +32,7 @@ pyzmq >= 25.0.0
msgspec
gguf >= 0.13.0
importlib_metadata; python_version < '3.10'
mistral_common[image,audio] >= 1.8.2
mistral_common[image,audio] >= 1.8.5
opencv-python-headless >= 4.11.0 # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12

View File

@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.2 # required for voxtral test
mistral_common[image,audio] >= 1.8.5 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test

View File

@ -29,7 +29,7 @@ torchaudio==2.8.0
torchvision==0.23.0
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[image,audio] >= 1.8.2 # required for voxtral test
mistral_common[image,audio] >= 1.8.5 # required for voxtral test
num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test
opencv-python-headless >= 4.11.0 # required for video test

View File

@ -474,7 +474,7 @@ mbstrdecoder==1.1.3
# typepy
mdurl==0.1.2
# via markdown-it-py
mistral-common==1.8.2
mistral-common==1.8.5
# via -r requirements/test.in
mlflow==2.22.0
# via terratorch
@ -1012,8 +1012,6 @@ sentence-transformers==3.2.1
# via
# -r requirements/test.in
# mteb
sentencepiece==0.2.0
# via mistral-common
setuptools==77.0.3
# via
# lightning-utilities

View File

@ -6,8 +6,7 @@ from collections.abc import Mapping
from typing import Literal, Optional
import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens
from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
@ -2119,34 +2118,9 @@ def test_apply_mistral_chat_template_thinking_chunk():
},
{"role": "user", "content": "Thanks, what is 3+3?"},
]
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Devstral-Small-2507"
"mistralai/Magistral-Small-2509"
)
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
# Add think special tokens to the tokenizer
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value
)
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
rank=36, is_control=True, token_str=SpecialTokens.end_think.value
)
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
k: v
for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
if v not in {35, 36}
}
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.begin_think.value
] = 35
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.end_think.value
] = 36
mistral_tokenizer.instruct.BEGIN_THINK = 35
mistral_tokenizer.instruct.END_THINK = 36
# =================================================================
tokens_ids = apply_mistral_chat_template(
mistral_tokenizer, messages, chat_template=None, tools=None

View File

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional
import pytest
from mistral_common.multimodal import download_image
from mistral_common.protocol.instruct.messages import ImageURLChunk
from mistral_common.protocol.instruct.chunk import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk

View File

@ -6,12 +6,8 @@ import json
import pytest
import pytest_asyncio
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.messages import (
AudioChunk,
RawAudio,
TextChunk,
UserMessage,
)
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from vllm.transformers_utils.tokenizer import MistralTokenizer

View File

@ -6,7 +6,8 @@ from typing import Optional, Union
import numpy as np
import pytest
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image

View File

@ -9,7 +9,8 @@ from typing import Any, Union
import numpy as np
import pytest
import torch.nn as nn
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image

View File

@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from mistral_common.tokens.tokenizers.base import SpecialTokens
from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer
from tests.reasoning.utils import run_reasoning_extraction_mistral
from vllm.reasoning import ReasoningParser, ReasoningParserManager
@ -14,33 +12,9 @@ parser_name = "mistral"
@pytest.fixture(scope="module")
def mistral_tokenizer():
# TODO(Julien): upon model release change to a tokenizer already configured.
# =================================================================
mistral_tokenizer = MistralTokenizer.from_pretrained(
"mistralai/Devstral-Small-2507"
"mistralai/Magistral-Small-2509"
)
assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
# Add think special tokens to the tokenizer
mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
rank=35, is_control=True, token_str=SpecialTokens.begin_think.value
)
mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
rank=36, is_control=True, token_str=SpecialTokens.end_think.value
)
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
k: v
for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
if v not in {35, 36}
}
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.begin_think.value
] = 35
mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
SpecialTokens.end_think.value
] = 36
mistral_tokenizer.instruct.BEGIN_THINK = 35
mistral_tokenizer.instruct.END_THINK = 36
# =================================================================
return mistral_tokenizer

File diff suppressed because it is too large Load Diff

View File

@ -403,20 +403,12 @@ def resolve_mistral_chat_template(
chat_template: Optional[str],
**kwargs: Any,
) -> Optional[str]:
if chat_template is not None:
logger.warning_once(
"'chat_template' cannot be overridden for mistral tokenizer."
)
if "add_generation_prompt" in kwargs:
logger.warning_once(
"'add_generation_prompt' is not supported for mistral tokenizer, "
"so it will be ignored."
)
if "continue_final_message" in kwargs:
logger.warning_once(
"'continue_final_message' is not supported for mistral tokenizer, "
"so it will be ignored."
if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
raise ValueError(
"'chat_template' or 'chat_template_kwargs' cannot be overridden "
"for mistral tokenizer."
)
return None

View File

@ -10,7 +10,8 @@ from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image

View File

@ -12,12 +12,8 @@ import regex as re
import torch
import torch.nn as nn
from mistral_common.audio import mel_filter_bank
from mistral_common.protocol.instruct.messages import (
AudioChunk,
RawAudio,
TextChunk,
UserMessage,
)
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder

View File

@ -1,34 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union, cast
import huggingface_hub
import regex as re
from huggingface_hub import HfApi, hf_hub_download
from transformers.tokenization_utils_base import BatchEncoding
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_base import TokenizerBase
from vllm.utils import is_list_of
if TYPE_CHECKING:
# make sure `mistral_common` is lazy imported,
# so that users who only use non-mistral models
# will not be bothered by the dependency.
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as PublicMistralTokenizer,
from mistral_common.protocol.instruct.request import (
ChatCompletionRequest as MistralChatCompletionRequest,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as TransformersMistralTokenizer,
)
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
logger = init_logger(__name__)
def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"):
# SEE: https://github.com/vllm-project/vllm/pull/9951
# Credits go to: @gcalmettes
# NOTE: There is currently a bug in pydantic where attributes
@ -65,7 +58,7 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
request.messages[i]["tool_calls"] = validated_tool_calls
def truncate_tool_call_ids(request: "ChatCompletionRequest"):
def truncate_tool_call_ids(request: "MistralChatCompletionRequest"):
"""Truncates tool call IDs for Mistral's ID requirements."""
for i, message in enumerate(request.messages):
if message.get("role") == "assistant":
@ -95,84 +88,34 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"):
request.messages[i]["tool_call_id"] = tool_call_id
def validate_request_params(request: "ChatCompletionRequest"):
if request.skip_special_tokens is not None and not request.skip_special_tokens:
raise ValueError(
"skip_special_tokens=False is not supported for Mistral tokenizers."
)
def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]:
repo_cache = os.path.join(
huggingface_hub.constants.HF_HUB_CACHE,
huggingface_hub.constants.REPO_ID_SEPARATOR.join(
["models", *repo_id.split("/")]
),
)
if revision is None:
revision_file = os.path.join(repo_cache, "refs", "main")
if os.path.isfile(revision_file):
with open(revision_file) as file:
revision = file.read()
if revision:
revision_dir = os.path.join(repo_cache, "snapshots", revision)
if os.path.isdir(revision_dir):
return os.listdir(revision_dir)
return []
def find_tokenizer_file(files: list[str]):
# Accept both versioned (tokenizer.model.v3) and unversioned
# (tokenizer.model) forms, plus tekken.json and tokenizer.mm.model
# variants. Previous pattern only matched the versioned variants.
file_pattern = re.compile(
r"^tokenizer\.model(\.v.*)?|tekken\.json|tokenizer\.mm\.model(\.v.*)?$"
)
matched_files = [file for file in files if file_pattern.match(file)]
if len(matched_files) > 1:
logger.warning(
"Multiple files matched pattern `%s`: %s. Using %s.",
file_pattern.pattern,
matched_files,
matched_files[0],
)
elif len(matched_files) == 0:
raise OSError(
f"Found {len(matched_files)} files matching the "
f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral "
f"tokenizer is present in {files}."
)
return matched_files[0]
def _aggregate_content(content: list) -> list[dict[str, Any]]:
aggregated_content: list[dict[str, Any]] = []
for chunk in content:
if (
chunk.get("type") == "text"
and aggregated_content
and aggregated_content[-1].get("type") == "text"
):
aggregated_content[-1]["text"] += "\n\n" + chunk.get("text")
else:
aggregated_content.append(chunk)
if len(aggregated_content) == 1 and aggregated_content[0].get("type") == "text":
content = aggregated_content[0]["text"]
return content
def make_mistral_chat_completion_request(
def _prepare_apply_chat_template_tools_and_messages(
messages: list["ChatCompletionMessageParam"],
tools: Optional[list[dict[str, Any]]] = None,
) -> "ChatCompletionRequest":
continue_final_message: bool = False,
add_generation_prompt: bool = False,
) -> tuple[list["ChatCompletionMessageParam"], Optional[list[dict[str, Any]]]]:
if add_generation_prompt and continue_final_message:
raise ValueError(
"Cannot set both `add_generation_prompt` and "
"`continue_final_message` to True."
)
last_message = cast(dict[str, Any], messages[-1])
if last_message["role"] == "assistant":
last_message["prefix"] = True
# add_generation_prompt is directly handled by the tokenizer but we
# check if the user is trying to use it with a final assistant message
# which is probably not what they want.
# If add_generation_prompt is False, we don't need to check anything.
if add_generation_prompt and last_message["role"] == "assistant":
raise ValueError(
"Cannot set `add_generation_prompt` to True when "
"the last message is from the assistant. Consider "
"using `continue_final_message` instead."
)
if continue_final_message and last_message["role"] != "assistant":
raise ValueError(
"Cannot set `continue_final_message` to True when "
"the last message is not from the assistant."
)
# mistral-common requires AssistantMessage content to be string [1].
#
@ -181,13 +124,6 @@ def make_mistral_chat_completion_request(
# Remove reasoning_content as unsupported by Mistral
_ = message.pop("reasoning_content", None) # type: ignore
# Convert list text content to string
if message.get("role") in ("assistant", "tool"):
content: Any = message.get("content")
if isinstance(content, list):
content = _aggregate_content(content)
message["content"] = content
# The Mistral client, in comparison to the OpenAI client, requires the
# "parameters" dict and the "description" string to be present
# even if they are empty.
@ -200,108 +136,113 @@ def make_mistral_chat_completion_request(
if function.get("description") is None:
function["description"] = ""
from mistral_common.protocol.instruct.request import ChatCompletionRequest
return messages, tools
return ChatCompletionRequest(messages=messages, tools=tools) # type: ignore[type-var]
def validate_request_params(request: "ChatCompletionRequest"):
if request.chat_template is not None or request.chat_template_kwargs is not None:
raise ValueError("chat_template is not supported for Mistral tokenizers.")
def _tekken_token_to_id(tokenizer: "Tekkenizer", t: Union[str, bytes]) -> int:
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
assert isinstance(tokenizer, Tekkenizer), type(tokenizer)
t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
shift = tokenizer.num_special_tokens
try:
return shift + tokenizer._tekken_token2id_nospecial[t_bytes]
except KeyError:
t_str = t_bytes.decode("utf-8")
if t_str in tokenizer._special_tokens_reverse_vocab:
return tokenizer._special_tokens_reverse_vocab[t_str]
logger.warning(
"Failed to convert token %s to id, replacing with <unk>", t_bytes
)
return tokenizer.unk_id
class MistralTokenizer(TokenizerBase):
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
self.mistral = tokenizer
self.instruct = tokenizer.instruct_tokenizer
_mistral_version_str = self.instruct.tokenizer.version.value
self.version: int = int(_mistral_version_str.split("v")[-1])
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
self.is_tekken = isinstance(tokenizer_, Tekkenizer)
def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
self._special_token_policy = (
SpecialTokenPolicy.IGNORE if self.is_tekken else None
)
self.transformers_tokenizer = tokenizer
self.mistral = tokenizer.tokenizer
self.instruct = self.mistral.instruct_tokenizer
self.tokenizer = self.instruct.tokenizer
_mistral_version_str = str(self.tokenizer.version.value)
self.version: int = int(_mistral_version_str.split("v")[-1])
self.is_tekken = isinstance(self.tokenizer, Tekkenizer)
self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer)
if not (self.is_tekken or self.is_spm):
raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}")
self._vocab = tokenizer_.vocab()
# Convert to a dict[str, int] to match protocol, but this is a lossy
# conversion. There may be multiple token ids that decode to the same
# string due to partial UTF-8 byte sequences being converted to <20>
self._vocab_dict = {token: idx for idx, token in enumerate(self._vocab)}
self.tokenizer = tokenizer_
# Reverse order to ensure that the lowest token id is kept.
self._vocab_dict = {
self.convert_ids_to_tokens([i], skip_special_tokens=False)[0]: i
for i in range(self.vocab_size - 1, -1, -1)
}
# Sort the dict for convenience
self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))
# Vocab sorted by token id.
self._vocab = self.tokenizer._vocab
self._max_token_id = self.vocab_size - 1
@classmethod
def from_pretrained(
cls, path_or_repo_id: str, *, revision: Optional[str] = None
) -> "MistralTokenizer":
if not Path(path_or_repo_id).exists():
assert len(path_or_repo_id.split("/")) == 2, (
"You have either provided a non-existent path: "
"{path_or_repo_id} or an invalid HF Hub repo id."
)
tokenizer_file = cls._download_mistral_tokenizer_from_hf(
path_or_repo_id, revision
)
elif Path(path_or_repo_id).is_dir():
tokenizer_file_name = find_tokenizer_file(os.listdir(path_or_repo_id))
tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
else:
assert Path(path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
tokenizer_file = str(Path(path_or_repo_id))
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as PublicMistralTokenizer,
from transformers.tokenization_mistral_common import (
MistralCommonTokenizer as TransformersMistralTokenizer,
)
mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
return cls(mistral_tokenizer)
@staticmethod
def _download_mistral_tokenizer_from_hf(
tokenizer_name: str, revision: Optional[str]
) -> str:
try:
hf_api = HfApi()
files = hf_api.list_repo_files(repo_id=tokenizer_name, revision=revision)
except ConnectionError as exc:
files = list_local_repo_files(repo_id=tokenizer_name, revision=revision)
if len(files) == 0:
raise exc
filename = find_tokenizer_file(files)
tokenizer_file = hf_hub_download(
tokenizer_name, filename=filename, revision=revision
str_revision = "main" if revision is None else revision
return cls(
TransformersMistralTokenizer.from_pretrained(
path_or_repo_id, revision=str_revision
)
)
return tokenizer_file
# the following attributes are set to fit vLLM's design and are used
# by the structured output backends.
@property
def all_special_tokens_extended(self) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokens
# tekken defines its own extended special tokens list
if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
special_tokens = self.tokenizer.SPECIAL_TOKENS
else:
special_tokens = list(SpecialTokens)
return [s.value if isinstance(s, SpecialTokens) else s for s in special_tokens]
return self.all_special_tokens
@property
def all_special_tokens(self) -> list[str]:
return self.all_special_tokens_extended
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
return [
self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
for i in self.all_special_ids
]
@property
def all_special_ids(self) -> list[int]:
return [self.all_special_tokens.index(t) for t in self.all_special_tokens]
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens}
elif self.is_spm:
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
self.tokenizer
)
special_ids = self.tokenizer._control_tokens
else:
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
return sorted(special_ids)
@property
def bos_token_id(self) -> int:
@ -317,7 +258,7 @@ class MistralTokenizer(TokenizerBase):
@property
def pad_token(self) -> str:
raise NotImplementedError()
return self.transformers_tokenizer.pad_token
@property
def is_fast(self) -> bool:
@ -325,7 +266,7 @@ class MistralTokenizer(TokenizerBase):
@property
def vocab_size(self) -> int:
return len(self._vocab)
return self.transformers_tokenizer.vocab_size
@property
def max_token_id(self) -> int:
@ -335,6 +276,23 @@ class MistralTokenizer(TokenizerBase):
def truncation_side(self) -> str:
raise NotImplementedError()
def _is_special_token_id(self, token_id: int) -> bool:
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
if self.is_spm:
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
self.tokenizer
)
return token_id in self.tokenizer._control_tokens
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
return token_id < self.tokenizer.num_special_tokens
else:
raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
def __len__(self) -> int:
return self.vocab_size
@ -346,25 +304,19 @@ class MistralTokenizer(TokenizerBase):
truncation: bool = False,
max_length: Optional[int] = None,
):
input_ids: Union[list[int], list[list[int]]]
# For list[str], original prompt text
if is_list_of(text, str):
input_ids_: list[list[int]] = []
for p in text:
each_input_ids = self.encode_one(p, truncation, max_length)
input_ids_.append(each_input_ids)
input_ids = input_ids_
# For list[int], apply chat template output, already tokens.
elif is_list_of(text, int):
input_ids = text
# For str, single prompt text
else:
input_ids = self.encode_one(text, truncation, max_length)
return BatchEncoding({"input_ids": input_ids})
return self.transformers_tokenizer(
text=text,
text_pair=text_pair,
add_special_tokens=add_special_tokens,
truncation=truncation,
max_length=max_length,
)
@property
def vocab(self) -> list[str]:
return self._vocab
def get_vocab(self) -> dict[str, int]:
# NB: the dictionary form of the vocabulary collapses token ids that map
# to the same string but have different bytes
return self._vocab_dict
def get_added_vocab(self) -> dict[str, int]:
@ -378,11 +330,9 @@ class MistralTokenizer(TokenizerBase):
max_length: Optional[int] = None,
) -> list[int]:
# Mistral Tokenizers should not add special tokens
input_ids = self.encode(text)
if truncation:
input_ids = input_ids[:max_length]
return input_ids
return self.transformers_tokenizer.encode(
text, add_special_tokens=False, truncation=truncation, max_length=max_length
)
def encode(
self,
@ -391,15 +341,20 @@ class MistralTokenizer(TokenizerBase):
max_length: Optional[int] = None,
add_special_tokens: Optional[bool] = None,
) -> list[int]:
# `encode` should only be used for prompt completion
# it should never be used for chat_completion.
# For chat completion use `apply_chat_template`
if add_special_tokens is not None:
return self.tokenizer.encode(
text, bos=add_special_tokens, eos=add_special_tokens
return self.transformers_tokenizer.encode(
text,
truncation=truncation,
max_length=max_length,
add_special_tokens=add_special_tokens,
)
else:
return self.tokenizer.encode(text, bos=True, eos=False)
encoded = self.tokenizer.encode(text, bos=True, eos=False)
if truncation is not False and max_length is not None:
return encoded[:max_length]
else:
return encoded
def apply_chat_template(
self,
@ -407,59 +362,79 @@ class MistralTokenizer(TokenizerBase):
tools: Optional[list[dict[str, Any]]] = None,
**kwargs,
) -> list[int]:
request = make_mistral_chat_completion_request(messages, tools)
encoded = self.mistral.encode_chat_completion(request)
add_generation_prompt = kwargs.pop("add_generation_prompt", False)
continue_final_message = kwargs.get("continue_final_message", False)
padding = kwargs.get("padding", False)
truncation = kwargs.get("truncation", False)
max_length = kwargs.get("max_length")
# encode-decode to get clean prompt
return encoded.tokens
messages, tools = _prepare_apply_chat_template_tools_and_messages(
messages, tools, continue_final_message, add_generation_prompt
)
return self.transformers_tokenizer.apply_chat_template(
conversation=messages,
tools=tools,
continue_final_message=continue_final_message,
tokenize=True,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=None,
return_dict=False,
)
def decode(
self, ids: Union[list[int], int], skip_special_tokens: bool = True
) -> str:
return self.transformers_tokenizer.decode(
ids, skip_special_tokens=skip_special_tokens
)
def convert_tokens_to_string(self, tokens: list[str]) -> str:
from mistral_common.tokens.tokenizers.base import SpecialTokens
from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
)
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
to_decode_special_tokens = {SpecialTokens.tool_calls}
if self.is_tekken:
assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
tokens = [
t
for t in tokens
if (
t is SpecialTokens.tool_calls
or t not in self.tokenizer._all_special_tokens
)
if (t in to_decode_special_tokens or t not in self.all_special_tokens)
]
if any(isinstance(t, bytes) for t in tokens):
# we need to encode and decode all tokens again
shift = self.tokenizer.num_special_tokens
def _token_to_id(t: str):
t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
try:
return (
shift + self.tokenizer._tekken_token2id_nospecial[t_bytes]
)
except KeyError:
logger.warning(
"Failed to convert token %s to id, replacing with <unk>",
t_bytes,
)
return self.tokenizer.unk_id
ids = [_token_to_id(t) for t in tokens]
decoded = self.tokenizer.decode(ids, self._special_token_policy)
ids = [_tekken_token_to_id(self.tokenizer, t) for t in tokens]
# We filtered unwanted special tokens before
# so we can decode the rest.
decoded = self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP)
else:
decoded = "".join(tokens)
else:
# make sure certain special tokens like Tool calls are
# not decoded
special_tokens = {SpecialTokens.tool_calls}
assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
self.tokenizer
)
regular_tokens: list[str] = []
decoded_list = []
decoded_list: list[str] = []
decoded = ""
for token in tokens:
if token in special_tokens:
if token in to_decode_special_tokens:
if regular_tokens:
decoded_list.append(
self.tokenizer.decode(
regular_tokens, self._special_token_policy
regular_tokens, SpecialTokenPolicy.IGNORE
)
)
regular_tokens = []
@ -469,66 +444,56 @@ class MistralTokenizer(TokenizerBase):
if regular_tokens:
decoded_list.append(
self.tokenizer.decode(regular_tokens, self._special_token_policy)
self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
)
decoded = "".join(decoded_list)
return decoded
def decode(
self, ids: Union[list[int], int], skip_special_tokens: bool = True
) -> str:
assert skip_special_tokens, (
"skip_special_tokens=False is not supported for Mistral tokenizers."
)
if isinstance(ids, int):
ids = [ids]
return self.tokenizer.decode(ids, self._special_token_policy)
def convert_ids_to_tokens(
self,
ids: list[int],
skip_special_tokens: bool = True,
) -> list[str]:
from mistral_common.tokens.tokenizers.base import SpecialTokens
from mistral_common.tokens.tokenizers.base import (
SpecialTokenPolicy,
SpecialTokens,
)
from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13
# TODO(Patrick) - potentially allow special tokens to not be skipped
assert skip_special_tokens, (
"skip_special_tokens=False is not supported for Mistral tokenizers."
)
if not skip_special_tokens:
return [self.tokenizer.id_to_piece(token_id) for token_id in ids]
assert self.is_tekken or self.is_spm, type(self.tokenizer)
if self.is_tekken:
# skip special tokens except tool call and think tokens
non_skip_special_tokens = {
self.tokenizer.get_control_token(SpecialTokens.tool_calls)
non_skip_special_tokens_ids = {
self.tokenizer.get_control_token(SpecialTokens.tool_calls),
}
if isinstance(self.instruct, InstructTokenizerV13):
if self.instruct.BEGIN_THINK:
non_skip_special_tokens.add(self.instruct.BEGIN_THINK)
non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK)
if self.instruct.END_THINK:
non_skip_special_tokens.add(self.instruct.END_THINK)
ids = [
non_skip_special_tokens_ids.add(self.instruct.END_THINK)
ids_kept = [
i
for i in ids
if i > self.tokenizer.num_special_tokens or i in non_skip_special_tokens
if i in non_skip_special_tokens_ids or not self._is_special_token_id(i)
]
tokens = [self.tokenizer.id_to_piece(id) for id in ids]
# We filtered unwanted special tokens so we can decode the rest.
tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept]
if any("<EFBFBD>" in t for t in tokens) and self.is_tekken:
# if a decoded token contains the replacement character, then the
# token has an incomplete UTF-8 character so we must use bytes
# See: https://github.com/vllm-project/vllm/pull/8640
# https://github.com/vllm-project/vllm/pull/9625
# if underlying tokenizeir is sentencepiece, we just add "<22>"
# if underlying tokenizer is sentencepiece, we just add "<22>".
# We filtered unwanted special tokens so we can decode the rest.
tokens = [
self.tokenizer.id_to_byte_piece(id, self._special_token_policy)
for id in ids
self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
if token_id not in self.all_special_ids
else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
for token_id in ids_kept
]
return tokens

View File

@ -43,34 +43,13 @@ class XgrammarBackend(StructuredOutputBackend):
if isinstance(self.tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try:
if self.tokenizer.is_tekken:
encoded_vocab = self.tokenizer._vocab
else:
encoded_vocab = [
token
for token, _ in sorted(
self.tokenizer.get_vocab().items(),
key=lambda x: x[1],
)
]
stop_token_ids = None
if (
hasattr(
self.tokenizer,
"eos_token_id",
)
and self.tokenizer.eos_token_id is not None
):
stop_token_ids = [self.tokenizer.eos_token_id]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(self.tokenizer)}. The tokenizer should have a "
"get_vocab method."
) from e
# not self.tokenizer.vocab_size as self.tokenizer.vocab
# collapses all decoded errors into a single token.
self.vocab_size = len(self.tokenizer.vocab)
tokenizer_info = xgr.TokenizerInfo( # type: ignore
encoded_vocab=encoded_vocab,
encoded_vocab=self.tokenizer.vocab,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type=xgr.VocabType.RAW
if self.tokenizer.is_tekken