Mirror of https://github.com/huggingface/transformers.git (synced 2025-10-21 01:23:56 +08:00)

Compare commits: v4.37.0 ... chat_templ

2 commits:
- e8ce410515
- 2268abf4a1
src/transformers/models/code_llama/tokenization_code_llama.py:

@@ -459,9 +459,9 @@ class CodeLlamaTokenizer(PreTrainedTokenizer):
             "{% if messages[0]['role'] == 'system' %}"
             "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
             "{% set system_message = messages[0]['content'] %}"
-            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
+            "{% elif use_default_prompt == true and not '<<SYS>>' in messages[0]['content'] %}"
             "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
-            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+            "{% set system_message = default_system_message %}"
             "{% else %}"
             "{% set loop_messages = messages %}"
             "{% set system_message = false %}"
@@ -484,12 +484,16 @@ class CodeLlamaTokenizer(PreTrainedTokenizer):
             "{% endif %}"
             "{% endfor %}"
         )
-        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
-        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
-        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
 
         return template
 
+    @property
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template_kwargs
+    def default_chat_template_kwargs(self):
+        return {
+            "use_default_prompt": self.use_default_system_prompt,
+            "default_system_message": DEFAULT_SYSTEM_PROMPT,
+        }
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None
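The change above (repeated for each Llama-family tokenizer below) swaps build-time string substitution, where `template.replace` spliced values into the `USE_DEFAULT_PROMPT` and `DEFAULT_SYSTEM_MESSAGE` placeholders, for ordinary Jinja variables supplied at render time, with `default_chat_template_kwargs` providing the per-tokenizer values. A minimal sketch of the difference, using plain jinja2 and a toy template rather than the real Llama one (transformers compiles templates through a sandboxed Jinja environment; plain `jinja2.Template` is used here only for illustration):

```python
from jinja2 import Template

# Toy template in the new style: configuration enters as Jinja variables
# (use_default_prompt, default_system_message) instead of being spliced into
# the template source with str.replace before compilation.
template = Template(
    "{% if use_default_prompt and messages[0]['role'] != 'system' %}"
    "<<SYS>>{{ default_system_message }}<</SYS>>\n"
    "{% endif %}"
    "{% for message in messages %}{{ message['role'] }}: {{ message['content'] }}\n{% endfor %}"
)

messages = [{"role": "user", "content": "Hello!"}]

# The same compiled template now serves any configuration:
print(template.render(messages=messages, use_default_prompt=True,
                      default_system_message="You are a helpful assistant."))
print(template.render(messages=messages, use_default_prompt=False,
                      default_system_message=""))
```

A side benefit visible in the removed lines: the old approach had to escape newlines and quotes in `DEFAULT_SYSTEM_PROMPT` before splicing it into template source; passing it as a render-time variable removes that escaping hazard entirely.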
src/transformers/models/code_llama/tokenization_code_llama_fast.py:

@@ -362,9 +362,9 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
             "{% if messages[0]['role'] == 'system' %}"
             "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
             "{% set system_message = messages[0]['content'] %}"
-            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
+            "{% elif use_default_prompt == true and not '<<SYS>>' in messages[0]['content'] %}"
             "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
-            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+            "{% set system_message = default_system_message %}"
             "{% else %}"
             "{% set loop_messages = messages %}"
             "{% set system_message = false %}"
@@ -387,12 +387,16 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
             "{% endif %}"
             "{% endfor %}"
        )
-        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
-        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
-        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
 
         return template
 
+    @property
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template_kwargs
+    def default_chat_template_kwargs(self):
+        return {
+            "use_default_prompt": self.use_default_system_prompt,
+            "default_system_message": DEFAULT_SYSTEM_PROMPT,
+        }
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
src/transformers/models/llama/tokenization_llama.py:

@@ -393,9 +393,9 @@ class LlamaTokenizer(PreTrainedTokenizer):
             "{% if messages[0]['role'] == 'system' %}"
             "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
             "{% set system_message = messages[0]['content'] %}"
-            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
+            "{% elif use_default_prompt == true and not '<<SYS>>' in messages[0]['content'] %}"
             "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
-            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+            "{% set system_message = default_system_message %}"
             "{% else %}"
             "{% set loop_messages = messages %}"
             "{% set system_message = false %}"
@@ -418,8 +418,11 @@ class LlamaTokenizer(PreTrainedTokenizer):
             "{% endif %}"
             "{% endfor %}"
         )
-        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
-        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
-        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
 
         return template
+
+    @property
+    def default_chat_template_kwargs(self):
+        return {
+            "use_default_prompt": self.use_default_system_prompt,
+            "default_system_message": DEFAULT_SYSTEM_PROMPT,
+        }
src/transformers/models/llama/tokenization_llama_fast.py:

@@ -210,9 +210,9 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             "{% if messages[0]['role'] == 'system' %}"
             "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
             "{% set system_message = messages[0]['content'] %}"
-            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
+            "{% elif use_default_prompt == true and not '<<SYS>>' in messages[0]['content'] %}"
             "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
-            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+            "{% set system_message = default_system_message %}"
             "{% else %}"
             "{% set loop_messages = messages %}"
             "{% set system_message = false %}"
@@ -235,8 +235,12 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             "{% endif %}"
             "{% endfor %}"
         )
-        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
-        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
-        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
 
         return template
+
+    @property
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template_kwargs
+    def default_chat_template_kwargs(self):
+        return {
+            "use_default_prompt": self.use_default_system_prompt,
+            "default_system_message": DEFAULT_SYSTEM_PROMPT,
+        }
src/transformers/tokenization_utils_base.py:

@@ -1567,6 +1567,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
 
         # Stores a Jinja template that formats chat histories into tokenizable strings
         self.chat_template = kwargs.pop("chat_template", None)
+        self.chat_template_kwargs = kwargs.pop("chat_template_kwargs", None)
 
         super().__init__(**kwargs)
@@ -1641,6 +1642,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         self,
         conversation: Union[List[Dict[str, str]], "Conversation"],
         chat_template: Optional[str] = None,
+        template_kwargs: Optional[Dict] = None,
         tokenize: bool = True,
         padding: bool = False,
         truncation: bool = False,
@@ -1659,6 +1661,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 with "role" and "content" keys, representing the chat history so far.
             chat_template (str, *optional*): A Jinja template to use for this conversion. If
                 this is not passed, the model's default chat template will be used instead.
+            template_kwargs (Dict, *optional*): Additional kwargs that will be made available to the template.
             tokenize (`bool`, defaults to `True`):
                 Whether to tokenize the output. If `False`, the output will be a string.
             padding (`bool`, defaults to `False`):
@@ -1693,10 +1696,20 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         else:
             chat_template = self.default_chat_template
 
+        if template_kwargs is None:
+            if self.chat_template_kwargs is not None:
+                template_kwargs = self.chat_template_kwargs
+            else:
+                template_kwargs = getattr(self, "default_chat_template_kwargs", {})
+
         # Compilation function uses a cache to avoid recompiling the same template
         compiled_template = self._compile_jinja_template(chat_template)
 
-        rendered = compiled_template.render(messages=conversation, **self.special_tokens_map)
+        # Add special tokens as template kwargs as well, but actual kwargs take priority over them
+        token_kwargs = {key: val for key, val in self.special_tokens_map.items() if key not in template_kwargs}
+        template_kwargs.update(token_kwargs)
+
+        rendered = compiled_template.render(messages=conversation, **template_kwargs)
 
         if padding is True:
             padding = "max_length"  # There's only one sequence here, so "longest" makes no sense
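Taken together, the hunks above resolve template variables in priority order: an explicit `template_kwargs` argument replaces the tokenizer's stored `chat_template_kwargs`, which in turn replaces the `default_chat_template_kwargs` property, and the special-token map is then merged in underneath whichever set won. A hypothetical call against this branch (the `template_kwargs` parameter does not exist in released transformers):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

chat = [{"role": "user", "content": "Explain what a chat template is."}]

# Override the defaults for this one call; variables not supplied here still
# reach the template from the special-token map (bos_token, eos_token, ...).
text = tokenizer.apply_chat_template(
    chat,
    template_kwargs={
        "use_default_prompt": True,
        "default_system_message": "Answer in one sentence.",
    },
    tokenize=False,
)
print(text)
```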
@@ -2302,6 +2315,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
 
         if self.chat_template is not None:
             tokenizer_config["chat_template"] = self.chat_template
+        if self.chat_template_kwargs is not None:
+            tokenizer_config["chat_template_kwargs"] = self.chat_template_kwargs
 
         if len(self.init_inputs) > 0:
             tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)
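With this last hunk, `chat_template_kwargs` is serialized next to `chat_template` in `tokenizer_config.json`, and the `__init__` change earlier in this diff pops it back out of the loaded config. A sketch of the round-trip this should enable on this branch (an assumption based on the two hunks, not a tested guarantee):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.chat_template_kwargs = {"use_default_prompt": False}

# save_pretrained() writes the kwargs into tokenizer_config.json alongside
# chat_template; from_pretrained() feeds that config back through __init__.
tokenizer.save_pretrained("llama-no-default-prompt")
reloaded = AutoTokenizer.from_pretrained("llama-no-default-prompt")
assert reloaded.chat_template_kwargs == {"use_default_prompt": False}
```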