Refactor MistralTokenizer (#26358)
Signed-off-by: Julien Denize <julien.denize@mistral.ai>
@@ -145,7 +145,7 @@ Supported models:

Known issues:

1. Mistral 7B struggles to generate parallel tool calls correctly.
-2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
+2. **For Transformers tokenization backend only**: Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is
   much shorter than what vLLM generates. Since an exception is thrown when this condition
   is not met, the following additional chat templates are provided:

@@ -154,7 +154,14 @@ Known issues:
* <gh-file:examples/tool_chat_template_mistral_parallel.jinja> - this is a "better" version that adds a tool-use system prompt
  when tools are provided, that results in much better reliability when working with parallel tool calling.

-Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`
+Recommended flags:
+
+1. To use [mistral-common](https://github.com/mistralai/mistral-common) the official Mistral tokenization backend:
+
+   `--tokenizer_mode mistral --config_format mistral --load_format mistral --tool-call-parser mistral`
+
+2. To use the default Transformers tokenization backend:
+   `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja`

### Llama Models (`llama3_json`)
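As a quick reference, the recommended mistral-common flags above map directly onto vLLM's offline `LLM` constructor. A minimal sketch (the model id is illustrative and not part of this change; any Mistral tool-calling model should work; `--tool-call-parser` itself applies to the OpenAI-compatible server and has no offline counterpart here):

```python
from vllm import LLM, SamplingParams

# Offline equivalent of:
#   vllm serve <model> --tokenizer_mode mistral --config_format mistral \
#       --load_format mistral
llm = LLM(
    model="mistralai/Mistral-Small-3.1-24B-Instruct-2503",  # illustrative model id
    tokenizer_mode="mistral",
    config_format="mistral",
    load_format="mistral",
)
out = llm.chat(
    [{"role": "user", "content": "What's the weather in Paris?"}],
    SamplingParams(max_tokens=128),
)
print(out[0].outputs[0].text)
```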
@@ -45,10 +45,12 @@ class ModelRequestData(NamedTuple):

# Voxtral
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
    from mistral_common.audio import Audio
-    from mistral_common.protocol.instruct.messages import (
+    from mistral_common.protocol.instruct.chunk import (
        AudioChunk,
        RawAudio,
        TextChunk,
+    )
+    from mistral_common.protocol.instruct.messages import (
        UserMessage,
    )
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
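The example change reflects that, with mistral_common >= 1.8.5, content-chunk types are imported from `mistral_common.protocol.instruct.chunk` while message classes stay in `...instruct.messages`. A small standalone sketch of the new import layout (the message text is illustrative):

```python
from mistral_common.protocol.instruct.chunk import TextChunk
from mistral_common.protocol.instruct.messages import UserMessage

# Chunks (TextChunk, AudioChunk, RawAudio, ...) now come from `chunk`;
# message containers such as UserMessage remain in `messages`.
msg = UserMessage(content=[TextChunk(text="Describe this audio clip.")])
print(type(msg.content[0]).__name__)  # TextChunk
```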
@@ -32,7 +32,7 @@ pyzmq >= 25.0.0
msgspec
gguf >= 0.13.0
importlib_metadata; python_version < '3.10'
-mistral_common[image,audio] >= 1.8.2
+mistral_common[image,audio] >= 1.8.5
opencv-python-headless >= 4.11.0 # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
@@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
-mistral_common[image,audio] >= 1.8.2 # required for voxtral test
+mistral_common[image,audio] >= 1.8.5 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
@@ -29,7 +29,7 @@ torchaudio==2.8.0
torchvision==0.23.0
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
-mistral_common[image,audio] >= 1.8.2 # required for voxtral test
+mistral_common[image,audio] >= 1.8.5 # required for voxtral test
num2words # required for smolvlm test
open_clip_torch==2.32.0 # Required for nemotron_vl test
opencv-python-headless >= 4.11.0 # required for video test
@@ -474,7 +474,7 @@ mbstrdecoder==1.1.3
    # via typepy
mdurl==0.1.2
    # via markdown-it-py
-mistral-common==1.8.2
+mistral-common==1.8.5
    # via -r requirements/test.in
mlflow==2.22.0
    # via terratorch
@@ -1012,8 +1012,6 @@ sentence-transformers==3.2.1
    # via
    #   -r requirements/test.in
    #   mteb
-sentencepiece==0.2.0
-    # via mistral-common
setuptools==77.0.3
    # via
    #   lightning-utilities
@@ -6,8 +6,7 @@ from collections.abc import Mapping
from typing import Literal, Optional

import pytest
-from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy, SpecialTokens
-from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer
+from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy

from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset

@@ -2119,34 +2118,9 @@ def test_apply_mistral_chat_template_thinking_chunk():
        },
        {"role": "user", "content": "Thanks, what is 3+3?"},
    ]

-    # TODO(Julien): upon model release change to a tokenizer already configured.
-    # =================================================================
    mistral_tokenizer = MistralTokenizer.from_pretrained(
-        "mistralai/Devstral-Small-2507"
+        "mistralai/Magistral-Small-2509"
    )
-    assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
-    # Add think special tokens to the tokenizer
-    mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
-        rank=35, is_control=True, token_str=SpecialTokens.begin_think.value
-    )
-    mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
-        rank=36, is_control=True, token_str=SpecialTokens.end_think.value
-    )
-    mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
-        k: v
-        for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
-        if v not in {35, 36}
-    }
-    mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
-        SpecialTokens.begin_think.value
-    ] = 35
-    mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
-        SpecialTokens.end_think.value
-    ] = 36
-    mistral_tokenizer.instruct.BEGIN_THINK = 35
-    mistral_tokenizer.instruct.END_THINK = 36
-    # =================================================================

    tokens_ids = apply_mistral_chat_template(
        mistral_tokenizer, messages, chat_template=None, tools=None
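The manual patching of think tokens is dropped because the test now loads a tokenizer that already defines them. A short sketch of the property the updated test relies on (downloading the repo is assumed to be possible in the environment):

```python
from vllm.transformers_utils.tokenizer import MistralTokenizer

tok = MistralTokenizer.from_pretrained("mistralai/Magistral-Small-2509")
# InstructTokenizerV13-style tokenizers expose the think-token ids directly,
# so tests no longer need to inject SpecialTokenInfo entries by hand.
assert tok.instruct.BEGIN_THINK is not None
assert tok.instruct.END_THINK is not None
```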
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional

import pytest
from mistral_common.multimodal import download_image
-from mistral_common.protocol.instruct.messages import ImageURLChunk
+from mistral_common.protocol.instruct.chunk import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
|
@ -6,12 +6,8 @@ import json
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from mistral_common.audio import Audio
|
||||
from mistral_common.protocol.instruct.messages import (
|
||||
AudioChunk,
|
||||
RawAudio,
|
||||
TextChunk,
|
||||
UserMessage,
|
||||
)
|
||||
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
|
||||
from mistral_common.protocol.instruct.messages import UserMessage
|
||||
|
||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||
|
||||
|
@@ -6,7 +6,8 @@ from typing import Optional, Union

import numpy as np
import pytest
-from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage
+from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
+from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
@@ -9,7 +9,8 @@ from typing import Any, Union
import numpy as np
import pytest
import torch.nn as nn
-from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage
+from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
+from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
@@ -2,8 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
-from mistral_common.tokens.tokenizers.base import SpecialTokens
-from mistral_common.tokens.tokenizers.tekken import SpecialTokenInfo, Tekkenizer

from tests.reasoning.utils import run_reasoning_extraction_mistral
from vllm.reasoning import ReasoningParser, ReasoningParserManager

@@ -14,33 +12,9 @@ parser_name = "mistral"

@pytest.fixture(scope="module")
def mistral_tokenizer():
-    # TODO(Julien): upon model release change to a tokenizer already configured.
-    # =================================================================
    mistral_tokenizer = MistralTokenizer.from_pretrained(
-        "mistralai/Devstral-Small-2507"
+        "mistralai/Magistral-Small-2509"
    )
-    assert isinstance(mistral_tokenizer.tokenizer, Tekkenizer)
-    # Add think special tokens to the tokenizer
-    mistral_tokenizer.tokenizer._all_special_tokens[35] = SpecialTokenInfo(
-        rank=35, is_control=True, token_str=SpecialTokens.begin_think.value
-    )
-    mistral_tokenizer.tokenizer._all_special_tokens[36] = SpecialTokenInfo(
-        rank=36, is_control=True, token_str=SpecialTokens.end_think.value
-    )
-    mistral_tokenizer.tokenizer._special_tokens_reverse_vocab = {
-        k: v
-        for k, v in mistral_tokenizer.tokenizer._special_tokens_reverse_vocab.items()
-        if v not in {35, 36}
-    }
-    mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
-        SpecialTokens.begin_think.value
-    ] = 35
-    mistral_tokenizer.tokenizer._special_tokens_reverse_vocab[
-        SpecialTokens.end_think.value
-    ] = 36
-    mistral_tokenizer.instruct.BEGIN_THINK = 35
-    mistral_tokenizer.instruct.END_THINK = 36
-    # =================================================================
    return mistral_tokenizer
File diff suppressed because it is too large
@@ -403,20 +403,12 @@ def resolve_mistral_chat_template(
    chat_template: Optional[str],
    **kwargs: Any,
) -> Optional[str]:
-    if chat_template is not None:
-        logger.warning_once(
-            "'chat_template' cannot be overridden for mistral tokenizer."
-        )
-    if "add_generation_prompt" in kwargs:
-        logger.warning_once(
-            "'add_generation_prompt' is not supported for mistral tokenizer, "
-            "so it will be ignored."
-        )
-    if "continue_final_message" in kwargs:
-        logger.warning_once(
-            "'continue_final_message' is not supported for mistral tokenizer, "
-            "so it will be ignored."
+    if chat_template is not None or kwargs.get("chat_template_kwargs") is not None:
+        raise ValueError(
+            "'chat_template' or 'chat_template_kwargs' cannot be overridden "
+            "for mistral tokenizer."
        )

    return None
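With this change, overriding the chat template for a Mistral tokenizer becomes a hard error instead of a logged warning. A sketch of the new behavior, assuming the function lives in `vllm.entrypoints.chat_utils` (the file path is not shown in this diff):

```python
from vllm.entrypoints.chat_utils import resolve_mistral_chat_template

# No override: still resolves to None, i.e. "use mistral-common's own encoding".
assert resolve_mistral_chat_template(chat_template=None) is None

# Any override now raises instead of being silently ignored.
try:
    resolve_mistral_chat_template(chat_template="{{ messages }}")
except ValueError as err:
    print(err)
```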
@@ -10,7 +10,8 @@ from typing import Annotated, Literal, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
-from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage
+from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
+from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
@@ -12,12 +12,8 @@ import regex as re
import torch
import torch.nn as nn
from mistral_common.audio import mel_filter_bank
-from mistral_common.protocol.instruct.messages import (
-    AudioChunk,
-    RawAudio,
-    TextChunk,
-    UserMessage,
-)
+from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
+from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
@@ -1,34 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union, cast

-import huggingface_hub
import regex as re
-from huggingface_hub import HfApi, hf_hub_download
-from transformers.tokenization_utils_base import BatchEncoding

from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_base import TokenizerBase
-from vllm.utils import is_list_of

if TYPE_CHECKING:
    # make sure `mistral_common` is lazy imported,
    # so that users who only use non-mistral models
    # will not be bothered by the dependency.
-    from mistral_common.protocol.instruct.request import ChatCompletionRequest
-    from mistral_common.tokens.tokenizers.mistral import (
-        MistralTokenizer as PublicMistralTokenizer,
+    from mistral_common.protocol.instruct.request import (
+        ChatCompletionRequest as MistralChatCompletionRequest,
    )
+    from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+    from transformers.tokenization_mistral_common import (
+        MistralCommonTokenizer as TransformersMistralTokenizer,
+    )

    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
+    from vllm.entrypoints.openai.protocol import ChatCompletionRequest

logger = init_logger(__name__)


-def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
+def maybe_serialize_tool_calls(request: "MistralChatCompletionRequest"):
    # SEE: https://github.com/vllm-project/vllm/pull/9951
    # Credits go to: @gcalmettes
    # NOTE: There is currently a bug in pydantic where attributes
@@ -65,7 +58,7 @@ def maybe_serialize_tool_calls(request: "ChatCompletionRequest"):
            request.messages[i]["tool_calls"] = validated_tool_calls


-def truncate_tool_call_ids(request: "ChatCompletionRequest"):
+def truncate_tool_call_ids(request: "MistralChatCompletionRequest"):
    """Truncates tool call IDs for Mistral's ID requirements."""
    for i, message in enumerate(request.messages):
        if message.get("role") == "assistant":
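`truncate_tool_call_ids` exists because Mistral's templates only accept 9-character tool call IDs, while OpenAI-style IDs generated by vLLM are much longer (see the docs change above). A toy illustration of the constraint only; keeping the last 9 characters is an assumption made for this sketch, the real rule is whatever the helper implements:

```python
def _truncate_id(tool_call_id: str) -> str:
    # Hypothetical stand-in for the trimming done by truncate_tool_call_ids.
    return tool_call_id[-9:] if len(tool_call_id) > 9 else tool_call_id

print(_truncate_id("chatcmpl-tool-8b2a7f0c1d9e4f56"))  # -> 'c1d9e4f56' (9 chars)
```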
@@ -95,84 +88,34 @@ def truncate_tool_call_ids(request: "ChatCompletionRequest"):
                request.messages[i]["tool_call_id"] = tool_call_id


def validate_request_params(request: "ChatCompletionRequest"):
    if request.skip_special_tokens is not None and not request.skip_special_tokens:
        raise ValueError(
            "skip_special_tokens=False is not supported for Mistral tokenizers."
        )


def list_local_repo_files(repo_id: str, revision: Optional[str]) -> list[str]:
    repo_cache = os.path.join(
        huggingface_hub.constants.HF_HUB_CACHE,
        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
            ["models", *repo_id.split("/")]
        ),
    )

    if revision is None:
        revision_file = os.path.join(repo_cache, "refs", "main")
        if os.path.isfile(revision_file):
            with open(revision_file) as file:
                revision = file.read()

    if revision:
        revision_dir = os.path.join(repo_cache, "snapshots", revision)
        if os.path.isdir(revision_dir):
            return os.listdir(revision_dir)

    return []


def find_tokenizer_file(files: list[str]):
    # Accept both versioned (tokenizer.model.v3) and unversioned
    # (tokenizer.model) forms, plus tekken.json and tokenizer.mm.model
    # variants. Previous pattern only matched the versioned variants.
    file_pattern = re.compile(
        r"^tokenizer\.model(\.v.*)?|tekken\.json|tokenizer\.mm\.model(\.v.*)?$"
    )

    matched_files = [file for file in files if file_pattern.match(file)]
    if len(matched_files) > 1:
        logger.warning(
            "Multiple files matched pattern `%s`: %s. Using %s.",
            file_pattern.pattern,
            matched_files,
            matched_files[0],
        )
    elif len(matched_files) == 0:
        raise OSError(
            f"Found {len(matched_files)} files matching the "
            f"pattern: `{file_pattern.pattern}`. Make sure that a Mistral "
            f"tokenizer is present in {files}."
        )

    return matched_files[0]


def _aggregate_content(content: list) -> list[dict[str, Any]]:
    aggregated_content: list[dict[str, Any]] = []
    for chunk in content:
        if (
            chunk.get("type") == "text"
            and aggregated_content
            and aggregated_content[-1].get("type") == "text"
        ):
            aggregated_content[-1]["text"] += "\n\n" + chunk.get("text")
        else:
            aggregated_content.append(chunk)
    if len(aggregated_content) == 1 and aggregated_content[0].get("type") == "text":
        content = aggregated_content[0]["text"]
    return content


-def make_mistral_chat_completion_request(
+def _prepare_apply_chat_template_tools_and_messages(
    messages: list["ChatCompletionMessageParam"],
    tools: Optional[list[dict[str, Any]]] = None,
-) -> "ChatCompletionRequest":
+    continue_final_message: bool = False,
+    add_generation_prompt: bool = False,
+) -> tuple[list["ChatCompletionMessageParam"], Optional[list[dict[str, Any]]]]:
+    if add_generation_prompt and continue_final_message:
+        raise ValueError(
+            "Cannot set both `add_generation_prompt` and "
+            "`continue_final_message` to True."
+        )

    last_message = cast(dict[str, Any], messages[-1])
-    if last_message["role"] == "assistant":
-        last_message["prefix"] = True
+    # add_generation_prompt is directly handled by the tokenizer but we
+    # check if the user is trying to use it with a final assistant message
+    # which is probably not what they want.
+    # If add_generation_prompt is False, we don't need to check anything.
+    if add_generation_prompt and last_message["role"] == "assistant":
+        raise ValueError(
+            "Cannot set `add_generation_prompt` to True when "
+            "the last message is from the assistant. Consider "
+            "using `continue_final_message` instead."
+        )
+    if continue_final_message and last_message["role"] != "assistant":
+        raise ValueError(
+            "Cannot set `continue_final_message` to True when "
+            "the last message is not from the assistant."
+        )

    # mistral-common requires AssistantMessage content to be string [1].
    #
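The replacement helper no longer builds a mistral-common `ChatCompletionRequest`; it only normalizes messages/tools and enforces the `add_generation_prompt` / `continue_final_message` rules before the Transformers tokenizer takes over. A usage sketch, assuming the module path `vllm.transformers_utils.tokenizers.mistral` (inferred, not shown in this diff):

```python
from vllm.transformers_utils.tokenizers.mistral import (
    _prepare_apply_chat_template_tools_and_messages as prepare,
)

msgs = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello, "},
]

# Allowed: continue a trailing assistant message.
prepare(msgs, tools=None, continue_final_message=True)

# Rejected: asking for a generation prompt after a trailing assistant message.
try:
    prepare(msgs, tools=None, add_generation_prompt=True)
except ValueError as err:
    print(err)
```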
@@ -181,13 +124,6 @@ def make_mistral_chat_completion_request(
        # Remove reasoning_content as unsupported by Mistral
        _ = message.pop("reasoning_content", None)  # type: ignore

-        # Convert list text content to string
-        if message.get("role") in ("assistant", "tool"):
-            content: Any = message.get("content")
-            if isinstance(content, list):
-                content = _aggregate_content(content)
-            message["content"] = content
-
    # The Mistral client, in comparison to the OpenAI client, requires the
    # "parameters" dict and the "description" string to be present
    # even if they are empty.
@@ -200,108 +136,113 @@ def make_mistral_chat_completion_request(
            if function.get("description") is None:
                function["description"] = ""

-    from mistral_common.protocol.instruct.request import ChatCompletionRequest
+    return messages, tools

-    return ChatCompletionRequest(messages=messages, tools=tools)  # type: ignore[type-var]

+def validate_request_params(request: "ChatCompletionRequest"):
+    if request.chat_template is not None or request.chat_template_kwargs is not None:
+        raise ValueError("chat_template is not supported for Mistral tokenizers.")
+
+
+def _tekken_token_to_id(tokenizer: "Tekkenizer", t: Union[str, bytes]) -> int:
+    from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+
+    assert isinstance(tokenizer, Tekkenizer), type(tokenizer)
+
+    t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
+    shift = tokenizer.num_special_tokens
+    try:
+        return shift + tokenizer._tekken_token2id_nospecial[t_bytes]
+    except KeyError:
+        t_str = t_bytes.decode("utf-8")
+        if t_str in tokenizer._special_tokens_reverse_vocab:
+            return tokenizer._special_tokens_reverse_vocab[t_str]
+        logger.warning(
+            "Failed to convert token %s to id, replacing with <unk>", t_bytes
+        )
+        return tokenizer.unk_id


class MistralTokenizer(TokenizerBase):
-    def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
-        self.mistral = tokenizer
-        self.instruct = tokenizer.instruct_tokenizer
-        _mistral_version_str = self.instruct.tokenizer.version.value
-        self.version: int = int(_mistral_version_str.split("v")[-1])
-
-        tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
-        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
-        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
-
-        self.is_tekken = isinstance(tokenizer_, Tekkenizer)
+    def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None:
        from mistral_common.tokens.tokenizers.sentencepiece import (
            SentencePieceTokenizer,
        )
+        from mistral_common.tokens.tokenizers.tekken import Tekkenizer

-        self.is_spm = isinstance(tokenizer_, SentencePieceTokenizer)
-        self._special_token_policy = (
-            SpecialTokenPolicy.IGNORE if self.is_tekken else None
-        )
+        self.transformers_tokenizer = tokenizer
+        self.mistral = tokenizer.tokenizer
+        self.instruct = self.mistral.instruct_tokenizer
+        self.tokenizer = self.instruct.tokenizer

+        _mistral_version_str = str(self.tokenizer.version.value)
+        self.version: int = int(_mistral_version_str.split("v")[-1])

+        self.is_tekken = isinstance(self.tokenizer, Tekkenizer)
+        self.is_spm = isinstance(self.tokenizer, SentencePieceTokenizer)
        if not (self.is_tekken or self.is_spm):
-            raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}")
+            raise TypeError(f"Unsupported tokenizer: {type(self.tokenizer)}")

-        self._vocab = tokenizer_.vocab()
-        # Convert to a dict[str, int] to match protocol, but this is a lossy
-        # conversion. There may be multiple token ids that decode to the same
-        # string due to partial UTF-8 byte sequences being converted to �
-        self._vocab_dict = {token: idx for idx, token in enumerate(self._vocab)}
-        self.tokenizer = tokenizer_
+        # Reverse order to ensure that the lowest token id is kept.
+        self._vocab_dict = {
+            self.convert_ids_to_tokens([i], skip_special_tokens=False)[0]: i
+            for i in range(self.vocab_size - 1, -1, -1)
+        }
+        # Sort the dict for convenience
+        self._vocab_dict = dict(sorted(self._vocab_dict.items(), key=lambda x: x[1]))

+        # Vocab sorted by token id.
+        self._vocab = self.tokenizer._vocab
        self._max_token_id = self.vocab_size - 1

    @classmethod
    def from_pretrained(
        cls, path_or_repo_id: str, *, revision: Optional[str] = None
    ) -> "MistralTokenizer":
-        if not Path(path_or_repo_id).exists():
-            assert len(path_or_repo_id.split("/")) == 2, (
-                "You have either provided a non-existent path: "
-                "{path_or_repo_id} or an invalid HF Hub repo id."
-            )
-            tokenizer_file = cls._download_mistral_tokenizer_from_hf(
-                path_or_repo_id, revision
-            )
-        elif Path(path_or_repo_id).is_dir():
-            tokenizer_file_name = find_tokenizer_file(os.listdir(path_or_repo_id))
-            tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name)
-        else:
-            assert Path(path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}"
-            tokenizer_file = str(Path(path_or_repo_id))
-
-        from mistral_common.tokens.tokenizers.mistral import (
-            MistralTokenizer as PublicMistralTokenizer,
+        from transformers.tokenization_mistral_common import (
+            MistralCommonTokenizer as TransformersMistralTokenizer,
        )

-        mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file)
-        return cls(mistral_tokenizer)
-
-    @staticmethod
-    def _download_mistral_tokenizer_from_hf(
-        tokenizer_name: str, revision: Optional[str]
-    ) -> str:
-        try:
-            hf_api = HfApi()
-            files = hf_api.list_repo_files(repo_id=tokenizer_name, revision=revision)
-        except ConnectionError as exc:
-            files = list_local_repo_files(repo_id=tokenizer_name, revision=revision)
-
-            if len(files) == 0:
-                raise exc
-
-        filename = find_tokenizer_file(files)
-
-        tokenizer_file = hf_hub_download(
-            tokenizer_name, filename=filename, revision=revision
+        str_revision = "main" if revision is None else revision
+        return cls(
+            TransformersMistralTokenizer.from_pretrained(
+                path_or_repo_id, revision=str_revision
+            )
        )
-        return tokenizer_file

    # the following attributes are set to fit vLLM's design and are used
    # by the structured output backends.
    @property
    def all_special_tokens_extended(self) -> list[str]:
-        from mistral_common.tokens.tokenizers.base import SpecialTokens
-
-        # tekken defines its own extended special tokens list
-        if hasattr(self.tokenizer, "SPECIAL_TOKENS"):
-            special_tokens = self.tokenizer.SPECIAL_TOKENS
-        else:
-            special_tokens = list(SpecialTokens)
-        return [s.value if isinstance(s, SpecialTokens) else s for s in special_tokens]
+        return self.all_special_tokens

    @property
    def all_special_tokens(self) -> list[str]:
-        return self.all_special_tokens_extended
+        from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
+
+        return [
+            self.tokenizer.decode([i], special_token_policy=SpecialTokenPolicy.KEEP)
+            for i in self.all_special_ids
+        ]

    @property
    def all_special_ids(self) -> list[int]:
-        return [self.all_special_tokens.index(t) for t in self.all_special_tokens]
+        from mistral_common.tokens.tokenizers.sentencepiece import (
+            SentencePieceTokenizer,
+        )
+        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+
+        if self.is_tekken:
+            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
+            special_ids = {t["rank"] for t in self.tokenizer._all_special_tokens}
+        elif self.is_spm:
+            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
+                self.tokenizer
+            )
+            special_ids = self.tokenizer._control_tokens
+        else:
+            raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
+        return sorted(special_ids)

    @property
    def bos_token_id(self) -> int:
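After the refactor, `MistralTokenizer` is a thin wrapper over the Transformers `MistralCommonTokenizer` rather than a re-implementation of tokenizer-file discovery and download. A sketch of what the new constructor wiring exposes (the model id is illustrative, taken from the tests above):

```python
from vllm.transformers_utils.tokenizer import MistralTokenizer

tok = MistralTokenizer.from_pretrained("mistralai/Magistral-Small-2509")

# from_pretrained now delegates to transformers' MistralCommonTokenizer and
# keeps handles to every layer of the stack:
print(type(tok.transformers_tokenizer).__name__)  # MistralCommonTokenizer
print(type(tok.mistral).__name__)                 # mistral-common MistralTokenizer
print(tok.version, tok.is_tekken, tok.is_spm)
```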
@@ -317,7 +258,7 @@ class MistralTokenizer(TokenizerBase):

    @property
    def pad_token(self) -> str:
-        raise NotImplementedError()
+        return self.transformers_tokenizer.pad_token

    @property
    def is_fast(self) -> bool:
@@ -325,7 +266,7 @@ class MistralTokenizer(TokenizerBase):

    @property
    def vocab_size(self) -> int:
-        return len(self._vocab)
+        return self.transformers_tokenizer.vocab_size

    @property
    def max_token_id(self) -> int:
@@ -335,6 +276,23 @@ class MistralTokenizer(TokenizerBase):
    def truncation_side(self) -> str:
        raise NotImplementedError()

+    def _is_special_token_id(self, token_id: int) -> bool:
+        from mistral_common.tokens.tokenizers.sentencepiece import (
+            SentencePieceTokenizer,
+        )
+        from mistral_common.tokens.tokenizers.tekken import Tekkenizer
+
+        if self.is_spm:
+            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
+                self.tokenizer
+            )
+            return token_id in self.tokenizer._control_tokens
+        if self.is_tekken:
+            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
+            return token_id < self.tokenizer.num_special_tokens
+        else:
+            raise ValueError(f"Unknown tokenizer type: {type(self.tokenizer)}")
+
    def __len__(self) -> int:
        return self.vocab_size
@@ -346,25 +304,19 @@ class MistralTokenizer(TokenizerBase):
        truncation: bool = False,
        max_length: Optional[int] = None,
    ):
-        input_ids: Union[list[int], list[list[int]]]
-        # For list[str], original prompt text
-        if is_list_of(text, str):
-            input_ids_: list[list[int]] = []
-            for p in text:
-                each_input_ids = self.encode_one(p, truncation, max_length)
-                input_ids_.append(each_input_ids)
-            input_ids = input_ids_
-        # For list[int], apply chat template output, already tokens.
-        elif is_list_of(text, int):
-            input_ids = text
-        # For str, single prompt text
-        else:
-            input_ids = self.encode_one(text, truncation, max_length)
-        return BatchEncoding({"input_ids": input_ids})
+        return self.transformers_tokenizer(
+            text=text,
+            text_pair=text_pair,
+            add_special_tokens=add_special_tokens,
+            truncation=truncation,
+            max_length=max_length,
+        )

    @property
    def vocab(self) -> list[str]:
        return self._vocab

    def get_vocab(self) -> dict[str, int]:
        # NB: the dictionary form of the vocabulary collapses token ids that map
        # to the same string but have different bytes
        return self._vocab_dict

    def get_added_vocab(self) -> dict[str, int]:
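The `_vocab_dict` and `get_vocab` comments encode a subtle point: several token ids can decode to the same string (partial UTF-8 byte sequences all render as "�"), so a `dict[str, int]` view of the vocabulary is lossy, and the constructor iterates ids in reverse so the lowest id wins. A toy illustration of that rule (standalone, not vLLM code):

```python
# Two hypothetical ids decode to the same surface string.
decoded = {7: "hello", 42: "hello", 43: "world"}

# Iterating ids from high to low lets later (lower) ids overwrite earlier ones,
# so each string keeps its lowest id, matching the "Reverse order" comment.
vocab_dict = {decoded[i]: i for i in sorted(decoded, reverse=True)}
assert vocab_dict == {"hello": 7, "world": 43}
```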
@@ -378,11 +330,9 @@ class MistralTokenizer(TokenizerBase):
        max_length: Optional[int] = None,
    ) -> list[int]:
        # Mistral Tokenizers should not add special tokens
-        input_ids = self.encode(text)
-
-        if truncation:
-            input_ids = input_ids[:max_length]
-        return input_ids
+        return self.transformers_tokenizer.encode(
+            text, add_special_tokens=False, truncation=truncation, max_length=max_length
+        )

    def encode(
        self,
@@ -391,15 +341,20 @@ class MistralTokenizer(TokenizerBase):
        max_length: Optional[int] = None,
        add_special_tokens: Optional[bool] = None,
    ) -> list[int]:
        # `encode` should only be used for prompt completion
        # it should never be used for chat_completion.
        # For chat completion use `apply_chat_template`
        if add_special_tokens is not None:
-            return self.tokenizer.encode(
-                text, bos=add_special_tokens, eos=add_special_tokens
+            return self.transformers_tokenizer.encode(
+                text,
+                truncation=truncation,
+                max_length=max_length,
+                add_special_tokens=add_special_tokens,
            )
        else:
-            return self.tokenizer.encode(text, bos=True, eos=False)
+            encoded = self.tokenizer.encode(text, bos=True, eos=False)
+
+            if truncation is not False and max_length is not None:
+                return encoded[:max_length]
+            else:
+                return encoded

    def apply_chat_template(
        self,
@@ -407,59 +362,79 @@ class MistralTokenizer(TokenizerBase):
        tools: Optional[list[dict[str, Any]]] = None,
        **kwargs,
    ) -> list[int]:
-        request = make_mistral_chat_completion_request(messages, tools)
-        encoded = self.mistral.encode_chat_completion(request)
+        add_generation_prompt = kwargs.pop("add_generation_prompt", False)
+        continue_final_message = kwargs.get("continue_final_message", False)
+        padding = kwargs.get("padding", False)
+        truncation = kwargs.get("truncation", False)
+        max_length = kwargs.get("max_length")

-        # encode-decode to get clean prompt
-        return encoded.tokens
+        messages, tools = _prepare_apply_chat_template_tools_and_messages(
+            messages, tools, continue_final_message, add_generation_prompt
+        )
+
+        return self.transformers_tokenizer.apply_chat_template(
+            conversation=messages,
+            tools=tools,
+            continue_final_message=continue_final_message,
+            tokenize=True,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            return_tensors=None,
+            return_dict=False,
+        )
+
+    def decode(
+        self, ids: Union[list[int], int], skip_special_tokens: bool = True
+    ) -> str:
+        return self.transformers_tokenizer.decode(
+            ids, skip_special_tokens=skip_special_tokens
+        )

    def convert_tokens_to_string(self, tokens: list[str]) -> str:
-        from mistral_common.tokens.tokenizers.base import SpecialTokens
+        from mistral_common.tokens.tokenizers.base import (
+            SpecialTokenPolicy,
+            SpecialTokens,
+        )
+        from mistral_common.tokens.tokenizers.sentencepiece import (
+            SentencePieceTokenizer,
+        )
+        from mistral_common.tokens.tokenizers.tekken import Tekkenizer

+        to_decode_special_tokens = {SpecialTokens.tool_calls}
        if self.is_tekken:
+            assert isinstance(self.tokenizer, Tekkenizer), type(self.tokenizer)
            tokens = [
                t
                for t in tokens
-                if (
-                    t is SpecialTokens.tool_calls
-                    or t not in self.tokenizer._all_special_tokens
-                )
+                if (t in to_decode_special_tokens or t not in self.all_special_tokens)
            ]

            if any(isinstance(t, bytes) for t in tokens):
                # we need to encode and decode all tokens again
-                shift = self.tokenizer.num_special_tokens
-
-                def _token_to_id(t: str):
-                    t_bytes = t.encode("utf-8") if not isinstance(t, bytes) else t
-                    try:
-                        return (
-                            shift + self.tokenizer._tekken_token2id_nospecial[t_bytes]
-                        )
-                    except KeyError:
-                        logger.warning(
-                            "Failed to convert token %s to id, replacing with <unk>",
-                            t_bytes,
-                        )
-                        return self.tokenizer.unk_id
-
-                ids = [_token_to_id(t) for t in tokens]
-                decoded = self.tokenizer.decode(ids, self._special_token_policy)
+                ids = [_tekken_token_to_id(self.tokenizer, t) for t in tokens]
+                # We filtered unwanted special tokens before
+                # so we can decode the rest.
+                decoded = self.tokenizer.decode(ids, SpecialTokenPolicy.KEEP)
            else:
                decoded = "".join(tokens)
        else:
            # make sure certain special tokens like Tool calls are
            # not decoded
-            special_tokens = {SpecialTokens.tool_calls}
+            assert isinstance(self.tokenizer, SentencePieceTokenizer), type(
+                self.tokenizer
+            )
+
            regular_tokens: list[str] = []
-            decoded_list = []
+            decoded_list: list[str] = []
            decoded = ""

            for token in tokens:
-                if token in special_tokens:
+                if token in to_decode_special_tokens:
                    if regular_tokens:
                        decoded_list.append(
                            self.tokenizer.decode(
-                                regular_tokens, self._special_token_policy
+                                regular_tokens, SpecialTokenPolicy.IGNORE
                            )
                        )
                        regular_tokens = []
@@ -469,66 +444,56 @@ class MistralTokenizer(TokenizerBase):

            if regular_tokens:
                decoded_list.append(
-                    self.tokenizer.decode(regular_tokens, self._special_token_policy)
+                    self.tokenizer.decode(regular_tokens, SpecialTokenPolicy.IGNORE)
                )

            decoded = "".join(decoded_list)

        return decoded

-    def decode(
-        self, ids: Union[list[int], int], skip_special_tokens: bool = True
-    ) -> str:
-        assert skip_special_tokens, (
-            "skip_special_tokens=False is not supported for Mistral tokenizers."
-        )
-
-        if isinstance(ids, int):
-            ids = [ids]
-        return self.tokenizer.decode(ids, self._special_token_policy)

    def convert_ids_to_tokens(
        self,
        ids: list[int],
        skip_special_tokens: bool = True,
    ) -> list[str]:
-        from mistral_common.tokens.tokenizers.base import SpecialTokens
+        from mistral_common.tokens.tokenizers.base import (
+            SpecialTokenPolicy,
+            SpecialTokens,
+        )
        from mistral_common.tokens.tokenizers.instruct import InstructTokenizerV13

-        # TODO(Patrick) - potentially allow special tokens to not be skipped
-        assert skip_special_tokens, (
-            "skip_special_tokens=False is not supported for Mistral tokenizers."
-        )
+        if not skip_special_tokens:
+            return [self.tokenizer.id_to_piece(token_id) for token_id in ids]

        assert self.is_tekken or self.is_spm, type(self.tokenizer)

        if self.is_tekken:
            # skip special tokens except tool call and think tokens
-            non_skip_special_tokens = {
-                self.tokenizer.get_control_token(SpecialTokens.tool_calls)
+            non_skip_special_tokens_ids = {
+                self.tokenizer.get_control_token(SpecialTokens.tool_calls),
            }
            if isinstance(self.instruct, InstructTokenizerV13):
                if self.instruct.BEGIN_THINK:
-                    non_skip_special_tokens.add(self.instruct.BEGIN_THINK)
+                    non_skip_special_tokens_ids.add(self.instruct.BEGIN_THINK)
                if self.instruct.END_THINK:
-                    non_skip_special_tokens.add(self.instruct.END_THINK)
-            ids = [
+                    non_skip_special_tokens_ids.add(self.instruct.END_THINK)
+
+            ids_kept = [
                i
                for i in ids
-                if i > self.tokenizer.num_special_tokens or i in non_skip_special_tokens
+                if i in non_skip_special_tokens_ids or not self._is_special_token_id(i)
            ]

-        tokens = [self.tokenizer.id_to_piece(id) for id in ids]
+        # We filtered unwanted special tokens so we can decode the rest.
+        tokens = [self.tokenizer.id_to_piece(token_id) for token_id in ids_kept]

        if any("�" in t for t in tokens) and self.is_tekken:
            # if a decoded token contains the replacement character, then the
            # token has an incomplete UTF-8 character so we must use bytes
            # See: https://github.com/vllm-project/vllm/pull/8640
            #      https://github.com/vllm-project/vllm/pull/9625
-            # if underlying tokenizeir is sentencepiece, we just add "�"
+            # if underlying tokenizer is sentencepiece, we just add "�".
+            # We filtered unwanted special tokens so we can decode the rest.
            tokens = [
-                self.tokenizer.id_to_byte_piece(id, self._special_token_policy)
-                for id in ids
+                self.tokenizer.id_to_byte_piece(token_id, SpecialTokenPolicy.KEEP)
+                if token_id not in self.all_special_ids
+                else self.tokenizer.decode([token_id], SpecialTokenPolicy.KEEP)
+                for token_id in ids_kept
            ]

        return tokens
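Taken together, the refactored tokenizer delegates chat templating, decoding, and batching to the Transformers backend while keeping tekken/SentencePiece specifics for the token-level helpers. An end-to-end sketch (model id illustrative, network access assumed):

```python
from vllm.transformers_utils.tokenizer import MistralTokenizer

tok = MistralTokenizer.from_pretrained("mistralai/Magistral-Small-2509")

ids = tok.apply_chat_template(
    messages=[{"role": "user", "content": "Say hello."}],
    tools=None,
    add_generation_prompt=True,
)
print(tok.decode(ids))                     # delegated to the Transformers tokenizer
print(tok.convert_ids_to_tokens(ids)[:5])  # special tokens skipped by default
```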
@@ -43,34 +43,13 @@ class XgrammarBackend(StructuredOutputBackend):
        if isinstance(self.tokenizer, MistralTokenizer):
            # NOTE: ideally, xgrammar should handle this accordingly.
            # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
-            try:
-                if self.tokenizer.is_tekken:
-                    encoded_vocab = self.tokenizer._vocab
-                else:
-                    encoded_vocab = [
-                        token
-                        for token, _ in sorted(
-                            self.tokenizer.get_vocab().items(),
-                            key=lambda x: x[1],
-                        )
-                    ]
-                stop_token_ids = None
-                if (
-                    hasattr(
-                        self.tokenizer,
-                        "eos_token_id",
-                    )
-                    and self.tokenizer.eos_token_id is not None
-                ):
-                    stop_token_ids = [self.tokenizer.eos_token_id]
-            except AttributeError as e:
-                raise ValueError(
-                    f"Cannot get the vocabulary of the tokenizer "
-                    f"{type(self.tokenizer)}. The tokenizer should have a "
-                    "get_vocab method."
-                ) from e
+            # not self.tokenizer.vocab_size as self.tokenizer.vocab
+            # collapses all decoded errors into a single token.
+            self.vocab_size = len(self.tokenizer.vocab)
            tokenizer_info = xgr.TokenizerInfo(  # type: ignore
-                encoded_vocab=encoded_vocab,
+                encoded_vocab=self.tokenizer.vocab,
                # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43  # noqa: E501
                vocab_type=xgr.VocabType.RAW
                if self.tokenizer.is_tekken