Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 01:23:56 +08:00)

Compare commits: 1 commit

| Author | SHA1 | Date |
|---|---|---|
| | 3f8d7f8a05 | |
@@ -763,6 +763,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        split_special_tokens: bool = False,
        **kwargs,
    ) -> BatchEncoding:
        def get_input_ids(text):
@@ -840,6 +841,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
        return_special_tokens_mask: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        split_special_tokens: bool = False,
    ) -> BatchEncoding:
        """
        Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model. It
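The two hunks above thread a new `split_special_tokens` keyword through the slow tokenizer's `_batch_encode_plus`/`_encode_plus`, so the flag can now be passed per call rather than only at load time. A minimal usage sketch; the checkpoint name and the exact sub-tokens shown in the comments are illustrative, not taken from this diff:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

# Default behaviour: a registered special token stays a single piece.
print(tok.tokenize("[CLS] hello world"))
# e.g. ['[CLS]', 'hello', 'world']

# With the new per-call flag it is tokenized like ordinary text instead.
print(tok.tokenize("[CLS] hello world", split_special_tokens=True))
# e.g. ['[', 'cl', '##s', ']', 'hello', 'world']
```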
@@ -1748,10 +1748,18 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):

        using_default_template = False

        # First, handle the cases when the model has a dict of multiple templates
        if isinstance(self.chat_template, dict) or (
            self.chat_template is None and isinstance(self.default_chat_template, dict)
        ):
            if self.chat_template is not None:
                template_dict = self.chat_template
                using_default_dict = False
            else:
                template_dict = self.default_chat_template
                using_default_dict = True
@@ -1763,10 +1771,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                chat_template = template_dict[chat_template]
                if using_default_dict:
                    using_default_template = True
            elif chat_template is None and "default" in template_dict:
                chat_template = template_dict["default"]
                if using_default_dict:
                    using_default_template = True
            elif chat_template is None:
                raise ValueError(
                    "This model has multiple chat templates with no default specified! Please either pass a chat "
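The branch above resolves `chat_template` when it is a dict of named templates (or falls back to the deprecated `default_chat_template` dict), using `using_default_dict`/`using_default_template` to remember whether a library default was picked. From the caller's side the selection behaves roughly as sketched below; the checkpoint name and the Jinja strings are placeholders, not part of the diff:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("some-chat-model")  # placeholder checkpoint
tok.chat_template = {
    "default": "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}",
    "brief": "{% for m in messages %}{{ m['content'] }} {% endfor %}",
}
messages = [{"role": "user", "content": "Hi"}]

# No template name given: the "default" entry of the dict is used.
print(tok.apply_chat_template(messages, tokenize=False))

# A name given: that key is looked up in the dict, per the elif chain above.
print(tok.apply_chat_template(messages, chat_template="brief", tokenize=False))
```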
@@ -1907,6 +1919,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
            resume_download:
                Deprecated and ignored. All downloads are now resumed by default when possible.
                Will be removed in v5 of Transformers.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -1962,6 +1977,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
        assert tokenizer.unk_token == "<unk>"
        ```"""
        resume_download = kwargs.pop("resume_download", None)
        proxies = kwargs.pop("proxies", None)
        use_auth_token = kwargs.pop("use_auth_token", None)
        subfolder = kwargs.pop("subfolder", None)
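The docstring entries above describe ordinary hub-download kwargs of `from_pretrained`; `resume_download` is still accepted (and popped a few lines below) but is now ignored. A short sketch, reusing the placeholder proxy endpoints from the docstring and an illustrative checkpoint:

```python
from transformers import AutoTokenizer

# proxies maps a protocol or endpoint to a proxy server; resume_download can be
# omitted entirely, since downloads resume automatically when possible.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
    proxies={"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"},
)
print(tokenizer.unk_token)  # "[UNK]" for this checkpoint
```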
@@ -486,6 +486,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        split_special_tokens: bool = False,
    ) -> BatchEncoding:
        if not isinstance(batch_text_or_text_pairs, (tuple, list)):
            raise TypeError(
@@ -570,6 +571,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        split_special_tokens: bool = False,
        **kwargs,
    ) -> BatchEncoding:
        batched_input = [(text, text_pair)] if text_pair else [text]
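The same keyword is added to the fast tokenizer's `_batch_encode_plus` and `_encode_plus`, so a value chosen at load time can now be overridden per call. A hedged sketch, again with an illustrative checkpoint and outputs:

```python
from transformers import AutoTokenizer

# Load-time setting: special tokens are split on every call by default.
tok = AutoTokenizer.from_pretrained("bert-base-uncased", split_special_tokens=True)

print(tok.tokenize("[CLS] hello"))
# e.g. ['[', 'cl', '##s', ']', 'hello']  -- the special token is split

print(tok.tokenize("[CLS] hello", split_special_tokens=False))
# e.g. ['[CLS]', 'hello']  -- per-call override keeps it whole
```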
@@ -4172,30 +4172,69 @@ class TokenizerTesterMixin:
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            special_token = "[SPECIAL_TOKEN]"
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
                    pretrained_name, additional_special_tokens=[special_token], split_special_tokens=True, **kwargs
                )
                tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
                    pretrained_name,
                    additional_special_tokens=[special_token],
                    split_special_tokens=True,
                    **kwargs,
                    from_slow=True,
                )
                tokenizer_p = self.tokenizer_class.from_pretrained(
                    pretrained_name, additional_special_tokens=[special_token], split_special_tokens=True, **kwargs
                )

                if not tokenizer.is_fast:
                    # bloom, gptneox etc only have a fast
                    tokenizer.add_special_tokens(
                        {
                            "additional_special_tokens": [
                                AddedToken(special_token, rstrip=True, lstrip=True, normalized=True, special=True)
                            ]
                        }
                    )
                encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
                self.assertEqual(len(encoded_special_token), 1)
                encoded_special_token = tokenizer_p.encode(
                    special_token, add_special_tokens=False, split_special_tokens=False
                )
                self.assertEqual(len(encoded_special_token), 1)

                encoded_split_special_token = tokenizer.encode(
                    special_token, add_special_tokens=False, split_special_tokens=True
                )
                if len(encoded_split_special_token) == 1:
                    # if we have subword tokenization or special vocab
                    self.assertTrue(
                        encoded_split_special_token[0] != tokenizer.convert_tokens_to_ids(special_token)
                    )
                else:
                    self.assertTrue(len(encoded_split_special_token) > 1)
                encoded_split_special_token = tokenizer_p.encode(special_token, add_special_tokens=False)
                self.assertTrue(len(encoded_split_special_token) > 1)
                p_output = tokenizer_p.tokenize(f"Hey this is a {special_token} token")
                r_output = tokenizer_r.tokenize(f"Hey this is a {special_token} token")
                cr_output = tokenizer_cr.tokenize(f"Hey this is a {special_token} token")

                self.assertEqual(p_output, r_output)
                self.assertEqual(cr_output, r_output)
                self.assertTrue(special_token not in p_output)

                p_output_explicit = tokenizer_p.tokenize(
                    f"Hey this is a {special_token} token", split_special_tokens=False
                )
                r_output_explicit = tokenizer_r.tokenize(
                    f"Hey this is a {special_token} token", split_special_tokens=False
                )
                cr_output_explicit = tokenizer_cr.tokenize(
                    f"Hey this is a {special_token} token", split_special_tokens=False
                )

                self.assertTrue(special_token in p_output_explicit)
                self.assertEqual(p_output_explicit, r_output_explicit)
                self.assertEqual(cr_output_explicit, r_output_explicit)

                p_special_token_id = tokenizer_p.encode(special_token, add_special_tokens=False)[0]
                p_output = tokenizer_p(f"Hey this is a {special_token} token")
                r_output = tokenizer_r(f"Hey this is a {special_token} token")
                cr_output = tokenizer_cr(f"Hey this is a {special_token} token")

                self.assertTrue(p_special_token_id not in p_output)
                self.assertEqual(p_output, r_output)
                self.assertEqual(cr_output, r_output)

                tmpdirname = tempfile.mkdtemp()
                tokenizer_p.save_pretrained(tmpdirname)
                fast_from_saved = self.tokenizer_class.from_pretrained(tmpdirname)

                output_reloaded = fast_from_saved.tokenize(f"Hey this is a {special_token} token")
                self.assertTrue(special_token not in output_reloaded)

                output_explicit_reloaded = fast_from_saved.tokenize(
                    f"Hey this is a {special_token} token", split_special_tokens=False
                )
                self.assertTrue(special_token in output_explicit_reloaded)

    def test_added_tokens_serialization(self):
        # Utility to test the added vocab
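The test registers its sentinel through `AddedToken` when only a slow tokenizer is available; outside the test suite the equivalent registration and the resulting behaviour look roughly like this (the checkpoint, the token string, and the commented outputs are illustrative only):

```python
from transformers import AddedToken, AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # illustrative checkpoint

# Register an extra special token; lstrip/rstrip let it absorb surrounding
# whitespace and special=True marks it as a special token.
tok.add_special_tokens(
    {"additional_special_tokens": [AddedToken("[SPECIAL_TOKEN]", lstrip=True, rstrip=True, special=True)]}
)

print(tok.tokenize("Hey this is a [SPECIAL_TOKEN] token"))
# the added token should survive as a single piece by default

print(tok.tokenize("Hey this is a [SPECIAL_TOKEN] token", split_special_tokens=True))
# with split_special_tokens=True it is broken into ordinary sub-tokens
```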