Compare commits

...

1 Commit

Author SHA1 Message Date
3f8d7f8a05 attempt 2024-05-14 14:57:29 +02:00
4 changed files with 81 additions and 22 deletions

src/transformers/tokenization_utils.py

@@ -763,6 +763,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        split_special_tokens: bool = False,
        **kwargs,
    ) -> BatchEncoding:
        def get_input_ids(text):
@@ -840,6 +841,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
        return_special_tokens_mask: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        split_special_tokens: bool = False,
    ) -> BatchEncoding:
        """
        Prepares a sequence of input ids, or a pair of sequences of input ids, so that it can be used by the model. It
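
With these signature changes, split_special_tokens can be toggled per call on the slow tokenizer rather than only at construction time. A minimal usage sketch, assuming a BERT checkpoint purely for illustration:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)  # illustrative checkpoint
    # Default behaviour: a known special token maps to its single id.
    assert len(tok.encode("[CLS]", add_special_tokens=False)) == 1
    # Per-call override: the special token is tokenized like ordinary text.
    assert len(tok.encode("[CLS]", add_special_tokens=False, split_special_tokens=True)) > 1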

src/transformers/tokenization_utils_base.py

@@ -1748,10 +1748,18 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
        using_default_template = False
        # First, handle the cases when the model has a dict of multiple templates
        if isinstance(self.chat_template, dict) or (
            self.chat_template is None and isinstance(self.default_chat_template, dict)
        ):
            if self.chat_template is not None:
                template_dict = self.chat_template
                using_default_dict = False
            else:
                template_dict = self.default_chat_template
                using_default_dict = True
@@ -1763,10 +1771,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                    chat_template = template_dict[chat_template]
                    if using_default_dict:
                        using_default_template = True
            elif chat_template is None and "default" in template_dict:
                chat_template = template_dict["default"]
                if using_default_dict:
                    using_default_template = True
            elif chat_template is None:
                raise ValueError(
                    "This model has multiple chat templates with no default specified! Please either pass a chat "
@@ -1907,6 +1919,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -1962,6 +1977,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
        assert tokenizer.unk_token == "<unk>"
        ```"""
        resume_download = kwargs.pop("resume_download", None)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        subfolder = kwargs.pop("subfolder", None)
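
A hedged usage sketch for the documented proxies argument; the proxy addresses are the placeholder values from the docstring and the checkpoint is illustrative:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "bert-base-uncased",  # illustrative checkpoint
        proxies={"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"},
    )
    # resume_download is deprecated: the kwarg is popped and ignored, downloads
    # resume by default, and it will be removed in Transformers v5.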

src/transformers/tokenization_utils_fast.py

@@ -486,6 +486,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        split_special_tokens: bool = False,
    ) -> BatchEncoding:
        if not isinstance(batch_text_or_text_pairs, (tuple, list)):
            raise TypeError(
@@ -570,6 +571,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        split_special_tokens: bool = False,
        **kwargs,
    ) -> BatchEncoding:
        batched_input = [(text, text_pair)] if text_pair else [text]
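
The same per-call override now reaches the fast tokenizer's single and batched encode paths. A sketch, assuming the flag is forwarded from __call__ down to _batch_encode_plus as these signatures suggest:

    from transformers import AutoTokenizer

    fast_tok = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)  # illustrative
    batch = fast_tok(["[CLS] one", "[CLS] two"], split_special_tokens=True)
    # Each "[CLS]" is tokenized as ordinary text instead of mapping to its
    # single special-token id.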

tests/test_tokenization_common.py

@@ -4172,30 +4172,69 @@ class TokenizerTesterMixin:
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            special_token = "[SPECIAL_TOKEN]"
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
                    pretrained_name, additional_special_tokens=[special_token], split_special_tokens=True, **kwargs
                )
                tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
                    pretrained_name,
                    additional_special_tokens=[special_token],
                    split_special_tokens=True,
                    **kwargs,
                    from_slow=True,
                )
                tokenizer_p = self.tokenizer_class.from_pretrained(
                    pretrained_name, additional_special_tokens=[special_token], split_special_tokens=True, **kwargs
                )
                if not tokenizer.is_fast:
                    # bloom, gptneox etc only have a fast
                    tokenizer.add_special_tokens(
                        {
                            "additional_special_tokens": [
                                AddedToken(special_token, rstrip=True, lstrip=True, normalized=True, special=True)
                            ]
                        }
                    )
                encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
                self.assertEqual(len(encoded_special_token), 1)
                encoded_special_token = tokenizer_p.encode(
                    special_token, add_special_tokens=False, split_special_tokens=False
                )
                self.assertEqual(len(encoded_special_token), 1)
                encoded_split_special_token = tokenizer.encode(
                    special_token, add_special_tokens=False, split_special_tokens=True
                )
                if len(encoded_split_special_token) == 1:
                    # if we have subword tokenization or special vocab
                    self.assertTrue(
                        encoded_split_special_token[0] != tokenizer.convert_tokens_to_ids(special_token)
                    )
                else:
                    self.assertTrue(len(encoded_split_special_token) > 1)
                encoded_split_special_token = tokenizer_p.encode(special_token, add_special_tokens=False)
                self.assertTrue(len(encoded_split_special_token) > 1)
                p_output = tokenizer_p.tokenize(f"Hey this is a {special_token} token")
                r_output = tokenizer_r.tokenize(f"Hey this is a {special_token} token")
                cr_output = tokenizer_cr.tokenize(f"Hey this is a {special_token} token")
                self.assertEqual(p_output, r_output)
                self.assertEqual(cr_output, r_output)
                self.assertTrue(special_token not in p_output)
                p_output_explicit = tokenizer_p.tokenize(
                    f"Hey this is a {special_token} token", split_special_tokens=False
                )
                r_output_explicit = tokenizer_r.tokenize(
                    f"Hey this is a {special_token} token", split_special_tokens=False
                )
                cr_output_explicit = tokenizer_cr.tokenize(
                    f"Hey this is a {special_token} token", split_special_tokens=False
                )
                self.assertTrue(special_token in p_output_explicit)
                self.assertEqual(p_output_explicit, r_output_explicit)
                self.assertEqual(cr_output_explicit, r_output_explicit)
                p_special_token_id = tokenizer_p.encode(
                    special_token, add_special_tokens=False, split_special_tokens=False
                )[0]
                p_output = tokenizer_p(f"Hey this is a {special_token} token")
                r_output = tokenizer_r(f"Hey this is a {special_token} token")
                cr_output = tokenizer_cr(f"Hey this is a {special_token} token")
                self.assertTrue(p_special_token_id not in p_output["input_ids"])
                self.assertEqual(p_output, r_output)
                self.assertEqual(cr_output, r_output)
                tmpdirname = tempfile.mkdtemp()
                tokenizer_p.save_pretrained(tmpdirname)
                fast_from_saved = self.tokenizer_class.from_pretrained(tmpdirname)
                output_reloaded = fast_from_saved.tokenize(f"Hey this is a {special_token} token")
                self.assertTrue(special_token not in output_reloaded)
                output_explicit_reloaded = fast_from_saved.tokenize(
                    f"Hey this is a {special_token} token", split_special_tokens=False
                )
                self.assertTrue(special_token in output_explicit_reloaded)

    def test_added_tokens_serialization(self):
        # Utility to test the added vocab
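
The save/reload round-trip the test checks can be reproduced directly. A hedged sketch with an illustrative checkpoint and a temporary path: a tokenizer saved with split_special_tokens=True should keep splitting after reload, while split_special_tokens=False still restores the intact token:

    import tempfile
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(
        "bert-base-uncased",  # illustrative checkpoint
        additional_special_tokens=["[SPECIAL_TOKEN]"],
        split_special_tokens=True,
    )
    ckpt = tempfile.mkdtemp()
    tok.save_pretrained(ckpt)
    reloaded = AutoTokenizer.from_pretrained(ckpt)
    sentence = "Hey this is a [SPECIAL_TOKEN] token"
    # The persisted split_special_tokens=True setting survives reload...
    assert "[SPECIAL_TOKEN]" not in reloaded.tokenize(sentence)
    # ...and can still be overridden per call.
    assert "[SPECIAL_TOKEN]" in reloaded.tokenize(sentence, split_special_tokens=False)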