Mirror of https://github.com/huggingface/transformers.git
[chat template] update when "push_to_hub" (#39815)

* update templates push to hub
* revert jinja suffix and move it to processor file

Commit 313afcc468 (parent 7bba4d1202), committed via GitHub.
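In practice, the fix makes `save_jinja_files` usable together with `push_to_hub=True`: the flag is now popped once at the top of `save_pretrained`, before the remaining kwargs are forwarded to the Hub upload. A minimal usage sketch (the repo id is a placeholder):

    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

    # Save in the legacy JSON-embedded template format while pushing in one call.
    processor.save_pretrained(
        "local_dir",
        push_to_hub=True,
        repo_id="your-username/your-repo",  # placeholder
        save_jinja_files=False,
    )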
@@ -776,6 +776,7 @@ class ProcessorMixin(PushToHubMixin):
                 Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
         """
         use_auth_token = kwargs.pop("use_auth_token", None)
+        save_jinja_files = kwargs.pop("save_jinja_files", True)
 
         if use_auth_token is not None:
             warnings.warn(
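The switch from a late `kwargs.get` to an early `kwargs.pop` is the heart of the fix: whatever remains in `kwargs` is forwarded to `push_to_hub`, so reading the flag with `get` left it in the dict and leaked it into the Hub call. A minimal sketch of the pattern with stand-in functions (not the real signatures):

    def push_to_hub(repo_id, **hub_kwargs):
        # stand-in for PushToHubMixin.push_to_hub: an unexpected key here is the bug being fixed
        assert "save_jinja_files" not in hub_kwargs, "flag leaked into the Hub call"

    def save_pretrained(save_directory, push=False, **kwargs):
        save_jinja_files = kwargs.pop("save_jinja_files", True)  # pop: read once and remove
        # ... write files, honoring save_jinja_files ...
        if push:
            push_to_hub(kwargs.pop("repo_id"), **kwargs)  # leftover kwargs go to the Hub call

    save_pretrained("out", push=True, repo_id="user/repo", save_jinja_files=False)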
@@ -803,8 +804,6 @@ class ProcessorMixin(PushToHubMixin):
             configs.append(self)
             custom_object_save(self, save_directory, config=configs)
 
-        save_jinja_files = kwargs.get("save_jinja_files", True)
-
         for attribute_name in self.attributes:
             # Save the tokenizer in its own vocab file. The other attributes are saved as part of `processor_config.json`
             if attribute_name == "tokenizer":
@@ -840,7 +839,6 @@ class ProcessorMixin(PushToHubMixin):
         # Save `chat_template` in its own file. We can't get it from `processor_dict` as we popped it in `to_dict`
         # to avoid serializing chat template in json config file. So let's get it from `self` directly
         if self.chat_template is not None:
-            save_jinja_files = kwargs.get("save_jinja_files", True)
             is_single_template = isinstance(self.chat_template, str)
             if save_jinja_files and is_single_template:
                 # New format for single templates is to save them as chat_template.jinja
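For orientation, the on-disk layout this logic produces, sketched from the surrounding code (the `additional_chat_templates` directory name is an assumption about the `CHAT_TEMPLATE_DIR` constant used below):

    # save_jinja_files=True, chat_template is a str:
    #     chat_template.jinja
    # save_jinja_files=True, chat_template is a dict:
    #     chat_template.jinja                     <- the "default" template
    #     additional_chat_templates/<name>.jinja  <- one file per extra template
    # save_jinja_files=False (legacy):
    #     template embedded in the JSON config instead of separate .jinja files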
@@ -999,6 +997,7 @@ class ProcessorMixin(PushToHubMixin):
                 cache_dir=cache_dir,
                 token=token,
             ):
+                template = template.removesuffix(".jinja")
                 additional_chat_template_files[template] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"
         except EntryNotFoundError:
             pass  # No template dir means no template files
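The restored `removesuffix` call (the suffix handling the commit message moves into the processor file) normalizes the names yielded by `list_repo_templates`, which end in `.jinja`, back into bare template keys before the path map is built. A self-contained illustration:

    CHAT_TEMPLATE_DIR = "additional_chat_templates"  # assumed value of the constant

    additional_chat_template_files = {}
    for template in ["tool_use.jinja", "rag.jinja"]:  # example names a repo listing might yield
        template = template.removesuffix(".jinja")  # str.removesuffix requires Python 3.9+
        additional_chat_template_files[template] = f"{CHAT_TEMPLATE_DIR}/{template}.jinja"

    print(additional_chat_template_files)
    # {'tool_use': 'additional_chat_templates/tool_use.jinja', 'rag': 'additional_chat_templates/rag.jinja'}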
@@ -2512,6 +2512,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             A tuple of `str`: The files saved.
         """
         use_auth_token = kwargs.pop("use_auth_token", None)
+        save_jinja_files = kwargs.pop("save_jinja_files", True)
 
         if use_auth_token is not None:
             warnings.warn(
@@ -2560,7 +2561,6 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
             tokenizer_config["extra_special_tokens"] = self.extra_special_tokens
             tokenizer_config.update(self.extra_special_tokens)
 
-        save_jinja_files = kwargs.get("save_jinja_files", True)
         tokenizer_config, saved_raw_chat_template_files = self.save_chat_templates(
             save_directory, tokenizer_config, filename_prefix, save_jinja_files
         )
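The tokenizer path mirrors the processor change: the value popped at the top of `save_pretrained` is what reaches `save_chat_templates`, so the legacy mode survives a push as well. A quick local check of where the template lands (a sketch; assumes `tokenizer` has a string `chat_template` and the layout sketched earlier):

    import json
    import os

    tokenizer.save_pretrained("tok_dir", save_jinja_files=False)
    with open(os.path.join("tok_dir", "tokenizer_config.json")) as f:
        config = json.load(f)
    assert "chat_template" in config  # legacy: template embedded in the JSON config
    assert not os.path.exists("tok_dir/chat_template.jinja")  # no single-file template written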
@@ -163,7 +163,7 @@ def list_repo_templates(
     local_files_only: bool,
     revision: Optional[str] = None,
     cache_dir: Optional[str] = None,
-    token: Union[bool, str, None] = None,
+    token: Optional[Union[str, bool]] = None,
 ) -> list[str]:
     """List template files from a repo.
 
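The annotation change is cosmetic: `Optional[X]` is defined as `Union[X, None]`, and `Union` is order-insensitive, so both spellings denote the same type; the new form simply matches the `Optional[...]` style of the neighboring parameters:

    from typing import Optional, Union

    assert Optional[Union[str, bool]] == Union[bool, str, None]  # same type, different spelling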
@@ -34,7 +34,10 @@ from transformers import (
     AutoProcessor,
     AutoTokenizer,
     BertTokenizer,
+    LlamaTokenizer,
+    LlavaProcessor,
     ProcessorMixin,
+    SiglipImageProcessor,
     Wav2Vec2Config,
     Wav2Vec2FeatureExtractor,
     Wav2Vec2Processor,
@@ -57,6 +60,7 @@ from test_module.custom_tokenization import CustomTokenizer  # noqa E402
 
 
 SAMPLE_PROCESSOR_CONFIG = get_tests_dir("fixtures/dummy_feature_extractor_config.json")
+SAMPLE_VOCAB_LLAMA = get_tests_dir("fixtures/test_sentencepiece.model")
 SAMPLE_VOCAB = get_tests_dir("fixtures/vocab.json")
 SAMPLE_PROCESSOR_CONFIG_DIR = get_tests_dir("fixtures")
 
@@ -503,3 +507,43 @@ class ProcessorPushToHubTester(unittest.TestCase):
         new_processor = AutoProcessor.from_pretrained(tmp_repo.repo_id, trust_remote_code=True)
         # Can't make an isinstance check because the new_processor is from the CustomProcessor class of a dynamic module
         self.assertEqual(new_processor.__class__.__name__, "CustomProcessor")
+
+    def test_push_to_hub_with_chat_templates(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            tokenizer = LlamaTokenizer(SAMPLE_VOCAB_LLAMA, keep_accents=True)
+            image_processor = SiglipImageProcessor()
+            chat_template = "default dummy template for testing purposes only"
+            processor = LlavaProcessor(
+                tokenizer=tokenizer, image_processor=image_processor, chat_template=chat_template
+            )
+            self.assertEqual(processor.chat_template, chat_template)
+
+            existing_tokenizer_template = getattr(processor.tokenizer, "chat_template", None)
+            with TemporaryHubRepo(token=self._token) as tmp_repo:
+                processor.save_pretrained(
+                    tmp_dir, repo_id=tmp_repo.repo_id, token=self._token, push_to_hub=True, save_jinja_files=False
+                )
+                reloaded_processor = LlavaProcessor.from_pretrained(tmp_repo.repo_id)
+                self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
+                # When we don't use single-file chat template saving, processor and tokenizer chat templates
+                # should remain separate
+                self.assertEqual(
+                    getattr(reloaded_processor.tokenizer, "chat_template", None), existing_tokenizer_template
+                )
+
+            with TemporaryHubRepo(token=self._token) as tmp_repo:
+                processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, token=self._token, push_to_hub=True)
+                reloaded_processor = LlavaProcessor.from_pretrained(tmp_repo.repo_id)
+                self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
+                # When we save as single files, tokenizers and processors share a chat template, which means
+                # the reloaded tokenizer should get the chat template as well
+                self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
+
+            with TemporaryHubRepo(token=self._token) as tmp_repo:
+                processor.chat_template = {"default": "a", "secondary": "b"}
+                processor.save_pretrained(tmp_dir, repo_id=tmp_repo.repo_id, token=self._token, push_to_hub=True)
+                reloaded_processor = LlavaProcessor.from_pretrained(tmp_repo.repo_id)
+                self.assertEqual(processor.chat_template, reloaded_processor.chat_template)
+                # When we save as single files, tokenizers and processors share a chat template, which means
+                # the reloaded tokenizer should get the chat template as well
+                self.assertEqual(reloaded_processor.chat_template, reloaded_processor.tokenizer.chat_template)
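Since the template sharing happens at save time, the same round trip can be checked locally without a Hub repo; a sketch reusing the objects built in the test above:

    processor.save_pretrained(tmp_dir)  # default save_jinja_files=True writes template .jinja files
    reloaded = LlavaProcessor.from_pretrained(tmp_dir)
    # the saved template files are picked up by both the processor and its tokenizer
    assert reloaded.chat_template == reloaded.tokenizer.chat_template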
@@ -131,6 +131,32 @@ class TokenizerPushToHubTester(unittest.TestCase):
             new_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
             self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
 
+    def test_push_to_hub_chat_templates(self):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            vocab_file = os.path.join(tmp_dir, "vocab.txt")
+            with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
+                vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
+            tokenizer = BertTokenizer(vocab_file)
+            tokenizer.chat_template = "test template"
+
+            with TemporaryHubRepo(token=self._token) as tmp_repo:
+                tokenizer.save_pretrained(
+                    tmp_repo.repo_id, token=self._token, push_to_hub=True, save_jinja_files=False
+                )
+                reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
+                self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
+
+            with TemporaryHubRepo(token=self._token) as tmp_repo:
+                tokenizer.save_pretrained(tmp_repo.repo_id, token=self._token, push_to_hub=True)
+                reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
+                self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
+
+            with TemporaryHubRepo(token=self._token) as tmp_repo:
+                tokenizer.chat_template = {"default": "a", "secondary": "b"}
+                tokenizer.save_pretrained(tmp_repo.repo_id, token=self._token, push_to_hub=True)
+                reloaded_tokenizer = BertTokenizer.from_pretrained(tmp_repo.repo_id)
+                self.assertEqual(tokenizer.chat_template, reloaded_tokenizer.chat_template)
+
     def test_push_to_hub_via_save_pretrained(self):
         with TemporaryHubRepo(token=self._token) as tmp_repo:
             with tempfile.TemporaryDirectory() as tmp_dir: