📜 Add chat_template_path parameter to SFTConfig (#3599)

This commit is contained in:
Quentin Gallouédec
2025-06-20 14:15:03 +02:00
committed by GitHub
parent 37a71e82bf
commit 67f17f7ea4
4 changed files with 168 additions and 1 deletions

89
tests/data/template.jinja Normal file
View File

@ -0,0 +1,89 @@
{#- Qwen3-style chat template (test fixture). Renders a message list into
    <|im_start|>role ... <|im_end|> segments, with optional tool definitions,
    tool-call JSON blocks, and <think> reasoning sections for assistant turns.
    All comments use the {#- -#} form, which emits no output. -#}
{#- System turn: when tools are supplied, the system message (if any) is merged
    with a generated "# Tools" section listing each tool's JSON signature. -#}
{%- if tools %}
{{- '<|im_start|>system\n' }}
{%- if messages[0].role == 'system' %}
{{- messages[0].content + '\n\n' }}
{%- endif %}
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
{%- for tool in tools %}
{{- "\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{#- Scan messages in reverse to find the index of the last "real" user query
    (a user message that is not a wrapped <tool_response>). Assistant turns
    after this index get their <think> content rendered; earlier ones do not. -#}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
{%- set ns.multi_step_tool = false %}
{%- set ns.last_query_index = index %}
{%- endif %}
{%- endfor %}
{#- Main render loop over all messages. -#}
{%- for message in messages %}
{%- if message.content is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- endif %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{#- Reasoning either comes from an explicit reasoning_content field or is
    split out of content via the <think>...</think> markers. -#}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if loop.last or (not loop.last and reasoning_content) %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{#- Append any tool calls as <tool_call> JSON blocks after the content. -#}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if (loop.first and content) or (not loop.first) %}
{{- '\n' }}
{%- endif %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>\n{"name": "' }}
{{- tool_call.name }}
{{- '", "arguments": ' }}
{%- if tool_call.arguments is string %}
{{- tool_call.arguments }}
{%- else %}
{{- tool_call.arguments | tojson }}
{%- endif %}
{{- '}\n</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{#- Consecutive tool messages are grouped into a single user turn; open the
    turn only for the first of a run, close it only after the last. -#}
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\n<tool_response>\n' }}
{{- content }}
{{- '\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{#- Generation prompt: open an assistant turn; an empty <think> block is
    emitted when thinking is explicitly disabled. -#}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- endif %}
{%- endif %}

View File

@ -13,6 +13,7 @@
# limitations under the License.
import copy
import pathlib
import tempfile
import unittest
@ -1431,3 +1432,59 @@ class SFTTrainerTester2(unittest.TestCase):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_with_set_chat_template_from_model(self):
    """Training succeeds when `chat_template_path` points at a Hub model whose template is cloned."""
    # Conversational dataset; requires a chat template on the tokenizer.
    train_data = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default,
        # so we pull one from Qwen/Qwen3-4B via chat_template_path.
        config = SFTConfig(output_dir=tmp_dir, chat_template_path="Qwen/Qwen3-4B", report_to="none")
        trainer = SFTTrainer(
            model="trl-internal-testing/tiny-GPTNeoXForCausalLM", args=config, train_dataset=train_data
        )

        # Snapshot the initial weights so we can verify training updates them.
        initial_params = {n: p.clone() for n, p in trainer.model.named_parameters()}

        trainer.train()

        # A logged train loss means at least one optimization step completed.
        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

        # Every parameter should have moved away from its initial value.
        for n, before in initial_params.items():
            after = trainer.model.get_parameter(n)
            self.assertFalse(torch.allclose(before, after), f"Parameter {n} has not changed")
def test_train_with_set_chat_template_from_path(self):
    """Training succeeds when `chat_template_path` points at a local Jinja template file."""
    # Conversational dataset; requires a chat template on the tokenizer.
    train_data = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default,
        # so we load one from the checked-in fixture file.
        config = SFTConfig(
            output_dir=tmp_dir,
            chat_template_path=str(pathlib.Path(__file__).parent / "data" / "template.jinja"),
            report_to="none",
        )
        trainer = SFTTrainer(
            model="trl-internal-testing/tiny-GPTNeoXForCausalLM", args=config, train_dataset=train_data
        )

        # Snapshot the initial weights so we can verify training updates them.
        initial_params = {n: p.clone() for n, p in trainer.model.named_parameters()}

        trainer.train()

        # A logged train loss means at least one optimization step completed.
        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

        # Every parameter should have moved away from its initial value.
        for n, before in initial_params.items():
            after = trainer.model.get_parameter(n)
            self.assertFalse(torch.allclose(before, after), f"Parameter {n} has not changed")

View File

@ -38,6 +38,11 @@ class SFTConfig(TrainingArguments):
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
    Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
    argument of the [`SFTTrainer`] is provided as a string.
chat_template_path (`str` or `None`, *optional*, defaults to `None`):
If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
embedding layer is resized accordingly.
> Parameters that control the data preprocessing
@ -130,6 +135,15 @@ class SFTConfig(TrainingArguments):
"the `SFTTrainer` is provided as a string."
},
)
# Optional override for the model's chat template: either a tokenizer path
# (local dir or Hub repo) to clone the template from, or a direct path to a
# .jinja/.j2 file whose contents are used verbatim.
chat_template_path: Optional[str] = field(
    default=None,
    metadata={
        "help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local "
        "directory or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, "
        "you must ensure that any special tokens referenced in the template are added to the tokenizer and "
        "that the model's embedding layer is resized accordingly."
    },
)
# Parameters that control the data preprocessing
dataset_text_field: str = field(

View File

@ -51,7 +51,7 @@ from ..data_utils import (
pack_dataset,
truncate_dataset,
)
from ..models import clone_chat_template, get_act_offloading_ctx_manager
from .sft_config import SFTConfig
from .utils import (
    ConstantLengthDataset,
@ -342,6 +342,13 @@ class SFTTrainer(Trainer):
if isinstance(model, str):
    model = self._create_model_from_path(model, args)
# Apply the user-provided chat template, if any. Two forms are supported:
# - a path to a .jinja/.j2 file: its raw text becomes the tokenizer's template
#   (NOTE: any special tokens it references must already exist in the tokenizer;
#   the embedding layer is not resized here for this branch);
# - otherwise a tokenizer path (local dir or Hub repo): clone_chat_template
#   copies the template and, per its contract, may return an updated model.
if args.chat_template_path is not None:
    if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
        with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
            processing_class.chat_template = chat_template_file.read()
    else:
        model, processing_class = clone_chat_template(model, processing_class, args.chat_template_path)
# PEFT configuration and model wrapping
if peft_config is not None:
    model = self._prepare_peft_model(model, peft_config, args)