Compare commits

...

17 Commits

Author SHA1 Message Date
013d360b8f 🔹 Fix: Miscalculated mask shape in comments (#2925) 2025-02-21 17:01:53 +01:00
e5ae703d35 🐦🔥 6x faster GRPO with multi-step optimization (#2899)
* Add num_updates and epsilon parameters to GRPOConfig and GRPOTrainer

* test sampler

* update the loss computation

* fix eval sampler

* should work now

* buffer inputs with grad accum

* optimize when num_iterations == 1

* test

* minor comment removal and fix log metric

* beta position

* clarify comment [ci skip]

* clarify sampler doc [ci skip]

* fix collision with eval logging

* clarify
2025-02-20 19:51:45 +01:00
a92e00e810 🪪 Adds profiling decorators for GRPOTrainer (#2889)
* adds profiling decorator

* naming + precommit

* style

* revert inclusion of slider table

* revert 2

* revert3

* revert4

* revert 5 fml

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-20 09:57:42 +01:00
9b3c5bf64f 📍 [GRPO] add gradient_checkpointing (#2848)
* add gradient_checkpointing

* added a helper

* Update trl/trainer/grpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/grpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* minor refactor for better readability

* use acceelrate util

* enable_input_require_grads is in base class

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-18 18:09:16 +01:00
15fec312d5 🍃 GRPO - Do not load reference model when beta == 0 (#2806)
* 🔧 Optimize GRPO training by conditionally loading reference model based on beta value

*  Add test for GRPOTrainer with beta=0 to ensure no reference model and KL divergence

* 🔧 Refactor GRPOTrainer code for improved readability and maintainability

* 🔧 Simplify per_token_loss calculation in GRPOTrainer for clarity

* fix test, style, and some struct for clarity

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-18 17:57:15 +01:00
be1e34003c 🩳 max_seq_length to max_length (#2895)
* `max_seq_length` to `max_length`

* remove in 0.20
2025-02-18 16:53:37 +01:00
6aaf379a82 ⚰️ Remove deprecated (#2894) 2025-02-18 16:53:21 +01:00
49adf74833 Add vLLM guided decoding support to GRPO Trainer (#2811)
*  Add vLLM guided decoding support to GRPO Trainer

* 🔧 Update vLLM guided decoding in GRPO to use regex parameter

* style and docstring

* test

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-18 16:53:05 +01:00
6c54f023ae 🪂 Don't gather logits in SFT to avoid hanging (#2890)
* Don't gather logits

* Remove unused function and test
2025-02-18 15:31:08 +01:00
963243a7d1 Optimize vllm num_generations (#2855)
* small optimization of vllm batching

* style

* adds comment

* style
2025-02-18 11:44:15 +01:00
aafd8cbea5 🍟 [SFT] Handles the dataset if it has been preprocessed (#2863)
* return dataset if it's preprocessed

* add is_processed flag variable

* add test

* move test_sft_trainer_directly_with_pretokenized_data to Tester2

* Update sft_trainer.py

* no need for padding and truncation

* minor reorganization

* Update trl/trainer/sft_trainer.py

* let the collator pad

* style

* fix tests

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-18 09:56:47 +01:00
822653824b 🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading (#2873) 2025-02-17 20:34:07 +01:00
ba036576d4 💬 Add maybe_convert_to_chatml map for conversational datasets in SFT (#2862)
* add back get_formatting_func_from_dataset

* maybe_convert_to_chatml

* maybe_convert_to_chatml before maybe_apply_chat_template map

* remove comment

* test

* desc

* style

* Update trl/data_utils.py

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-17 16:47:06 +01:00
293b620950 [GRPO] Fix loss normalization (#2881)
* fix GRPO loss normalization

* fix sum dim

* fix loss= repeated
2025-02-17 13:26:21 +01:00
ae3bd0d07a 🆙 Bump vLLM min version to 0.7.2 (#2860)
Bumps vllm as there were a number of throughput improvements in vllm==0.7.2

Also may resolve issue such as https://github.com/huggingface/trl/issues/2851
2025-02-17 10:54:07 +01:00
6d9fc11fd6 [SFT] fix check for AutoLigerKernelForCausalLM (#2874)
* fix check for AutoLigerKernelForCausalLM

* fix case where AutoLigerKernelForCausalLM is not defined

* update min liger version

* formatting

* fix win CI
2025-02-17 07:50:55 +01:00
ffcb9f4aee ⬆️ Bump dev version 2025-02-13 14:33:44 +00:00
24 changed files with 879 additions and 247 deletions

View File

@ -42,7 +42,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
--max_seq_length $SEQ_LEN \
--max_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""

View File

@ -12,6 +12,10 @@
[[autodoc]] maybe_apply_chat_template
## maybe_convert_to_chatml
[[autodoc]] maybe_convert_to_chatml
## extract_prompt
[[autodoc]] extract_prompt

View File

@ -44,7 +44,7 @@ training_args = DPOConfig(..., max_completion_length=...)
</hfoption>
<hfoption id="SFT">
SFT truncation is applied to the input sequence via the `max_seq_length` parameter.
SFT truncation is applied to the input sequence via the `max_length` parameter.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_input_ids.png" alt="Truncation input ids" width="600"/>
@ -55,7 +55,7 @@ To set the truncation parameter, use the following code snippet:
```python
from trl import SFTConfig
training_args = SFTConfig(..., max_seq_length=...)
training_args = SFTConfig(..., max_length=...)
```
</hfoption>
@ -85,7 +85,7 @@ Packing eliminates padding, preserves all sequence information, and allows for f
```python
from trl import SFTConfig
training_args = SFTConfig(..., packing=True, max_seq_length=512)
training_args = SFTConfig(..., packing=True, max_length=512)
```
<Tip warning={true}>

View File

@ -19,7 +19,7 @@ from trl import SFTConfig, SFTTrainer
dataset = load_dataset("stanfordnlp/imdb", split="train")
training_args = SFTConfig(
max_seq_length=512,
max_length=512,
output_dir="/tmp",
)
trainer = SFTTrainer(
@ -29,7 +29,7 @@ trainer = SFTTrainer(
)
trainer.train()
```
Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
Make sure to pass the correct value for `max_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
You can also construct a model outside of the trainer and pass it as follows:
@ -550,12 +550,12 @@ import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number
max_length = 2048 # Supports automatic RoPE Scaling, so choose any number
# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/mistral-7b",
max_seq_length=max_seq_length,
max_seq_length=max_length,
dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
@ -581,7 +581,7 @@ model = FastLanguageModel.get_peft_model(
random_state=3407,
)
training_args = SFTConfig(output_dir="./output", max_seq_length=max_seq_length)
training_args = SFTConfig(output_dir="./output", max_length=max_length)
trainer = SFTTrainer(
model=model,
@ -624,7 +624,7 @@ To learn more about Liger-Kernel, visit their [official repository](https://gith
Pay attention to the following best practices when training a model with that trainer:
- [`SFTTrainer`] always truncates by default the sequences to the `max_seq_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
- [`SFTTrainer`] always truncates by default the sequences to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.

View File

@ -185,7 +185,7 @@ trainer = SFTTrainer(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
max_seq_length=None,
max_length=None,
formatting_func=prepare_sample_text,
processing_class=tokenizer,
args=training_args,

View File

@ -71,7 +71,7 @@ To create the package for PyPI.
from setuptools import find_packages, setup
__version__ = "0.15.0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
__version__ = "0.16.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
REQUIRED_PKGS = [
"accelerate>=0.34.0",
@ -85,13 +85,13 @@ EXTRAS = {
"diffusers": ["diffusers>=0.18.0"],
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
# liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
"liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"],
"liger": ["liger-kernel>=0.5.3; sys_platform != 'win32'"],
"mergekit": ["mergekit>=0.0.5.1"],
"peft": ["peft>=0.8.0"],
"quantization": ["bitsandbytes"],
"scikit": ["scikit-learn"],
"test": ["parameterized", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "pytest"],
"vllm": ["vllm>=0.7.1; sys_platform != 'win32'"], # vllm is not available on Windows
"vllm": ["vllm>=0.7.2; sys_platform != 'win32'"], # vllm is not available on Windows
"vlm": ["Pillow"],
}
EXTRAS["dev"] = []

View File

@ -46,7 +46,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
def setUp(self):
self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]")
self.max_seq_length = 128
self.max_length = 128
self.peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
@ -74,7 +74,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
max_seq_length=self.max_seq_length,
max_length=self.max_length,
)
trainer = SFTTrainer(
@ -100,7 +100,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
max_seq_length=self.max_seq_length,
max_length=self.max_length,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
@ -135,7 +135,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
max_steps=10,
fp16=True,
packing=packing,
max_seq_length=self.max_seq_length,
max_length=self.max_length,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
@ -172,7 +172,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
max_steps=10,
fp16=True, # this is sufficient to enable amp
packing=packing,
max_seq_length=self.max_seq_length,
max_length=self.max_length,
)
model = AutoModelForCausalLM.from_pretrained(model_name)
@ -205,7 +205,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
max_seq_length=self.max_seq_length,
max_length=self.max_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@ -242,7 +242,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
max_seq_length=self.max_seq_length,
max_length=self.max_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@ -286,7 +286,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
max_seq_length=self.max_seq_length,
max_length=self.max_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@ -324,7 +324,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
per_device_train_batch_size=2,
max_steps=10,
packing=packing,
max_seq_length=self.max_seq_length,
max_length=self.max_length,
fp16=True, # this is sufficient to enable amp
gradient_checkpointing=True,
gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@ -364,7 +364,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
training_args = SFTConfig(
packing=packing,
max_seq_length=self.max_seq_length,
max_length=self.max_length,
output_dir=tmp_dir,
logging_strategy="no",
report_to="none",
@ -411,7 +411,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
per_device_train_batch_size=2,
max_steps=2,
packing=packing,
max_seq_length=self.max_seq_length,
max_length=self.max_length,
use_liger=True,
)

View File

@ -24,6 +24,7 @@ from trl.data_utils import (
extract_prompt,
is_conversational,
maybe_apply_chat_template,
maybe_convert_to_chatml,
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_examples,
@ -435,6 +436,51 @@ class TestPackExamples(unittest.TestCase):
self.assertEqual(dataset.to_dict(), expected_output)
class TestMaybeConvertToChatML(unittest.TestCase):
def test_with_conversations_key(self):
# Particular case where the key is "conversations": we rename it to "messages"
example = {
"conversations": [
{"from": "user", "value": "What color is the sky?"},
{"from": "assistant", "value": "It is blue."},
]
}
expected_output = {
"messages": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
]
}
self.assertEqual(maybe_convert_to_chatml(example), expected_output)
def test_without_conversations_key(self):
# Same as before, but we don't rename the keys
example = {
"prompt": [{"from": "user", "value": "What color is the sky?"}],
"completion": [{"from": "assistant", "value": "It is blue."}],
}
expected_output = {
"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}],
}
self.assertEqual(maybe_convert_to_chatml(example), expected_output)
def test_not_conversional(self):
# When not needed, the example should remain unchanged
example = {"text": "The sky is blue."}
self.assertEqual(maybe_convert_to_chatml(example), example)
def test_already_chatml(self):
# When the example is already in ChatML format, it should remain unchanged
example = {
"messages": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
]
}
self.assertEqual(maybe_convert_to_chatml(example), example)
# Run the tests
if __name__ == "__main__":
unittest.main()

View File

@ -25,10 +25,113 @@ from transformers.utils import is_peft_available
from trl import GRPOConfig, GRPOTrainer
from trl.import_utils import is_vllm_available
from trl.trainer.grpo_trainer import RepeatRandomSampler
if is_peft_available():
from peft import LoraConfig
from peft import LoraConfig, PeftModel
class RepeatRandomSamplerTester(unittest.TestCase):
def test_sampler(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=2)
# Should output something like [4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 5, 5]
sampled = list(sampler)
# Check that the length is doubled
assert len(sampled) == 2 * len(dataset)
# Check that all indexes are present
assert set(sampled) == set(range(len(dataset)))
# Check that each element is repeated twice
assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
def test_sampler_no_repeat(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=1)
# Should output something like [4, 3, 0, 1, 2, 6, 5]
sampled = list(sampler)
# Check that the length is the same
assert len(sampled) == len(dataset)
# Check that all indexes are present
assert set(sampled) == set(range(len(dataset)))
def test_sampler_with_batch_size(self):
dataset = ["a", "b", "c", "d", "e", "f", "g", "h"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
# Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6, 5, 7, 5, 7]
sampled = list(sampler)
# Check that the length is doubled
assert len(sampled) == 2 * len(dataset)
# Check that all indexes are present
assert set(sampled) == set(range(len(dataset)))
# Check that each element is repeated as expected
assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4))
def test_sampler_with_batch_size_and_drop(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
# Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6]
sampled = list(sampler)
# Check that the length is doubled
assert len(sampled) == 2 * (
len(dataset) - 1
) # one element is dropped, because it's not enough to form a batch
# Check that the sampled indexes are a subset of the dataset indexes
assert set(sampled).issubset(set(range(len(dataset))))
# Check that each element is repeated as expected
assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4))
def test_sampler_with_mini_repeat_count_and_batch_size_1(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2)
# Should output something like [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0,
# 1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6]
sampled = list(sampler)
# Check that the length is quadrupled
assert len(sampled) == 4 * (len(dataset) - 1) # 1 element is dropped, because it's not enough to form a batch
# Check that the sampled indexes are a subset of the dataset indexes
assert set(sampled).issubset(set(range(len(dataset))))
# Check that each element is repeated as expected
assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
# Check that the batch is repeated as expected
assert sampled[0:6] == sampled[6:12]
assert sampled[12:18] == sampled[18:24]
def test_sampler_with_mini_repeat_count_and_batch_size_2(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2)
# Should output something like [4, 4, 4, 3, 3, 3, 4, 4, 4, 3, 3, 3,
# 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
# 2, 2, 2, 6, 6, 6, 2, 2, 2, 6, 6, 6]
sampled = list(sampler)
# Check that the length is sextupled
assert len(sampled) == 6 * (len(dataset) - 1) # 1 element is dropped, because it's not enough to form a batch
# Check that the sampled indexes are a subset of the dataset indexes
assert set(sampled).issubset(set(range(len(dataset))))
# Check that each element is repeated as expected
assert all(sampled[i] == sampled[i + 1] == sampled[i + 2] for i in range(0, len(sampled), 3))
# Check that the batch is repeated as expected
assert sampled[0:6] == sampled[6:12]
assert sampled[12:18] == sampled[18:24]
assert sampled[24:30] == sampled[30:36]
def test_sampler_with_mini_repeat_count_and_batch_size_3(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3)
# Should output something like [4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3,
# 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
# 2, 2, 6, 6, 2, 2, 6, 6, 2, 2, 6, 6]
sampled = list(sampler)
# Check that the length is sextupled
assert len(sampled) == 6 * (len(dataset) - 1) # 1 element is dropped, because it's not enough to form a batch
# Check that the sampled indexes are a subset of the dataset indexes
assert set(sampled).issubset(set(range(len(dataset))))
# Check that each element is repeated as expected
assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
# Check that the batch is repeated as expected
assert sampled[0:4] == sampled[4:8] == sampled[8:12]
assert sampled[12:16] == sampled[16:20] == sampled[20:24]
assert sampled[24:28] == sampled[28:32] == sampled[32:36]
class GRPOTrainerTester(unittest.TestCase):
@ -96,6 +199,37 @@ class GRPOTrainerTester(unittest.TestCase):
trainer.train()
def test_training_multiple_iterations(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
num_iterations=2,
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
@require_peft
def test_training_peft(self):
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
@ -133,6 +267,57 @@ class GRPOTrainerTester(unittest.TestCase):
elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
@require_peft
def test_training_peft_with_gradient_checkpointing(self):
"""Test that training works with PEFT and gradient checkpointing enabled."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
model = AutoModelForCausalLM.from_pretrained(
"trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
torch_dtype=torch.float32, # Use float32 for testing to avoid precision issues
use_cache=False, # Required for gradient checkpointing
)
lora_config = LoraConfig(
r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none"
)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1,
per_device_train_batch_size=3,
num_generations=3,
max_completion_length=32,
gradient_checkpointing=True, # Enable gradient checkpointing
report_to="none",
)
trainer = GRPOTrainer(
model=model,
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
peft_config=lora_config,
)
# Verify gradient checkpointing is enabled
self.assertIsInstance(trainer.model, PeftModel)
# Store initial parameters to check which ones change
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that only LoRA parameters have changed, base model parameters remain unchanged
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if "lora" in n.lower(): # LoRA parameters should change
self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.")
else: # Base model parameters should not change
self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.")
def test_training_different_reward_model(self):
# Use a reward model different from the model: different chat template, tokenization, etc.
dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
@ -500,6 +685,36 @@ class GRPOTrainerTester(unittest.TestCase):
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
def test_beta_zero_no_ref_model_and_no_kl(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
beta=0.0, # set beta to 0 to test the case where the reference model is not used
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
@unittest.skipIf(not is_vllm_available(), "vLLM is not available")
@require_torch_accelerator
@require_peft
@ -548,3 +763,39 @@ class GRPOTrainerTester(unittest.TestCase):
elif "base_layer" not in n and "original_module" not in n:
# We expect the peft params to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
@unittest.skipIf(not is_vllm_available(), "vLLM is not available")
@require_torch_accelerator
def test_training_vllm_guided_decoding(self):
"""Test that training works with vLLM for generation with guided decoding."""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
report_to="none",
use_vllm=True,
vllm_device="cuda:0", # will raise a warning, but allows this test to work with only one GPU
vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
)
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct", # tiny model is too small for vLLM
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

View File

@ -288,7 +288,7 @@ class SFTTrainerTester(unittest.TestCase):
self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))
def test_sft_trainer_with_pretokenzied_data_packing(self):
def test_sft_trainer_with_pretokenized_data_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
@ -326,7 +326,7 @@ class SFTTrainerTester(unittest.TestCase):
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
max_seq_length=32, # make sure there is at least 1 packed sequence
max_length=32, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@ -353,7 +353,7 @@ class SFTTrainerTester(unittest.TestCase):
train_dataset=self.conversational_lm_dataset["train"],
)
# Same, but with packing with `max_seq_length`
# Same, but with packing with `max_length`
training_args = SFTConfig(
output_dir=tmp_dir,
dataloader_drop_last=True,
@ -361,7 +361,7 @@ class SFTTrainerTester(unittest.TestCase):
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
max_seq_length=16, # make sure there is at least 1 packed sequence
max_length=16, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@ -396,7 +396,7 @@ class SFTTrainerTester(unittest.TestCase):
eval_steps=1,
save_steps=1,
per_device_train_batch_size=2,
max_seq_length=32, # make sure there is at least 1 packed sequence
max_length=32, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@ -461,7 +461,7 @@ class SFTTrainerTester(unittest.TestCase):
save_steps=1,
num_train_epochs=2,
per_device_train_batch_size=2,
max_seq_length=16,
max_length=16,
packing=True,
report_to="none",
)
@ -485,7 +485,7 @@ class SFTTrainerTester(unittest.TestCase):
save_steps=1,
num_train_epochs=2,
per_device_train_batch_size=2,
max_seq_length=16,
max_length=16,
report_to="none",
)
trainer = SFTTrainer(
@ -534,7 +534,7 @@ class SFTTrainerTester(unittest.TestCase):
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
max_seq_length=16,
max_length=16,
packing=True,
report_to="none",
)
@ -558,7 +558,7 @@ class SFTTrainerTester(unittest.TestCase):
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
max_seq_length=16,
max_length=16,
packing=True,
report_to="none",
)
@ -583,7 +583,7 @@ class SFTTrainerTester(unittest.TestCase):
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
max_seq_length=16,
max_length=16,
report_to="none",
)
trainer = SFTTrainer(
@ -606,7 +606,7 @@ class SFTTrainerTester(unittest.TestCase):
max_steps=2,
save_steps=1,
per_device_train_batch_size=2,
max_seq_length=16,
max_length=16,
report_to="none",
)
trainer = SFTTrainer(
@ -755,7 +755,7 @@ class SFTTrainerTester(unittest.TestCase):
save_steps=1,
per_device_train_batch_size=2,
packing=True,
max_seq_length=500,
max_length=500,
report_to="none",
)
trainer = SFTTrainer(
@ -782,7 +782,7 @@ class SFTTrainerTester(unittest.TestCase):
per_device_train_batch_size=2,
save_strategy="epoch",
packing=True,
max_seq_length=500,
max_length=500,
report_to="none",
)
trainer = SFTTrainer(
@ -1088,7 +1088,7 @@ class SFTTrainerTester(unittest.TestCase):
per_device_train_batch_size=2,
gradient_checkpointing=True,
packing=True,
max_seq_length=16, # make sure there is at least 1 packed sequence
max_length=16, # make sure there is at least 1 packed sequence
eval_packing=False,
report_to="none",
)
@ -1114,7 +1114,7 @@ class SFTTrainerTester(unittest.TestCase):
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
max_seq_length=16, # make sure there is at least 1 packed sequence
max_length=16, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@ -1139,7 +1139,7 @@ class SFTTrainerTester(unittest.TestCase):
save_steps=2,
per_device_train_batch_size=2,
gradient_checkpointing=True,
max_seq_length=16, # make sure there is at least 1 packed sequence
max_length=16, # make sure there is at least 1 packed sequence
packing=False,
report_to="none",
)
@ -1370,3 +1370,65 @@ class SFTTrainerTester2(unittest.TestCase):
"base_layer" not in n
): # We expect the peft parameters to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_with_non_chatml_conversational_data(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
# Rename role/content to from/value to ensure SFT works with non-chatML conversational data
def rename_fields(example: list[dict]):
return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]}
dataset = dataset.map(rename_fields, remove_columns="messages")
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(args=training_args, model=model, train_dataset=dataset)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_sft_trainer_with_pretokenized_data(self):
# Get the model and dataset
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
def tokenize_example(example):
return tokenizer(example["text"])
# Apply tokenization
tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"])
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(args=training_args, model=model, train_dataset=tokenized_dataset)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

View File

@ -368,7 +368,7 @@ class TrainerArgTester(unittest.TestCase):
tmp_dir,
dataset_text_field="dummy_text_field",
packing=True,
max_seq_length=256,
max_length=256,
dataset_num_proc=4,
dataset_batch_size=512,
neftune_noise_alpha=0.1,
@ -379,7 +379,7 @@ class TrainerArgTester(unittest.TestCase):
trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset)
self.assertEqual(trainer.args.dataset_text_field, "dummy_text_field")
self.assertEqual(trainer.args.packing, True)
self.assertEqual(trainer.args.max_seq_length, 256)
self.assertEqual(trainer.args.max_length, 256)
self.assertEqual(trainer.args.dataset_num_proc, 4)
self.assertEqual(trainer.args.dataset_batch_size, 512)
self.assertEqual(trainer.args.neftune_noise_alpha, 0.1)

View File

@ -27,7 +27,6 @@ from trl.trainer import compute_accuracy
from trl.trainer.utils import (
DataCollatorForChatML,
batch_generation,
compute_token_accuracy,
decode_and_strip_padding,
flush_left,
generate_model_card,
@ -456,60 +455,6 @@ class TestFlushLeft(unittest.TestCase):
self.assertTrue(torch.equal(new_mask, expected_mask))
class TestComputeTokenAccuracy(unittest.TestCase):
def test_basic_accuracy(self):
# Test basic accuracy computation
logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]]) # Shape: [2, 2, 2]
labels = torch.tensor([[1, 0], [1, 0]]) # Shape: [2, 2]
accuracy = compute_token_accuracy(logits, labels)
self.assertAlmostEqual(accuracy, 0.75) # 3 correct out of 4 tokens
def test_with_ignore_index(self):
# Test accuracy computation with ignored tokens
logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]])
labels = torch.tensor([[1, -100], [1, 0]]) # -100 is ignored
accuracy = compute_token_accuracy(logits, labels, ignore_index=-100)
self.assertAlmostEqual(accuracy, 2 / 3) # 2 correct out of 3 non-ignored tokens
def test_all_ignored(self):
# Test case where all tokens are ignored
logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
labels = torch.tensor([[-100, -100]])
accuracy = compute_token_accuracy(logits, labels)
self.assertEqual(accuracy, 0.0) # No valid tokens to compute accuracy
def test_perfect_accuracy(self):
# Test case with 100% accuracy
logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
labels = torch.tensor([[1, 0]])
accuracy = compute_token_accuracy(logits, labels)
self.assertEqual(accuracy, 1.0) # All predictions correct
def test_zero_accuracy(self):
# Test case with 0% accuracy
logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
labels = torch.tensor([[0, 1]])
accuracy = compute_token_accuracy(logits, labels)
self.assertEqual(accuracy, 0.0) # All predictions wrong
def test_batch_accuracy(self):
# Test accuracy computation across multiple batches
logits = torch.tensor(
[
[[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]], # Batch 1
[[0.2, 0.8], [0.7, 0.3], [0.6, 0.4]], # Batch 2
]
)
labels = torch.tensor(
[
[1, 0, 1], # Batch 1
[1, 0, -100], # Batch 2 (last token ignored)
]
)
accuracy = compute_token_accuracy(logits, labels)
self.assertAlmostEqual(accuracy, 0.8)
class TestSelectiveLogSoftmax(unittest.TestCase):
@parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)])
def test_selective_log_softmax(self, dtype):

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.15.0"
__version__ = "0.16.0.dev0"
from typing import TYPE_CHECKING
@ -26,6 +26,7 @@ _import_structure = {
"extract_prompt",
"is_conversational",
"maybe_apply_chat_template",
"maybe_convert_to_chatml",
"maybe_extract_prompt",
"maybe_unpair_preference_dataset",
"pack_examples",
@ -101,7 +102,7 @@ _import_structure = {
"XPOTrainer",
],
"trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "compute_token_accuracy"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
}
try:
@ -126,6 +127,7 @@ if TYPE_CHECKING:
extract_prompt,
is_conversational,
maybe_apply_chat_template,
maybe_convert_to_chatml,
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_examples,
@ -202,7 +204,7 @@ if TYPE_CHECKING:
XPOTrainer,
)
from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
from .trainer.utils import compute_token_accuracy, get_kbit_device_map, get_peft_config, get_quantization_config
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
try:
if not is_diffusers_available():

View File

@ -31,7 +31,8 @@ def is_conversational(example: dict[str, Any]) -> bool:
dataset type.
Returns:
`bool`: `True` if the data is in a conversational format, `False` otherwise.
`bool`:
`True` if the data is in a conversational format, `False` otherwise.
Examples:
@ -185,20 +186,21 @@ def maybe_apply_chat_template(
For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of
messages, where each message is a dictionary with keys `"role"` and `"content"`.
tokenizer (`PreTrainedTokenizer`):
The tokenizer to apply the chat template with.
Tokenizer to apply the chat template with.
tools (`list[Union[dict, Callable]]` or `None`, *optional*, defaults to `None`):
A list of tools (callable functions) that will be accessible to the model.
If the template does not support function calling, this argument will have no effect
Returns:
`dict[str, str]`: The formatted example with the chat template applied.
`dict[str, str]`:
Formatted example with the chat template applied.
Notes:
- This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by
`"text"`.
- This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced
by `"text"`.
- In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt. Else,
if the last role is `"assistant"`, the final message is continued.
- In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt.
Else, if the last role is `"assistant"`, the final message is continued.
Example:
@ -462,3 +464,52 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
# Split the values into chunks of size seq_length
examples = {k: [v[i : i + seq_length] for i in range(0, len(v), seq_length)] for k, v in examples.items()}
return examples
def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]:
"""
Convert a conversational dataset with fields `from` and `value` to ChatML format.
This function modifies conversational data to align with OpenAI's ChatML format:
- Replaces the key `"from"` with `"role"` in message dictionaries.
- Replaces the key `"value"` with `"content"` in message dictionaries.
- Renames `"conversations"` to `"messages"` for consistency with ChatML.
Args:
example (`dict[str, list]`):
A single data entry containing a list of messages.
Returns:
`dict[str, list]`:
Example reformatted to ChatML style.
Example:
```python
>>> from trl import maybe_convert_to_chatml
>>> example = {
... "conversations": [
... {"from": "user", "value": "What color is the sky?"},
... {"from": "assistant", "value": "It is blue."}
... ]
... }
>>> maybe_convert_to_chatml(example)
{'messages': [{'role': 'user', 'content': 'What color is the sky?'},
{'role': 'assistant', 'content': 'It is blue.'}]}
```
"""
# List of possible keys containing message lists
for key in ["prompt", "completion", "chosen", "rejected", "messages", "conversations"]:
if key in example and isinstance(example[key], list):
messages = example[key]
for message in messages:
if isinstance(message, dict):
if "from" in message:
message["role"] = message.pop("from")
if "value" in message:
message["content"] = message.pop("value")
# Rename "conversations" to "messages"
if "conversations" in example:
example["messages"] = example.pop("conversations")
return example

41
trl/extras/profiling.py Normal file
View File

@ -0,0 +1,41 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import time
from transformers import is_wandb_available
if is_wandb_available():
import wandb
def profiling_decorator(func):
"""
Decorator to profile a function and log the time taken to execute it.
"""
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
start_time = time.perf_counter()
result = func(self, *args, **kwargs)
end_time = time.perf_counter()
duration = end_time - start_time
if "wandb" in self.args.report_to and wandb.run is not None and self.accelerator.is_main_process:
wandb.log({f"profiling/Time taken: {self.__class__.__name__}.{func.__name__}": duration})
return result
return wrapper

View File

@ -20,7 +20,6 @@ from typing import TYPE_CHECKING, Literal, Optional, Union
from accelerate.utils import is_deepspeed_available
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.utils.deprecation import deprecate_kwarg
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
@ -175,7 +174,6 @@ def add_hooks(model: "DeepSpeedEngine") -> None:
@contextmanager
@deprecate_kwarg("is_peft_model", "0.16.0", warn_if_greater_or_equal_version=True)
def unwrap_model_for_generation(
model: Union["DistributedDataParallel", "DeepSpeedEngine"],
accelerator: "Accelerator",

View File

@ -76,7 +76,6 @@ _import_structure = {
"disable_dropout_in_model",
"empty_cache",
"peft_module_casting_to_bf16",
"compute_token_accuracy",
],
"xpo_config": ["XPOConfig"],
"xpo_trainer": ["XPOTrainer"],
@ -145,7 +144,6 @@ if TYPE_CHECKING:
DataCollatorForCompletionOnlyLM,
RunningMoments,
compute_accuracy,
compute_token_accuracy,
disable_dropout_in_model,
empty_cache,
peft_module_casting_to_bf16,

View File

@ -51,7 +51,6 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPIN
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_xpu_available
from transformers.utils.deprecation import deprecate_kwarg
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from ..models import PreTrainedModelWrapper, create_reference_model
@ -202,9 +201,6 @@ class DPOTrainer(Trainer):
_tag_names = ["trl", "dpo"]
@deprecate_kwarg(
"tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,

View File

@ -87,7 +87,7 @@ class GKDTrainer(SFTTrainer):
):
# add remove_unused_columns=False to the dataclass args
args.remove_unused_columns = False
data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
super().__init__(
model,

View File

@ -79,6 +79,8 @@ class GRPOConfig(TrainingArguments):
If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
context size, which might be much larger than the KV cache, leading to inefficiencies.
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
> Parameters that control the training
@ -86,7 +88,12 @@ class GRPOConfig(TrainingArguments):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
beta (`float`, *optional*, defaults to `0.04`):
KL coefficient.
KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
speed.
num_iterations (`int`, *optional*, defaults to `1`):
Number of iterations per batch (denoted as μ in the algorithm).
epsilon (`float`, *optional*, defaults to `0.2`):
Epsilon value for clipping.
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
weighted equally with weight `1.0`.
@ -201,6 +208,10 @@ class GRPOConfig(TrainingArguments):
"context size, which might be much larger than the KV cache, leading to inefficiencies."
},
)
vllm_guided_decoding_regex: Optional[str] = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
# Parameters that control the training
learning_rate: float = field(
@ -212,7 +223,18 @@ class GRPOConfig(TrainingArguments):
)
beta: float = field(
default=0.04,
metadata={"help": "KL coefficient."},
metadata={
"help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
"training speed."
},
)
num_iterations: int = field(
default=1,
metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
)
epsilon: float = field(
default=0.2,
metadata={"help": "Epsilon value for clipping."},
)
reward_weights: Optional[list[float]] = field(
default=None,

View File

@ -43,6 +43,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.utils import is_peft_available
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..extras.profiling import profiling_decorator
from ..import_utils import is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
@ -55,6 +56,7 @@ if is_peft_available():
if is_vllm_available():
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams
if is_wandb_available():
import wandb
@ -66,26 +68,63 @@ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
class RepeatRandomSampler(Sampler):
"""
Sampler that repeats the indices of a dataset N times.
Sampler that repeats the indices of a dataset in a structured manner.
Args:
data_source (`Sized`):
Dataset to sample from.
repeat_count (`int`):
Number of times to repeat each index.
seed (`Optional[int]`):
mini_repeat_count (`int`):
Number of times to repeat each index per batch.
batch_size (`int`, *optional*, defaults to `1`):
Number of unique indices per batch.
repeat_count (`int`, *optional*, defaults to `1`):
Number of times to repeat the full sampling process.
seed (`int` or `None`, *optional*, defaults to `None`):
Random seed for reproducibility (only affects this sampler).
Example:
```python
>>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
>>> sampler = RepeatRandomSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4)
>>> list(sampler)
[2, 2, 0, 0, 3, 3, 1, 1]
[4, 4, 3, 3, 0, 0,
4, 4, 3, 3, 0, 0,
4, 4, 3, 3, 0, 0,
4, 4, 3, 3, 0, 0,
1, 1, 2, 2, 6, 6,
1, 1, 2, 2, 6, 6,
1, 1, 2, 2, 6, 6,
1, 1, 2, 2, 6, 6]
```
```txt
mini_repeat_count = 3
- - -
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, |
repeat_count = 2
0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, ...] |
--------- --------- --------- ---------
--------- --------- --------- ---------
--------- --------- --------- ---------
batch_size = 12
```
"""
def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
def __init__(
self,
data_source: Sized,
mini_repeat_count: int,
batch_size: int = 1,
repeat_count: int = 1,
seed: Optional[int] = None,
):
self.data_source = data_source
self.mini_repeat_count = mini_repeat_count
self.batch_size = batch_size
self.repeat_count = repeat_count
self.num_samples = len(data_source)
self.seed = seed
@ -94,15 +133,25 @@ class RepeatRandomSampler(Sampler):
self.generator.manual_seed(seed)
def __iter__(self):
indexes = [
idx
for idx in torch.randperm(self.num_samples, generator=self.generator).tolist()
for _ in range(self.repeat_count)
]
return iter(indexes)
# E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
def __len__(self):
return self.num_samples * self.repeat_count
# [2, 4, 3, 1, 0, 6, 5]
# -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3)
indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
# [[2, 4, 3], [1, 0, 6], [5]]
# -> [[2, 4, 3], [1, 0, 6]]
indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
for chunk in indexes:
for _ in range(self.repeat_count):
for index in chunk:
for _ in range(self.mini_repeat_count):
yield index
def __len__(self) -> int:
return self.num_samples * self.mini_repeat_count * self.repeat_count
class GRPOTrainer(Trainer):
@ -244,18 +293,28 @@ class GRPOTrainer(Trainer):
)
if peft_config is not None:
if not is_peft_available():
raise ImportError("PEFT is required to use `peft_config`. Run `pip install peft`.")
model = get_peft_model(model, peft_config)
# Enable gradient checkpointing if requested
if args.gradient_checkpointing:
model = self._enable_gradient_checkpointing(model, args)
# Reference model
if is_deepspeed_zero3_enabled():
self.beta = args.beta
if self.beta == 0.0:
# If beta is 0.0, the reference model is not needed
self.ref_model = None
elif is_deepspeed_zero3_enabled():
self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
elif not is_peft_model(model):
# If PEFT configuration is not provided, create a reference model based on the initial model.
self.ref_model = create_reference_model(model)
else:
elif is_peft_model(model):
# If PEFT is used, the reference model is not needed since the adapter can be disabled
# to revert to the initial model.
self.ref_model = None
else:
# If PEFT configuration is not provided, create a reference model based on the initial model.
self.ref_model = create_reference_model(model)
# Processing class
if processing_class is None:
@ -313,7 +372,14 @@ class GRPOTrainer(Trainer):
self.num_generations = args.num_generations # = G in the GRPO paper
self.use_vllm = args.use_vllm
self.beta = args.beta
# Multi-step
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
self.epsilon = args.epsilon
# Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle.
self._step = 0
# Buffer the batch to reuse generated outputs across multiple updates. For more details, see
# `_get_train_sampler` and `_prepare_inputs`.
self._buffered_inputs = [None] * args.gradient_accumulation_steps
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
@ -324,7 +390,7 @@ class GRPOTrainer(Trainer):
model.warnings_issued["estimate_tokens"] = True
# Initialize the metrics
self._metrics = defaultdict(list)
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
self.log_completions = args.log_completions
super().__init__(
@ -412,9 +478,19 @@ class GRPOTrainer(Trainer):
enable_prefix_caching=True,
max_model_len=self.args.vllm_max_model_len,
)
# Guided decoding, if enabled
if args.vllm_guided_decoding_regex is not None:
guided_decoding = GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex)
else:
guided_decoding = None
# Sampling parameters
self.sampling_params = SamplingParams(
temperature=args.temperature,
max_tokens=self.max_completion_length,
guided_decoding=guided_decoding,
n=args.num_generations,
)
self._last_loaded_step = 0 # tag to avoid useless loading during grad accumulation
@ -461,20 +537,78 @@ class GRPOTrainer(Trainer):
self._signature_columns = ["prompt"]
def _get_train_sampler(self) -> Sampler:
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
# preventing discrepancies in group formation.
return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
# Returns a sampler that
# 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are
# distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt
# group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies
# in group formation.
# 2. repeats the batch multiple times to allow reusing generaations across multiple updates. Refer to
# _prepare_inputs to see how the generations are stored and reused.
# In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the
# second row shows the second sampled batch, and so on.
#
# | GPU 0 | GPU 1 | GPU 2 |
#
# global_step step <───────> num_generations=3
# <───────────> per_device_train_batch_size=4
# ▲ 0 0 0 0 0 1 1 1 2 2 2 3 3 3 │
# grad_accum=3 │ 0 1 4 4 4 5 5 5 6 6 6 7 7 7 │ Generate completions for each prompt
# ▼ 0 2 8 8 8 9 9 9 10 10 10 11 11 11 │
#
# 1 3 0 0 0 1 1 1 2 2 2 3 3 3 │ The sampled prompts are the same as in the first iteration
# 1 4 4 4 4 5 5 5 6 6 6 7 7 7 │ Reuse the completions (here, once, because num_iterations=2)
# 1 5 8 8 8 9 9 9 10 10 10 11 11 11 │
#
# 2 6 12 12 12 13 13 13 14 14 14 15 15 15
# 2 7 16 16 16 17 17 17 18 18 18 19 19 19
# 2 8 20 20 20 21 21 21 22 22 22 23 23 23
# ...
effective_batch_size = (
self.args.per_device_train_batch_size
* self.accelerator.num_processes
* self.args.gradient_accumulation_steps
)
return RepeatRandomSampler(
data_source=self.train_dataset,
mini_repeat_count=self.num_generations,
batch_size=effective_batch_size // self.num_generations,
repeat_count=self.num_iterations,
seed=self.args.seed,
)
def _get_eval_sampler(self, eval_dataset) -> Sampler:
# Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
# identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
# within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
# preventing discrepancies in group formation.
return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
# See _get_train_sampler for an explanation of the sampler.
return RepeatRandomSampler(
data_source=eval_dataset,
mini_repeat_count=self.num_generations,
seed=self.args.seed,
)
def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
"""Enables gradient checkpointing for the model."""
# Ensure use_cache is disabled
model.config.use_cache = False
# Enable gradient checkpointing on the base model for PEFT
if is_peft_model(model):
model.base_model.gradient_checkpointing_enable()
# Enable gradient checkpointing for non-PEFT models
else:
model.gradient_checkpointing_enable()
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
use_reentrant = (
"use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
)
if use_reentrant:
model.enable_input_require_grads()
return model
# Get the per-token log probabilities for the completions for the model and the reference model
@profiling_decorator
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
@ -486,6 +620,7 @@ class GRPOTrainer(Trainer):
logits = logits[:, -logits_to_keep:]
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
@profiling_decorator
def _move_model_to_vllm(self):
with unwrap_model_for_generation(
self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
@ -495,7 +630,6 @@ class GRPOTrainer(Trainer):
if is_peft_model(unwrapped_model):
unwrapped_model.merge_adapter()
state_dict = unwrapped_model.state_dict()
unwrapped_model.unmerge_adapter()
# Remove base_model and base_layer prefixes
state_dict = {
k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
@ -510,11 +644,32 @@ class GRPOTrainer(Trainer):
}
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
# Unmerge the adapter to restore the model to its original state.
# This must be done after loading weights to ensure they correspond to the merged state.
if is_peft_model(unwrapped_model):
unwrapped_model.unmerge_adapter()
@profiling_decorator
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
mode = "eval" if self.control.should_evaluate else "train"
if mode == "train":
if self.state.global_step % self.num_iterations == 0:
inputs = self._generate_and_score_completions(inputs)
self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
else:
inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
self._step += 1
else:
# In evaluation, we don't reuse completions across multiple updates, so we don't need to buffer inputs.
inputs = self._generate_and_score_completions(inputs)
return inputs
def _generate_and_score_completions(
self, inputs: dict[str, Union[torch.Tensor, Any]]
) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device
prompts = [x["prompt"] for x in inputs]
prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
@ -538,8 +693,17 @@ class GRPOTrainer(Trainer):
# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text))
all_outputs = self.llm.generate(
ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
)
completion_ids = []
for outputs in all_outputs:
for output in outputs.outputs:
completion_ids.append(output.token_ids)
else:
completion_ids = [None] * len(all_prompts_text)
# Broadcast the completions from the main process to all processes, ensuring each process receives its
@ -575,12 +739,23 @@ class GRPOTrainer(Trainer):
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# Concatenate prompt_mask with completion_mask for logit computation
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
with torch.inference_mode():
if self.ref_model is not None:
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
# computation here, and use per_token_logps.detach() instead.
if self.num_iterations > 1:
old_per_token_logps = self._get_per_token_logps(
self.model, prompt_completion_ids, attention_mask, logits_to_keep
)
else:
old_per_token_logps = None
if self.beta == 0.0:
ref_per_token_logps = None
elif self.ref_model is not None:
ref_per_token_logps = self._get_per_token_logps(
self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
)
@ -647,16 +822,21 @@ class GRPOTrainer(Trainer):
advantages = advantages[process_slice]
# Log the metrics
mode = "eval" if self.control.should_evaluate else "train"
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics[mode]["completion_length"].append(completion_length)
reward_per_func = rewards_per_func.mean(0)
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
reward_func_name = reward_func.config._name_or_path.split("/")[-1]
else:
reward_func_name = reward_func.__name__
self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
self._metrics[mode][f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
self._metrics["reward"].append(rewards.mean().item())
self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
self._metrics[mode]["reward"].append(rewards.mean().item())
self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
if (
self.log_completions
@ -682,10 +862,12 @@ class GRPOTrainer(Trainer):
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"old_per_token_logps": old_per_token_logps,
"ref_per_token_logps": ref_per_token_logps,
"advantages": advantages,
}
@profiling_decorator
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
@ -700,22 +882,36 @@ class GRPOTrainer(Trainer):
per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
# Compute the KL divergence between the model and the reference model
ref_per_token_logps = inputs["ref_per_token_logps"]
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
if self.beta != 0.0:
ref_per_token_logps = inputs["ref_per_token_logps"]
per_token_kl = (
torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
)
# x - x.detach() allows for preserving gradients from x
# Compute the loss
advantages = inputs["advantages"]
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's computation (see
# _generate_and_score_completions) and use per_token_logps.detach() instead.
old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
if self.beta != 0.0:
per_token_loss = per_token_loss + self.beta * per_token_kl
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
# Log the metrics
completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
self._metrics["completion_length"].append(completion_length)
mode = "eval" if self.control.should_evaluate else "train"
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
if self.beta != 0.0:
mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
is_clipped = (per_token_loss1 < per_token_loss2).float()
clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
return loss
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
@ -727,11 +923,12 @@ class GRPOTrainer(Trainer):
return loss, None, None
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()} # average the metrics
mode = "eval" if self.control.should_evaluate else "train"
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
# start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
if next(iter(logs.keys())).startswith("eval_"):
if mode == "eval":
metrics = {f"eval_{key}": val for key, val in metrics.items()}
logs = {**logs, **metrics}
@ -739,7 +936,7 @@ class GRPOTrainer(Trainer):
super().log(logs, start_time)
else: # transformers<=4.46
super().log(logs)
self._metrics.clear()
self._metrics[mode].clear()
def create_model_card(
self,

View File

@ -49,13 +49,11 @@ class SFTConfig(TrainingArguments):
`skip_prepare_dataset`.
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
max_seq_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the
right.
max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
packing (`bool`, *optional*, defaults to `False`):
Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence
length.
Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define sequence length.
eval_packing (`bool` or `None`, *optional*, defaults to `None`):
Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
@ -95,19 +93,19 @@ class SFTConfig(TrainingArguments):
default=None,
metadata={"help": "Number of processes to use for processing the dataset."},
)
max_seq_length: Optional[int] = field(
max_length: Optional[int] = field(
default=1024,
metadata={
"help": "Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated "
"from the right. If `None`, no truncation is applied. When packing is enabled, this value sets the "
"help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from"
"the right. If `None`, no truncation is applied. When packing is enabled, this value sets the "
"sequence length."
},
)
packing: bool = field(
default=False,
metadata={
"help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to "
"define sequence length."
"help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define "
"sequence length."
},
)
eval_packing: Optional[bool] = field(
@ -132,13 +130,17 @@ class SFTConfig(TrainingArguments):
num_of_sequences: int = field(
default=None,
metadata={
"help": "Deprecated. Use `max_seq_length` instead, which specifies the maximum length of the tokenized "
"help": "Deprecated. Use `max_length` instead, which specifies the maximum length of the tokenized "
"sequence, unlike `num_of_sequences`, which referred to string sequences."
},
)
chars_per_token: float = field(
default=None,
metadata={"help": "Deprecated. If you want to customize the packing length, use `max_seq_length`."},
metadata={"help": "Deprecated. If you want to customize the packing length, use `max_length`."},
)
max_seq_length: Optional[int] = field(
default=None,
metadata={"help": "Deprecated. Use `max_length` instead."},
)
def __post_init__(self):
@ -153,7 +155,7 @@ class SFTConfig(TrainingArguments):
if self.num_of_sequences is not None:
warnings.warn(
"`num_of_sequences` is deprecated and will be remove in version 0.18.0. Use `max_seq_length` instead, "
"`num_of_sequences` is deprecated and will be remove in version 0.18.0. Use `max_length` instead, "
"which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which r"
"eferred to string sequences.",
DeprecationWarning,
@ -162,6 +164,12 @@ class SFTConfig(TrainingArguments):
if self.chars_per_token is not None:
warnings.warn(
"`chars_per_token` is deprecated and will be remove in version 0.18.0. If you want to customize the "
"packing length, use `max_seq_length`.",
"packing length, use `max_length`.",
DeprecationWarning,
)
if self.max_seq_length is not None:
warnings.warn(
"`max_seq_length` is deprecated and will be remove in version 0.20.0. Use `max_length` instead.",
DeprecationWarning,
)

View File

@ -41,17 +41,10 @@ from transformers import (
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_liger_kernel_available, is_peft_available
from transformers.utils.deprecation import deprecate_kwarg
from ..data_utils import is_conversational, maybe_apply_chat_template, pack_examples
from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples
from .sft_config import SFTConfig
from .utils import (
ConstantLengthDataset,
compute_token_accuracy,
generate_model_card,
get_comet_experiment_url,
peft_module_casting_to_bf16,
)
from .utils import ConstantLengthDataset, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16
if is_peft_available():
@ -107,6 +100,8 @@ class SFTTrainer(Trainer):
- [Standard](dataset_formats#standard): Each sample contains plain text.
- [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
and content).
The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
@ -140,9 +135,6 @@ class SFTTrainer(Trainer):
_tag_names = ["trl", "sft"]
@deprecate_kwarg(
"tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
)
def __init__(
self,
model: Union[str, nn.Module, PreTrainedModel],
@ -181,12 +173,13 @@ class SFTTrainer(Trainer):
)
if isinstance(model, str):
model = self._create_model_from_path(model, args)
self.use_liger = is_liger_kernel_available() and isinstance(model, AutoLigerKernelForCausalLM)
# PEFT configuration and model wrapping
if peft_config is not None:
model = self._prepare_peft_model(model, peft_config, args)
# 3. Handle the tokenizer
# Handle the tokenizer
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
if processing_class.pad_token is None:
@ -275,8 +268,10 @@ class SFTTrainer(Trainer):
if args.use_liger:
if not is_liger_kernel_available():
raise ImportError("Please install Liger-kernel for use_liger=True")
return AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
return model
def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
"""Prepares a model for PEFT training."""
@ -368,6 +363,10 @@ class SFTTrainer(Trainer):
if isinstance(dataset, ConstantLengthDataset):
return dataset
# If the dataset is already preprocessed (tokenized), skip the processing steps.
column_names = list(next(iter(dataset)).keys())
is_processed = "input_ids" in column_names
# Build the kwargs for the `map` function
map_kwargs = {}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
@ -375,7 +374,15 @@ class SFTTrainer(Trainer):
with PartialState().local_main_process_first():
# Apply the formatting function if any
if formatting_func is not None:
if formatting_func is not None and is_processed:
warnings.warn(
"You passed a dataset that is already processed (contains an `input_ids` field) together with a "
"formatting function. Therefore `formatting_func` will be ignored. Either remove the "
"`formatting_func` or pass a dataset that is not already processed.",
UserWarning,
)
if formatting_func is not None and not is_processed:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"
@ -395,6 +402,15 @@ class SFTTrainer(Trainer):
dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
# Convert the dataset to ChatML if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML"
dataset = dataset.map(
maybe_convert_to_chatml,
remove_columns="conversations" if "conversations" in dataset.column_names else None,
**map_kwargs,
)
# Apply the chat template if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
@ -405,24 +421,30 @@ class SFTTrainer(Trainer):
**map_kwargs,
)
# Tokenize the dataset
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)
# Tokenize the dataset if needed
if not is_processed:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
def tokenize(ex):
tokenized = processing_class(ex[args.dataset_text_field])
return {"input_ids": tokenized["input_ids"], "attention_mask": tokenized["attention_mask"]}
dataset = dataset.map(tokenize, **map_kwargs)
# Pack or truncate
if packing:
if args.max_seq_length is None:
raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
if args.max_length is None:
raise ValueError("When packing is enabled, `max_length` can't be `None`.")
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Packing {dataset_name} dataset"
dataset = dataset.select_columns("input_ids")
dataset = dataset.map(
pack_examples, batched=True, fn_kwargs={"seq_length": args.max_seq_length}, **map_kwargs
pack_examples, batched=True, fn_kwargs={"seq_length": args.max_length}, **map_kwargs
)
elif args.max_seq_length is not None:
elif args.max_length is not None:
dataset = dataset.map(
lambda ex: {key: ex[key][: args.max_seq_length] for key in ["input_ids", "attention_mask"]},
lambda ex: {key: ex[key][: args.max_length] for key in ["input_ids", "attention_mask"]},
**map_kwargs,
)
# For Liger kernel, ensure only input_ids is present
@ -440,18 +462,28 @@ class SFTTrainer(Trainer):
)
# Compute token accuracy if we have labels and if the model is not using Liger (no logits)
if "labels" in inputs and not self.args.use_liger:
if "labels" in inputs and not self.use_liger:
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = inputs["labels"][..., 1:].contiguous()
# Gather logits and labels from all GPUs first
shift_logits = self.accelerator.gather_for_metrics(shift_logits)
shift_labels = self.accelerator.gather_for_metrics(shift_labels)
# Get predictions
predictions = shift_logits.argmax(dim=-1)
# Then compute accuracy on the gathered tensors
if self.accelerator.is_main_process:
accuracy = compute_token_accuracy(shift_logits, shift_labels)
self._metrics["mean_token_accuracy"].append(accuracy)
# Create mask for non-padding tokens (assuming ignore_index is -100)
mask = shift_labels != -100
# Calculate accuracy only on non-padding tokens
correct_predictions = (predictions == shift_labels) & mask
total_tokens = mask.sum()
correct_tokens = correct_predictions.sum()
# Gather the correct_tokens and total_tokens across all processes
correct_tokens = self.accelerator.gather_for_metrics(correct_tokens)
total_tokens = self.accelerator.gather_for_metrics(total_tokens)
# Compute the mean token accuracy and log it
accuracy = (correct_tokens.sum() / total_tokens.sum()).item() if total_tokens.sum() > 0 else 0.0
self._metrics["mean_token_accuracy"].append(accuracy)
return (loss, outputs) if return_outputs else loss

View File

@ -140,7 +140,7 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
warnings.warn(
f"Could not find response key `{self.response_template}` in the following instance: "
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
"calculation. Note, if this happens often, consider increasing the `max_seq_length`.",
"calculation. Note, if this happens often, consider increasing the `max_length`.",
UserWarning,
)
batch["labels"][i, :] = self.ignore_index
@ -167,7 +167,7 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
warnings.warn(
f"Could not find response key `{self.response_template}` in the following instance: "
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
"calculation. Note, if this happens often, consider increasing the `max_seq_length`.",
"calculation. Note, if this happens often, consider increasing the `max_length`.",
UserWarning,
)
batch["labels"][i, :] = self.ignore_index
@ -182,7 +182,7 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
warnings.warn(
f"Could not find instruction key `{self.instruction_template}` in the following instance: "
f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
"calculation. Note, if this happens often, consider increasing the `max_seq_length`.",
"calculation. Note, if this happens often, consider increasing the `max_length`.",
UserWarning,
)
batch["labels"][i, :] = self.ignore_index
@ -1650,27 +1650,6 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
return mask, *tensors
def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
"""
Compute the mean token accuracy.
"""
# Get predictions
predictions = logits.argmax(dim=-1)
# Create mask for non-padding tokens (assuming pad_token_id is ignore_index)
mask = labels != ignore_index
# Calculate accuracy only on non-padding tokens
correct_predictions = (predictions == labels) & mask
total_tokens = mask.sum()
correct_tokens = correct_predictions.sum()
# Calculate accuracy
accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0
return accuracy
def selective_log_softmax(logits, index):
"""
A memory-efficient implementation of the common `log_softmax -> gather` operation.