Compare commits

...

16 Commits

Author SHA1 Message Date
73e0b95cab fix tests 2024-10-14 18:32:36 +01:00
b4406a26c1 Update save-load test 2024-10-14 18:08:39 +01:00
fd6f269052 Inverse template save-loading in separate file 2024-10-14 18:08:39 +01:00
7676abd496 Tests are run in the CI! 2024-10-14 18:08:39 +01:00
fc4482558b Check tests are being run in CI 2024-10-14 18:08:39 +01:00
fc51e7a04c Add reverse templating for the tools list as well 2024-10-14 18:08:39 +01:00
d6b7914b77 push todo 2024-10-14 18:08:39 +01:00
b51b330221 make fixup 2024-10-14 18:08:39 +01:00
39b977995a make fixup 2024-10-14 18:08:39 +01:00
550a6f1fad One more try! 2024-10-14 18:08:39 +01:00
6fd52c6f67 Check tests are still being run 2024-10-14 18:08:39 +01:00
eec1c2033b Check tests are still being run 2024-10-14 18:08:39 +01:00
07a0a752c8 Check tests are still being run 2024-10-14 18:08:39 +01:00
481109b2df Refactor chat template tests out so they run once, instead of on every class 2024-10-14 18:08:39 +01:00
d9f478f706 Fix imports 2024-10-14 18:08:39 +01:00
47a3ca55de Initial commit 2024-10-14 18:08:39 +01:00
4 changed files with 537 additions and 345 deletions

View File

