Mirror of https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Compare commits
17 Commits
SHA1 (author and date columns were lost in extraction):
013d360b8f
e5ae703d35
a92e00e810
9b3c5bf64f
15fec312d5
be1e34003c
6aaf379a82
49adf74833
6c54f023ae
963243a7d1
aafd8cbea5
822653824b
ba036576d4
293b620950
ae3bd0d07a
6d9fc11fd6
ffcb9f4aee
@@ -42,7 +42,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
     --output_dir $OUTPUT_DIR \
     --max_steps $MAX_STEPS \
     --per_device_train_batch_size $BATCH_SIZE \
-    --max_seq_length $SEQ_LEN \
+    --max_length $SEQ_LEN \
     $EXTRA_TRAINING_ARGS
 """
@@ -12,6 +12,10 @@
 [[autodoc]] maybe_apply_chat_template

+## maybe_convert_to_chatml
+
+[[autodoc]] maybe_convert_to_chatml
+
 ## extract_prompt

 [[autodoc]] extract_prompt
@@ -44,7 +44,7 @@ training_args = DPOConfig(..., max_completion_length=...)
 </hfoption>
 <hfoption id="SFT">

-SFT truncation is applied to the input sequence via the `max_seq_length` parameter.
+SFT truncation is applied to the input sequence via the `max_length` parameter.

 <div class="flex justify-center">
 <img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_input_ids.png" alt="Truncation input ids" width="600"/>
@@ -55,7 +55,7 @@ To set the truncation parameter, use the following code snippet:
 ```python
 from trl import SFTConfig

-training_args = SFTConfig(..., max_seq_length=...)
+training_args = SFTConfig(..., max_length=...)
 ```

 </hfoption>
@@ -85,7 +85,7 @@ Packing eliminates padding, preserves all sequence information, and allows for f
 ```python
 from trl import SFTConfig

-training_args = SFTConfig(..., packing=True, max_seq_length=512)
+training_args = SFTConfig(..., packing=True, max_length=512)
 ```

 <Tip warning={true}>
@@ -19,7 +19,7 @@ from trl import SFTConfig, SFTTrainer
 dataset = load_dataset("stanfordnlp/imdb", split="train")

 training_args = SFTConfig(
-    max_seq_length=512,
+    max_length=512,
     output_dir="/tmp",
 )
 trainer = SFTTrainer(
@@ -29,7 +29,7 @@ trainer = SFTTrainer(
 )
 trainer.train()
 ```
-Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
+Make sure to pass the correct value for `max_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.

 You can also construct a model outside of the trainer and pass it as follows:
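A quick sketch of what that fallback resolves to in practice (the checkpoint name here is only an illustration, not part of the diff):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# When max_length is not passed, SFT effectively truncates at:
effective_max_length = min(tokenizer.model_max_length, 1024)
print(effective_max_length)
```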
@@ -550,12 +550,12 @@ import torch
 from trl import SFTConfig, SFTTrainer
 from unsloth import FastLanguageModel

-max_seq_length = 2048  # Supports automatic RoPE Scaling, so choose any number
+max_length = 2048  # Supports automatic RoPE Scaling, so choose any number

 # Load model
 model, tokenizer = FastLanguageModel.from_pretrained(
     model_name="unsloth/mistral-7b",
-    max_seq_length=max_seq_length,
+    max_seq_length=max_length,
     dtype=None,  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
     load_in_4bit=True,  # Use 4bit quantization to reduce memory usage. Can be False
     # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
@@ -581,7 +581,7 @@ model = FastLanguageModel.get_peft_model(
     random_state=3407,
 )

-training_args = SFTConfig(output_dir="./output", max_seq_length=max_seq_length)
+training_args = SFTConfig(output_dir="./output", max_length=max_length)

 trainer = SFTTrainer(
     model=model,
@@ -624,7 +624,7 @@ To learn more about Liger-Kernel, visit their [official repository](https://gith

 Pay attention to the following best practices when training a model with that trainer:

-- [`SFTTrainer`] always truncates by default the sequences to the `max_seq_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
+- [`SFTTrainer`] always truncates by default the sequences to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
 - For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
 - For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
 - If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
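A minimal sketch of the "create the `PeftModel` outside the trainer" pattern from the bullets above, assuming PEFT is installed; the base checkpoint and LoRA settings are illustrative, not prescribed by the diff:

```python
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("stanfordnlp/imdb", split="train")
base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
peft_model = get_peft_model(base_model, LoraConfig(r=16, lora_alpha=32))

# The ready-made PeftModel is passed directly; per the last bullet above,
# no from_pretrained() keyword arguments are forwarded to the trainer.
trainer = SFTTrainer(
    model=peft_model,
    args=SFTConfig(output_dir="/tmp", max_length=512),
    train_dataset=dataset,
)
```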
@@ -185,7 +185,7 @@ trainer = SFTTrainer(
     train_dataset=train_dataset,
     eval_dataset=eval_dataset,
     peft_config=peft_config,
-    max_seq_length=None,
+    max_length=None,
     formatting_func=prepare_sample_text,
     processing_class=tokenizer,
     args=training_args,
setup.py (6 changed lines)
@@ -71,7 +71,7 @@ To create the package for PyPI.
 from setuptools import find_packages, setup


-__version__ = "0.15.0"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+__version__ = "0.16.0.dev0"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)

 REQUIRED_PKGS = [
     "accelerate>=0.34.0",
@@ -85,13 +85,13 @@ EXTRAS = {
     "diffusers": ["diffusers>=0.18.0"],
     "judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
     # liger-kernel depends on triton, which is only available on Linux https://github.com/triton-lang/triton#compatibility
-    "liger": ["liger-kernel>=0.4.0; sys_platform != 'win32'"],
+    "liger": ["liger-kernel>=0.5.3; sys_platform != 'win32'"],
     "mergekit": ["mergekit>=0.0.5.1"],
     "peft": ["peft>=0.8.0"],
     "quantization": ["bitsandbytes"],
     "scikit": ["scikit-learn"],
     "test": ["parameterized", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "pytest"],
-    "vllm": ["vllm>=0.7.1; sys_platform != 'win32'"],  # vllm is not available on Windows
+    "vllm": ["vllm>=0.7.2; sys_platform != 'win32'"],  # vllm is not available on Windows
     "vlm": ["Pillow"],
 }
 EXTRAS["dev"] = []
@@ -46,7 +46,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
     def setUp(self):
         self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]")
         self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]")
-        self.max_seq_length = 128
+        self.max_length = 128
         self.peft_config = LoraConfig(
             lora_alpha=16,
             lora_dropout=0.1,
@@ -74,7 +74,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
                 per_device_train_batch_size=2,
                 max_steps=10,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
             )

             trainer = SFTTrainer(
@@ -100,7 +100,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
                 per_device_train_batch_size=2,
                 max_steps=10,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
             )

             model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -135,7 +135,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
                 max_steps=10,
                 fp16=True,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
             )

             model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -172,7 +172,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
                 max_steps=10,
                 fp16=True,  # this is sufficient to enable amp
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
             )

             model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -205,7 +205,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
                 per_device_train_batch_size=2,
                 max_steps=10,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
                 fp16=True,  # this is sufficient to enable amp
                 gradient_checkpointing=True,
                 gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -242,7 +242,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
                 per_device_train_batch_size=2,
                 max_steps=10,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
                 fp16=True,  # this is sufficient to enable amp
                 gradient_checkpointing=True,
                 gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -286,7 +286,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
                 per_device_train_batch_size=2,
                 max_steps=10,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
                 fp16=True,  # this is sufficient to enable amp
                 gradient_checkpointing=True,
                 gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -324,7 +324,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
                 per_device_train_batch_size=2,
                 max_steps=10,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
                 fp16=True,  # this is sufficient to enable amp
                 gradient_checkpointing=True,
                 gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
@@ -364,7 +364,7 @@ class SFTTrainerSlowTester(unittest.TestCase):

             training_args = SFTConfig(
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
                 output_dir=tmp_dir,
                 logging_strategy="no",
                 report_to="none",
@@ -411,7 +411,7 @@ class SFTTrainerSlowTester(unittest.TestCase):
                 per_device_train_batch_size=2,
                 max_steps=2,
                 packing=packing,
-                max_seq_length=self.max_seq_length,
+                max_length=self.max_length,
                 use_liger=True,
             )

@@ -24,6 +24,7 @@ from trl.data_utils import (
     extract_prompt,
     is_conversational,
     maybe_apply_chat_template,
+    maybe_convert_to_chatml,
     maybe_extract_prompt,
     maybe_unpair_preference_dataset,
     pack_examples,
@@ -435,6 +436,51 @@ class TestPackExamples(unittest.TestCase):
         self.assertEqual(dataset.to_dict(), expected_output)


+class TestMaybeConvertToChatML(unittest.TestCase):
+    def test_with_conversations_key(self):
+        # Particular case where the key is "conversations": we rename it to "messages"
+        example = {
+            "conversations": [
+                {"from": "user", "value": "What color is the sky?"},
+                {"from": "assistant", "value": "It is blue."},
+            ]
+        }
+        expected_output = {
+            "messages": [
+                {"role": "user", "content": "What color is the sky?"},
+                {"role": "assistant", "content": "It is blue."},
+            ]
+        }
+        self.assertEqual(maybe_convert_to_chatml(example), expected_output)
+
+    def test_without_conversations_key(self):
+        # Same as before, but we don't rename the keys
+        example = {
+            "prompt": [{"from": "user", "value": "What color is the sky?"}],
+            "completion": [{"from": "assistant", "value": "It is blue."}],
+        }
+        expected_output = {
+            "prompt": [{"role": "user", "content": "What color is the sky?"}],
+            "completion": [{"role": "assistant", "content": "It is blue."}],
+        }
+        self.assertEqual(maybe_convert_to_chatml(example), expected_output)
+
+    def test_not_conversional(self):
+        # When not needed, the example should remain unchanged
+        example = {"text": "The sky is blue."}
+        self.assertEqual(maybe_convert_to_chatml(example), example)
+
+    def test_already_chatml(self):
+        # When the example is already in ChatML format, it should remain unchanged
+        example = {
+            "messages": [
+                {"role": "user", "content": "What color is the sky?"},
+                {"role": "assistant", "content": "It is blue."},
+            ]
+        }
+        self.assertEqual(maybe_convert_to_chatml(example), example)
+
+
 # Run the tests
 if __name__ == "__main__":
     unittest.main()
@@ -25,10 +25,113 @@ from transformers.utils import is_peft_available

 from trl import GRPOConfig, GRPOTrainer
 from trl.import_utils import is_vllm_available
 from trl.trainer.grpo_trainer import RepeatRandomSampler


 if is_peft_available():
-    from peft import LoraConfig
+    from peft import LoraConfig, PeftModel


+class RepeatRandomSamplerTester(unittest.TestCase):
+    def test_sampler(self):
+        dataset = ["a", "b", "c", "d", "e", "f", "g"]
+        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2)
+        # Should output something like [4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 5, 5]
+        sampled = list(sampler)
+        # Check that the length is doubled
+        assert len(sampled) == 2 * len(dataset)
+        # Check that all indexes are present
+        assert set(sampled) == set(range(len(dataset)))
+        # Check that each element is repeated twice
+        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
+
+    def test_sampler_no_repeat(self):
+        dataset = ["a", "b", "c", "d", "e", "f", "g"]
+        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1)
+        # Should output something like [4, 3, 0, 1, 2, 6, 5]
+        sampled = list(sampler)
+        # Check that the length is the same
+        assert len(sampled) == len(dataset)
+        # Check that all indexes are present
+        assert set(sampled) == set(range(len(dataset)))
+
+    def test_sampler_with_batch_size(self):
+        dataset = ["a", "b", "c", "d", "e", "f", "g", "h"]
+        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
+        # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6, 5, 7, 5, 7]
+        sampled = list(sampler)
+        # Check that the length is doubled
+        assert len(sampled) == 2 * len(dataset)
+        # Check that all indexes are present
+        assert set(sampled) == set(range(len(dataset)))
+        # Check that each element is repeated as expected
+        assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4))
+
+    def test_sampler_with_batch_size_and_drop(self):
+        dataset = ["a", "b", "c", "d", "e", "f", "g"]
+        sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
+        # Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6]
+        sampled = list(sampler)
+        # Check that the length is doubled
+        assert len(sampled) == 2 * (
+            len(dataset) - 1
+        )  # one element is dropped, because it's not enough to form a batch
+        # Check that the sampled indexes are a subset of the dataset indexes
+        assert set(sampled).issubset(set(range(len(dataset))))
+        # Check that each element is repeated as expected
+        assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4))
+
+    def test_sampler_with_mini_repeat_count_and_batch_size_1(self):
+        dataset = ["a", "b", "c", "d", "e", "f", "g"]
+        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2)
+        # Should output something like [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0,
+        #                                1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6]
+        sampled = list(sampler)
+        # Check that the length is quadrupled
+        assert len(sampled) == 4 * (len(dataset) - 1)  # 1 element is dropped, because it's not enough to form a batch
+        # Check that the sampled indexes are a subset of the dataset indexes
+        assert set(sampled).issubset(set(range(len(dataset))))
+        # Check that each element is repeated as expected
+        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
+        # Check that the batch is repeated as expected
+        assert sampled[0:6] == sampled[6:12]
+        assert sampled[12:18] == sampled[18:24]
+
+    def test_sampler_with_mini_repeat_count_and_batch_size_2(self):
+        dataset = ["a", "b", "c", "d", "e", "f", "g"]
+        sampler = RepeatRandomSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2)
+        # Should output something like [4, 4, 4, 3, 3, 3, 4, 4, 4, 3, 3, 3,
+        #                                0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
+        #                                2, 2, 2, 6, 6, 6, 2, 2, 2, 6, 6, 6]
+        sampled = list(sampler)
+        # Check that the length is sextupled
+        assert len(sampled) == 6 * (len(dataset) - 1)  # 1 element is dropped, because it's not enough to form a batch
+        # Check that the sampled indexes are a subset of the dataset indexes
+        assert set(sampled).issubset(set(range(len(dataset))))
+        # Check that each element is repeated as expected
+        assert all(sampled[i] == sampled[i + 1] == sampled[i + 2] for i in range(0, len(sampled), 3))
+        # Check that the batch is repeated as expected
+        assert sampled[0:6] == sampled[6:12]
+        assert sampled[12:18] == sampled[18:24]
+        assert sampled[24:30] == sampled[30:36]
+
+    def test_sampler_with_mini_repeat_count_and_batch_size_3(self):
+        dataset = ["a", "b", "c", "d", "e", "f", "g"]
+        sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3)
+        # Should output something like [4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3,
+        #                                0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
+        #                                2, 2, 6, 6, 2, 2, 6, 6, 2, 2, 6, 6]
+        sampled = list(sampler)
+        # Check that the length is sextupled
+        assert len(sampled) == 6 * (len(dataset) - 1)  # 1 element is dropped, because it's not enough to form a batch
+        # Check that the sampled indexes are a subset of the dataset indexes
+        assert set(sampled).issubset(set(range(len(dataset))))
+        # Check that each element is repeated as expected
+        assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
+        # Check that the batch is repeated as expected
+        assert sampled[0:4] == sampled[4:8] == sampled[8:12]
+        assert sampled[12:16] == sampled[16:20] == sampled[20:24]
+        assert sampled[24:28] == sampled[28:32] == sampled[32:36]
+
+
 class GRPOTrainerTester(unittest.TestCase):
@@ -96,6 +199,37 @@ class GRPOTrainerTester(unittest.TestCase):

         trainer.train()

+    def test_training_multiple_iterations(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=32,  # reduce the completion length to reduce memory usage
+                num_iterations=2,
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
     @require_peft
     def test_training_peft(self):
         model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
@@ -133,6 +267,57 @@ class GRPOTrainerTester(unittest.TestCase):
             elif "base_layer" not in n:  # We expect the peft params to be different (except for the base layer)
                 self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")

+    @require_peft
+    def test_training_peft_with_gradient_checkpointing(self):
+        """Test that training works with PEFT and gradient checkpointing enabled."""
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        model = AutoModelForCausalLM.from_pretrained(
+            "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            torch_dtype=torch.float32,  # Use float32 for testing to avoid precision issues
+            use_cache=False,  # Required for gradient checkpointing
+        )
+
+        lora_config = LoraConfig(
+            r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none"
+        )
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,
+                per_device_train_batch_size=3,
+                num_generations=3,
+                max_completion_length=32,
+                gradient_checkpointing=True,  # Enable gradient checkpointing
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model=model,
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+                peft_config=lora_config,
+            )
+
+            # Verify gradient checkpointing is enabled
+            self.assertIsInstance(trainer.model, PeftModel)
+
+            # Store initial parameters to check which ones change
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that only LoRA parameters have changed, base model parameters remain unchanged
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                if "lora" in n.lower():  # LoRA parameters should change
+                    self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.")
+                else:  # Base model parameters should not change
+                    self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.")
+
     def test_training_different_reward_model(self):
         # Use a reward model different from the model: different chat template, tokenization, etc.
         dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
@@ -500,6 +685,36 @@ class GRPOTrainerTester(unittest.TestCase):
             new_param = trainer.model.get_parameter(n)
             self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

+    def test_beta_zero_no_ref_model_and_no_kl(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                beta=0.0,  # set beta to 0 to test the case where the reference model is not used
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=32,  # reduce the completion length to reduce memory usage
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
     @unittest.skipIf(not is_vllm_available(), "vLLM is not available")
     @require_torch_accelerator
     @require_peft
@@ -548,3 +763,39 @@ class GRPOTrainerTester(unittest.TestCase):
             elif "base_layer" not in n and "original_module" not in n:
                 # We expect the peft params to be different (except for the base layer)
                 self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
+
+    @unittest.skipIf(not is_vllm_available(), "vLLM is not available")
+    @require_torch_accelerator
+    def test_training_vllm_guided_decoding(self):
+        """Test that training works with vLLM for generation with guided decoding."""
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=32,  # reduce the completion length to reduce memory usage
+                report_to="none",
+                use_vllm=True,
+                vllm_device="cuda:0",  # will raise a warning, but allows this test to work with only one GPU
+                vllm_guided_decoding_regex=r"<reasoning>\n.*\n</reasoning>\n<answer>\n.*\n</answer>",
+            )
+            trainer = GRPOTrainer(
+                model="Qwen/Qwen2.5-0.5B-Instruct",  # tiny model is too small for vLLM
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
@@ -288,7 +288,7 @@ class SFTTrainerTester(unittest.TestCase):

         self.assertIn("model.safetensors", os.listdir(tmp_dir + "/checkpoint-2"))

-    def test_sft_trainer_with_pretokenzied_data_packing(self):
+    def test_sft_trainer_with_pretokenized_data_packing(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             training_args = SFTConfig(
                 output_dir=tmp_dir,
@@ -326,7 +326,7 @@ class SFTTrainerTester(unittest.TestCase):
                 eval_steps=1,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=32,  # make sure there is at least 1 packed sequence
+                max_length=32,  # make sure there is at least 1 packed sequence
                 packing=True,
                 report_to="none",
             )
@@ -353,7 +353,7 @@ class SFTTrainerTester(unittest.TestCase):
                 train_dataset=self.conversational_lm_dataset["train"],
             )

-            # Same, but with packing with `max_seq_length`
+            # Same, but with packing with `max_length`
             training_args = SFTConfig(
                 output_dir=tmp_dir,
                 dataloader_drop_last=True,
@@ -361,7 +361,7 @@ class SFTTrainerTester(unittest.TestCase):
                 eval_steps=1,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=16,  # make sure there is at least 1 packed sequence
+                max_length=16,  # make sure there is at least 1 packed sequence
                 packing=True,
                 report_to="none",
             )
@@ -396,7 +396,7 @@ class SFTTrainerTester(unittest.TestCase):
                 eval_steps=1,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=32,  # make sure there is at least 1 packed sequence
+                max_length=32,  # make sure there is at least 1 packed sequence
                 packing=True,
                 report_to="none",
             )
@@ -461,7 +461,7 @@ class SFTTrainerTester(unittest.TestCase):
                 save_steps=1,
                 num_train_epochs=2,
                 per_device_train_batch_size=2,
-                max_seq_length=16,
+                max_length=16,
                 packing=True,
                 report_to="none",
             )
@@ -485,7 +485,7 @@ class SFTTrainerTester(unittest.TestCase):
                 save_steps=1,
                 num_train_epochs=2,
                 per_device_train_batch_size=2,
-                max_seq_length=16,
+                max_length=16,
                 report_to="none",
             )
             trainer = SFTTrainer(
@@ -534,7 +534,7 @@ class SFTTrainerTester(unittest.TestCase):
                 max_steps=2,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=16,
+                max_length=16,
                 packing=True,
                 report_to="none",
             )
@@ -558,7 +558,7 @@ class SFTTrainerTester(unittest.TestCase):
                 max_steps=2,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=16,
+                max_length=16,
                 packing=True,
                 report_to="none",
             )
@@ -583,7 +583,7 @@ class SFTTrainerTester(unittest.TestCase):
                 max_steps=2,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=16,
+                max_length=16,
                 report_to="none",
             )
             trainer = SFTTrainer(
@@ -606,7 +606,7 @@ class SFTTrainerTester(unittest.TestCase):
                 max_steps=2,
                 save_steps=1,
                 per_device_train_batch_size=2,
-                max_seq_length=16,
+                max_length=16,
                 report_to="none",
             )
             trainer = SFTTrainer(
@@ -755,7 +755,7 @@ class SFTTrainerTester(unittest.TestCase):
                 save_steps=1,
                 per_device_train_batch_size=2,
                 packing=True,
-                max_seq_length=500,
+                max_length=500,
                 report_to="none",
             )
             trainer = SFTTrainer(
@@ -782,7 +782,7 @@ class SFTTrainerTester(unittest.TestCase):
                 per_device_train_batch_size=2,
                 save_strategy="epoch",
                 packing=True,
-                max_seq_length=500,
+                max_length=500,
                 report_to="none",
             )
             trainer = SFTTrainer(
@@ -1088,7 +1088,7 @@ class SFTTrainerTester(unittest.TestCase):
                 per_device_train_batch_size=2,
                 gradient_checkpointing=True,
                 packing=True,
-                max_seq_length=16,  # make sure there is at least 1 packed sequence
+                max_length=16,  # make sure there is at least 1 packed sequence
                 eval_packing=False,
                 report_to="none",
             )
@@ -1114,7 +1114,7 @@ class SFTTrainerTester(unittest.TestCase):
                 save_steps=2,
                 per_device_train_batch_size=2,
                 gradient_checkpointing=True,
-                max_seq_length=16,  # make sure there is at least 1 packed sequence
+                max_length=16,  # make sure there is at least 1 packed sequence
                 packing=True,
                 report_to="none",
             )
@@ -1139,7 +1139,7 @@ class SFTTrainerTester(unittest.TestCase):
                 save_steps=2,
                 per_device_train_batch_size=2,
                 gradient_checkpointing=True,
-                max_seq_length=16,  # make sure there is at least 1 packed sequence
+                max_length=16,  # make sure there is at least 1 packed sequence
                 packing=False,
                 report_to="none",
             )
@@ -1370,3 +1370,65 @@ class SFTTrainerTester2(unittest.TestCase):
                 "base_layer" not in n
             ):  # We expect the peft parameters to be different (except for the base layer)
                 self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_with_non_chatml_conversational_data(self):
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
+
+        # Rename role/content to from/value to ensure SFT works with non-chatML conversational data
+        def rename_fields(example: list[dict]):
+            return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]}
+
+        dataset = dataset.map(rename_fields, remove_columns="messages")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Initialize the trainer
+            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
+            trainer = SFTTrainer(args=training_args, model=model, train_dataset=dataset)
+
+            # Save the initial parameters to compare them later
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            # Train the model
+            trainer.train()
+
+            # Check that the training loss is not None
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_sft_trainer_with_pretokenized_data(self):
+        # Get the model and dataset
+        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+        model = AutoModelForCausalLM.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
+
+        def tokenize_example(example):
+            return tokenizer(example["text"])
+
+        # Apply tokenization
+        tokenized_dataset = dataset.map(tokenize_example, remove_columns=["text"])
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Initialize the trainer
+            training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
+            trainer = SFTTrainer(args=training_args, model=model, train_dataset=tokenized_dataset)
+
+            # Save the initial parameters to compare them later
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            # Train the model
+            trainer.train()
+
+            # Check that the training loss is not None
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
@@ -368,7 +368,7 @@ class TrainerArgTester(unittest.TestCase):
             tmp_dir,
             dataset_text_field="dummy_text_field",
             packing=True,
-            max_seq_length=256,
+            max_length=256,
             dataset_num_proc=4,
             dataset_batch_size=512,
             neftune_noise_alpha=0.1,
@@ -379,7 +379,7 @@ class TrainerArgTester(unittest.TestCase):
         trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset)
         self.assertEqual(trainer.args.dataset_text_field, "dummy_text_field")
         self.assertEqual(trainer.args.packing, True)
-        self.assertEqual(trainer.args.max_seq_length, 256)
+        self.assertEqual(trainer.args.max_length, 256)
         self.assertEqual(trainer.args.dataset_num_proc, 4)
         self.assertEqual(trainer.args.dataset_batch_size, 512)
         self.assertEqual(trainer.args.neftune_noise_alpha, 0.1)
@@ -27,7 +27,6 @@ from trl.trainer import compute_accuracy
 from trl.trainer.utils import (
     DataCollatorForChatML,
     batch_generation,
-    compute_token_accuracy,
     decode_and_strip_padding,
     flush_left,
     generate_model_card,
@@ -456,60 +455,6 @@ class TestFlushLeft(unittest.TestCase):
         self.assertTrue(torch.equal(new_mask, expected_mask))


-class TestComputeTokenAccuracy(unittest.TestCase):
-    def test_basic_accuracy(self):
-        # Test basic accuracy computation
-        logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]])  # Shape: [2, 2, 2]
-        labels = torch.tensor([[1, 0], [1, 0]])  # Shape: [2, 2]
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertAlmostEqual(accuracy, 0.75)  # 3 correct out of 4 tokens
-
-    def test_with_ignore_index(self):
-        # Test accuracy computation with ignored tokens
-        logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]])
-        labels = torch.tensor([[1, -100], [1, 0]])  # -100 is ignored
-        accuracy = compute_token_accuracy(logits, labels, ignore_index=-100)
-        self.assertAlmostEqual(accuracy, 2 / 3)  # 2 correct out of 3 non-ignored tokens
-
-    def test_all_ignored(self):
-        # Test case where all tokens are ignored
-        logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
-        labels = torch.tensor([[-100, -100]])
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertEqual(accuracy, 0.0)  # No valid tokens to compute accuracy
-
-    def test_perfect_accuracy(self):
-        # Test case with 100% accuracy
-        logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
-        labels = torch.tensor([[1, 0]])
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertEqual(accuracy, 1.0)  # All predictions correct
-
-    def test_zero_accuracy(self):
-        # Test case with 0% accuracy
-        logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
-        labels = torch.tensor([[0, 1]])
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertEqual(accuracy, 0.0)  # All predictions wrong
-
-    def test_batch_accuracy(self):
-        # Test accuracy computation across multiple batches
-        logits = torch.tensor(
-            [
-                [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]],  # Batch 1
-                [[0.2, 0.8], [0.7, 0.3], [0.6, 0.4]],  # Batch 2
-            ]
-        )
-        labels = torch.tensor(
-            [
-                [1, 0, 1],  # Batch 1
-                [1, 0, -100],  # Batch 2 (last token ignored)
-            ]
-        )
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertAlmostEqual(accuracy, 0.8)
-
-
 class TestSelectiveLogSoftmax(unittest.TestCase):
     @parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)])
     def test_selective_log_softmax(self, dtype):
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-__version__ = "0.15.0"
+__version__ = "0.16.0.dev0"

 from typing import TYPE_CHECKING

@@ -26,6 +26,7 @@ _import_structure = {
         "extract_prompt",
         "is_conversational",
         "maybe_apply_chat_template",
+        "maybe_convert_to_chatml",
         "maybe_extract_prompt",
         "maybe_unpair_preference_dataset",
         "pack_examples",
@@ -101,7 +102,7 @@ _import_structure = {
         "XPOTrainer",
     ],
     "trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
-    "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "compute_token_accuracy"],
+    "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
 }

 try:
@@ -126,6 +127,7 @@ if TYPE_CHECKING:
         extract_prompt,
         is_conversational,
         maybe_apply_chat_template,
+        maybe_convert_to_chatml,
         maybe_extract_prompt,
         maybe_unpair_preference_dataset,
         pack_examples,
@@ -202,7 +204,7 @@ if TYPE_CHECKING:
         XPOTrainer,
     )
     from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
-    from .trainer.utils import compute_token_accuracy, get_kbit_device_map, get_peft_config, get_quantization_config
+    from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config

 try:
     if not is_diffusers_available():
@@ -31,7 +31,8 @@ def is_conversational(example: dict[str, Any]) -> bool:
         dataset type.

     Returns:
-        `bool`: `True` if the data is in a conversational format, `False` otherwise.
+        `bool`:
+            `True` if the data is in a conversational format, `False` otherwise.

     Examples:

@@ -185,20 +186,21 @@ def maybe_apply_chat_template(
             For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of
             messages, where each message is a dictionary with keys `"role"` and `"content"`.
         tokenizer (`PreTrainedTokenizer`):
-            The tokenizer to apply the chat template with.
+            Tokenizer to apply the chat template with.
         tools (`list[Union[dict, Callable]]` or `None`, *optional*, defaults to `None`):
             A list of tools (callable functions) that will be accessible to the model.
             If the template does not support function calling, this argument will have no effect

     Returns:
-        `dict[str, str]`: The formatted example with the chat template applied.
+        `dict[str, str]`:
+            Formatted example with the chat template applied.

     Notes:
-        - This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced by
-          `"text"`.
+        - This function does not alter the keys, except for Language modeling dataset, where `"messages"` is replaced
+          by `"text"`.

-        - In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt. Else,
-          if the last role is `"assistant"`, the final message is continued.
+        - In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt.
+          Else, if the last role is `"assistant"`, the final message is continued.

     Example:

@@ -462,3 +464,52 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
     # Split the values into chunks of size seq_length
     examples = {k: [v[i : i + seq_length] for i in range(0, len(v), seq_length)] for k, v in examples.items()}
     return examples
+
+
+def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]:
+    """
+    Convert a conversational dataset with fields `from` and `value` to ChatML format.
+
+    This function modifies conversational data to align with OpenAI's ChatML format:
+    - Replaces the key `"from"` with `"role"` in message dictionaries.
+    - Replaces the key `"value"` with `"content"` in message dictionaries.
+    - Renames `"conversations"` to `"messages"` for consistency with ChatML.
+
+    Args:
+        example (`dict[str, list]`):
+            A single data entry containing a list of messages.
+
+    Returns:
+        `dict[str, list]`:
+            Example reformatted to ChatML style.
+
+    Example:
+    ```python
+    >>> from trl import maybe_convert_to_chatml
+    >>> example = {
+    ...     "conversations": [
+    ...         {"from": "user", "value": "What color is the sky?"},
+    ...         {"from": "assistant", "value": "It is blue."}
+    ...     ]
+    ... }
+    >>> maybe_convert_to_chatml(example)
+    {'messages': [{'role': 'user', 'content': 'What color is the sky?'},
+                  {'role': 'assistant', 'content': 'It is blue.'}]}
+    ```
+    """
+    # List of possible keys containing message lists
+    for key in ["prompt", "completion", "chosen", "rejected", "messages", "conversations"]:
+        if key in example and isinstance(example[key], list):
+            messages = example[key]
+            for message in messages:
+                if isinstance(message, dict):
+                    if "from" in message:
+                        message["role"] = message.pop("from")
+                    if "value" in message:
+                        message["content"] = message.pop("value")
+
+    # Rename "conversations" to "messages"
+    if "conversations" in example:
+        example["messages"] = example.pop("conversations")
+
+    return example
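A small illustration of the `pack_examples` chunking shown in the hunk above. Only the splitting step is visible in this diff; the flattening of the input lists happens earlier in the function, so treat the exact output as an assumption:

```python
from trl import pack_examples

examples = {"input_ids": [[1, 2, 3], [4, 5, 6, 7]]}
# Sequences are joined end to end, then re-split into fixed-size chunks
print(pack_examples(examples, seq_length=3))
# expected: {'input_ids': [[1, 2, 3], [4, 5, 6], [7]]}
```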
trl/extras/profiling.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import time
+
+from transformers import is_wandb_available
+
+
+if is_wandb_available():
+    import wandb
+
+
+def profiling_decorator(func):
+    """
+    Decorator to profile a function and log the time taken to execute it.
+    """
+
+    @functools.wraps(func)
+    def wrapper(self, *args, **kwargs):
+        start_time = time.perf_counter()
+        result = func(self, *args, **kwargs)
+        end_time = time.perf_counter()
+        duration = end_time - start_time
+
+        if "wandb" in self.args.report_to and wandb.run is not None and self.accelerator.is_main_process:
+            wandb.log({f"profiling/Time taken: {self.__class__.__name__}.{func.__name__}": duration})
+        return result
+
+    return wrapper
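A hedged usage sketch of the new decorator: it assumes the instance exposes `args.report_to` and `accelerator.is_main_process`, as TRL trainers do. The class and attribute stubs below are invented for illustration only:

```python
from types import SimpleNamespace

from trl.extras.profiling import profiling_decorator


class DummyTrainer:
    """Hypothetical stand-in for a trainer; only the attributes the decorator reads."""

    def __init__(self):
        self.args = SimpleNamespace(report_to=[])  # no "wandb" here -> nothing is logged
        self.accelerator = SimpleNamespace(is_main_process=True)

    @profiling_decorator
    def step(self):
        # Wall-clock time of this call is measured; it is logged to wandb only
        # when "wandb" is in args.report_to and a run is active on the main process
        return sum(range(1_000_000))


DummyTrainer().step()
```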
@@ -20,7 +20,6 @@ from typing import TYPE_CHECKING, Literal, Optional, Union

 from accelerate.utils import is_deepspeed_available
 from transformers import PreTrainedModel, PreTrainedTokenizer
-from transformers.utils.deprecation import deprecate_kwarg

 from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead

@@ -175,7 +174,6 @@ def add_hooks(model: "DeepSpeedEngine") -> None:


 @contextmanager
-@deprecate_kwarg("is_peft_model", "0.16.0", warn_if_greater_or_equal_version=True)
 def unwrap_model_for_generation(
     model: Union["DistributedDataParallel", "DeepSpeedEngine"],
     accelerator: "Accelerator",
@@ -76,7 +76,6 @@ _import_structure = {
         "disable_dropout_in_model",
         "empty_cache",
         "peft_module_casting_to_bf16",
-        "compute_token_accuracy",
     ],
     "xpo_config": ["XPOConfig"],
     "xpo_trainer": ["XPOTrainer"],
@@ -145,7 +144,6 @@ if TYPE_CHECKING:
         DataCollatorForCompletionOnlyLM,
         RunningMoments,
         compute_accuracy,
-        compute_token_accuracy,
         disable_dropout_in_model,
         empty_cache,
         peft_module_casting_to_bf16,
@@ -51,7 +51,6 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPIN
 from transformers.trainer_callback import TrainerCallback
 from transformers.trainer_utils import EvalLoopOutput
 from transformers.utils import is_peft_available, is_torch_xpu_available
-from transformers.utils.deprecation import deprecate_kwarg

 from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
 from ..models import PreTrainedModelWrapper, create_reference_model
@@ -202,9 +201,6 @@ class DPOTrainer(Trainer):

     _tag_names = ["trl", "dpo"]

-    @deprecate_kwarg(
-        "tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
-    )
     def __init__(
         self,
         model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
@@ -87,7 +87,7 @@ class GKDTrainer(SFTTrainer):
     ):
         # add remove_unused_columns=False to the dataclass args
         args.remove_unused_columns = False
-        data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_seq_length)
+        data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)

         super().__init__(
             model,
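For context, a minimal sketch of constructing that collator on its own, using only the two keyword arguments the diff line confirms; the tokenizer checkpoint is illustrative:

```python
from transformers import AutoTokenizer
from trl.trainer.utils import DataCollatorForChatML

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# After this change, GKD reads the length cap from the config's max_length
# field rather than the removed max_seq_length
collator = DataCollatorForChatML(tokenizer=tokenizer, max_length=512)
```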
@@ -79,6 +79,8 @@ class GRPOConfig(TrainingArguments):
             If set, the `max_model_len` to use for vLLM. This could be useful when running with reduced
             `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model
             context size, which might be much larger than the KV cache, leading to inefficiencies.
+        vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
+            Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.

     > Parameters that control the training

@@ -86,7 +88,12 @@ class GRPOConfig(TrainingArguments):
             Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
             [`~transformers.TrainingArguments`].
         beta (`float`, *optional*, defaults to `0.04`):
-            KL coefficient.
+            KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
+            speed.
+        num_iterations (`int`, *optional*, defaults to `1`):
+            Number of iterations per batch (denoted as μ in the algorithm).
+        epsilon (`float`, *optional*, defaults to `0.2`):
+            Epsilon value for clipping.
         reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
             Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
             weighted equally with weight `1.0`.
@@ -201,6 +208,10 @@ class GRPOConfig(TrainingArguments):
             "context size, which might be much larger than the KV cache, leading to inefficiencies."
         },
     )
+    vllm_guided_decoding_regex: Optional[str] = field(
+        default=None,
+        metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
+    )

     # Parameters that control the training
     learning_rate: float = field(
@@ -212,7 +223,18 @@ class GRPOConfig(TrainingArguments):
     )
     beta: float = field(
         default=0.04,
-        metadata={"help": "KL coefficient."},
+        metadata={
+            "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
+            "training speed."
+        },
     )
+    num_iterations: int = field(
+        default=1,
+        metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
+    )
+    epsilon: float = field(
+        default=0.2,
+        metadata={"help": "Epsilon value for clipping."},
+    )
     reward_weights: Optional[list[float]] = field(
         default=None,
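A sketch combining the options this compare introduces; the values are illustrative, not recommendations:

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-output",
    beta=0.0,  # new semantics: skip loading the reference model entirely
    num_iterations=2,  # μ: optimization iterations per generation batch
    epsilon=0.2,  # clipping range
    use_vllm=True,
    vllm_guided_decoding_regex=r"<answer>\n.*\n</answer>",  # constrain completions
)
```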
@ -43,6 +43,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.utils import is_peft_available
|
||||
|
||||
from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
||||
from ..extras.profiling import profiling_decorator
|
||||
from ..import_utils import is_vllm_available
|
||||
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
|
||||
from .callbacks import SyncRefModelCallback
|
||||
@ -55,6 +56,7 @@ if is_peft_available():
|
||||
|
||||
if is_vllm_available():
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
@ -66,26 +68,63 @@ RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
|
||||
|
||||
class RepeatRandomSampler(Sampler):
|
||||
"""
|
||||
Sampler that repeats the indices of a dataset N times.
|
||||
Sampler that repeats the indices of a dataset in a structured manner.
|
||||
|
||||
Args:
|
||||
data_source (`Sized`):
|
||||
Dataset to sample from.
|
||||
repeat_count (`int`):
|
||||
Number of times to repeat each index.
|
||||
seed (`Optional[int]`):
|
||||
mini_repeat_count (`int`):
|
||||
Number of times to repeat each index per batch.
|
||||
batch_size (`int`, *optional*, defaults to `1`):
|
||||
Number of unique indices per batch.
|
||||
repeat_count (`int`, *optional*, defaults to `1`):
|
||||
Number of times to repeat the full sampling process.
|
||||
seed (`int` or `None`, *optional*, defaults to `None`):
|
||||
Random seed for reproducibility (only affects this sampler).
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> sampler = RepeatRandomSampler(["a", "b", "c", "d"], repeat_count=2)
|
||||
>>> sampler = RepeatRandomSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4)
|
||||
>>> list(sampler)
|
||||
[2, 2, 0, 0, 3, 3, 1, 1]
|
||||
[4, 4, 3, 3, 0, 0,
|
||||
4, 4, 3, 3, 0, 0,
|
||||
4, 4, 3, 3, 0, 0,
|
||||
4, 4, 3, 3, 0, 0,
|
||||
|
||||
1, 1, 2, 2, 6, 6,
|
||||
1, 1, 2, 2, 6, 6,
|
||||
1, 1, 2, 2, 6, 6,
|
||||
1, 1, 2, 2, 6, 6]
|
||||
```
|
||||
|
||||
```txt
|
||||
mini_repeat_count = 3
|
||||
- - -
|
||||
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
|
||||
4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
|
||||
8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, |
|
||||
repeat_count = 2
|
||||
0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
|
||||
4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
|
||||
8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, ...] |
|
||||
--------- --------- --------- ---------
|
||||
--------- --------- --------- ---------
|
||||
--------- --------- --------- ---------
|
||||
batch_size = 12
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, data_source: Sized, repeat_count: int, seed: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
data_source: Sized,
|
||||
mini_repeat_count: int,
|
||||
batch_size: int = 1,
|
||||
repeat_count: int = 1,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
self.data_source = data_source
|
||||
self.mini_repeat_count = mini_repeat_count
|
||||
self.batch_size = batch_size
|
||||
self.repeat_count = repeat_count
|
||||
self.num_samples = len(data_source)
|
||||
self.seed = seed
|
||||
@ -94,15 +133,25 @@ class RepeatRandomSampler(Sampler):
|
||||
self.generator.manual_seed(seed)
|
||||
|
||||
def __iter__(self):
|
||||
indexes = [
|
||||
idx
|
||||
for idx in torch.randperm(self.num_samples, generator=self.generator).tolist()
|
||||
for _ in range(self.repeat_count)
|
||||
]
|
||||
return iter(indexes)
|
||||
# E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
|
||||
indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples * self.repeat_count
|
||||
# [2, 4, 3, 1, 0, 6, 5]
|
||||
# -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3)
|
||||
indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
|
||||
|
||||
# [[2, 4, 3], [1, 0, 6], [5]]
|
||||
# -> [[2, 4, 3], [1, 0, 6]]
|
||||
indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
|
||||
|
||||
for chunk in indexes:
|
||||
for _ in range(self.repeat_count):
|
||||
for index in chunk:
|
||||
for _ in range(self.mini_repeat_count):
|
||||
yield index
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.num_samples * self.mini_repeat_count * self.repeat_count
|
||||
|
||||
|
||||
class GRPOTrainer(Trainer):
@ -244,18 +293,28 @@ class GRPOTrainer(Trainer):
            )

        if peft_config is not None:
            if not is_peft_available():
                raise ImportError("PEFT is required to use `peft_config`. Run `pip install peft`.")
            model = get_peft_model(model, peft_config)

        # Enable gradient checkpointing if requested
        if args.gradient_checkpointing:
            model = self._enable_gradient_checkpointing(model, args)

        # Reference model
        if is_deepspeed_zero3_enabled():
        self.beta = args.beta
        if self.beta == 0.0:
            # If beta is 0.0, the reference model is not needed
            self.ref_model = None
        elif is_deepspeed_zero3_enabled():
            self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
        elif not is_peft_model(model):
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)
        else:
        elif is_peft_model(model):
            # If PEFT is used, the reference model is not needed since the adapter can be disabled
            # to revert to the initial model.
            self.ref_model = None
        else:
            # If PEFT configuration is not provided, create a reference model based on the initial model.
            self.ref_model = create_reference_model(model)

        # Processing class
        if processing_class is None:
@ -313,7 +372,14 @@ class GRPOTrainer(Trainer):
        self.num_generations = args.num_generations  # = G in the GRPO paper
        self.use_vllm = args.use_vllm

        self.beta = args.beta
        # Multi-step
        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
        self.epsilon = args.epsilon
        # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle.
        self._step = 0
        # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
        # `_get_train_sampler` and `_prepare_inputs`.
        self._buffered_inputs = [None] * args.gradient_accumulation_steps

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
@ -324,7 +390,7 @@ class GRPOTrainer(Trainer):
        model.warnings_issued["estimate_tokens"] = True

        # Initialize the metrics
        self._metrics = defaultdict(list)
        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
        self.log_completions = args.log_completions

        super().__init__(
@ -412,9 +478,19 @@ class GRPOTrainer(Trainer):
                enable_prefix_caching=True,
                max_model_len=self.args.vllm_max_model_len,
            )

            # Guided decoding, if enabled
            if args.vllm_guided_decoding_regex is not None:
                guided_decoding = GuidedDecodingParams(backend="outlines", regex=args.vllm_guided_decoding_regex)
            else:
                guided_decoding = None

            # Sampling parameters
            self.sampling_params = SamplingParams(
                temperature=args.temperature,
                max_tokens=self.max_completion_length,
                guided_decoding=guided_decoding,
                n=args.num_generations,
            )

        self._last_loaded_step = 0  # tag to avoid useless loading during grad accumulation
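The guided-decoding wiring above can be exercised outside the trainer. A minimal sketch, assuming a local vLLM install and mirroring the `GuidedDecodingParams`/`SamplingParams` calls from the diff; the regex, model id, and literal values are illustrative stand-ins for `vllm_guided_decoding_regex`, `max_completion_length`, and `num_generations`:

```python
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Constrain every completion to match a regex, as vllm_guided_decoding_regex does above
guided_decoding = GuidedDecodingParams(backend="outlines", regex=r"<answer>.*</answer>")

sampling_params = SamplingParams(
    temperature=0.9,
    max_tokens=64,   # plays the role of max_completion_length
    guided_decoding=guided_decoding,
    n=8,             # num_generations completions per unique prompt
)

llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")  # illustrative model id
outputs = llm.generate(["What is 2 + 2?"], sampling_params=sampling_params)
for out in outputs[0].outputs:  # n completions for the single prompt
    print(out.text)
```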
@ -461,20 +537,78 @@ class GRPOTrainer(Trainer):
        self._signature_columns = ["prompt"]

    def _get_train_sampler(self) -> Sampler:
        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
        # preventing discrepancies in group formation.
        return RepeatRandomSampler(self.train_dataset, self.num_generations, seed=self.args.seed)
        # Returns a sampler that
        # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are
        #    distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt
        #    group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies
        #    in group formation.
        # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to
        #    _prepare_inputs to see how the generations are stored and reused.

        # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the
        # second row shows the second sampled batch, and so on.
        #
        #                                      |   GPU 0  |   GPU 1  |   GPU 2  |
        #
        #               global_step   step     <───────>  num_generations=3
        #                                      <───────────> per_device_train_batch_size=4
        #                ▲   0          0      0   0   0   1   1   1   2   2   2   3   3   3  │
        #  grad_accum=3  │   0          1      4   4   4   5   5   5   6   6   6   7   7   7  │ Generate completions for each prompt
        #                ▼   0          2      8   8   8   9   9   9  10  10  10  11  11  11  │
        #
        #                    1          3      0   0   0   1   1   1   2   2   2   3   3   3  │ The sampled prompts are the same as in the first iteration
        #                    1          4      4   4   4   5   5   5   6   6   6   7   7   7  │ Reuse the completions (here, once, because num_iterations=2)
        #                    1          5      8   8   8   9   9   9  10  10  10  11  11  11  │
        #
        #                    2          6     12  12  12  13  13  13  14  14  14  15  15  15
        #                    2          7     16  16  16  17  17  17  18  18  18  19  19  19
        #                    2          8     20  20  20  21  21  21  22  22  22  23  23  23
        #                                          ...
        effective_batch_size = (
            self.args.per_device_train_batch_size
            * self.accelerator.num_processes
            * self.args.gradient_accumulation_steps
        )
        return RepeatRandomSampler(
            data_source=self.train_dataset,
            mini_repeat_count=self.num_generations,
            batch_size=effective_batch_size // self.num_generations,
            repeat_count=self.num_iterations,
            seed=self.args.seed,
        )

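The arithmetic behind `batch_size=effective_batch_size // self.num_generations` is worth spelling out. A quick sketch using the numbers from the figure above (3 GPUs, per-device batch 4, gradient accumulation 3, 3 generations per prompt):

```python
per_device_train_batch_size = 4
num_processes = 3                 # GPUs
gradient_accumulation_steps = 3
num_generations = 3               # G in the GRPO paper
num_iterations = 2                # mu in the GRPO paper

effective_batch_size = per_device_train_batch_size * num_processes * gradient_accumulation_steps
assert effective_batch_size == 36          # completions consumed per optimizer step

unique_prompts = effective_batch_size // num_generations
assert unique_prompts == 12                # matches batch_size = 12 in the figure

# The sampler emits each of those 12 prompts num_generations times, and replays
# the whole 36-completion batch num_iterations times before drawing new prompts.
```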
    def _get_eval_sampler(self, eval_dataset) -> Sampler:
        # Returns a sampler that ensures each prompt is repeated across multiple processes. This guarantees that
        # identical prompts are distributed to different GPUs, allowing rewards to be computed and normalized correctly
        # within each prompt group. Using the same seed across processes ensures consistent prompt assignment,
        # preventing discrepancies in group formation.
        return RepeatRandomSampler(eval_dataset, self.num_generations, seed=self.args.seed)
        # See _get_train_sampler for an explanation of the sampler.
        return RepeatRandomSampler(
            data_source=eval_dataset,
            mini_repeat_count=self.num_generations,
            seed=self.args.seed,
        )

    def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
        """Enables gradient checkpointing for the model."""
        # Ensure use_cache is disabled
        model.config.use_cache = False

        # Enable gradient checkpointing on the base model for PEFT
        if is_peft_model(model):
            model.base_model.gradient_checkpointing_enable()
        # Enable gradient checkpointing for non-PEFT models
        else:
            model.gradient_checkpointing_enable()

        gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
        use_reentrant = (
            "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
        )

        if use_reentrant:
            model.enable_input_require_grads()

        return model

    # Get the per-token log probabilities for the completions for the model and the reference model
    @profiling_decorator
    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
        # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
        logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
@ -486,6 +620,7 @@ class GRPOTrainer(Trainer):
        logits = logits[:, -logits_to_keep:]
        return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens

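The off-by-one handling above (request `logits_to_keep + 1` logits, drop the last one) exists because the logit at position t predicts token t+1. A self-contained sketch of the same shift, assuming a plain random logits tensor in place of a model call:

```python
import torch

B, T, V = 2, 5, 11            # batch, sequence length, vocab size
logits_to_keep = 3            # completion length
input_ids = torch.randint(V, (B, T))
logits = torch.randn(B, T, V)

# Drop the last position: logits[:, t] scores token t+1, and the final logit has no target
logits = logits[:, :-1, :]
# Keep only the positions that score completion tokens
logits = logits[:, -logits_to_keep:]
target_ids = input_ids[:, -logits_to_keep:]

# log_softmax -> gather: the operation selective_log_softmax implements memory-efficiently
per_token_logps = torch.log_softmax(logits, dim=-1).gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
print(per_token_logps.shape)  # torch.Size([2, 3])
```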
    @profiling_decorator
    def _move_model_to_vllm(self):
        with unwrap_model_for_generation(
            self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
@ -495,7 +630,6 @@ class GRPOTrainer(Trainer):
            if is_peft_model(unwrapped_model):
                unwrapped_model.merge_adapter()
                state_dict = unwrapped_model.state_dict()
                unwrapped_model.unmerge_adapter()
                # Remove base_model and base_layer prefixes
                state_dict = {
                    k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
@ -510,11 +644,32 @@ class GRPOTrainer(Trainer):
                }
            else:
                state_dict = unwrapped_model.state_dict()
            if self.accelerator.is_main_process:
                llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                llm_model.load_weights(state_dict.items())
            if self.accelerator.is_main_process:
                llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
                llm_model.load_weights(state_dict.items())
            # Unmerge the adapter to restore the model to its original state.
            # This must be done after loading weights to ensure they correspond to the merged state.
            if is_peft_model(unwrapped_model):
                unwrapped_model.unmerge_adapter()

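The key renaming above (dropping `base_model.model.` and `.base_layer`) converts merged-PEFT parameter names back into the base model's namespace so `load_weights` can match them. A toy illustration with made-up keys:

```python
# Hypothetical merged-adapter keys, shaped like those a PEFT-wrapped model produces
state_dict = {
    "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": "W_q",
    "base_model.model.lm_head.weight": "W_lm",
}

# Same dict comprehension as in _move_model_to_vllm
renamed = {
    k.removeprefix("base_model.model.").replace(".base_layer", ""): v
    for k, v in state_dict.items()
}
print(renamed)
# {'model.layers.0.self_attn.q_proj.weight': 'W_q', 'lm_head.weight': 'W_lm'}
```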
    @profiling_decorator
    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        mode = "eval" if self.control.should_evaluate else "train"
        if mode == "train":
            if self.state.global_step % self.num_iterations == 0:
                inputs = self._generate_and_score_completions(inputs)
                self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
            else:
                inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
            self._step += 1
        else:
            # In evaluation, we don't reuse completions across multiple updates, so we don't need to buffer inputs.
            inputs = self._generate_and_score_completions(inputs)
        return inputs

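The interplay of `global_step`, `_step`, and the buffer can be traced with plain integers. A minimal sketch, assuming `num_iterations=2` and `gradient_accumulation_steps=3` as in the sampler figure (the real trainer updates `global_step` via the `Trainer` state; here it is simulated with a loop):

```python
num_iterations = 2
gradient_accumulation_steps = 3
buffered = [None] * gradient_accumulation_steps

step = 0
for global_step in range(4):                      # optimizer steps
    for _ in range(gradient_accumulation_steps):  # forward/backward passes
        if global_step % num_iterations == 0:
            batch = f"generated@step{step}"       # fresh generation
            buffered[step % gradient_accumulation_steps] = batch
        else:
            batch = buffered[step % gradient_accumulation_steps]  # reuse
        print(global_step, step, batch)
        step += 1
# Generation happens at global_step 0 and 2; steps 1 and 3 replay the buffer.
```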
    def _generate_and_score_completions(
        self, inputs: dict[str, Union[torch.Tensor, Any]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        device = self.accelerator.device
        prompts = [x["prompt"] for x in inputs]
        prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
@ -538,8 +693,17 @@ class GRPOTrainer(Trainer):
            # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
            all_prompts_text = gather_object(prompts_text)
            if self.accelerator.is_main_process:
                outputs = self.llm.generate(all_prompts_text, sampling_params=self.sampling_params, use_tqdm=False)
                completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
                # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
                # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                # prompt individually.
                ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text))
                all_outputs = self.llm.generate(
                    ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
                )
                completion_ids = []
                for outputs in all_outputs:
                    for output in outputs.outputs:
                        completion_ids.append(output.token_ids)
            else:
                completion_ids = [None] * len(all_prompts_text)
            # Broadcast the completions from the main process to all processes, ensuring each process receives its
@ -575,12 +739,23 @@ class GRPOTrainer(Trainer):
        completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

        # Concatenate prompt_mask with completion_mask for logit computation
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B*G, P+C)
        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)

        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        with torch.inference_mode():
            if self.ref_model is not None:
            # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
            # computation here, and use per_token_logps.detach() instead.
            if self.num_iterations > 1:
                old_per_token_logps = self._get_per_token_logps(
                    self.model, prompt_completion_ids, attention_mask, logits_to_keep
                )
            else:
                old_per_token_logps = None

            if self.beta == 0.0:
                ref_per_token_logps = None
            elif self.ref_model is not None:
                ref_per_token_logps = self._get_per_token_logps(
                    self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
                )
@ -647,16 +822,21 @@ class GRPOTrainer(Trainer):
            advantages = advantages[process_slice]

        # Log the metrics
        mode = "eval" if self.control.should_evaluate else "train"

        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics[mode]["completion_length"].append(completion_length)

        reward_per_func = rewards_per_func.mean(0)
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
            else:
                reward_func_name = reward_func.__name__
            self._metrics[f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
            self._metrics[mode][f"rewards/{reward_func_name}"].append(reward_per_func[i].item())

        self._metrics["reward"].append(rewards.mean().item())
        self._metrics["reward_std"].append(std_grouped_rewards.mean().item())
        self._metrics[mode]["reward"].append(rewards.mean().item())
        self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())

        if (
            self.log_completions
@ -682,10 +862,12 @@ class GRPOTrainer(Trainer):
            "prompt_mask": prompt_mask,
            "completion_ids": completion_ids,
            "completion_mask": completion_mask,
            "old_per_token_logps": old_per_token_logps,
            "ref_per_token_logps": ref_per_token_logps,
            "advantages": advantages,
        }

    @profiling_decorator
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
@ -700,22 +882,36 @@ class GRPOTrainer(Trainer):
        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

        # Compute the KL divergence between the model and the reference model
        ref_per_token_logps = inputs["ref_per_token_logps"]
        per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
        if self.beta != 0.0:
            ref_per_token_logps = inputs["ref_per_token_logps"]
            per_token_kl = (
                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
            )

        # x - x.detach() allows for preserving gradients from x
        # Compute the loss
        advantages = inputs["advantages"]
        per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
        per_token_loss = -(per_token_loss - self.beta * per_token_kl)
        loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation (see
        # _generate_and_score_completions) and use per_token_logps.detach() instead.
        old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
        coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if self.beta != 0.0:
            per_token_loss = per_token_loss + self.beta * per_token_kl
        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()

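The loss block above is the PPO-style clipped surrogate plus an optional KL penalty; the `exp(x) - x - 1` form is the non-negative k3 estimator of the KL divergence. A standalone sketch with random tensors, assuming the same shapes the trainer uses:

```python
import torch

B, C = 4, 6                                            # completions in batch, completion length
epsilon, beta = 0.2, 0.04
per_token_logps = torch.randn(B, C, requires_grad=True)
old_per_token_logps = per_token_logps.detach()         # the num_iterations == 1 case
ref_per_token_logps = torch.randn(B, C)
advantages = torch.randn(B)
completion_mask = torch.ones(B, C)

coef_1 = torch.exp(per_token_logps - old_per_token_logps)  # importance ratio
coef_2 = torch.clamp(coef_1, 1 - epsilon, 1 + epsilon)     # clipped ratio
per_token_loss = -torch.min(coef_1 * advantages.unsqueeze(1),
                            coef_2 * advantages.unsqueeze(1))

# k3 KL estimator: exp(ref - logp) - (ref - logp) - 1 >= 0
delta = ref_per_token_logps - per_token_logps
per_token_kl = torch.exp(delta) - delta - 1
per_token_loss = per_token_loss + beta * per_token_kl

loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
loss.backward()  # gradients flow through per_token_logps only
```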
        # Log the metrics
        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
        self._metrics["completion_length"].append(completion_length)
        mode = "eval" if self.control.should_evaluate else "train"

        mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
        self._metrics["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
        if self.beta != 0.0:
            mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
            self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())

        is_clipped = (per_token_loss1 < per_token_loss2).float()
        clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
        self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
        return loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
@ -727,11 +923,12 @@ class GRPOTrainer(Trainer):
        return loss, None, None

    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
        metrics = {key: sum(val) / len(val) for key, val in self._metrics.items()}  # average the metrics
        mode = "eval" if self.control.should_evaluate else "train"
        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}  # average the metrics

        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
        if next(iter(logs.keys())).startswith("eval_"):
        if mode == "eval":
            metrics = {f"eval_{key}": val for key, val in metrics.items()}

        logs = {**logs, **metrics}
@ -739,7 +936,7 @@ class GRPOTrainer(Trainer):
            super().log(logs, start_time)
        else:  # transformers<=4.46
            super().log(logs)
        self._metrics.clear()
        self._metrics[mode].clear()

    def create_model_card(
        self,
@ -49,13 +49,11 @@ class SFTConfig(TrainingArguments):
            `skip_prepare_dataset`.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of processes to use for processing the dataset.
        max_seq_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the
            right.
        max_length (`int` or `None`, *optional*, defaults to `1024`):
            Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from the right.
            If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
        packing (`bool`, *optional*, defaults to `False`):
            Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence
            length.
            Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define sequence length.
        eval_packing (`bool` or `None`, *optional*, defaults to `None`):
            Whether to pack the eval dataset. If `None`, uses the same value as `packing`.

@ -95,19 +93,19 @@ class SFTConfig(TrainingArguments):
        default=None,
        metadata={"help": "Number of processes to use for processing the dataset."},
    )
    max_seq_length: Optional[int] = field(
    max_length: Optional[int] = field(
        default=1024,
        metadata={
            "help": "Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated "
            "from the right. If `None`, no truncation is applied. When packing is enabled, this value sets the "
            "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from "
            "the right. If `None`, no truncation is applied. When packing is enabled, this value sets the "
            "sequence length."
        },
    )
    packing: bool = field(
        default=False,
        metadata={
            "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to "
            "define sequence length."
            "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_length` to define "
            "sequence length."
        },
    )
    eval_packing: Optional[bool] = field(
@ -132,13 +130,17 @@ class SFTConfig(TrainingArguments):
    num_of_sequences: int = field(
        default=None,
        metadata={
            "help": "Deprecated. Use `max_seq_length` instead, which specifies the maximum length of the tokenized "
            "help": "Deprecated. Use `max_length` instead, which specifies the maximum length of the tokenized "
            "sequence, unlike `num_of_sequences`, which referred to string sequences."
        },
    )
    chars_per_token: float = field(
        default=None,
        metadata={"help": "Deprecated. If you want to customize the packing length, use `max_seq_length`."},
        metadata={"help": "Deprecated. If you want to customize the packing length, use `max_length`."},
    )
    max_seq_length: Optional[int] = field(
        default=None,
        metadata={"help": "Deprecated. Use `max_length` instead."},
    )

    def __post_init__(self):
@ -153,7 +155,7 @@ class SFTConfig(TrainingArguments):

        if self.num_of_sequences is not None:
            warnings.warn(
                "`num_of_sequences` is deprecated and will be removed in version 0.18.0. Use `max_seq_length` instead, "
                "`num_of_sequences` is deprecated and will be removed in version 0.18.0. Use `max_length` instead, "
                "which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which r"
                "eferred to string sequences.",
                DeprecationWarning,
@ -162,6 +164,12 @@ class SFTConfig(TrainingArguments):
        if self.chars_per_token is not None:
            warnings.warn(
                "`chars_per_token` is deprecated and will be removed in version 0.18.0. If you want to customize the "
                "packing length, use `max_seq_length`.",
                "packing length, use `max_length`.",
                DeprecationWarning,
            )

        if self.max_seq_length is not None:
            warnings.warn(
                "`max_seq_length` is deprecated and will be removed in version 0.20.0. Use `max_length` instead.",
                DeprecationWarning,
            )
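For users, the deprecation path above boils down to renaming one argument. A before/after sketch; per the warning text, the old spelling keeps working until 0.20.0 but emits a `DeprecationWarning`:

```python
from trl import SFTConfig

# Deprecated spelling: still accepted for now, but warns
args_old = SFTConfig(output_dir="/tmp", max_seq_length=512)

# Preferred spelling going forward
args_new = SFTConfig(output_dir="/tmp", max_length=512)
```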
@ -41,17 +41,10 @@ from transformers import (
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_liger_kernel_available, is_peft_available
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import is_conversational, maybe_apply_chat_template, pack_examples
from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples
from .sft_config import SFTConfig
from .utils import (
    ConstantLengthDataset,
    compute_token_accuracy,
    generate_model_card,
    get_comet_experiment_url,
    peft_module_casting_to_bf16,
)
from .utils import ConstantLengthDataset, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16


if is_peft_available():
@ -107,6 +100,8 @@ class SFTTrainer(Trainer):
            - [Standard](dataset_formats#standard): Each sample contains plain text.
            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
              and content).

            The trainer also supports processed datasets (tokenized) as long as they contain an `input_ids` field.
        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
@ -140,9 +135,6 @@ class SFTTrainer(Trainer):

    _tag_names = ["trl", "sft"]

    @deprecate_kwarg(
        "tokenizer", "0.16.0", "processing_class", warn_if_greater_or_equal_version=True, raise_if_both_names=True
    )
    def __init__(
        self,
        model: Union[str, nn.Module, PreTrainedModel],
@ -181,12 +173,13 @@ class SFTTrainer(Trainer):
        )
        if isinstance(model, str):
            model = self._create_model_from_path(model, args)
        self.use_liger = is_liger_kernel_available() and isinstance(model, AutoLigerKernelForCausalLM)

        # PEFT configuration and model wrapping
        if peft_config is not None:
            model = self._prepare_peft_model(model, peft_config, args)

        # 3. Handle the tokenizer
        # Handle the tokenizer
        if processing_class is None:
            processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
            if processing_class.pad_token is None:
@ -275,8 +268,10 @@ class SFTTrainer(Trainer):
        if args.use_liger:
            if not is_liger_kernel_available():
                raise ImportError("Please install Liger-kernel for use_liger=True")
            return AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
        return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
            model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
        return model

    def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
        """Prepares a model for PEFT training."""
@ -368,6 +363,10 @@ class SFTTrainer(Trainer):
        if isinstance(dataset, ConstantLengthDataset):
            return dataset

        # If the dataset is already preprocessed (tokenized), skip the processing steps.
        column_names = list(next(iter(dataset)).keys())
        is_processed = "input_ids" in column_names

        # Build the kwargs for the `map` function
        map_kwargs = {}
        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
@ -375,7 +374,15 @@ class SFTTrainer(Trainer):

        with PartialState().local_main_process_first():
            # Apply the formatting function if any
            if formatting_func is not None:
            if formatting_func is not None and is_processed:
                warnings.warn(
                    "You passed a dataset that is already processed (contains an `input_ids` field) together with a "
                    "formatting function. Therefore `formatting_func` will be ignored. Either remove the "
                    "`formatting_func` or pass a dataset that is not already processed.",
                    UserWarning,
                )

            if formatting_func is not None and not is_processed:
                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                    map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"

@ -395,6 +402,15 @@ class SFTTrainer(Trainer):

            dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])

            # Convert the dataset to ChatML if needed
            if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML"
            dataset = dataset.map(
                maybe_convert_to_chatml,
                remove_columns="conversations" if "conversations" in dataset.column_names else None,
                **map_kwargs,
            )

|
||||
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
|
||||
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
|
||||
@ -405,24 +421,30 @@ class SFTTrainer(Trainer):
|
||||
**map_kwargs,
|
||||
)
|
||||
|
||||
# Tokenize the dataset
|
||||
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
|
||||
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
|
||||
dataset = dataset.map(lambda ex: processing_class(ex[args.dataset_text_field]), **map_kwargs)
|
||||
# Tokenize the dataset if needed
|
||||
if not is_processed:
|
||||
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
|
||||
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
|
||||
|
||||
def tokenize(ex):
|
||||
tokenized = processing_class(ex[args.dataset_text_field])
|
||||
return {"input_ids": tokenized["input_ids"], "attention_mask": tokenized["attention_mask"]}
|
||||
|
||||
dataset = dataset.map(tokenize, **map_kwargs)
|
||||
|
||||
# Pack or truncate
|
||||
if packing:
|
||||
if args.max_seq_length is None:
|
||||
raise ValueError("When packing is enabled, `max_seq_length` can't be `None`.")
|
||||
if args.max_length is None:
|
||||
raise ValueError("When packing is enabled, `max_length` can't be `None`.")
|
||||
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
|
||||
map_kwargs["desc"] = f"Packing {dataset_name} dataset"
|
||||
dataset = dataset.select_columns("input_ids")
|
||||
dataset = dataset.map(
|
||||
pack_examples, batched=True, fn_kwargs={"seq_length": args.max_seq_length}, **map_kwargs
|
||||
pack_examples, batched=True, fn_kwargs={"seq_length": args.max_length}, **map_kwargs
|
||||
)
|
||||
elif args.max_seq_length is not None:
|
||||
elif args.max_length is not None:
|
||||
dataset = dataset.map(
|
||||
lambda ex: {key: ex[key][: args.max_seq_length] for key in ["input_ids", "attention_mask"]},
|
||||
lambda ex: {key: ex[key][: args.max_length] for key in ["input_ids", "attention_mask"]},
|
||||
**map_kwargs,
|
||||
)
|
||||
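`pack_examples` concatenates the tokenized examples and slices the result into fixed-length chunks, which is why `max_length` must not be `None` when packing. A rough standalone sketch of that behavior; the real helper is in `trl.data_utils`, and this mirrors it under the assumption of simple list-of-lists batches:

```python
def pack_examples_sketch(examples: dict[str, list[list[int]]], seq_length: int) -> dict:
    # Concatenate all sequences per key, then split into seq_length-sized chunks
    packed = {k: sum(v, []) for k, v in examples.items()}
    return {
        k: [v[i : i + seq_length] for i in range(0, len(v), seq_length)]
        for k, v in packed.items()
    }

batch = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
print(pack_examples_sketch(batch, seq_length=4))
# {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8], [9]]}
```

In contrast, the `elif args.max_length is not None` branch simply truncates each example independently, discarding everything past `max_length` tokens.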
            # For Liger kernel, ensure only input_ids is present
@ -440,18 +462,28 @@ class SFTTrainer(Trainer):
        )

        # Compute token accuracy if we have labels and if the model is not using Liger (no logits)
        if "labels" in inputs and not self.args.use_liger:
        if "labels" in inputs and not self.use_liger:
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = inputs["labels"][..., 1:].contiguous()

            # Gather logits and labels from all GPUs first
            shift_logits = self.accelerator.gather_for_metrics(shift_logits)
            shift_labels = self.accelerator.gather_for_metrics(shift_labels)
            # Get predictions
            predictions = shift_logits.argmax(dim=-1)

            # Then compute accuracy on the gathered tensors
            if self.accelerator.is_main_process:
                accuracy = compute_token_accuracy(shift_logits, shift_labels)
                self._metrics["mean_token_accuracy"].append(accuracy)
            # Create mask for non-padding tokens (assuming ignore_index is -100)
            mask = shift_labels != -100

            # Calculate accuracy only on non-padding tokens
            correct_predictions = (predictions == shift_labels) & mask
            total_tokens = mask.sum()
            correct_tokens = correct_predictions.sum()

            # Gather the correct_tokens and total_tokens across all processes
            correct_tokens = self.accelerator.gather_for_metrics(correct_tokens)
            total_tokens = self.accelerator.gather_for_metrics(total_tokens)

            # Compute the mean token accuracy and log it
            accuracy = (correct_tokens.sum() / total_tokens.sum()).item() if total_tokens.sum() > 0 else 0.0
            self._metrics["mean_token_accuracy"].append(accuracy)

        return (loss, outputs) if return_outputs else loss

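The inlined accuracy above replaces the old `compute_token_accuracy` helper (removed further down) with a version that gathers two scalar token counts instead of full logits across processes, which is far lighter on communication. The core masked computation, as a single-process sketch with made-up labels:

```python
import torch

shift_logits = torch.randn(2, 5, 11)                  # (batch, seq, vocab)
shift_labels = torch.tensor([[7, 3, -100, -100, 2],
                             [1, -100, 4, 4, -100]])  # -100 marks ignored positions

predictions = shift_logits.argmax(dim=-1)
mask = shift_labels != -100                           # keep only supervised tokens
correct = ((predictions == shift_labels) & mask).sum()
total = mask.sum()
accuracy = (correct / total).item() if total > 0 else 0.0
print(accuracy)  # fraction of non-ignored tokens predicted exactly
```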
@ -140,7 +140,7 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
                warnings.warn(
                    f"Could not find response key `{self.response_template}` in the following instance: "
                    f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
                    "calculation. Note, if this happens often, consider increasing the `max_seq_length`.",
                    "calculation. Note, if this happens often, consider increasing the `max_length`.",
                    UserWarning,
                )
                batch["labels"][i, :] = self.ignore_index
@ -167,7 +167,7 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
                warnings.warn(
                    f"Could not find response key `{self.response_template}` in the following instance: "
                    f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
                    "calculation. Note, if this happens often, consider increasing the `max_seq_length`.",
                    "calculation. Note, if this happens often, consider increasing the `max_length`.",
                    UserWarning,
                )
                batch["labels"][i, :] = self.ignore_index
@ -182,7 +182,7 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
                warnings.warn(
                    f"Could not find instruction key `{self.instruction_template}` in the following instance: "
                    f"{self.tokenizer.decode(batch['input_ids'][i])}. This instance will be ignored in loss "
                    "calculation. Note, if this happens often, consider increasing the `max_seq_length`.",
                    "calculation. Note, if this happens often, consider increasing the `max_length`.",
                    UserWarning,
                )
                batch["labels"][i, :] = self.ignore_index
@ -1650,27 +1650,6 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
    return mask, *tensors


def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
    """
    Compute the mean token accuracy.
    """
    # Get predictions
    predictions = logits.argmax(dim=-1)

    # Create mask for non-padding tokens (assuming pad_token_id is ignore_index)
    mask = labels != ignore_index

    # Calculate accuracy only on non-padding tokens
    correct_predictions = (predictions == labels) & mask
    total_tokens = mask.sum()
    correct_tokens = correct_predictions.sum()

    # Calculate accuracy
    accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0

    return accuracy


def selective_log_softmax(logits, index):
    """
    A memory-efficient implementation of the common `log_softmax -> gather` operation.
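The diff ends inside this docstring, but the idea behind `selective_log_softmax` can be sketched: instead of materializing the full `(batch, seq, vocab)` log-softmax and then gathering, subtract the per-position `logsumexp` from just the gathered logits. A minimal mathematically equivalent version, assuming float32 logits; the library's actual implementation may differ in how it chunks the work:

```python
import torch

def selective_log_softmax_sketch(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # log_softmax(logits)[..., index] == logits[..., index] - logsumexp(logits, dim=-1)
    selected = logits.gather(dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
    return selected - torch.logsumexp(logits, dim=-1)

logits = torch.randn(2, 4, 11)
index = torch.randint(11, (2, 4))
reference = torch.log_softmax(logits, dim=-1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
assert torch.allclose(selective_log_softmax_sketch(logits, index), reference, atol=1e-5)
```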