mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 10:03:51 +08:00
📜 Add chat_template_path
parameter to SFTConfig
(#3599)
This commit is contained in:
committed by
GitHub
parent
37a71e82bf
commit
67f17f7ea4
89
tests/data/template.jinja
Normal file
89
tests/data/template.jinja
Normal file
@ -0,0 +1,89 @@
|
||||
{%- if tools %}
|
||||
{{- '<|im_start|>system\n' }}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- messages[0].content + '\n\n' }}
|
||||
{%- endif %}
|
||||
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
||||
{%- for tool in tools %}
|
||||
{{- "\n" }}
|
||||
{{- tool | tojson }}
|
||||
{%- endfor %}
|
||||
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
||||
{%- else %}
|
||||
{%- if messages[0].role == 'system' %}
|
||||
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
||||
{%- for message in messages[::-1] %}
|
||||
{%- set index = (messages|length - 1) - loop.index0 %}
|
||||
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
|
||||
{%- set ns.multi_step_tool = false %}
|
||||
{%- set ns.last_query_index = index %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- for message in messages %}
|
||||
{%- if message.content is string %}
|
||||
{%- set content = message.content %}
|
||||
{%- else %}
|
||||
{%- set content = '' %}
|
||||
{%- endif %}
|
||||
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{%- set reasoning_content = '' %}
|
||||
{%- if message.reasoning_content is string %}
|
||||
{%- set reasoning_content = message.reasoning_content %}
|
||||
{%- else %}
|
||||
{%- if '</think>' in content %}
|
||||
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
||||
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- if loop.index0 > ns.last_query_index %}
|
||||
{%- if loop.last or (not loop.last and reasoning_content) %}
|
||||
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- else %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||
{%- endif %}
|
||||
{%- if message.tool_calls %}
|
||||
{%- for tool_call in message.tool_calls %}
|
||||
{%- if (loop.first and content) or (not loop.first) %}
|
||||
{{- '\n' }}
|
||||
{%- endif %}
|
||||
{%- if tool_call.function %}
|
||||
{%- set tool_call = tool_call.function %}
|
||||
{%- endif %}
|
||||
{{- '<tool_call>\n{"name": "' }}
|
||||
{{- tool_call.name }}
|
||||
{{- '", "arguments": ' }}
|
||||
{%- if tool_call.arguments is string %}
|
||||
{{- tool_call.arguments }}
|
||||
{%- else %}
|
||||
{{- tool_call.arguments | tojson }}
|
||||
{%- endif %}
|
||||
{{- '}\n</tool_call>' }}
|
||||
{%- endfor %}
|
||||
{%- endif %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- elif message.role == "tool" %}
|
||||
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
||||
{{- '<|im_start|>user' }}
|
||||
{%- endif %}
|
||||
{{- '\n<tool_response>\n' }}
|
||||
{{- content }}
|
||||
{{- '\n</tool_response>' }}
|
||||
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||
{{- '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- if enable_thinking is defined and enable_thinking is false %}
|
||||
{{- '<think>\n\n</think>\n\n' }}
|
||||
{%- endif %}
|
||||
{%- endif %}
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import pathlib
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
@ -1431,3 +1432,59 @@ class SFTTrainerTester2(unittest.TestCase):
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||
|
||||
def test_train_with_set_chat_template_from_model(self):
|
||||
# Get the dataset
|
||||
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Initialize the trainer
|
||||
training_args = SFTConfig(output_dir=tmp_dir, chat_template_path="Qwen/Qwen3-4B", report_to="none")
|
||||
# trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default
|
||||
trainer = SFTTrainer(
|
||||
model="trl-internal-testing/tiny-GPTNeoXForCausalLM", args=training_args, train_dataset=dataset
|
||||
)
|
||||
|
||||
# Save the initial parameters to compare them later
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
# Train the model
|
||||
trainer.train()
|
||||
|
||||
# Check that the training loss is not None
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# Check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||
|
||||
def test_train_with_set_chat_template_from_path(self):
|
||||
# Get the dataset
|
||||
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Initialize the trainer
|
||||
training_args = SFTConfig(
|
||||
output_dir=tmp_dir,
|
||||
chat_template_path=str(pathlib.Path(__file__).parent / "data" / "template.jinja"),
|
||||
report_to="none",
|
||||
)
|
||||
# trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default
|
||||
trainer = SFTTrainer(
|
||||
model="trl-internal-testing/tiny-GPTNeoXForCausalLM", args=training_args, train_dataset=dataset
|
||||
)
|
||||
|
||||
# Save the initial parameters to compare them later
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
# Train the model
|
||||
trainer.train()
|
||||
|
||||
# Check that the training loss is not None
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# Check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||
|
@ -38,6 +38,11 @@ class SFTConfig(TrainingArguments):
|
||||
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
||||
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
||||
argument of the [`SFTTrainer`] is provided as a string.
|
||||
chat_template_path (`str` or `None`, *optional*, defaults to `None`):
|
||||
If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
|
||||
or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
|
||||
ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
|
||||
embedding layer is resized accordingly.
|
||||
|
||||
> Parameters that control the data preprocessing
|
||||
|
||||
@ -130,6 +135,15 @@ class SFTConfig(TrainingArguments):
|
||||
"the `SFTTrainer` is provided as a string."
|
||||
},
|
||||
)
|
||||
chat_template_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local "
|
||||
"directory or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, "
|
||||
"you must ensure that any special tokens referenced in the template are added to the tokenizer and "
|
||||
"that the model's embedding layer is resized accordingly."
|
||||
},
|
||||
)
|
||||
|
||||
# Parameters that control the data preprocessing
|
||||
dataset_text_field: str = field(
|
||||
|
@ -51,7 +51,7 @@ from ..data_utils import (
|
||||
pack_dataset,
|
||||
truncate_dataset,
|
||||
)
|
||||
from ..models import get_act_offloading_ctx_manager
|
||||
from ..models import clone_chat_template, get_act_offloading_ctx_manager
|
||||
from .sft_config import SFTConfig
|
||||
from .utils import (
|
||||
ConstantLengthDataset,
|
||||
@ -342,6 +342,13 @@ class SFTTrainer(Trainer):
|
||||
if isinstance(model, str):
|
||||
model = self._create_model_from_path(model, args)
|
||||
|
||||
if args.chat_template_path is not None:
|
||||
if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
|
||||
with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
|
||||
processing_class.chat_template = chat_template_file.read()
|
||||
else:
|
||||
model, processing_class = clone_chat_template(model, processing_class, args.chat_template_path)
|
||||
|
||||
# PEFT configuration and model wrapping
|
||||
if peft_config is not None:
|
||||
model = self._prepare_peft_model(model, peft_config, args)
|
||||
|
Reference in New Issue
Block a user