mirror of https://github.com/huggingface/transformers.git
synced 2025-10-21 01:23:56 +08:00

Compare commits: v4.56.0 ... reverse_te (16 commits)

SHA1
73e0b95cab
b4406a26c1
fd6f269052
7676abd496
fc4482558b
fc51e7a04c
d6b7914b77
b51b330221
39b977995a
550a6f1fad
6fd52c6f67
eec1c2033b
07a0a752c8
481109b2df
d9f478f706
47a3ca55de
@@ -65,7 +65,11 @@ from .utils import (
     requires_backends,
     to_py_obj,
 )
-from .utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices
+from .utils.chat_template_utils import (
+    _compile_inverse_template,
+    _compile_jinja_template,
+    _render_with_assistant_indices,
+)
 from .utils.import_utils import PROTOBUF_IMPORT_ERROR
 
 
@@ -145,6 +149,7 @@ AudioInput = Union["np.ndarray", "torch.Tensor", List["np.ndarray"], List["torch.Tensor"]]
 SPECIAL_TOKENS_MAP_FILE = "special_tokens_map.json"
 ADDED_TOKENS_FILE = "added_tokens.json"
 TOKENIZER_CONFIG_FILE = "tokenizer_config.json"
+INVERSE_TEMPLATE_FILE = "inverse_template.jinja"
 
 # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file
 FULL_TOKENIZER_FILE = "tokenizer.json"
@@ -1631,6 +1636,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             # we reconstruct that into a single dict while loading them.
             self.chat_template = {template["name"]: template["template"] for template in self.chat_template}
 
+        self.inverse_template = kwargs.pop("inverse_template", None)
+
         super().__init__(**kwargs)
 
     @property
@@ -1916,6 +1923,24 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         else:
             return rendered
 
+    def apply_inverse_template(self, chat: str, inverse_template: Optional[str] = None, skip_json_load: bool = False):
+        if inverse_template is None:
+            if self.inverse_template is not None:
+                inverse_template = self.inverse_template
+            else:
+                raise ValueError(
+                    "Cannot use apply_inverse_template() because tokenizer.inverse_template is not set! Please set "
+                    "the tokenizer.inverse_template attribute to a valid Jinja template string."
+                )
+        # Compilation function uses a cache to avoid recompiling the same template
+        compiled_template = _compile_inverse_template(inverse_template)
+
+        template_out = compiled_template.render(chat=chat)
+        if skip_json_load:
+            return template_out
+        else:
+            return json.loads(template_out)
+
     def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str:
         """
         Retrieve the chat template string used for tokenizing chat messages. This template is used
@@ -2121,6 +2146,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             "tokenizer_config_file": TOKENIZER_CONFIG_FILE,
             # tokenizer_file used to initialize a slow from a fast. Properly copy the `addedTokens` instead of adding in random orders
             "tokenizer_file": FULL_TOKENIZER_FILE,
+            "inverse_template": INVERSE_TEMPLATE_FILE,
         }
         vocab_files = {**cls.vocab_files_names, **additional_files_names}
         if "tokenizer_file" in vocab_files:
@@ -2241,6 +2267,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         from_slow = kwargs.get("from_slow", False)
         gguf_file = kwargs.get("gguf_file", None)
         has_tokenizer_file = resolved_vocab_files.get("tokenizer_file", None) is not None
+        inverse_template_file = resolved_vocab_files.pop("inverse_template", None)
 
         # If one passes a GGUF file path to `gguf_file` there is no need for this check as the tokenizer will be
         # loaded directly from the GGUF file.
@@ -2342,6 +2369,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 f" from is '{cls.__name__}'."
             )
 
+        if inverse_template_file is not None:
+            with open(inverse_template_file) as chat_template_handle:
+                init_kwargs["inverse_template"] = chat_template_handle.read()
+
         # Update with newly provided kwargs
         init_kwargs.update(kwargs)
 
@@ -2577,6 +2608,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + TOKENIZER_CONFIG_FILE
         )
 
+        inverse_chat_template_file = os.path.join(
+            save_directory, (filename_prefix + "-" if filename_prefix else "") + INVERSE_TEMPLATE_FILE
+        )
+
         tokenizer_config = copy.deepcopy(self.init_kwargs)
 
         # Let's save the init kwargs
@@ -2599,6 +2634,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         else:
             tokenizer_config["chat_template"] = self.chat_template
 
+        if self.inverse_template is not None:
+            with open(inverse_chat_template_file, "w", encoding="utf-8") as f:
+                f.write(self.inverse_template)
+
         if len(self.init_inputs) > 0:
             tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
         for file_id in self.vocab_files_names.keys():
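The changes above give every tokenizer an `inverse_template` attribute that mirrors `chat_template`: it is saved to its own `inverse_template.jinja` file rather than into `tokenizer_config.json`, resolved again on `from_pretrained`, and rendered by `apply_inverse_template()`, which binds the formatted chat string to the Jinja variable `chat` and parses the rendered output as JSON unless `skip_json_load=True`. A minimal sketch of the round trip, assuming a hypothetical checkpoint name and a toy "role: content" chat format that are not part of this diff:

from transformers import AutoTokenizer

# Hypothetical checkpoint; any tokenizer works here because the inverse template is set by hand.
tokenizer = AutoTokenizer.from_pretrained("some-org/some-chat-model")

# An inverse template is ordinary Jinja that receives the rendered chat as `chat` and must emit
# JSON. `finditer`, `MULTILINE` and the `tojson` filter come from the sandboxed environment
# that _compile_inverse_template() sets up.
tokenizer.inverse_template = (
    '{%- set msgs = finditer("^([a-z]+): (.*)$", chat, flags=MULTILINE) %}'
    '{"messages": ['
    "{%- for m in msgs %}"
    "{{- dict(role=m.group[1], content=m.group[2]) | tojson }}"
    "{%- if not loop.last %}, {% endif %}"
    "{%- endfor %}"
    "]}"
)

chat_str = "user: hi\nassistant: hello"
recovered = tokenizer.apply_inverse_template(chat_str)  # json.loads is applied by default
assert recovered["messages"][0] == {"role": "user", "content": "hi"}

# skip_json_load=True returns the raw rendered string instead of a parsed object.
raw = tokenizer.apply_inverse_template(chat_str, skip_json_load=True)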
@@ -16,6 +16,7 @@ import inspect
 import json
+import re
 from contextlib import contextmanager
 from dataclasses import dataclass
 from datetime import datetime
 from functools import lru_cache
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints
@@ -28,7 +29,7 @@ from .import_utils import is_jinja_available, is_torch_available, is_vision_available
 if is_jinja_available():
     import jinja2
     from jinja2.ext import Extension
-    from jinja2.sandbox import ImmutableSandboxedEnvironment
+    from jinja2.sandbox import ImmutableSandboxedEnvironment, SandboxedEnvironment
 else:
     jinja2 = None
 
@@ -406,17 +407,6 @@ def _compile_jinja_template(chat_template):
             "apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
         )
 
-    def raise_exception(message):
-        raise jinja2.exceptions.TemplateError(message)
-
-    def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
-        # We override the built-in tojson filter because Jinja's default filter escapes HTML characters
-        # We also expose some options like custom indents and separators
-        return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
-
-    def strftime_now(format):
-        return datetime.now().strftime(format)
-
     jinja_env = ImmutableSandboxedEnvironment(
         trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
     )
@@ -424,3 +414,76 @@
     jinja_env.globals["raise_exception"] = raise_exception
     jinja_env.globals["strftime_now"] = strftime_now
     return jinja_env.from_string(chat_template)
+
+
+@lru_cache
+def _compile_inverse_template(inverse_template):
+    if version.parse(jinja2.__version__) < version.parse("3.1.0"):
+        raise ImportError(
+            "apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
+        )
+
+    jinja_env = SandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[jinja2.ext.loopcontrols])
+    jinja_env.globals["raise_exception"] = raise_exception
+    jinja_env.globals["finditer"] = finditer
+    jinja_env.globals["sort_by_group_start"] = sort_by_group_start
+    jinja_env.filters["tojson"] = tojson
+    jinja_env.globals["json_loads"] = json_loads
+    jinja_env.globals["IGNORECASE"] = re.IGNORECASE
+    jinja_env.globals["MULTILINE"] = re.MULTILINE
+    jinja_env.globals["DOTALL"] = re.DOTALL
+    return jinja_env.from_string(inverse_template)
+
+
+# Functions for the Jinja environments below this line
+
+
+def finditer(pattern, string, flags=0, add_tag=None, add_tag_from_group=None):
+    @dataclass
+    class NewMatchObject:
+        group: List[str]
+        group_starts: List[int]
+        tag: Optional[str]
+
+    if add_tag is not None and add_tag_from_group is not None:
+        raise jinja2.exceptions.TemplateError("Cannot use add_tag and add_tag_from_group at the same time!")
+    out = []
+    for match in re.finditer(pattern, string, flags=flags):
+        # groups() by default does not include group(0), the whole string
+        # so we add it in manually to make things match up
+        groups = [match.group(0)] + list(match.groups())
+        group_starts = [match.start(i) for i in range(len(groups))]
+        if add_tag_from_group is not None:
+            add_tag = groups[add_tag_from_group]
+        out.append(NewMatchObject(group=groups, group_starts=group_starts, tag=add_tag))
+    return out
+
+
+def sort_by_group_start(matches, group_idx=0, group_idx_by_tag=None):
+    if group_idx_by_tag is None:
+        group_idx_by_tag = {}
+
+    def sort_key(match):
+        # Use the idx specific to this tag if present, or the global group_idx if not
+        idx = group_idx_by_tag.get(match.tag, group_idx)
+        return match.group_starts[idx]
+
+    return sorted(matches, key=sort_key)
+
+
+def json_loads(string):
+    return json.loads(string)
+
+
+def raise_exception(message):
+    raise jinja2.exceptions.TemplateError(message)
+
+
+def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
+    # We override the built-in tojson filter because Jinja's default filter escapes HTML characters
+    # We also expose some options like custom indents and separators
+    return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
+
+
+def strftime_now(format_str):
+    return datetime.now().strftime(format_str)
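Two of the globals registered above do most of the work in real inverse templates: `finditer` wraps `re.finditer` matches into plain, tag-aware objects that Jinja can iterate, and `sort_by_group_start` interleaves matches from several independent patterns back into document order. A standalone illustration in plain Python, using simplified re-implementations of the two helpers defined in this diff:

import re
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class NewMatchObject:
    group: List[str]          # group[0] is the whole match, like re's group(0)
    group_starts: List[int]   # start offset of each group in the source string
    tag: Optional[str]        # caller-supplied label, e.g. "user" or "assistant"

def finditer(pattern, string, flags=0, add_tag=None):
    out = []
    for match in re.finditer(pattern, string, flags=flags):
        groups = [match.group(0)] + list(match.groups())
        out.append(NewMatchObject(groups, [match.start(i) for i in range(len(groups))], add_tag))
    return out

chat = "[INST] hi [/INST] hello</s>[INST] bye [/INST] goodbye</s>"
users = finditer(r"\[INST\] (.+?) \[/INST\]", chat, add_tag="user")
assts = finditer(r"\[/INST\] (.+?)</s>", chat, add_tag="assistant")

# Sorting on the start of capture group 1 restores conversation order even though the
# two match lists come from different patterns; this is what
# sort_by_group_start(users + assts, group_idx=1) computes.
combined = sorted(users + assts, key=lambda m: m.group_starts[1])
print([(m.tag, m.group[1]) for m in combined])
# [('user', 'hi'), ('assistant', 'hello'), ('user', 'bye'), ('assistant', 'goodbye')]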
@@ -49,7 +49,6 @@ from transformers.testing_utils import (
     check_json_file_has_correct_format,
     get_tests_dir,
     is_pt_tf_cross_test,
-    require_jinja,
     require_read_token,
     require_tf,
     require_tokenizers,
@@ -1072,337 +1071,6 @@ class TokenizerTesterMixin:
         if tokenizer.num_special_tokens_to_add(pair=True):
             self.assertIn(None, output.sequence_ids())
 
-    @require_jinja
-    def test_chat_template(self):
-        dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
-        dummy_conversation = [
-            {"role": "system", "content": "system message"},
-            {"role": "user", "content": "user message"},
-            {"role": "assistant", "content": "assistant message"},
-        ]
-        expected_output = "systemsystem messageuseruser messageassistantassistant message"
-        tokenizers = self.get_tokenizers()
-        for tokenizer in tokenizers:
-            with self.subTest(f"{tokenizer.__class__.__name__}"):
-                output = tokenizer.apply_chat_template(
-                    dummy_conversation, chat_template=dummy_template, tokenize=False, return_dict=False
-                )
-                self.assertEqual(output, expected_output)  # Test we can pass chat_template arg
-
-                # Check that no error raised when tokenize=True
-                output = tokenizer.apply_chat_template(
-                    dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=False
-                )
-                dict_output = tokenizer.apply_chat_template(
-                    dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=True
-                )
-                self.assertEqual(dict_output["input_ids"], output)  # Test return_dict behaviour matches
-
-                tokenizer.chat_template = dummy_template
-                self.assertEqual(tokenizer.chat_template, dummy_template)  # Test property setter
-                output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
-                self.assertEqual(output, expected_output)  # Test chat_template attribute is used if no arg is passed
-                # Check that no error raised
-                tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
-
-                with tempfile.TemporaryDirectory() as tmp_dir_name:
-                    tokenizer.save_pretrained(tmp_dir_name)
-                    tokenizer = tokenizer.from_pretrained(tmp_dir_name)
-
-                self.assertEqual(tokenizer.chat_template, dummy_template)  # Test template has persisted
-                output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
-                self.assertEqual(output, expected_output)  # Test output is the same after reloading
-                # Check that no error raised
-                tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
-
-    @require_jinja
-    def test_chat_template_batched(self):
-        dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
-        dummy_conversations = [
-            [
-                {"role": "system", "content": "system message"},
-                {"role": "user", "content": "user message"},
-                {"role": "assistant", "content": "assistant message"},
-            ],
-            [
-                {"role": "system", "content": "system message 2"},
-                {"role": "user", "content": "user message 2"},
-                {"role": "assistant", "content": "assistant message 2"},
-            ],
-        ]
-        tokenizers = self.get_tokenizers()
-        for tokenizer in tokenizers:
-            with self.subTest(f"{tokenizer.__class__.__name__}"):
-                output = tokenizer.apply_chat_template(
-                    dummy_conversations, chat_template=dummy_template, tokenize=False
-                )
-                self.assertEqual(
-                    output,
-                    [
-                        "systemsystem messageuseruser messageassistantassistant message",
-                        "systemsystem message 2useruser message 2assistantassistant message 2",
-                    ],
-                )
-                one_element_output = tokenizer.apply_chat_template(
-                    dummy_conversations[:1], chat_template=dummy_template, tokenize=False
-                )
-                self.assertEqual(
-                    one_element_output, ["systemsystem messageuseruser messageassistantassistant message"]
-                )  # Assert that list structure is retained even with one element
-                tokenizer.apply_chat_template(
-                    dummy_conversations, chat_template=dummy_template, tokenize=True
-                )  # Check that no error raised
-
-    @require_jinja
-    def test_jinja_loopcontrols(self):
-        break_template = """
-        {%- for message in messages %}
-            {{- message.role + " " + message.content }}
-            {%- if loop.first %}
-                {%- break %}
-            {%- endif %}
-        {%- endfor %}""".strip()
-
-        dummy_conversation = [
-            {"role": "system", "content": "1"},
-            {"role": "user", "content": "2"},
-            {"role": "assistant", "content": "3"},
-        ]
-
-        tokenizers = self.get_tokenizers()
-        for tokenizer in tokenizers:
-            with self.subTest(f"{tokenizer.__class__.__name__}"):
-                break_output = tokenizer.apply_chat_template(
-                    dummy_conversation, chat_template=break_template, tokenize=False
-                )
-                self.assertEqual(break_output, "system 1")  # Loop should break after first iter
-
-    @require_jinja
-    def test_jinja_strftime(self):
-        strftime_template = """{{- strftime_now("%Y-%m-%d") }}""".strip()
-
-        dummy_conversation = [
-            {"role": "system", "content": "1"},
-            {"role": "user", "content": "2"},
-            {"role": "assistant", "content": "3"},
-        ]
-
-        tokenizers = self.get_tokenizers()
-        for tokenizer in tokenizers:
-            with self.subTest(f"{tokenizer.__class__.__name__}"):
-                strftime_output = tokenizer.apply_chat_template(
-                    dummy_conversation, chat_template=strftime_template, tokenize=False
-                )
-
-                # Assert that we get a date formatted as expected
-                self.assertEqual(len(strftime_output), 10)
-                self.assertEqual(len(strftime_output.split("-")), 3)
-
-    @require_jinja
-    def test_chat_template_return_assistant_tokens_mask(self):
-        dummy_template = (
-            "{% for message in messages %}"
-            "{% if (message['role'] != 'assistant') %}"
-            "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
-            "{% elif (message['role'] == 'assistant')%}"
-            "{{'<|im_start|>' + message['role'] + '\n'}}"
-            "{% generation %}"
-            "{{message['content'] + '<|im_end|>'}}"
-            "{% endgeneration %}"
-            "{{'\n'}}"
-            "{% endif %}"
-            "{% endfor %}"
-        )
-        conversations = [
-            [
-                {"role": "system", "content": "system message"},
-                {"role": "user", "content": "user message"},
-                {"role": "assistant", "content": "start turn 1 assistant message. end turn 1"},
-                {"role": "user", "content": "user message 2"},
-                {"role": "assistant", "content": "start turn 2 assistant message. end turn 2"},
-            ],
-            [
-                {"role": "system", "content": "system message 3"},
-                {"role": "user", "content": "user message 3"},
-                {"role": "assistant", "content": "start turn 3 assistant message. end turn 3"},
-                {"role": "user", "content": "user message 4"},
-                {"role": "assistant", "content": "start turn 4 assistant message. end turn 4"},
-            ],
-        ]
-
-        # These are the prefix and suffix strings of all the assistant messages. Used to find the assistant substring
-        # in the entire chat string, and then find the corresponding tokens in the tokenized output.
-        assistant_prefix_suffix = [
-            [("start turn 1", "end turn 1<|im_end|>"), ("start turn 2", "end turn 2<|im_end|>")],
-            [("start turn 3", "end turn 3<|im_end|>"), ("start turn 4", "end turn 4<|im_end|>")],
-        ]
-        for tokenizer, pretrained_name, _ in self.tokenizers_list:
-            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
-                if not self.test_rust_tokenizer:
-                    self.skipTest(reason="No fast tokenizer defined")
-
-                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name)
-
-                # check batched
-                output = tokenizer_r.apply_chat_template(
-                    conversations,
-                    chat_template=dummy_template,
-                    tokenize=True,
-                    return_assistant_tokens_mask=True,
-                    return_dict=True,
-                )
-                for i, conv in enumerate(conversations):
-                    chat_string = tokenizer_r.apply_chat_template(
-                        conversations[i], tokenize=False, chat_template=dummy_template
-                    )
-                    assistant_start = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][0][0]))
-                    assistant_end = output.char_to_token(
-                        i,
-                        chat_string.index(assistant_prefix_suffix[i][0][1])
-                        + len(assistant_prefix_suffix[i][0][1])
-                        - 1,
-                    )
-
-                    assistant_start2 = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][1][0]))
-                    assistant_end2 = output.char_to_token(
-                        i,
-                        chat_string.index(assistant_prefix_suffix[i][1][1])
-                        + len(assistant_prefix_suffix[i][1][1])
-                        - 1,
-                    )
-
-                    # assert 1 in first assistant message
-                    self.assertEqual(
-                        output["assistant_masks"][i][assistant_start : assistant_end + 1],
-                        [1] * (assistant_end - assistant_start + 1),
-                    )
-                    # assert 1 second assistant message
-                    self.assertEqual(
-                        output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1],
-                        [1] * (assistant_end2 - assistant_start2 + 1),
-                    )
-
-                    # assert 0 in user/system indices
-                    self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start)
-                    self.assertEqual(
-                        output["assistant_masks"][i][assistant_end + 1 : assistant_start2],
-                        [0] * (assistant_start2 - assistant_end - 1),
-                    )
-
-                # check not batched
-                output = tokenizer_r.apply_chat_template(
-                    conversations[0],
-                    chat_template=dummy_template,
-                    tokenize=True,
-                    return_assistant_tokens_mask=True,
-                    return_dict=True,
-                )
-
-                chat_string = tokenizer_r.apply_chat_template(
-                    conversations[0], tokenize=False, chat_template=dummy_template
-                )
-                assistant_start = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][0][0]))
-                assistant_end = output.char_to_token(
-                    0, chat_string.index(assistant_prefix_suffix[0][0][1]) + len(assistant_prefix_suffix[0][0][1]) - 1
-                )
-                assistant_start2 = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][1][0]))
-                assistant_end2 = output.char_to_token(
-                    0, chat_string.index(assistant_prefix_suffix[0][1][1]) + len(assistant_prefix_suffix[0][1][1]) - 1
-                )
-
-                # assert 1 in assistant indices
-                self.assertEqual(
-                    output["assistant_masks"][assistant_start : assistant_end + 1],
-                    [1] * (assistant_end - assistant_start + 1),
-                )
-                self.assertEqual(
-                    output["assistant_masks"][assistant_start2 : assistant_end2 + 1],
-                    [1] * (assistant_end2 - assistant_start2 + 1),
-                )
-
-                # assert 0 in user/system indices
-                self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start)
-                self.assertEqual(
-                    output["assistant_masks"][assistant_end + 1 : assistant_start2],
-                    [0] * (assistant_start2 - assistant_end - 1),
-                )
-
-    @require_jinja
-    def test_continue_final_message(self):
-        dummy_template = """
-        {%- for message in messages %}
-            {{- "<|im_start|>" + message['role'] + "\n" + message['content'] + "<|im_end|>" + "\n"}}
-        {%- endfor %}"""
-        dummy_conversation = [
-            {"role": "system", "content": "system message"},
-            {"role": "user", "content": "user message"},
-            {"role": "assistant", "content": "assistant message"},
-        ]
-        tokenizers = self.get_tokenizers()
-        for tokenizer in tokenizers:
-            with self.subTest(f"{tokenizer.__class__.__name__}"):
-                output = tokenizer.apply_chat_template(
-                    dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=False
-                )
-                self.assertEqual(
-                    output,
-                    "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n",
-                )
-                prefill_output = tokenizer.apply_chat_template(
-                    dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True
-                )
-                # Assert that the final message is unterminated
-                self.assertEqual(
-                    prefill_output,
-                    "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
-                )
-
-    @require_jinja
-    def test_chat_template_dict(self):
-        dummy_template_1 = "{{'a'}}"
-        dummy_template_2 = "{{'b'}}"
-        dummy_conversation = [
-            {"role": "user", "content": "user message"},
-        ]
-        tokenizers = self.get_tokenizers()
-        for tokenizer in tokenizers:
-            with self.subTest(f"{tokenizer.__class__.__name__}"):
-                tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
-                output1 = tokenizer.apply_chat_template(
-                    dummy_conversation, chat_template=dummy_template_1, tokenize=False
-                )
-                output1_via_dict = tokenizer.apply_chat_template(
-                    dummy_conversation, chat_template="template1", tokenize=False
-                )
-                self.assertEqual(output1, output1_via_dict)
-                output2 = tokenizer.apply_chat_template(
-                    dummy_conversation, chat_template=dummy_template_2, tokenize=False
-                )
-                output2_via_dict = tokenizer.apply_chat_template(
-                    dummy_conversation, chat_template="template2", tokenize=False
-                )
-                self.assertEqual(output2, output2_via_dict)
-
-    @require_jinja
-    def test_chat_template_dict_saving(self):
-        dummy_template_1 = "{{'a'}}"
-        dummy_template_2 = "{{'b'}}"
-        tokenizers = self.get_tokenizers()
-        for tokenizer in tokenizers:
-            with self.subTest(f"{tokenizer.__class__.__name__}"):
-                tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
-                with tempfile.TemporaryDirectory() as tmp_dir_name:
-                    tokenizer.save_pretrained(tmp_dir_name)
-                    config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
-                    # Assert that chat templates are correctly serialized as lists of dictionaries
-                    self.assertEqual(
-                        config_dict["chat_template"],
-                        [{"name": "template1", "template": "{{'a'}}"}, {"name": "template2", "template": "{{'b'}}"}],
-                    )
-                    new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
-                # Assert that the serialized list is correctly reconstructed as a single dict
-                self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
-
     def test_number_of_added_tokens(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
@@ -12,9 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
+import os
+import tempfile
 import unittest
+from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
+from transformers import AutoTokenizer
+from transformers.testing_utils import require_jinja
 from transformers.utils import DocstringParsingException, TypeHintParsingException, get_json_schema
 
 
@@ -474,3 +480,419 @@ class JsonSchemaGeneratorTest(unittest.TestCase):
             },
         }
         self.assertEqual(schema["function"], expected_schema)
+
+
+class ChatTemplateTest(unittest.TestCase):
+    def _get_tokenizer(self):
+        return AutoTokenizer.from_pretrained("hf-internal-testing/tiny-gpt2-with-chatml-template")
+
+    @require_jinja
+    def test_chat_template(self):
+        dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
+        dummy_conversation = [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "user message"},
+            {"role": "assistant", "content": "assistant message"},
+        ]
+        expected_output = "systemsystem messageuseruser messageassistantassistant message"
+        tokenizer = self._get_tokenizer()
+        output = tokenizer.apply_chat_template(
+            dummy_conversation, chat_template=dummy_template, tokenize=False, return_dict=False
+        )
+        self.assertEqual(output, expected_output)  # Test we can pass chat_template arg
+
+        # Check that no error raised when tokenize=True
+        output = tokenizer.apply_chat_template(
+            dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=False
+        )
+        dict_output = tokenizer.apply_chat_template(
+            dummy_conversation, chat_template=dummy_template, tokenize=True, return_dict=True
+        )
+        self.assertEqual(dict_output["input_ids"], output)  # Test return_dict behaviour matches
+
+        tokenizer.chat_template = dummy_template
+        self.assertEqual(tokenizer.chat_template, dummy_template)  # Test property setter
+        output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
+        self.assertEqual(output, expected_output)  # Test chat_template attribute is used if no arg is passed
+        # Check that no error raised
+        tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
+
+        with tempfile.TemporaryDirectory() as tmp_dir_name:
+            tokenizer.save_pretrained(tmp_dir_name)
+            tokenizer = tokenizer.from_pretrained(tmp_dir_name)
+
+        self.assertEqual(tokenizer.chat_template, dummy_template)  # Test template has persisted
+        output = tokenizer.apply_chat_template(dummy_conversation, tokenize=False, return_dict=False)
+        self.assertEqual(output, expected_output)  # Test output is the same after reloading
+        # Check that no error raised
+        tokenizer.apply_chat_template(dummy_conversation, tokenize=True, return_dict=False)
+
+    @require_jinja
+    def test_chat_template_batched(self):
+        dummy_template = "{% for message in messages %}{{message['role'] + message['content']}}{% endfor %}"
+        dummy_conversations = [
+            [
+                {"role": "system", "content": "system message"},
+                {"role": "user", "content": "user message"},
+                {"role": "assistant", "content": "assistant message"},
+            ],
+            [
+                {"role": "system", "content": "system message 2"},
+                {"role": "user", "content": "user message 2"},
+                {"role": "assistant", "content": "assistant message 2"},
+            ],
+        ]
+        tokenizer = self._get_tokenizer()
+        output = tokenizer.apply_chat_template(dummy_conversations, chat_template=dummy_template, tokenize=False)
+        self.assertEqual(
+            output,
+            [
+                "systemsystem messageuseruser messageassistantassistant message",
+                "systemsystem message 2useruser message 2assistantassistant message 2",
+            ],
+        )
+        one_element_output = tokenizer.apply_chat_template(
+            dummy_conversations[:1], chat_template=dummy_template, tokenize=False
+        )
+        self.assertEqual(
+            one_element_output, ["systemsystem messageuseruser messageassistantassistant message"]
+        )  # Assert that list structure is retained even with one element
+        tokenizer.apply_chat_template(
+            dummy_conversations, chat_template=dummy_template, tokenize=True
+        )  # Check that no error raised
+
+    @require_jinja
+    def test_jinja_loopcontrols(self):
+        break_template = """
+        {%- for message in messages %}
+            {{- message.role + " " + message.content }}
+            {%- if loop.first %}
+                {%- break %}
+            {%- endif %}
+        {%- endfor %}""".strip()
+
+        dummy_conversation = [
+            {"role": "system", "content": "1"},
+            {"role": "user", "content": "2"},
+            {"role": "assistant", "content": "3"},
+        ]
+
+        tokenizer = self._get_tokenizer()
+        break_output = tokenizer.apply_chat_template(dummy_conversation, chat_template=break_template, tokenize=False)
+        self.assertEqual(break_output, "system 1")  # Loop should break after first iter
+
+    @require_jinja
+    def test_jinja_strftime(self):
+        strftime_template = """{{- strftime_now("%Y-%m-%d") }}""".strip()
+
+        dummy_conversation = [
+            {"role": "system", "content": "1"},
+            {"role": "user", "content": "2"},
+            {"role": "assistant", "content": "3"},
+        ]
+
+        tokenizer = self._get_tokenizer()
+        strftime_output = tokenizer.apply_chat_template(
+            dummy_conversation, chat_template=strftime_template, tokenize=False
+        )
+
+        # Assert that we get a date formatted as expected
+        self.assertEqual(len(strftime_output), 10)
+        self.assertEqual(len(strftime_output.split("-")), 3)
+
+    @require_jinja
+    def test_chat_template_return_assistant_tokens_mask(self):
+        dummy_template = (
+            "{% for message in messages %}"
+            "{% if (message['role'] != 'assistant') %}"
+            "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
+            "{% elif (message['role'] == 'assistant')%}"
+            "{{'<|im_start|>' + message['role'] + '\n'}}"
+            "{% generation %}"
+            "{{message['content'] + '<|im_end|>'}}"
+            "{% endgeneration %}"
+            "{{'\n'}}"
+            "{% endif %}"
+            "{% endfor %}"
+        )
+        conversations = [
+            [
+                {"role": "system", "content": "system message"},
+                {"role": "user", "content": "user message"},
+                {"role": "assistant", "content": "start turn 1 assistant message. end turn 1"},
+                {"role": "user", "content": "user message 2"},
+                {"role": "assistant", "content": "start turn 2 assistant message. end turn 2"},
+            ],
+            [
+                {"role": "system", "content": "system message 3"},
+                {"role": "user", "content": "user message 3"},
+                {"role": "assistant", "content": "start turn 3 assistant message. end turn 3"},
+                {"role": "user", "content": "user message 4"},
+                {"role": "assistant", "content": "start turn 4 assistant message. end turn 4"},
+            ],
+        ]
+
+        # These are the prefix and suffix strings of all the assistant messages. Used to find the assistant substring
+        # in the entire chat string, and then find the corresponding tokens in the tokenized output.
+        assistant_prefix_suffix = [
+            [("start turn 1", "end turn 1<|im_end|>"), ("start turn 2", "end turn 2<|im_end|>")],
+            [("start turn 3", "end turn 3<|im_end|>"), ("start turn 4", "end turn 4<|im_end|>")],
+        ]
+        tokenizer = self._get_tokenizer()
+
+        # check batched
+        output = tokenizer.apply_chat_template(
+            conversations,
+            chat_template=dummy_template,
+            tokenize=True,
+            return_assistant_tokens_mask=True,
+            return_dict=True,
+        )
+        for i, conv in enumerate(conversations):
+            chat_string = tokenizer.apply_chat_template(conversations[i], tokenize=False, chat_template=dummy_template)
+            assistant_start = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][0][0]))
+            assistant_end = output.char_to_token(
+                i,
+                chat_string.index(assistant_prefix_suffix[i][0][1]) + len(assistant_prefix_suffix[i][0][1]) - 1,
+            )
+
+            assistant_start2 = output.char_to_token(i, chat_string.index(assistant_prefix_suffix[i][1][0]))
+            assistant_end2 = output.char_to_token(
+                i,
+                chat_string.index(assistant_prefix_suffix[i][1][1]) + len(assistant_prefix_suffix[i][1][1]) - 1,
+            )
+
+            # assert 1 in first assistant message
+            self.assertEqual(
+                output["assistant_masks"][i][assistant_start : assistant_end + 1],
+                [1] * (assistant_end - assistant_start + 1),
+            )
+            # assert 1 second assistant message
+            self.assertEqual(
+                output["assistant_masks"][i][assistant_start2 : assistant_end2 + 1],
+                [1] * (assistant_end2 - assistant_start2 + 1),
+            )
+
+            # assert 0 in user/system indices
+            self.assertEqual(output["assistant_masks"][i][:assistant_start], [0] * assistant_start)
+            self.assertEqual(
+                output["assistant_masks"][i][assistant_end + 1 : assistant_start2],
+                [0] * (assistant_start2 - assistant_end - 1),
+            )
+
+        # check not batched
+        output = tokenizer.apply_chat_template(
+            conversations[0],
+            chat_template=dummy_template,
+            tokenize=True,
+            return_assistant_tokens_mask=True,
+            return_dict=True,
+        )
+
+        chat_string = tokenizer.apply_chat_template(conversations[0], tokenize=False, chat_template=dummy_template)
+        assistant_start = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][0][0]))
+        assistant_end = output.char_to_token(
+            0, chat_string.index(assistant_prefix_suffix[0][0][1]) + len(assistant_prefix_suffix[0][0][1]) - 1
+        )
+        assistant_start2 = output.char_to_token(0, chat_string.index(assistant_prefix_suffix[0][1][0]))
+        assistant_end2 = output.char_to_token(
+            0, chat_string.index(assistant_prefix_suffix[0][1][1]) + len(assistant_prefix_suffix[0][1][1]) - 1
+        )
+
+        # assert 1 in assistant indices
+        self.assertEqual(
+            output["assistant_masks"][assistant_start : assistant_end + 1],
+            [1] * (assistant_end - assistant_start + 1),
+        )
+        self.assertEqual(
+            output["assistant_masks"][assistant_start2 : assistant_end2 + 1],
+            [1] * (assistant_end2 - assistant_start2 + 1),
+        )
+
+        # assert 0 in user/system indices
+        self.assertEqual(output["assistant_masks"][:assistant_start], [0] * assistant_start)
+        self.assertEqual(
+            output["assistant_masks"][assistant_end + 1 : assistant_start2],
+            [0] * (assistant_start2 - assistant_end - 1),
+        )
+
+    @require_jinja
+    def test_continue_final_message(self):
+        dummy_template = """
+        {%- for message in messages %}
+            {{- "<|im_start|>" + message['role'] + "\n" + message['content'] + "<|im_end|>" + "\n"}}
+        {%- endfor %}"""
+        dummy_conversation = [
+            {"role": "system", "content": "system message"},
+            {"role": "user", "content": "user message"},
+            {"role": "assistant", "content": "assistant message"},
+        ]
+        tokenizer = self._get_tokenizer()
+        output = tokenizer.apply_chat_template(
+            dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=False
+        )
+        self.assertEqual(
+            output,
+            "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message<|im_end|>\n",
+        )
+        prefill_output = tokenizer.apply_chat_template(
+            dummy_conversation, chat_template=dummy_template, tokenize=False, continue_final_message=True
+        )
+        # Assert that the final message is unterminated
+        self.assertEqual(
+            prefill_output,
+            "<|im_start|>system\nsystem message<|im_end|>\n<|im_start|>user\nuser message<|im_end|>\n<|im_start|>assistant\nassistant message",
+        )
+
+    @require_jinja
+    def test_chat_template_dict(self):
+        dummy_template_1 = "{{'a'}}"
+        dummy_template_2 = "{{'b'}}"
+        dummy_conversation = [
+            {"role": "user", "content": "user message"},
+        ]
+        tokenizer = self._get_tokenizer()
+        tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
+        output1 = tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template_1, tokenize=False)
+        output1_via_dict = tokenizer.apply_chat_template(dummy_conversation, chat_template="template1", tokenize=False)
+        self.assertEqual(output1, output1_via_dict)
+        output2 = tokenizer.apply_chat_template(dummy_conversation, chat_template=dummy_template_2, tokenize=False)
+        output2_via_dict = tokenizer.apply_chat_template(dummy_conversation, chat_template="template2", tokenize=False)
+        self.assertEqual(output2, output2_via_dict)
+
+    @require_jinja
+    def test_chat_template_dict_saving(self):
+        dummy_template_1 = "{{'a'}}"
+        dummy_template_2 = "{{'b'}}"
+        tokenizer = self._get_tokenizer()
+        tokenizer.chat_template = {"template1": dummy_template_1, "template2": dummy_template_2}
+        with tempfile.TemporaryDirectory() as tmp_dir_name:
+            tokenizer.save_pretrained(tmp_dir_name)
+            config_dict = json.load(open(os.path.join(tmp_dir_name, "tokenizer_config.json")))
+            # Assert that chat templates are correctly serialized as lists of dictionaries
+            self.assertEqual(
+                config_dict["chat_template"],
+                [{"name": "template1", "template": "{{'a'}}"}, {"name": "template2", "template": "{{'b'}}"}],
+            )
+            new_tokenizer = tokenizer.from_pretrained(tmp_dir_name)
+        # Assert that the serialized list is correctly reconstructed as a single dict
+        self.assertEqual(new_tokenizer.chat_template, tokenizer.chat_template)
+
+
+class InverseChatTemplateTest(unittest.TestCase):
+    def _get_tokenizer(self):
+        tokenizer = AutoTokenizer.from_pretrained("Rocketknight1/tiny-gpt2-with-mistral-tool-template")
+        tokenizer.inverse_template = r"""
+{%- set tools = finditer("\[AVAILABLE_TOOLS\] (.*?)\[\/AVAILABLE_TOOLS\]", chat, flags=16) %}
+{%- set user_messages = finditer('(?:\[INST\] )(.+?)\[\/INST\]', chat, flags=16, add_tag="user") %}
+{%- set asst_messages = finditer('(?:\[\/INST\]|\[\/TOOL_RESULTS\]) (.+?)<\/s>', chat, flags=16, add_tag="assistant") %}
+{%- set available_tools = finditer('\[AVAILABLE_TOOLS\] (.*?)\[\/AVAILABLE_TOOLS\]', chat, flags=16, add_tag="available_tools") %}
+{%- set tool_calls = finditer('\[TOOL_CALLS\] (.+?\])<\/s>', chat, flags=16, add_tag="tool_calls") %}
+{%- set tool_results = finditer('\[TOOL_RESULTS\] (.+?)\[\/TOOL_RESULTS\]', chat, flags=16, add_tag="tool") %}
+{%- set combined = sort_by_group_start(user_messages + asst_messages + tool_calls + tool_results, group_idx=1) %}
+{{- '{' }}
+{%- if tools | length > 0 %}
+    {%- set tools = json_loads(tools[0].group[1]) %}
+    {{- '"tools": ' }}
+    {{- tools | tojson }}
+    {{- ', ' }}
+{%- endif %}
+{{- '"messages": [' }}
+{%- for match in combined %}
+    {%- if match.tag == 'assistant' or match.tag == 'user' %}
+        {%- set message_dict = dict(role=match.tag, content=match.group[1]) %}
+    {%- elif match.tag == "tool_calls" %}
+        {%- set tool_call = json_loads(match.group[1])[0] %}
+        {%- set tool_call = dict(type="function", id=tool_call["id"]|string, function=dict(name=tool_call.name, arguments=tool_call.arguments)) %}
+        {%- set message_dict = dict(role="assistant", tool_calls=[tool_call]) %}
+    {%- elif match.tag == "tool" %}
+        {%- set base_dict = json_loads(match.group[1]) %}
+        {%- set message_dict = dict(role=match.tag, content=base_dict.content|string, tool_call_id=base_dict.call_id|string) %}
+    {%- endif %}
+    {{- message_dict | tojson }}
+    {%- if not loop.last %}
+        {{- ", " }}
+    {%- else %}
+        {{- "]" }}
+    {%- endif %}
+{%- endfor %}
+{{- "}" }}
+""".strip()
+        return tokenizer
+
+    @require_jinja
+    def test_inverse_chat_template_save_load(self):
+        # Pick a tokenizer with no chat template or inverse template
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        self.assertIsNone(tokenizer.inverse_template)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tmp_dir = Path(tmp_dir)
+            tokenizer.save_pretrained(tmp_dir / "tokenizer")
+            tokenizer_config = json.load(open(tmp_dir / "tokenizer/tokenizer_config.json"))
+            self.assertNotIn("inverse_template", tokenizer_config)
+            self.assertFalse(Path(tmp_dir, "tokenizer", "inverse_template.jinja").is_file())
+            tokenizer.inverse_template = "aaaa"
+            tokenizer.save_pretrained(tmp_dir / "tokenizer_with_inverse_template")
+            tokenizer_config = json.load(open(tmp_dir / "tokenizer_with_inverse_template/tokenizer_config.json"))
+            self.assertNotIn("inverse_template", tokenizer_config)  # Make sure it's separate
+            self.assertTrue(Path(tmp_dir, "tokenizer_with_inverse_template", "inverse_template.jinja").is_file())
+            reloaded_tokenizer = AutoTokenizer.from_pretrained(str(tmp_dir / "tokenizer_with_inverse_template"))
+            self.assertEqual(reloaded_tokenizer.inverse_template, "aaaa")
+
+    @require_jinja
+    def test_simple_chat_inversion(self):
+        tokenizer = self._get_tokenizer()
+        chat = [
+            {"role": "user", "content": "user message"},
+            {"role": "assistant", "content": "assistant message"},
+        ]
+        chat_str = tokenizer.apply_chat_template(chat, tokenize=False)
+        inverted_chat = tokenizer.apply_inverse_template(chat_str)
+        self.assertEqual(chat, inverted_chat["messages"])
+
+    @require_jinja
+    def test_chat_inversion_with_tool_calls(self):
+        tokenizer = self._get_tokenizer()
+        chat = [
+            {"role": "user", "content": "user message"},
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {
+                        "type": "function",
+                        "id": "9Ae3bDc2F",
+                        "function": {
+                            "name": "get_current_temperature",
+                            "arguments": {"location": "Paris, France", "unit": "celsius"},
+                        },
+                    }
+                ],
+            },
+            {"role": "tool", "content": "22.0", "tool_call_id": "9Ae3bDc2F"},
+            {"role": "assistant", "content": "assistant message"},
+        ]
+        chat_str = tokenizer.apply_chat_template(chat, tokenize=False)
+        inverted_chat = tokenizer.apply_inverse_template(chat_str)
+        self.assertEqual(chat, inverted_chat["messages"])
+
+    @require_jinja
+    def test_tool_extraction(self):
+        tokenizer = self._get_tokenizer()
+        chat = [
+            {"role": "user", "content": "user message"},
+        ]
+
+        def tool_fn(location: str, unit: str):
+            """
+            Get the current temperature
+
+            Args:
+                location: The location to get the temperature from
+                unit: The unit to return the temperature in
+            """
+            return 22.0
+
+        chat_str = tokenizer.apply_chat_template(chat, tools=[tool_fn], tokenize=False)
+        inverted_chat = tokenizer.apply_inverse_template(chat_str)
+        self.assertEqual(inverted_chat["messages"], chat)
+        self.assertEqual(inverted_chat["tools"], [get_json_schema(tool_fn)])
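The Mistral-style inverse template in InverseChatTemplateTest is dense, but its trickiest branch, reconstructing an assistant tool-call message from the JSON payload that the chat template serialized after [TOOL_CALLS], reduces to a few lines of Jinja. A cut-down sketch runnable with jinja2 alone; the environment is a minimal stand-in for what _compile_inverse_template() builds above, and the payload string is a made-up example in the same shape the tests use:

import json

from jinja2.sandbox import SandboxedEnvironment

# Minimal stand-in for the environment _compile_inverse_template() constructs.
env = SandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
env.globals["json_loads"] = json.loads
env.filters["tojson"] = lambda x: json.dumps(x, ensure_ascii=False)

# Just the tool_calls branch of the inverse template, applied to one captured payload.
template = env.from_string(
    "{%- set tool_call = json_loads(raw)[0] %}"
    '{{- dict(role="assistant", tool_calls=[dict(type="function", id=tool_call["id"]|string, '
    "function=dict(name=tool_call.name, arguments=tool_call.arguments))]) | tojson }}"
)

raw = '[{"id": "9Ae3bDc2F", "name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}}]'
message = json.loads(template.render(raw=raw))
assert message["tool_calls"][0]["function"]["name"] == "get_current_temperature"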