💬 Fix setup_chat_format and add clone_chat_template (#3404)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Quentin Gallouédec
2025-06-15 15:59:42 +02:00
committed by GitHub
parent 91b3f5ee9a
commit 4126803875
7 changed files with 134 additions and 13 deletions

View File

@@ -1,5 +1,9 @@
# Model Utilities
## clone_chat_template
[[autodoc]] clone_chat_template
## get_act_offloading_ctx_manager
[[autodoc]] models.get_act_offloading_ctx_manager

View File

@@ -63,26 +63,26 @@ If you'd like to compute loss on both the prompt **and** the completion while
### Add Special Tokens for Chat Format
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system, and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
The [`clone_chat_template`] function prepares a model and tokenizer for conversational AI tasks. This function:
- Adds special tokens to the tokenizer, e.g., `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
- Resizes the model's embedding layer to accommodate the new tokens.
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format.
- _Optionally_, you can pass `resize_to_multiple_of` to round the embedding size up to a multiple of that value, e.g., `64` (see the short rounding sketch after the example below). If you want to see more formats supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl).
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format
from trl import clone_chat_template
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# Set up the chat format with the default 'chatml' format
model, tokenizer = setup_chat_format(model, tokenizer)
# Set up the chat format
model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
```
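To make the `resize_to_multiple_of` behavior concrete, here is a minimal sketch of the rounding that the underlying `pad_to_multiple_of` argument performs; the vocabulary size below is purely illustrative and not tied to any particular tokenizer:

```python
import math

# Illustrative numbers only: suppose the tokenizer ends up with 50,262 tokens
# after the chat special tokens have been added.
vocab_size = 50_262
resize_to_multiple_of = 64

# resize_token_embeddings(..., pad_to_multiple_of=64) rounds the embedding matrix
# up to the next multiple of 64.
padded_size = math.ceil(vocab_size / resize_to_multiple_of) * resize_to_multiple_of
print(padded_size)  # 50304
```

Padding the embedding matrix this way is a throughput optimization (see the Karpathy reference in the source comment); it does not change which tokens the tokenizer produces.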
> [!WARNING]
> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply `setup_chat_format()`, as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in `SFTConfig`; for example, for `Qwen/Qwen2.5-1.5B`, one should set `eos_token="<|im_end|>"`.
> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply [`clone_chat_template()`], as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in [`SFTConfig`]; for example, for `Qwen/Qwen2.5-1.5B`, one should set `eos_token="<|im_end|>"`.
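For such models, a minimal sketch of what this looks like with [`SFTConfig`] follows; the output directory is only a placeholder:

```python
from trl import SFTConfig

# Qwen base tokenizers already ship a ChatML-style template, so we only align
# the EOS token instead of calling clone_chat_template.
training_args = SFTConfig(
    output_dir="Qwen2.5-1.5B-SFT",  # placeholder output directory
    eos_token="<|im_end|>",
)
```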
With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.

View File

@@ -19,7 +19,7 @@ from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.extras.dataset_formatting import get_formatting_func_from_dataset
from trl.models.utils import ChatMlSpecialTokens, setup_chat_format
from trl.models.utils import ChatMlSpecialTokens, clone_chat_template, setup_chat_format
class DatasetFormattingTestCase(unittest.TestCase):
@@ -124,7 +124,7 @@ class SetupChatFormatTestCase(unittest.TestCase):
    def test_setup_chat_format(self):
        modified_model, modified_tokenizer = setup_chat_format(
            self.model, self.tokenizer, format="chatml", resize_to_multiple_of=64
            self.model, self.tokenizer, format="chatml", resize_to_multiple_of=123
        )

        _chatml = ChatMlSpecialTokens()
@@ -135,7 +135,7 @@ class SetupChatFormatTestCase(unittest.TestCase):
        self.assertEqual(modified_tokenizer.eos_token, _chatml.eos_token)
        self.assertEqual(modified_tokenizer.pad_token, _chatml.pad_token)
        self.assertEqual(modified_tokenizer.bos_token, _chatml.bos_token)
        self.assertEqual((self.model.get_input_embeddings().weight.shape[0] % 64), 0)
        self.assertEqual((modified_model.get_input_embeddings().num_embeddings % 123), 0)

    def test_example_with_setup_model(self):
        modified_model, modified_tokenizer = setup_chat_format(
@@ -153,3 +153,38 @@ class SetupChatFormatTestCase(unittest.TestCase):
            prompt,
            "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n",
        )

class CloneChatTemplateTestCase(unittest.TestCase):
    def setUp(self):
        # This tokenizer doesn't have a chat_template by default
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
        # This one has a chat_template by default
        self.source = "trl-internal-testing/tiny-Qwen3ForCausalLM"

    def test_clone(self):
        _, modified_tokenizer = clone_chat_template(self.model, self.tokenizer, self.source)
        # Check if special tokens are correctly set
        self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>")

    def test_clone_with_resize(self):
        modified_model, _ = clone_chat_template(self.model, self.tokenizer, self.source, resize_to_multiple_of=123)
        # Check that the input embeddings have been resized to a multiple of 123
        self.assertEqual((modified_model.get_input_embeddings().num_embeddings % 123), 0)

    def test_apply_new_chat_template(self):
        _, modified_tokenizer = clone_chat_template(self.model, self.tokenizer, self.source)
        messages = [
            {"role": "system", "content": "You are helpful"},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi, how can I help you?"},
        ]
        prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)
        self.assertEqual(
            prompt,
            "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nHi, how can I help you?<|im_end|>\n",
        )

View File

@@ -41,6 +41,7 @@ _import_structure = {
        "AutoModelForCausalLMWithValueHead",
        "AutoModelForSeq2SeqLMWithValueHead",
        "PreTrainedModelWrapper",
        "clone_chat_template",
        "create_reference_model",
        "setup_chat_format",
    ],
@@ -136,6 +137,7 @@ if TYPE_CHECKING:
        AutoModelForCausalLMWithValueHead,
        AutoModelForSeq2SeqLMWithValueHead,
        PreTrainedModelWrapper,
        clone_chat_template,
        create_reference_model,
        setup_chat_format,
    )

View File

@@ -23,6 +23,7 @@ _import_structure = {
    "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
    "utils": [
        "SUPPORTED_ARCHITECTURES",
        "clone_chat_template",
        "prepare_deepspeed",
        "prepare_fsdp",
        "setup_chat_format",
@@ -49,6 +50,7 @@ if TYPE_CHECKING:
    from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
    from .utils import (
        SUPPORTED_ARCHITECTURES,
        clone_chat_template,
        prepare_deepspeed,
        prepare_fsdp,
        setup_chat_format,

View File

@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union
import torch.nn as nn
from packaging import version
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
@@ -82,6 +82,10 @@ def setup_chat_format(
    """
    Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens.

    <Tip warning="true">

    We recommend using [`clone_chat_template`] instead of this function.

    </Tip>

    If the model already has a chat template, this will throw an error. If you want to overwrite it, please set `tokenizer.chat_template` to `None`.

    Args:
@@ -116,7 +120,11 @@
    # resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377
    model.resize_token_embeddings(
        len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None
        # After studying many tokenizers, we found that len(tokenizer.vocab) is the most reliable way to get the vocab
        # size. Avoid using tokenizer.vocab_size or tokenizer.vocab_size + len(tokenizer.added_tokens_encoder),
        # as handling of special and added tokens varies across tokenizers.
        new_num_tokens=len(tokenizer.vocab),
        pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None,
    )

    # Update the model config to use the new eos & bos tokens
    if getattr(model, "config", None) is not None:
@@ -132,6 +140,75 @@ def setup_chat_format(
    return model, tokenizer

def clone_chat_template(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    source_tokenizer_path: str,
    resize_to_multiple_of: Optional[int] = 64,
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Clones a chat template from a source tokenizer to the target tokenizer and updates the model accordingly.

    This function:
    - Copies the chat template from a source tokenizer to the target tokenizer.
    - Adds any new tokens from the source tokenizer to the target tokenizer.
    - Sets and synchronizes the EOS token across the tokenizer and model.
    - Resizes the model's token embeddings to match the new vocabulary size, optionally rounding it up to a multiple
      of a specified value.

    Args:
        model (`PreTrainedModel`):
            Model to update.
        tokenizer (`PreTrainedTokenizer`):
            Tokenizer to update.
        source_tokenizer_path (`str`):
            Path or identifier of the pretrained tokenizer to clone from.
        resize_to_multiple_of (`int` or `None`, *optional*, defaults to `64`):
            The embedding layer will be resized to the new vocabulary size. If this is not `None`, it will round up
            the new vocabulary size to the nearest multiple of this value.

    Returns:
        model (`PreTrainedModel`):
            Updated model with resized token embeddings and EOS token configured.
        tokenizer (`~transformers.PreTrainedTokenizer`):
            Updated tokenizer with the chat template and special tokens applied.

    Example:
    ```python
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import clone_chat_template

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
    ```
    """
    # Load the source tokenizer containing the desired chat template
    tokenizer_source = AutoTokenizer.from_pretrained(source_tokenizer_path)

    # Copy the chat template from the source tokenizer
    tokenizer.chat_template = tokenizer_source.get_chat_template()

    # Ensure all added tokens from the source are available in the target tokenizer
    tokenizer.add_tokens(list(tokenizer_source.added_tokens_decoder.values()))

    # Set the EOS token from the source tokenizer (important for generation)
    tokenizer.eos_token = tokenizer_source.eos_token
    model.config.eos_token_id = tokenizer.eos_token_id
    model.generation_config.eos_token_id = tokenizer.eos_token_id

    # Resize model embeddings to include any new tokens, optionally rounding up to a multiple
    model.resize_token_embeddings(
        # After studying many tokenizers, we found that len(tokenizer.vocab) is the most reliable way to get the
        # vocab size. Avoid using tokenizer.vocab_size or tokenizer.vocab_size + len(tokenizer.added_tokens_encoder),
        # as handling of special and added tokens varies across tokenizers.
        new_num_tokens=len(tokenizer.vocab),
        pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None,
    )

    return model, tokenizer

def remove_hooks(model: "DeepSpeedEngine") -> None:
    """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
    if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer

View File

@@ -63,10 +63,10 @@ from trl import (
    SFTConfig,
    SFTTrainer,
    TrlParser,
    clone_chat_template,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    setup_chat_format,
)
@@ -104,7 +104,8 @@ def main(script_args, training_args, model_args):
    # Set default chat template if needed
    if tokenizer.chat_template is None:
        model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
        # TODO: source should be passed as an argument
        model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")

    ################
    # Dataset