Compare commits

...

2 Commits

SHA1        Message                                                                Date
e8ce410515  Save attribute correctly                                               2023-09-15 14:58:28 +01:00
2268abf4a1  Update LLaMA with template kwargs and add support to the base method   2023-09-15 14:55:50 +01:00
5 changed files with 55 additions and 25 deletions

src/transformers/models/code_llama/tokenization_code_llama.py

@@ -459,9 +459,9 @@ class CodeLlamaTokenizer(PreTrainedTokenizer):
             "{% if messages[0]['role'] == 'system' %}"
             "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
             "{% set system_message = messages[0]['content'] %}"
-            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
+            "{% elif use_default_prompt == true and not '<<SYS>>' in messages[0]['content'] %}"
             "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
-            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+            "{% set system_message = default_system_message %}"
             "{% else %}"
             "{% set loop_messages = messages %}"
             "{% set system_message = false %}"
@@ -484,12 +484,16 @@ class CodeLlamaTokenizer(PreTrainedTokenizer):
             "{% endif %}"
             "{% endfor %}"
         )
-        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
-        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
-        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
         return template

+    @property
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template_kwargs
+    def default_chat_template_kwargs(self):
+        return {
+            "use_default_prompt": self.use_default_system_prompt,
+            "default_system_message": DEFAULT_SYSTEM_PROMPT,
+        }
+
     def __getstate__(self):
         state = self.__dict__.copy()
         state["sp_model"] = None

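The three lines dropped from the default chat template above are the point of the first commit: instead of baking values into the template source with str.replace() (which forced the newline and quote escaping visible in the removed lines), the flag and the default message become ordinary Jinja variables supplied at render time. A minimal standalone sketch of that difference, assuming only that jinja2 is installed (variable names mirror the diff, the <<SYS>> wrapping is illustrative):

# New style: values are passed at render time, so no escaping or string
# substitution into the template source is needed.
from jinja2 import Template

template = Template(
    "{% if use_default_prompt %}<<SYS>>{{ default_system_message }}<</SYS>>{% endif %}"
)
print(template.render(use_default_prompt=True, default_system_message="You are a helpful assistant."))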
src/transformers/models/code_llama/tokenization_code_llama_fast.py

@@ -362,9 +362,9 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
             "{% if messages[0]['role'] == 'system' %}"
             "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
             "{% set system_message = messages[0]['content'] %}"
-            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
+            "{% elif use_default_prompt == true and not '<<SYS>>' in messages[0]['content'] %}"
             "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
-            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+            "{% set system_message = default_system_message %}"
             "{% else %}"
             "{% set loop_messages = messages %}"
             "{% set system_message = false %}"
@@ -387,12 +387,16 @@ class CodeLlamaTokenizerFast(PreTrainedTokenizerFast):
             "{% endif %}"
             "{% endfor %}"
         )
-        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
-        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
-        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
         return template

+    @property
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template_kwargs
+    def default_chat_template_kwargs(self):
+        return {
+            "use_default_prompt": self.use_default_system_prompt,
+            "default_system_message": DEFAULT_SYSTEM_PROMPT,
+        }
+
     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:

src/transformers/models/llama/tokenization_llama.py

@@ -393,9 +393,9 @@ class LlamaTokenizer(PreTrainedTokenizer):
             "{% if messages[0]['role'] == 'system' %}"
             "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
             "{% set system_message = messages[0]['content'] %}"
-            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
+            "{% elif use_default_prompt == true and not '<<SYS>>' in messages[0]['content'] %}"
             "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
-            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+            "{% set system_message = default_system_message %}"
             "{% else %}"
             "{% set loop_messages = messages %}"
             "{% set system_message = false %}"
@@ -418,8 +418,11 @@ class LlamaTokenizer(PreTrainedTokenizer):
             "{% endif %}"
             "{% endfor %}"
         )
-        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
-        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
-        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
         return template

+    @property
+    def default_chat_template_kwargs(self):
+        return {
+            "use_default_prompt": self.use_default_system_prompt,
+            "default_system_message": DEFAULT_SYSTEM_PROMPT,
+        }

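With the property in place on LlamaTokenizer, the intended call-site usage looks roughly like the following hedged sketch. The template_kwargs parameter and chat_template_kwargs attribute exist only with these commits applied, and the checkpoint name is illustrative:

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
messages = [{"role": "user", "content": "Hello!"}]

# Per-call override: replace the default system message for this render only.
text = tokenizer.apply_chat_template(
    messages,
    template_kwargs={"use_default_prompt": True, "default_system_message": "Answer in French."},
    tokenize=False,
)

# Per-tokenizer override: applies to every later apply_chat_template call.
tokenizer.chat_template_kwargs = {"use_default_prompt": False}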
src/transformers/models/llama/tokenization_llama_fast.py

@@ -210,9 +210,9 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             "{% if messages[0]['role'] == 'system' %}"
             "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
             "{% set system_message = messages[0]['content'] %}"
-            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
+            "{% elif use_default_prompt == true and not '<<SYS>>' in messages[0]['content'] %}"
             "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
-            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+            "{% set system_message = default_system_message %}"
             "{% else %}"
             "{% set loop_messages = messages %}"
             "{% set system_message = false %}"
@@ -235,8 +235,12 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
             "{% endif %}"
             "{% endfor %}"
         )
-        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
-        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
-        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
         return template

+    @property
+    # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template_kwargs
+    def default_chat_template_kwargs(self):
+        return {
+            "use_default_prompt": self.use_default_system_prompt,
+            "default_system_message": DEFAULT_SYSTEM_PROMPT,
+        }

src/transformers/tokenization_utils_base.py

@@ -1567,6 +1567,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         # Stores a Jinja template that formats chat histories into tokenizable strings
         self.chat_template = kwargs.pop("chat_template", None)
+        self.chat_template_kwargs = kwargs.pop("chat_template_kwargs", None)

         super().__init__(**kwargs)
@@ -1641,6 +1642,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         self,
         conversation: Union[List[Dict[str, str]], "Conversation"],
         chat_template: Optional[str] = None,
+        template_kwargs: Optional[Dict] = None,
         tokenize: bool = True,
         padding: bool = False,
         truncation: bool = False,
@@ -1659,6 +1661,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
                 with "role" and "content" keys, representing the chat history so far.
             chat_template (str, *optional*): A Jinja template to use for this conversion. If
                 this is not passed, the model's default chat template will be used instead.
+            template_kwargs (Dict, *optional*): Additional kwargs that will be made available to the template.
             tokenize (`bool`, defaults to `True`):
                 Whether to tokenize the output. If `False`, the output will be a string.
             padding (`bool`, defaults to `False`):
@@ -1693,10 +1696,20 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         else:
             chat_template = self.default_chat_template

+        if template_kwargs is None:
+            if self.chat_template_kwargs is not None:
+                template_kwargs = self.chat_template_kwargs
+            else:
+                template_kwargs = getattr(self, "default_chat_template_kwargs", {})
+
         # Compilation function uses a cache to avoid recompiling the same template
         compiled_template = self._compile_jinja_template(chat_template)

-        rendered = compiled_template.render(messages=conversation, **self.special_tokens_map)
+        # Add special tokens as template kwargs as well, but actual kwargs take priority over them
+        token_kwargs = {key: val for key, val in self.special_tokens_map.items() if key not in template_kwargs}
+        template_kwargs.update(token_kwargs)
+        rendered = compiled_template.render(messages=conversation, **template_kwargs)

         if padding is True:
             padding = "max_length"  # There's only one sequence here, so "longest" makes no sense
@@ -2302,6 +2315,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         if self.chat_template is not None:
             tokenizer_config["chat_template"] = self.chat_template
+        if self.chat_template_kwargs is not None:
+            tokenizer_config["chat_template_kwargs"] = self.chat_template_kwargs

         if len(self.init_inputs) > 0:
             tokenizer_config["init_inputs"] = copy.deepcopy(self.init_inputs)