Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 10:03:51 +08:00)

💬 Fix `setup_chat_format` and add `clone_chat_template` (#3404)

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

Committed via GitHub. Commit 4126803875, parent 91b3f5ee9a.
@@ -1,5 +1,9 @@
 # Model Utilities
 
+## clone_chat_template
+
+[[autodoc]] clone_chat_template
+
 ## get_act_offloading_ctx_manager
 
 [[autodoc]] models.get_act_offloading_ctx_manager
@@ -63,26 +63,26 @@ If you’d like to compute loss on both the prompt **and** the completion while
 ### Add Special Tokens for Chat Format
 
 Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system, and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
 
-The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
+The [`clone_chat_template`] function is a useful utility to prepare a model and tokenizer for conversational AI tasks. This function:
 - Adds special tokens to the tokenizer, e.g., `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
 - Resizes the model’s embedding layer to accommodate the new tokens.
-- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
+- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format.
 - _Optionally_, you can pass `resize_to_multiple_of` to round the resized embedding layer up to a multiple of that value, e.g., `64`. If you want to see more formats supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl).
 
 ```python
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from trl import setup_chat_format
+from trl import clone_chat_template
 
 # Load model and tokenizer
 model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
 tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
 
-# Set up the chat format with the default 'chatml' format
-model, tokenizer = setup_chat_format(model, tokenizer)
+# Set up the chat format
+model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
 ```
 
 > [!WARNING]
-> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply `setup_chat_format()`, as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in `SFTConfig`; for example, for `Qwen/Qwen2.5-1.5B`, one should set `eos_token="<|im_end|>"`.
+> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply [`clone_chat_template()`], as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in [`SFTConfig`]; for example, for `Qwen/Qwen2.5-1.5B`, one should set `eos_token="<|im_end|>"`.
 
 With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
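A note on the warning above: when the base tokenizer already ships a chat template, aligning the EOS token through the trainer config is sufficient. A minimal sketch, assuming the `eos_token` argument of [`SFTConfig`] described above and a hypothetical output directory:

```python
from trl import SFTConfig

# Sketch only: for a base model such as Qwen/Qwen2.5-1.5B whose tokenizer already has a
# chat template, skip clone_chat_template() and just align the EOS token with the template.
training_args = SFTConfig(
    output_dir="qwen2.5-1.5b-sft",  # hypothetical output directory
    eos_token="<|im_end|>",         # matches the template's end-of-turn token
)
```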
@@ -19,7 +19,7 @@ from datasets import Dataset, load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from trl.extras.dataset_formatting import get_formatting_func_from_dataset
-from trl.models.utils import ChatMlSpecialTokens, setup_chat_format
+from trl.models.utils import ChatMlSpecialTokens, clone_chat_template, setup_chat_format
 
 
 class DatasetFormattingTestCase(unittest.TestCase):
@@ -124,7 +124,7 @@ class SetupChatFormatTestCase(unittest.TestCase):
 
     def test_setup_chat_format(self):
         modified_model, modified_tokenizer = setup_chat_format(
-            self.model, self.tokenizer, format="chatml", resize_to_multiple_of=64
+            self.model, self.tokenizer, format="chatml", resize_to_multiple_of=123
         )
 
         _chatml = ChatMlSpecialTokens()
@@ -135,7 +135,7 @@ class SetupChatFormatTestCase(unittest.TestCase):
         self.assertEqual(modified_tokenizer.eos_token, _chatml.eos_token)
         self.assertEqual(modified_tokenizer.pad_token, _chatml.pad_token)
         self.assertEqual(modified_tokenizer.bos_token, _chatml.bos_token)
-        self.assertEqual((self.model.get_input_embeddings().weight.shape[0] % 64), 0)
+        self.assertEqual((modified_model.get_input_embeddings().num_embeddings % 123), 0)
 
     def test_example_with_setup_model(self):
         modified_model, modified_tokenizer = setup_chat_format(
@@ -153,3 +153,38 @@ class SetupChatFormatTestCase(unittest.TestCase):
             prompt,
             "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n",
         )
+
+
+class CloneChatTemplateTestCase(unittest.TestCase):
+    def setUp(self):
+        # This tokenizer doesn't have a chat_template by default
+        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
+        # This one has a chat_template by default
+        self.source = "trl-internal-testing/tiny-Qwen3ForCausalLM"
+
+    def test_clone(self):
+        _, modified_tokenizer = clone_chat_template(self.model, self.tokenizer, self.source)
+
+        # Check if special tokens are correctly set
+        self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>")
+
+    def test_clone_with_resize(self):
+        modified_model, _ = clone_chat_template(self.model, self.tokenizer, self.source, resize_to_multiple_of=123)
+
+        # Check that the input embeddings have been resized to a multiple of 123
+        self.assertEqual((modified_model.get_input_embeddings().num_embeddings % 123), 0)
+
+    def test_apply_new_chat_template(self):
+        _, modified_tokenizer = clone_chat_template(self.model, self.tokenizer, self.source)
+        messages = [
+            {"role": "system", "content": "You are helpful"},
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi, how can I help you?"},
+        ]
+        prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)
+
+        self.assertEqual(
+            prompt,
+            "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nHi, how can I help you?<|im_end|>\n",
+        )
@@ -41,6 +41,7 @@ _import_structure = {
         "AutoModelForCausalLMWithValueHead",
         "AutoModelForSeq2SeqLMWithValueHead",
         "PreTrainedModelWrapper",
+        "clone_chat_template",
         "create_reference_model",
         "setup_chat_format",
     ],
@@ -136,6 +137,7 @@ if TYPE_CHECKING:
         AutoModelForCausalLMWithValueHead,
         AutoModelForSeq2SeqLMWithValueHead,
         PreTrainedModelWrapper,
+        clone_chat_template,
         create_reference_model,
         setup_chat_format,
     )
@@ -23,6 +23,7 @@ _import_structure = {
     "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
     "utils": [
         "SUPPORTED_ARCHITECTURES",
+        "clone_chat_template",
         "prepare_deepspeed",
         "prepare_fsdp",
         "setup_chat_format",
@@ -49,6 +50,7 @@ if TYPE_CHECKING:
     from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
     from .utils import (
         SUPPORTED_ARCHITECTURES,
+        clone_chat_template,
         prepare_deepspeed,
         prepare_fsdp,
         setup_chat_format,
@@ -20,7 +20,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 
 import torch.nn as nn
 from packaging import version
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
 
 from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
 
@@ -82,6 +82,10 @@ def setup_chat_format(
     """
     Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens.
 
+    <Tip warning="true">
+    We recommend using [`clone_chat_template`] instead of this function.
+    </Tip>
+
     If the model already has a chat template, this will throw an error. If you want to overwrite it, please set `tokenizer.chat_template` to `None`.
 
     Args:
@@ -116,7 +120,11 @@ def setup_chat_format(
 
     # resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377
     model.resize_token_embeddings(
-        len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None
+        # After studying many tokenizers, we found that len(tokenizer.vocab) is the most reliable way to get the vocab
+        # size. Avoid using tokenizer.vocab_size or tokenizer.vocab_size + len(tokenizer.added_tokens_encoder),
+        # as handling of special and added tokens varies across tokenizers.
+        new_num_tokens=len(tokenizer.vocab),
+        pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None,
    )
     # Update the model config to use the new eos & bos tokens
     if getattr(model, "config", None) is not None:
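The comment introduced above about `len(tokenizer.vocab)` is easy to check interactively. A minimal sketch, assuming a standard Hugging Face fast tokenizer; the `gpt2` checkpoint is only an illustration:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative checkpoint
print(tokenizer.vocab_size)      # base vocabulary only (50257 for gpt2)

tokenizer.add_tokens(["<|im_start|>", "<|im_end|>"])
print(tokenizer.vocab_size)      # unchanged: added tokens are not counted here
print(len(tokenizer.vocab))      # includes the newly added tokens, matching the resize target
```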
@@ -132,6 +140,75 @@ def setup_chat_format(
     return model, tokenizer
 
 
+def clone_chat_template(
+    model: PreTrainedModel,
+    tokenizer: PreTrainedTokenizer,
+    source_tokenizer_path: str,
+    resize_to_multiple_of: Optional[int] = 64,
+) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
+    """
+    Clones a chat template from a source tokenizer to the target tokenizer and updates the model accordingly.
+
+    This function:
+    - Copies the chat template from a source tokenizer to the target tokenizer.
+    - Adds any new tokens from the source tokenizer to the target tokenizer.
+    - Sets and synchronizes the EOS token across the tokenizer and model.
+    - Resizes the model's token embeddings to match the new vocabulary size, optionally rounding it up to a multiple of
+      a specified value.
+
+    Args:
+        model (`PreTrainedModel`):
+            Model to update.
+        tokenizer (`PreTrainedTokenizer`):
+            Tokenizer to update.
+        source_tokenizer_path (`str`):
+            Path or identifier of the pretrained tokenizer to clone from.
+        resize_to_multiple_of (`int` or `None`, *optional*, defaults to `64`):
+            The embedding layer will be resized to the new vocabulary size. If this is not `None`, it will round up the
+            new vocabulary size to the nearest multiple of this value.
+
+    Returns:
+        model (`PreTrainedModel`):
+            Updated model with resized token embeddings and EOS token configured.
+        tokenizer (`~transformers.PreTrainedTokenizer`):
+            Updated tokenizer with the chat template and special tokens applied.
+
+    Example:
+    ```python
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+    from trl import clone_chat_template
+
+    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
+    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
+    model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
+    ```
+    """
+    # Load the source tokenizer containing the desired chat template
+    tokenizer_source = AutoTokenizer.from_pretrained(source_tokenizer_path)
+
+    # Copy the chat template from the source tokenizer
+    tokenizer.chat_template = tokenizer_source.get_chat_template()
+
+    # Ensure all added tokens from the source are available in the target tokenizer
+    tokenizer.add_tokens(list(tokenizer_source.added_tokens_decoder.values()))
+
+    # Set the EOS token from the source tokenizer (important for generation)
+    tokenizer.eos_token = tokenizer_source.eos_token
+    model.config.eos_token_id = tokenizer.eos_token_id
+    model.generation_config.eos_token_id = tokenizer.eos_token_id
+
+    # Resize model embeddings to include any new tokens, optionally rounding up to a multiple
+    model.resize_token_embeddings(
+        # After studying many tokenizers, we found that len(tokenizer.vocab) is the most reliable way to get the vocab
+        # size. Avoid using tokenizer.vocab_size or tokenizer.vocab_size + len(tokenizer.added_tokens_encoder),
+        # as handling of special and added tokens varies across tokenizers.
+        new_num_tokens=len(tokenizer.vocab),
+        pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None,
+    )
+
+    return model, tokenizer
+
+
 def remove_hooks(model: "DeepSpeedEngine") -> None:
     """Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
     if not hasattr(model, "optimizer"):  # before the first training step, the model has no optimizer
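As a usage note on the new function above: a minimal sanity-check sketch, reusing the model names from the documentation example, that the cloned template and EOS token are applied:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import clone_chat_template

# Sketch only: clone the chat template from Qwen/Qwen3-0.6B onto a template-less base model
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")

# The tokenizer now formats conversations with the cloned (ChatML-style) template
messages = [{"role": "user", "content": "Hello"}]
print(tokenizer.apply_chat_template(messages, tokenize=False))

# And the model's generation config stops on the cloned EOS token
print(tokenizer.eos_token, model.generation_config.eos_token_id)
```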
@@ -63,10 +63,10 @@ from trl import (
     SFTConfig,
     SFTTrainer,
     TrlParser,
+    clone_chat_template,
     get_kbit_device_map,
     get_peft_config,
     get_quantization_config,
-    setup_chat_format,
 )
 
 
@@ -104,7 +104,8 @@ def main(script_args, training_args, model_args):
 
     # Set default chat template if needed
     if tokenizer.chat_template is None:
-        model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
+        # TODO: source should be passed as an argument
+        model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
 
     ################
     # Dataset
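The TODO above notes that the clone source should eventually be configurable. A purely hypothetical sketch of how such an argument could be declared for the script's parser; the `chat_template_source` field is invented for illustration and is not part of this commit:

```python
from dataclasses import dataclass, field

@dataclass
class ChatTemplateArguments:
    # Hypothetical script argument; this commit hard-codes "Qwen/Qwen3-0.6B" instead.
    chat_template_source: str = field(
        default="Qwen/Qwen3-0.6B",
        metadata={"help": "Tokenizer to clone the chat template from when the base model has none."},
    )

# Inside main(), the hard-coded call could then become:
# model, tokenizer = clone_chat_template(model, tokenizer, chat_template_args.chat_template_source)
```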