mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
📜 Add chat_template_path
parameter to SFTConfig
(#3599)
This commit is contained in:
committed by
GitHub
parent
37a71e82bf
commit
67f17f7ea4
89
tests/data/template.jinja
Normal file
89
tests/data/template.jinja
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
{%- if tools %}
|
||||||
|
{{- '<|im_start|>system\n' }}
|
||||||
|
{%- if messages[0].role == 'system' %}
|
||||||
|
{{- messages[0].content + '\n\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
||||||
|
{%- for tool in tools %}
|
||||||
|
{{- "\n" }}
|
||||||
|
{{- tool | tojson }}
|
||||||
|
{%- endfor %}
|
||||||
|
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
||||||
|
{%- else %}
|
||||||
|
{%- if messages[0].role == 'system' %}
|
||||||
|
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
||||||
|
{%- for message in messages[::-1] %}
|
||||||
|
{%- set index = (messages|length - 1) - loop.index0 %}
|
||||||
|
{%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
|
||||||
|
{%- set ns.multi_step_tool = false %}
|
||||||
|
{%- set ns.last_query_index = index %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- for message in messages %}
|
||||||
|
{%- if message.content is string %}
|
||||||
|
{%- set content = message.content %}
|
||||||
|
{%- else %}
|
||||||
|
{%- set content = '' %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
||||||
|
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
||||||
|
{%- elif message.role == "assistant" %}
|
||||||
|
{%- set reasoning_content = '' %}
|
||||||
|
{%- if message.reasoning_content is string %}
|
||||||
|
{%- set reasoning_content = message.reasoning_content %}
|
||||||
|
{%- else %}
|
||||||
|
{%- if '</think>' in content %}
|
||||||
|
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
||||||
|
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if loop.index0 > ns.last_query_index %}
|
||||||
|
{%- if loop.last or (not loop.last and reasoning_content) %}
|
||||||
|
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
|
||||||
|
{%- else %}
|
||||||
|
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- else %}
|
||||||
|
{{- '<|im_start|>' + message.role + '\n' + content }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if message.tool_calls %}
|
||||||
|
{%- for tool_call in message.tool_calls %}
|
||||||
|
{%- if (loop.first and content) or (not loop.first) %}
|
||||||
|
{{- '\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- if tool_call.function %}
|
||||||
|
{%- set tool_call = tool_call.function %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<tool_call>\n{"name": "' }}
|
||||||
|
{{- tool_call.name }}
|
||||||
|
{{- '", "arguments": ' }}
|
||||||
|
{%- if tool_call.arguments is string %}
|
||||||
|
{{- tool_call.arguments }}
|
||||||
|
{%- else %}
|
||||||
|
{{- tool_call.arguments | tojson }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '}\n</tool_call>' }}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- elif message.role == "tool" %}
|
||||||
|
{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
|
||||||
|
{{- '<|im_start|>user' }}
|
||||||
|
{%- endif %}
|
||||||
|
{{- '\n<tool_response>\n' }}
|
||||||
|
{{- content }}
|
||||||
|
{{- '\n</tool_response>' }}
|
||||||
|
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
||||||
|
{{- '<|im_end|>\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
{%- if add_generation_prompt %}
|
||||||
|
{{- '<|im_start|>assistant\n' }}
|
||||||
|
{%- if enable_thinking is defined and enable_thinking is false %}
|
||||||
|
{{- '<think>\n\n</think>\n\n' }}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
@ -13,6 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import pathlib
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
@ -1431,3 +1432,59 @@ class SFTTrainerTester2(unittest.TestCase):
|
|||||||
for n, param in previous_trainable_params.items():
|
for n, param in previous_trainable_params.items():
|
||||||
new_param = trainer.model.get_parameter(n)
|
new_param = trainer.model.get_parameter(n)
|
||||||
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
|
||||||
|
def test_train_with_set_chat_template_from_model(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = SFTConfig(output_dir=tmp_dir, chat_template_path="Qwen/Qwen3-4B", report_to="none")
|
||||||
|
# trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default
|
||||||
|
trainer = SFTTrainer(
|
||||||
|
model="trl-internal-testing/tiny-GPTNeoXForCausalLM", args=training_args, train_dataset=dataset
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the initial parameters to compare them later
|
||||||
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
|
# Check the params have changed
|
||||||
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
|
||||||
|
def test_train_with_set_chat_template_from_path(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = SFTConfig(
|
||||||
|
output_dir=tmp_dir,
|
||||||
|
chat_template_path=str(pathlib.Path(__file__).parent / "data" / "template.jinja"),
|
||||||
|
report_to="none",
|
||||||
|
)
|
||||||
|
# trl-internal-testing/tiny-GPTNeoXForCausalLM doesn't have a chat template set by default
|
||||||
|
trainer = SFTTrainer(
|
||||||
|
model="trl-internal-testing/tiny-GPTNeoXForCausalLM", args=training_args, train_dataset=dataset
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the initial parameters to compare them later
|
||||||
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
|
# Check the params have changed
|
||||||
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
@ -38,6 +38,11 @@ class SFTConfig(TrainingArguments):
|
|||||||
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
||||||
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
||||||
argument of the [`SFTTrainer`] is provided as a string.
|
argument of the [`SFTTrainer`] is provided as a string.
|
||||||
|
chat_template_path (`str` or `None`, *optional*, defaults to `None`):
|
||||||
|
If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
|
||||||
|
or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
|
||||||
|
ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
|
||||||
|
embedding layer is resized accordingly.
|
||||||
|
|
||||||
> Parameters that control the data preprocessing
|
> Parameters that control the data preprocessing
|
||||||
|
|
||||||
@ -130,6 +135,15 @@ class SFTConfig(TrainingArguments):
|
|||||||
"the `SFTTrainer` is provided as a string."
|
"the `SFTTrainer` is provided as a string."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
chat_template_path: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local "
|
||||||
|
"directory or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, "
|
||||||
|
"you must ensure that any special tokens referenced in the template are added to the tokenizer and "
|
||||||
|
"that the model's embedding layer is resized accordingly."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Parameters that control the data preprocessing
|
# Parameters that control the data preprocessing
|
||||||
dataset_text_field: str = field(
|
dataset_text_field: str = field(
|
||||||
|
@ -51,7 +51,7 @@ from ..data_utils import (
|
|||||||
pack_dataset,
|
pack_dataset,
|
||||||
truncate_dataset,
|
truncate_dataset,
|
||||||
)
|
)
|
||||||
from ..models import get_act_offloading_ctx_manager
|
from ..models import clone_chat_template, get_act_offloading_ctx_manager
|
||||||
from .sft_config import SFTConfig
|
from .sft_config import SFTConfig
|
||||||
from .utils import (
|
from .utils import (
|
||||||
ConstantLengthDataset,
|
ConstantLengthDataset,
|
||||||
@ -342,6 +342,13 @@ class SFTTrainer(Trainer):
|
|||||||
if isinstance(model, str):
|
if isinstance(model, str):
|
||||||
model = self._create_model_from_path(model, args)
|
model = self._create_model_from_path(model, args)
|
||||||
|
|
||||||
|
if args.chat_template_path is not None:
|
||||||
|
if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
|
||||||
|
with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
|
||||||
|
processing_class.chat_template = chat_template_file.read()
|
||||||
|
else:
|
||||||
|
model, processing_class = clone_chat_template(model, processing_class, args.chat_template_path)
|
||||||
|
|
||||||
# PEFT configuration and model wrapping
|
# PEFT configuration and model wrapping
|
||||||
if peft_config is not None:
|
if peft_config is not None:
|
||||||
model = self._prepare_peft_model(model, peft_config, args)
|
model = self._prepare_peft_model(model, peft_config, args)
|
||||||
|
Reference in New Issue
Block a user