@ -65,7 +65,11 @@ from .utils import (
requires_backends,
to_py_obj,
)
from .utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices
from .utils.chat_template_utils import (
_compile_inverse_template,
_compile_jinja_template,
_render_with_assistant_indices,
)
from .utils.import_utils import PROTOBUF_IMPORT_ERROR
@ -145,6 +149,7 @@ AudioInput = Union["np.ndarray", "torch.Tensor", List["np.ndarray"], List["torch
SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
ADDED_TOKENS_FILE = "added_tokens.json"
TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
INVERSE_TEMPLATE_FILE = "inverse_template.jinja"
# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
FULL_TOKENIZER_FILE = "tokenizer.json"
@ -1631,6 +1636,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# we reconstruct that into a single dict while loading them.
self.chat_template = {template["name"]: template["template"] for template in self.chat_template}
self.inverse_template = kwargs.pop("inverse_template", None)
super().__init__(**kwargs)
@property
@ -1916,6 +1923,24 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
else:
return rendered
def apply_inverse_template(self, chat: str, inverse_template: Optional[str] = None, skip_json_load: bool = False):
if inverse_template is None:
if self.inverse_template is not None:
inverse_template = self.inverse_template
else:
raise ValueError(
"Cannot use apply_inverse_template() because tokenizer.inverse_template is not set! Please set "
"the tokenizer.inverse_template attribute to a valid Jinja template string."
)
# Compilation function uses a cache to avoid recompiling the same template
compiled_template = _compile_inverse_template(inverse_template)
template_out = compiled_template.render(chat=chat)
if skip_json_load:
return template_out
else:
return json.loads(template_out)
def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str:
"""
Retrieve the chat template string used for tokenizing chat messages. This template is used
@ -2121,6 +2146,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
"tokenizer_config_file": TOKENIZER_CONFIG_FILE,
# tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders
"tokenizer_file": FULL_TOKENIZER_FILE,
"inverse_template": INVERSE_TEMPLATE_FILE,
}
vocab_files = {**cls.vocab_files_names, **additional_files_names}
if "tokenizer_file" in vocab_files:
@ -2241,6 +2267,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
from_slow = kwargs.get("from_slow", False)
gguf_file = kwargs.get("gguf_file", None)
has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None
inverse_template_file = resolved_vocab_files.pop("inverse_template", None)
# If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be
# loaded directly from the GGUF file.
@ -2342,6 +2369,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
f" from is '{cls.__name__}'."
)
if inverse_template_file is not None:
with open(inverse_template_file) as chat_template_handle:
init_kwargs["inverse_template"] = chat_template_handle.read()
# Update with newly provided kwargs
init_kwargs.update(kwargs)
@ -2577,6 +2608,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE
)
inverse_chat_template_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + INVERSE_TEMPLATE_FILE
)
tokenizer_config = copy.deepcopy(self.init_kwargs)
# Let's save the init kwargs
@ -2599,6 +2634,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
else:
tokenizer_config["chat_template"] = self.chat_template
if self.inverse_template is not None:
with open(inverse_chat_template_file, "w", encoding="utf-8") as f:
f.write(self.inverse_template)
if len(self.init_inputs) > 0:
tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
for file_id in self.vocab_files_names.keys():

View File

@ -16,6 +16,7 @@ import inspect
import json
import re
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints
@ -28,7 +29,7 @@ from .import_utils import is_jinja_available, is_torch_available, is_vision_avai
if is_jinja_available():
import jinja2
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
from jinja2.sandbox import ImmutableSandboxedEnvironment, SandboxedEnvironment
else:
jinja2 = None
@ -406,17 +407,6 @@ def _compile_jinja_template(chat_template):
"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
)
def raise_exception(message):
raise jinja2.exceptions.TemplateError(message)
def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
# We override the built-in tojson filter because Jinja's default filter escapes HTML characters
# We also expose some options like custom indents and separators
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
def strftime_now(format):
return datetime.now().strftime(format)
jinja_env = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
)
@ -424,3 +414,76 @@ def _compile_jinja_template(chat_template):
jinja_env.globals["raise_exception"] = raise_exception
jinja_env.globals["strftime_now"] = strftime_now
return jinja_env.from_string(chat_template)
@lru_cache
def _compile_inverse_template(inverse_template):
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError(
"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
)
jinja_env = SandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[jinja2.ext.loopcontrols])
jinja_env.globals["raise_exception"] = raise_exception
jinja_env.globals["finditer"] = finditer
jinja_env.globals["sort_by_group_start"] = sort_by_group_start
jinja_env.filters["tojson"] = tojson
jinja_env.globals["json_loads"] = json_loads
jinja_env.globals["IGNORECASE"] = re.IGNORECASE
jinja_env.globals["MULTILINE"] = re.MULTILINE
jinja_env.globals["DOTALL"] = re.DOTALL
return jinja_env.from_string(inverse_template)
# Functions for the Jinja environments below this line
def finditer(pattern, string, flags=0, add_tag=None, add_tag_from_group=None):
@dataclass
class NewMatchObject:
group: List[str]
group_starts: List[int]
tag: Optional[str]
if add_tag is not None and add_tag_from_group is not None:
raise jinja2.exceptions.TemplateError("Cannot use add_tag and add_tag_from_group at the same time!")
out = []
for match in re.finditer(pattern, string, flags=flags):
# groups() by default does not include group(0), the whole string
# so we add it in manually to make things match up
groups = [match.group(0)] + list(match.groups())
group_starts = [match.start(i) for i in range(len(groups))]
if add_tag_from_group is not None:
add_tag = groups[add_tag_from_group]
out.append(NewMatchObject(group=groups, group_starts=group_starts, tag=add_tag))
return out
def sort_by_group_start(matches, group_idx=0, group_idx_by_tag=None):
if group_idx_by_tag is None:
group_idx_by_tag = {}
def sort_key(match):
# Use the idx specific to this tag if present, or the global group_idx if not
idx = group_idx_by_tag.get(match.tag, group_idx)
return match.group_starts[idx]
return sorted(matches, key=sort_key)
def json_loads(string):
return json.loads(string)
def raise_exception(message):
raise jinja2.exceptions.TemplateError(message)
def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
# We override the built-in tojson filter because Jinja's default filter escapes HTML characters
# We also expose some options like custom indents and separators
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
def strftime_now(format_str):
return datetime.now().strftime(format_str)

View File

@ -49,7 +49,6 @@ from transformers.testing_utils import (
check_json_file_has_correct_format,
get_tests_dir,
is_pt_tf_cross_test,
require_jinja,
require_read_token,
require_tf,
require_tokenizers,
@ -1072,337 +1071,6 @@ class TokenizerTesterMixin:
if tokenizer.num_special_tokens_to_add(pair=True):
self.assertIn(None, output.sequence_ids())
@require_jinja
def test_chat_template(self):
dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
dummy_conversation = [
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
]
expected_output = "systemsystem messageuseruser messageassistantassistant message"
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False, return_dict=False
)
self.assertEqual(output, expected_output) # Test we can pass chat_template arg
# Check that no error raised when tokenize=True
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=False
)
dict_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=True
)
self.assertEqual(dict_output["input_ids"], output) # Test return_dict behaviour matches
tokenizer.chat_template = dummy_template
self.assertEqual(tokenizer.chat_template, dummy_template) # Test property setter
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
self.assertEqual(output, expected_output) # Test chat_template attribute is used if no arg is passed
# Check that no error raised
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name)
tokenizer = tokenizer.from_pretrained(tmp_dir_name)
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
self.assertEqual(output, expected_output) # Test output is the same after reloading
# Check that no error raised
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
@require_jinja
def test_chat_template_batched(self):
dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
dummy_conversations = [
[
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
],
[
{"role": "system", "content": "system message 2"},
{"role": "user", "content": "user message 2"},
{"role": "assistant", "content": "assistant message 2"},
],
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
output = tokenizer.apply_chat_template(
dummy_conversations, chat_template=dummy_template, tokenize=False
)
self.assertEqual(
output,
[
"systemsystem messageuseruser messageassistantassistant message",
"systemsystem message 2useruser message 2assistantassistant message 2",
],
)
one_element_output = tokenizer.apply_chat_template(
dummy_conversations[:1], chat_template=dummy_template, tokenize=False
)
self.assertEqual(
one_element_output, ["systemsystem messageuseruser messageassistantassistant message"]
) # Assert that list structure is retained even with one element
tokenizer.apply_chat_template(
dummy_conversations, chat_template=dummy_template, tokenize=True
) # Check that no error raised
@require_jinja
def test_jinja_loopcontrols(self):
break_template = """
{%- for message in messages %}
{{- message.role + " " + message.content }}
{%- if loop.first %}
{%- break %}
{%- endif %}
{%- endfor %}""".strip()
dummy_conversation = [
{"role": "system", "content": "1"},
{"role": "user", "content": "2"},
{"role": "assistant", "content": "3"},
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
break_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=break_template, tokenize=False
)
self.assertEqual(break_output, "system 1") # Loop should break after first iter
@require_jinja
def test_jinja_strftime(self):
strftime_template = """{{- strftime_now("%Y-%m-%d") }}""".strip()
dummy_conversation = [
{"role": "system", "content": "1"},
{"role": "user", "content": "2"},
{"role": "assistant", "content": "3"},
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
strftime_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=strftime_template, tokenize=False
)
# Assert that we get a date formatted as expected
self.assertEqual(len(strftime_output), 10)
self.assertEqual(len(strftime_output.split("-")), 3)
@require_jinja
def test_chat_template_return_assistant_tokens_mask(self):
dummy_template = (
"{% for message in messages %}"
"{% if (message['role'] != 'assistant') %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% elif (message['role'] == 'assistant')%}"
"{{'<|im_start|>' + message['role'] + '\n'}}"
"{% generation %}"
"{{message['content'] + '<|im_end|>'}}"
"{% endgeneration %}"
"{{'\n'}}"
"{% endif %}"
"{% endfor %}"
)
conversations = [
[
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "start turn 1 assistant message. end turn 1"},
{"role": "user", "content": "user message 2"},
{"role": "assistant", "content": "start turn 2 assistant message. end turn 2"},
],
[
{"role": "system", "content": "system message 3"},
{"role": "user", "content": "user message 3"},
{"role": "assistant", "content": "start turn 3 assistant message. end turn 3"},
{"role": "user", "content": "user message 4"},
{"role": "assistant", "content": "start turn 4 assistant message. end turn 4"},
],
]
# These are the prefix and suffix strings of all the assistant messages. Used to find the assistant substring
# in the entire chat string, and then find the corresponding tokens in the tokenized output.
assistant_prefix_suffix = [
[("start turn 1", "end turn 1<|im_end|>"), ("start turn 2", "end turn 2<|im_end|>")],
[("start turn 3", "end turn 3<|im_end|>"), ("start turn 4", "end turn 4<|im_end|>")],
]
for tokenizer, pretrained_name, _ in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
if not self.test_rust_tokenizer:
self.skipTest(reason="No fast tokenizer defined")
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)
# check batched
output = tokenizer_r.apply_chat_template(
conversations,
chat_template=dummy_template,
tokenize=True,
return_assistant_tokens_mask=True,
return_dict=True,
)
for i, conv in enumerate(conversations):
chat_string = tokenizer_r.apply_chat_template(
conversations[i], tokenize=False, chat_template=dummy_template
)
assistant_start = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][0][0]))
assistant_end = output.char_to_token(
i,
chat_string.index(assistant_prefix_suffix[i][0][1])
+ len(assistant_prefix_suffix[i][0][1])
- 1,
)
assistant_start2 = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][1][0]))
assistant_end2 = output.char_to_token(
i,
chat_string.index(assistant_prefix_suffix[i][1][1])
+ len(assistant_prefix_suffix[i][1][1])
- 1,
)
# assert 1 in first assistant message
self.assertEqual(
output["assistant_masks"][i][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
# assert 1 second assistant message
self.assertEqual(
output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)
# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start)
self.assertEqual(
output["assistant_masks"][i][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)
# check not batched
output = tokenizer_r.apply_chat_template(
conversations[0],
chat_template=dummy_template,
tokenize=True,
return_assistant_tokens_mask=True,
return_dict=True,
)
chat_string = tokenizer_r.apply_chat_template(
conversations[0], tokenize=False, chat_template=dummy_template
)
assistant_start = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][0][0]))
assistant_end = output.char_to_token(
0, chat_string.index(assistant_prefix_suffix[0][0][1]) + len(assistant_prefix_suffix[0][0][1]) - 1
)
assistant_start2 = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][1][0]))
assistant_end2 = output.char_to_token(
0, chat_string.index(assistant_prefix_suffix[0][1][1]) + len(assistant_prefix_suffix[0][1][1]) - 1
)
# assert 1 in assistant indices
self.assertEqual(
output["assistant_masks"][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
self.assertEqual(
output["assistant_masks"][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)
# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start)
self.assertEqual(
output["assistant_masks"][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)
@require_jinja
def test_continue_final_message(self):
dummy_template = """
{%- for message in messages %}
{{- "<|im_start|>" + message['role'] + "\n" + message['content'] + "<|im_end|>" + "\n"}}
{%- endfor %}"""
dummy_conversation = [
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=False
)
self.assertEqual(
output,
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n",
)
prefill_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True
)
# Assert that the final message is unterminated
self.assertEqual(
prefill_output,
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
)
@require_jinja
def test_chat_template_dict(self):
dummy_template_1 = "{{'a'}}"
dummy_template_2 = "{{'b'}}"
dummy_conversation = [
{"role": "user", "content": "user message"},
]
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
output1 = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template_1, tokenize=False
)
output1_via_dict = tokenizer.apply_chat_template(
dummy_conversation, chat_template="template1", tokenize=False
)
self.assertEqual(output1, output1_via_dict)
output2 = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template_2, tokenize=False
)
output2_via_dict = tokenizer.apply_chat_template(
dummy_conversation, chat_template="template2", tokenize=False
)
self.assertEqual(output2, output2_via_dict)
@require_jinja
def test_chat_template_dict_saving(self):
dummy_template_1 = "{{'a'}}"
dummy_template_2 = "{{'b'}}"
tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name)
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
# Assert that chat templates are correctly serialized as lists of dictionaries
self.assertEqual(
config_dict["chat_template"],
[{"name": "template1", "template": "{{'a'}}"}, {"name": "template2", "template": "{{'b'}}"}],
)
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
# Assert that the serialized list is correctly reconstructed as a single dict
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
def test_number_of_added_tokens(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:

View File

@ -12,9 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
import unittest
from pathlib import Path
from typing import List, Optional, Tuple, Union
from transformers import AutoTokenizer
from transformers.testing_utils import require_jinja
from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
@ -474,3 +480,419 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
},
}
self.assertEqual(schema["function"], expected_schema)
class ChatTemplateTest(unittest.TestCase):
def _get_tokenizer(self):
return AutoTokenizer.from_pretrained("hf-internal-testing/tiny-gpt2-with-chatml-template")
@require_jinja
def test_chat_template(self):
dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
dummy_conversation = [
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
]
expected_output = "systemsystem messageuseruser messageassistantassistant message"
tokenizer = self._get_tokenizer()
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False, return_dict=False
)
self.assertEqual(output, expected_output) # Test we can pass chat_template arg
# Check that no error raised when tokenize=True
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=False
)
dict_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=True
)
self.assertEqual(dict_output["input_ids"], output) # Test return_dict behaviour matches
tokenizer.chat_template = dummy_template
self.assertEqual(tokenizer.chat_template, dummy_template) # Test property setter
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
self.assertEqual(output, expected_output) # Test chat_template attribute is used if no arg is passed
# Check that no error raised
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name)
tokenizer = tokenizer.from_pretrained(tmp_dir_name)
self.assertEqual(tokenizer.chat_template, dummy_template) # Test template has persisted
output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
self.assertEqual(output, expected_output) # Test output is the same after reloading
# Check that no error raised
tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
@require_jinja
def test_chat_template_batched(self):
dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
dummy_conversations = [
[
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
],
[
{"role": "system", "content": "system message 2"},
{"role": "user", "content": "user message 2"},
{"role": "assistant", "content": "assistant message 2"},
],
]
tokenizer = self._get_tokenizer()
output = tokenizer.apply_chat_template(dummy_conversations, chat_template=dummy_template, tokenize=False)
self.assertEqual(
output,
[
"systemsystem messageuseruser messageassistantassistant message",
"systemsystem message 2useruser message 2assistantassistant message 2",
],
)
one_element_output = tokenizer.apply_chat_template(
dummy_conversations[:1], chat_template=dummy_template, tokenize=False
)
self.assertEqual(
one_element_output, ["systemsystem messageuseruser messageassistantassistant message"]
) # Assert that list structure is retained even with one element
tokenizer.apply_chat_template(
dummy_conversations, chat_template=dummy_template, tokenize=True
) # Check that no error raised
@require_jinja
def test_jinja_loopcontrols(self):
break_template = """
{%- for message in messages %}
{{- message.role + " " + message.content }}
{%- if loop.first %}
{%- break %}
{%- endif %}
{%- endfor %}""".strip()
dummy_conversation = [
{"role": "system", "content": "1"},
{"role": "user", "content": "2"},
{"role": "assistant", "content": "3"},
]
tokenizer = self._get_tokenizer()
break_output = tokenizer.apply_chat_template(dummy_conversation, chat_template=break_template, tokenize=False)
self.assertEqual(break_output, "system 1") # Loop should break after first iter
@require_jinja
def test_jinja_strftime(self):
strftime_template = """{{- strftime_now("%Y-%m-%d") }}""".strip()
dummy_conversation = [
{"role": "system", "content": "1"},
{"role": "user", "content": "2"},
{"role": "assistant", "content": "3"},
]
tokenizer = self._get_tokenizer()
strftime_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=strftime_template, tokenize=False
)
# Assert that we get a date formatted as expected
self.assertEqual(len(strftime_output), 10)
self.assertEqual(len(strftime_output.split("-")), 3)
@require_jinja
def test_chat_template_return_assistant_tokens_mask(self):
dummy_template = (
"{% for message in messages %}"
"{% if (message['role'] != 'assistant') %}"
"{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
"{% elif (message['role'] == 'assistant')%}"
"{{'<|im_start|>' + message['role'] + '\n'}}"
"{% generation %}"
"{{message['content'] + '<|im_end|>'}}"
"{% endgeneration %}"
"{{'\n'}}"
"{% endif %}"
"{% endfor %}"
)
conversations = [
[
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "start turn 1 assistant message. end turn 1"},
{"role": "user", "content": "user message 2"},
{"role": "assistant", "content": "start turn 2 assistant message. end turn 2"},
],
[
{"role": "system", "content": "system message 3"},
{"role": "user", "content": "user message 3"},
{"role": "assistant", "content": "start turn 3 assistant message. end turn 3"},
{"role": "user", "content": "user message 4"},
{"role": "assistant", "content": "start turn 4 assistant message. end turn 4"},
],
]
# These are the prefix and suffix strings of all the assistant messages. Used to find the assistant substring
# in the entire chat string, and then find the corresponding tokens in the tokenized output.
assistant_prefix_suffix = [
[("start turn 1", "end turn 1<|im_end|>"), ("start turn 2", "end turn 2<|im_end|>")],
[("start turn 3", "end turn 3<|im_end|>"), ("start turn 4", "end turn 4<|im_end|>")],
]
tokenizer = self._get_tokenizer()
# check batched
output = tokenizer.apply_chat_template(
conversations,
chat_template=dummy_template,
tokenize=True,
return_assistant_tokens_mask=True,
return_dict=True,
)
for i, conv in enumerate(conversations):
chat_string = tokenizer.apply_chat_template(conversations[i], tokenize=False, chat_template=dummy_template)
assistant_start = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][0][0]))
assistant_end = output.char_to_token(
i,
chat_string.index(assistant_prefix_suffix[i][0][1]) + len(assistant_prefix_suffix[i][0][1]) - 1,
)
assistant_start2 = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][1][0]))
assistant_end2 = output.char_to_token(
i,
chat_string.index(assistant_prefix_suffix[i][1][1]) + len(assistant_prefix_suffix[i][1][1]) - 1,
)
# assert 1 in first assistant message
self.assertEqual(
output["assistant_masks"][i][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
# assert 1 second assistant message
self.assertEqual(
output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)
# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start)
self.assertEqual(
output["assistant_masks"][i][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)
# check not batched
output = tokenizer.apply_chat_template(
conversations[0],
chat_template=dummy_template,
tokenize=True,
return_assistant_tokens_mask=True,
return_dict=True,
)
chat_string = tokenizer.apply_chat_template(conversations[0], tokenize=False, chat_template=dummy_template)
assistant_start = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][0][0]))
assistant_end = output.char_to_token(
0, chat_string.index(assistant_prefix_suffix[0][0][1]) + len(assistant_prefix_suffix[0][0][1]) - 1
)
assistant_start2 = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][1][0]))
assistant_end2 = output.char_to_token(
0, chat_string.index(assistant_prefix_suffix[0][1][1]) + len(assistant_prefix_suffix[0][1][1]) - 1
)
# assert 1 in assistant indices
self.assertEqual(
output["assistant_masks"][assistant_start : assistant_end + 1],
[1] * (assistant_end - assistant_start + 1),
)
self.assertEqual(
output["assistant_masks"][assistant_start2 : assistant_end2 + 1],
[1] * (assistant_end2 - assistant_start2 + 1),
)
# assert 0 in user/system indices
self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start)
self.assertEqual(
output["assistant_masks"][assistant_end + 1 : assistant_start2],
[0] * (assistant_start2 - assistant_end - 1),
)
@require_jinja
def test_continue_final_message(self):
dummy_template = """
{%- for message in messages %}
{{- "<|im_start|>" + message['role'] + "\n" + message['content'] + "<|im_end|>" + "\n"}}
{%- endfor %}"""
dummy_conversation = [
{"role": "system", "content": "system message"},
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
]
tokenizer = self._get_tokenizer()
output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=False
)
self.assertEqual(
output,
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n",
)
prefill_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True
)
# Assert that the final message is unterminated
self.assertEqual(
prefill_output,
"<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
)
@require_jinja
def test_chat_template_dict(self):
dummy_template_1 = "{{'a'}}"
dummy_template_2 = "{{'b'}}"
dummy_conversation = [
{"role": "user", "content": "user message"},
]
tokenizer = self._get_tokenizer()
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
output1 = tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template_1, tokenize=False)
output1_via_dict = tokenizer.apply_chat_template(dummy_conversation, chat_template="template1", tokenize=False)
self.assertEqual(output1, output1_via_dict)
output2 = tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template_2, tokenize=False)
output2_via_dict = tokenizer.apply_chat_template(dummy_conversation, chat_template="template2", tokenize=False)
self.assertEqual(output2, output2_via_dict)
@require_jinja
def test_chat_template_dict_saving(self):
dummy_template_1 = "{{'a'}}"
dummy_template_2 = "{{'b'}}"
tokenizer = self._get_tokenizer()
tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
with tempfile.TemporaryDirectory() as tmp_dir_name:
tokenizer.save_pretrained(tmp_dir_name)
config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
# Assert that chat templates are correctly serialized as lists of dictionaries
self.assertEqual(
config_dict["chat_template"],
[{"name": "template1", "template": "{{'a'}}"}, {"name": "template2", "template": "{{'b'}}"}],
)
new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
# Assert that the serialized list is correctly reconstructed as a single dict
self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
class InverseChatTemplateTest(unittest.TestCase):
def _get_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained("Rocketknight1/tiny-gpt2-with-mistral-tool-template")
tokenizer.inverse_template = r"""
{%- set tools = finditer("\[AVAILABLE_TOOLS\] (.*?)\[\/AVAILABLE_TOOLS\]", chat, flags=16) %}
{%- set user_messages = finditer('(?:\[INST\] )(.+?)\[\/INST\]', chat, flags=16, add_tag="user") %}
{%- set asst_messages = finditer('(?:\[\/INST\]|\[\/TOOL_RESULTS\]) (.+?)<\/s>', chat, flags=16, add_tag="assistant") %}
{%- set available_tools = finditer('\[AVAILABLE_TOOLS\] (.*?)\[\/AVAILABLE_TOOLS\]', chat, flags=16, add_tag="available_tools") %}
{%- set tool_calls = finditer('\[TOOL_CALLS\] (.+?\])<\/s>', chat, flags=16, add_tag="tool_calls") %}
{%- set tool_results = finditer('\[TOOL_RESULTS\] (.+?)\[\/TOOL_RESULTS\]', chat, flags=16, add_tag="tool") %}
{%- set combined = sort_by_group_start(user_messages + asst_messages + tool_calls + tool_results, group_idx=1) %}
{{- '{' }}
{%- if tools | length > 0 %}
{%- set tools = json_loads(tools[0].group[1]) %}
{{- '"tools": ' }}
{{- tools | tojson }}
{{- ', ' }}
{%- endif %}
{{- '"messages": [' }}
{%- for match in combined %}
{%- if match.tag == 'assistant' or match.tag == 'user' %}
{%- set message_dict = dict(role=match.tag, content=match.group[1]) %}
{%- elif match.tag == "tool_calls" %}
{%- set tool_call = json_loads(match.group[1])[0] %}
{%- set tool_call = dict(type="function", id=tool_call["id"]|string, function=dict(name=tool_call.name, arguments=tool_call.arguments)) %}
{%- set message_dict = dict(role="assistant", tool_calls=[tool_call]) %}
{%- elif match.tag == "tool" %}
{%- set base_dict = json_loads(match.group[1]) %}
{%- set message_dict = dict(role=match.tag, content=base_dict.content|string, tool_call_id=base_dict.call_id|string) %}
{%- endif %}
{{- message_dict | tojson }}
{%- if not loop.last %}
{{- ", " }}
{%- else %}
{{- "]" }}
{%- endif %}
{%- endfor %}
{{- "}" }}
""".strip()
return tokenizer
@require_jinja
def test_inverse_chat_template_save_load(self):
# Pick a tokenizer with no chat template or reverse template
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
self.assertIsNone(tokenizer.inverse_template)
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir = Path(tmp_dir)
tokenizer.save_pretrained(tmp_dir / "tokenizer")
tokenizer_config = json.load(open(tmp_dir / "tokenizer/tokenizer_config.json"))
self.assertNotIn("inverse_template", tokenizer_config)
self.assertFalse(Path(tmp_dir, "tokenizer", "inverse_template.jinja").is_file())
tokenizer.inverse_template = "aaaa"
tokenizer.save_pretrained(tmp_dir / "tokenizer_with_inverse_template")
tokenizer_config = json.load(open(tmp_dir / "tokenizer_with_inverse_template/tokenizer_config.json"))
self.assertNotIn("inverse_template", tokenizer_config) # Make sure it's separate
self.assertTrue(Path(tmp_dir, "tokenizer", "inverse_template.jinja").is_file())
reloaded_tokenizer = AutoTokenizer.from_pretrained(str(tmp_dir / "tokenizer_with_inverse_template"))
self.assertEqual(reloaded_tokenizer.inverse_template, "aaaa")
@require_jinja
def test_simple_chat_inversion(self):
tokenizer = self._get_tokenizer()
chat = [
{"role": "user", "content": "user message"},
{"role": "assistant", "content": "assistant message"},
]
chat_str = tokenizer.apply_chat_template(chat, tokenize=False)
inverted_chat = tokenizer.apply_inverse_template(chat_str)
self.assertEqual(chat, inverted_chat["messages"])
@require_jinja
def test_chat_inversion_with_tool_calls(self):
tokenizer = self._get_tokenizer()
chat = [
{"role": "user", "content": "user message"},
{
"role": "assistant",
"tool_calls": [
{
"type": "function",
"id": "9Ae3bDc2F",
"function": {
"name": "get_current_temperature",
"arguments": {"location": "Paris, France", "unit": "celsius"},
},
}
],
},
{"role": "tool", "content": "22.0", "tool_call_id": "9Ae3bDc2F"},
{"role": "assistant", "content": "assistant message"},
]
chat_str = tokenizer.apply_chat_template(chat, tokenize=False)
inverted_chat = tokenizer.apply_inverse_template(chat_str)
self.assertEqual(chat, inverted_chat["messages"])
@require_jinja
def test_tool_extraction(self):
tokenizer = self._get_tokenizer()
chat = [
{"role": "user", "content": "user message"},
]
def tool_fn(location: str, unit: str):
"""
Get the current temperature
Args:
location: The location to get the temperature from
unit: The unit to return the temperature in
"""
return 22.0
chat_str = tokenizer.apply_chat_template(chat, tools=[tool_fn], tokenize=False)
inverted_chat = tokenizer.apply_inverse_template(chat_str)
self.assertEqual(inverted_chat["messages"], chat)
self.assertEqual(inverted_chat["tools"], [get_json_schema(tool_fn)])