Compare commits

...

1018 Commits

Author SHA1 Message Date
89bbe0d205 Merge branch 'main' into refactor_generate_5 2025-10-20 11:42:15 -06:00
28bba8c6b1 Added SFT LoRA notebook (#4244)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
2025-10-20 11:24:54 +02:00
2f1802bc6e Fix missing CI slow tests: ImportError: vLLM is not installed (#4304) 2025-10-20 08:03:48 +02:00
0739b1f87d fix import 2025-10-19 17:27:03 +00:00
7bb1ee07b0 require vision 2025-10-18 20:48:33 +00:00
ff6782a1e8 fix: update documentation for prepare_multimodal_messages_vllm 2025-10-18 05:10:05 +00:00
31913e2742 fix prepare_multimodal_messages 2025-10-18 05:05:19 +00:00
5f87ee989d fix return-dict 2025-10-18 03:55:42 +00:00
23d13f9ae9 oops 2025-10-18 01:18:04 +00:00
ced5450e0d safe prepare_multimodal_messages_vllm 2025-10-18 01:17:17 +00:00
1a6f04000b test 2025-10-18 00:53:54 +00:00
7a2936e0a2 style 2025-10-18 00:38:17 +00:00
ba8b93831f rloo 2025-10-18 00:37:20 +00:00
c0c88071a3 fix style 2025-10-18 00:08:25 +00:00
fe11512100 dedup and some fixes 2025-10-18 00:02:48 +00:00
919ff5bced Merge branch 'main' into refactor_generate_5 2025-10-17 22:59:41 +00:00
e0eec055b4 🧺 [4/N] Refactor _generate in GRPO/RLOO: Move forward_kwargs outside generation method (#4154)
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: YonatanGideoni <yonatan.gideoni@gmail.com>
Co-authored-by: burtenshaw <ben.burtenshaw@gmail.com>
Co-authored-by: sergiopaniego <sergiopaniegoblanco@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-10-17 15:36:13 -06:00
f4c554da22 Update links to docs in README to latest packaged version (#4084) 2025-10-17 08:06:40 -06:00
a932e2796d ⬆️ Bump dev version (#4293) 2025-10-15 18:11:52 -06:00
04fd1203af Release: v0.24 (#4292) 2025-10-15 18:10:10 -06:00
19d2f97932 Deprecate BestOfNSampler (#4291)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
Co-authored-by: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com>
2025-10-15 18:06:34 -06:00
31caf64778 Remove unused commands directory (#4258)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
2025-10-15 18:01:50 -06:00
8e2d5516ca Add accuracy reward (#4270)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-10-15 18:01:07 -06:00
94aac4a101 Remove how_to_train.md: outdated training FAQ (#4267)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
2025-10-15 23:49:04 +00:00
26b7c2507e Add support for token_type_ids in DPOTrainer (#4285)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-10-15 17:33:35 -06:00
aa25c2697c Remove using_llama_models.md: outdated Llama2-specific documentation (#4268)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
2025-10-15 14:13:27 -07:00
93c7d88563 Remove logging.md: trainer-specific metrics documentation (#4269)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
2025-10-15 14:12:32 -07:00
c7c041ecc8 Fix CI slow tests: ImportError: vLLM is not installed (#4287) 2025-10-15 18:15:36 +02:00
ef40c047aa Replace unittest skipTest with pytest.skip (#4263) 2025-10-15 18:15:28 +02:00
7e0adbc552 Fix CI dev test TypeError: unexpected keyword argument 'load_in_4bit' (#4262) 2025-10-15 18:14:49 +02:00
773afd9314 💰 RichProgressCallback enhancement (#4245) 2025-10-15 09:39:17 -06:00
966b397201 Fix CI slow test OSError: You are trying to access a gated repo (#4283) 2025-10-15 16:11:11 +02:00
927cf6ba46 Fix docstrings with Sphinx 'deprecated' directive (#4279) 2025-10-15 10:39:12 +02:00
56cb6ccf76 Fix typo in Colab link (#4276) 2025-10-14 18:51:17 +02:00
49c8f14b06 Add Qwen3-VL notebooks (SFT, GRPO) (#4275)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-14 18:45:01 +02:00
cefbacb30e Fix style with make precommit (#4265) 2025-10-14 12:13:15 +02:00
fae245a062 Use FutureWarning instead of DeprecationWarning (#4266) 2025-10-14 12:12:03 +02:00
2aa9506c69 Fix docstring interlinks (#4221) 2025-10-13 13:40:24 +02:00
d6eeb290d9 Raise deprecation warning for Python 3.9 (#4226) 2025-10-13 11:06:09 +02:00
1684ef279a Fix Python version check for skipping tests on Python 3.13.8 (#4246) 2025-10-10 17:41:24 +02:00
aab21eb5e7 Include chat_template_kwargs in apply_chat_template (#4233)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-10-10 10:39:29 -05:00
b997a31981 [Online-DPO] fix the completion_len == max_new_tokens crash (#4193)
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-10-10 17:21:01 +02:00
86d1963cc1 Fix CI slow test AttributeError: 'TestSFTTrainerSlow' object has no attribute 'addCleanup' (#4255) 2025-10-10 17:19:53 +02:00
039d526d24 Deprecate unused dataset_formatting module (#4242)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-10-10 10:16:18 -05:00
bcd059a384 Remove obsolete research_projects directory (#4243)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-10-10 10:15:47 -05:00
0e57b4a9df 🧺 [3/N] Refactor _generate in GRPO/RLOO: Rely on generator for prompt truncation (#4153) 2025-10-10 10:02:11 -05:00
98488e0946 Fix CI slow test ValueError: Unknown loss type: dapo (#4254) 2025-10-10 16:37:02 +02:00
f45e86571b Fix CI ImportError for 'require_torch_gpu_if_bnb_not_multi_backend_enabled' (#4253) 2025-10-10 16:13:22 +02:00
f5827928a0 Install peft from main for CI tests with dev dependencies (#4250) 2025-10-10 16:12:15 +02:00
f853e091ea Fix CI CUDA out of memory errors by improving GPU memory management (#4238) 2025-10-10 09:49:45 +02:00
803ec0d856 Fix CI slow test ValueError: Backward pass should have cleared tracker of all tensors (#4236)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-10-10 09:28:34 +02:00
7a0a615d50 Warnings pointing to RFC (#4224) 2025-10-09 17:05:36 -06:00
c38cb69ec7 🧘 Enhance markdown style (#4235) 2025-10-09 13:49:44 -05:00
68ef15c686 Remove unused log_example_reports.py script (#4241)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
2025-10-09 09:18:48 -07:00
3dd7fc2850 Fix CI IndentationError for Python 3.13.8 (#4240) 2025-10-09 15:46:41 +02:00
51ced65153 Replace setup with pyproject in CI tests paths (#4230) 2025-10-09 09:35:08 +02:00
4bb883a6e6 Update CI Docker image to pytorch/pytorch:2.8.0 (#4232) 2025-10-09 08:09:15 +02:00
f7846321e7 Remove unused Path import in __init__.py (#4227) 2025-10-08 21:30:54 +02:00
a944890ff1 Fix callable annotations (#4216) 2025-10-08 21:21:21 +02:00
521db3520a Fix CI unittest asserts (#4234) 2025-10-08 21:18:41 +02:00
e2c97a805a Exclude vllm dependencies from dev extra (#4229) 2025-10-08 18:14:23 +02:00
d1d0407d3c 🏷️ Account for token_type_ids in DataCollatorForVisionLanguageModeling (#4190) 2025-10-08 09:34:48 -06:00
824ff8c73e Add Efficient Online Training with GRPO and vLLM in TRL to community tutorials (#4219) 2025-10-08 12:59:04 +02:00
f15399d3d3 Fix entropy and accuracy calculation for prompt_tuning techniques. (#4196) 2025-10-08 09:42:19 +01:00
f6e7c200c0 Merge branch 'refactor_generate_4' into refactor_generate_5 2025-10-07 12:16:00 -06:00
a0ee1e635f Merge branch 'refactor_generate_3' into refactor_generate_4 2025-10-07 12:15:32 -06:00
45290c9cfc Merge branch 'main' into refactor_generate_3 2025-10-07 12:15:11 -06:00
cc578b6b14 🧺 [2/N] Refactor _generate in GRPO/RLOO: Use prompt_ids from generation (#4152) 2025-10-07 12:11:34 -06:00
30cf68a97b 🎨 Support mixing image+text and text-only examples (#4203)
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
2025-10-07 10:21:10 -06:00
452284b8dc Add trainers taxonomy to docs (#4195) 2025-10-07 16:06:30 +02:00
6be53e19bc [DOCS] fix prose in lora guide (#4217) 2025-10-07 10:40:37 +02:00
3080fc1bd7 Fix LoRA params in Python in LoRA without regret (#4215) 2025-10-07 09:56:04 +02:00
5e4a026160 Merge branch 'refactor_generate_4' into refactor_generate_5 2025-10-06 18:41:57 -06:00
5d870955f8 Fix prompt-completion labeling with add_generation_prompt and warning (#4201)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-10-06 18:35:50 -06:00
ed54e2a1cb Merge branch 'refactor_generate_3' into refactor_generate_4 2025-10-06 18:34:31 -06:00
ee03478a14 remove test case for prompt truncation 2025-10-07 00:32:37 +00:00
e3c679c9c7 style 2025-10-06 23:59:17 +00:00
ddf3405c6c gfpo 2025-10-06 23:59:08 +00:00
2ce6c1ff41 token_type_ids and RLOO 2025-10-06 23:53:53 +00:00
34034e7f76 Merge branch 'refactor_generate_3' into refactor_generate_4 2025-10-06 17:44:45 -06:00
a84325c73b style 2025-10-06 22:35:42 +00:00
cb1d4201f7 Merge branch 'refactor_generate_4' into refactor_generate_5 2025-10-06 16:34:22 -06:00
2c012dca20 Merge branch 'refactor_generate_3' into refactor_generate_4 2025-10-06 16:25:24 -06:00
db552be924 Merge branch 'refactor_generate_2' into refactor_generate_3 2025-10-06 16:25:14 -06:00
4a274d5271 Merge branch 'main' into refactor_generate_2 2025-10-06 16:25:07 -06:00
8265800abf Fix trl-internal-testing/tiny-DbrxForCausalLM (#4213) 2025-10-06 15:11:16 -06:00
ac2717f980 Merge branch 'refactor_generate_3' into refactor_generate_4 2025-10-06 13:21:18 -06:00
766bbcefa0 Merge branch 'refactor_generate_2' into refactor_generate_3 2025-10-06 13:19:59 -06:00
5b9a6ab7ae Merge branch 'main' into refactor_generate_2 2025-10-06 13:16:57 -06:00
65eb45c32b Apply style and revert change in sft_video_llm example (#4214) 2025-10-06 13:07:18 -06:00
ae6837f8d4 Removed tokenizer/processor creation from example scripts (#4211) 2025-10-06 18:40:18 +02:00
df386f9667 Merge branch 'main' into refactor_generate_2 2025-10-06 10:02:54 -06:00
7f5b4995b6 Replace setup with pyproject and fix packaging unintended modules (#4194) 2025-10-06 15:56:32 +00:00
d258e36e45 Remove Optional from processing_class in PPOTrainer (#4212) 2025-10-06 15:56:32 +00:00
4fdaa4c672 Updated vLLM integration guide (#4162)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-10-06 15:56:31 +00:00
8319ce0b75 Replace unittest with pytest (#4188) 2025-10-06 15:56:29 +00:00
6543f51a9d Hotfix: Exclude transformers 4.57.0 for Python 3.9 (#4209)
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
2025-10-06 15:55:39 +00:00
ae2a0e71ad Remove tokenizer creation from sft example script (#4197) 2025-10-06 15:55:39 +00:00
5d34144b6f Remove custome_container for building the docs (#4198) 2025-10-06 15:55:38 +00:00
c1e7ad2696 [DOCS/FIX] lora without regrets - fix lr (#4207) 2025-10-06 15:55:38 +00:00
21a67fc43f [DOCS] Lora without regret (#4181)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: sergiopaniego <sergiopaniegoblanco@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-10-06 15:55:38 +00:00
648947911a Replace remaining trainer.tokenizer with trainer.processing_class in GRPO test (#4192) 2025-10-06 15:55:38 +00:00
f9c3c3c726 🌡️ Have vLLM return processed (temperature scaled) log probs (#4163)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-10-06 15:55:38 +00:00
cf9d8e76c4 Hotfix wrong formatting of docstrings with blockquote tips (#4187) 2025-10-06 15:55:38 +00:00
192deb3b2b Fix CI ImportError: FlashAttention2 and decorator order for all parameterized tests (#4176) 2025-10-06 15:55:38 +00:00
e82db740f0 🔣 Fix test: replace trainer.tokenizer by trainer.processing_class (#4185) 2025-10-06 15:55:38 +00:00
56a8f1128b Replace setup with pyproject and fix packaging unintended modules (#4194) 2025-10-06 17:45:44 +02:00
529101537f Remove Optional from processing_class in PPOTrainer (#4212) 2025-10-06 16:04:06 +02:00
0588b1f01d Updated vLLM integration guide (#4162)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-10-06 15:57:17 +02:00
45ee98b05e Replace unittest with pytest (#4188) 2025-10-06 11:14:54 +02:00
3800a6ecc7 Hotfix: Exclude transformers 4.57.0 for Python 3.9 (#4209)
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
2025-10-06 11:13:21 +02:00
7ad9ce8acc Remove tokenizer creation from sft example script (#4197) 2025-10-06 11:04:20 +02:00
0c2dc14014 Remove custome_container for building the docs (#4198) 2025-10-06 08:31:58 +02:00
ced8b337ba [DOCS/FIX] lora without regrets - fix lr (#4207) 2025-10-06 08:23:11 +02:00
1eff7da9e0 [DOCS] Lora without regret (#4181)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: sergiopaniego <sergiopaniegoblanco@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-10-03 20:40:37 +02:00
1cbfb00b6a Replace remaining trainer.tokenizer with trainer.processing_class in GRPO test (#4192) 2025-10-03 09:08:53 +02:00
e086f073cf 🌡️ Have vLLM return processed (temperature scaled) log probs (#4163)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-10-01 11:58:13 -06:00
e5d437ed76 Hotfix wrong formatting of docstrings with blockquote tips (#4187) 2025-10-01 19:42:36 +02:00
d1b4691900 Fix CI ImportError: FlashAttention2 and decorator order for all parameterized tests (#4176) 2025-10-01 18:01:56 +02:00
39c603872f 🔣 Fix test: replace trainer.tokenizer by trainer.processing_class (#4185) 2025-10-01 09:16:42 -06:00
d599c207cd Merge branch 'main' into refactor_generate_2 2025-10-01 08:49:04 -06:00
5a4021f23e Fix handling of f_divergence_type in DPO (#4171)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-10-01 09:44:14 +02:00
377b0811c9 rm test_training_vlm_and_prompt_truncation 2025-10-01 02:28:38 +00:00
c434fa23bf truncation_side=left 2025-10-01 02:28:07 +00:00
ddfd3b58c9 same for rloo 2025-10-01 02:14:43 +00:00
4dce145d40 remove vision tokens 2025-10-01 01:09:40 +00:00
5cc6af57a5 Merge branch 'refactor_generate_2' into refactor_generate_3 2025-09-30 19:00:51 -06:00
5fca5b8802 fix normal generation path 2025-10-01 00:46:15 +00:00
49577adb19 Same for RLOO 2025-10-01 00:17:37 +00:00
e164ec5aab repicate all_prompt_ids 2025-10-01 00:11:48 +00:00
e7aa945273 fix vllm client server 2025-09-30 23:10:16 +00:00
f11759e66d Merge branch 'main' into refactor_generate_2 2025-09-30 16:29:59 -06:00
ea66a9e650 🧺 [1/N] Refactor _generate in GRPO/RLOO: list of ints instead of tensors (#4146)
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
2025-09-30 16:22:30 -06:00
da209f89fc 🎁 RewardTrainer refactor (#4093)
Co-authored-by: juejuezi <juejuezi.git@foxmail.com>
Co-authored-by: Yi Shi <96773624+singing-cat@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-09-30 15:13:45 -06:00
ebb8899f5d Fix Flash Attention x Padding-Free loss (#4170) 2025-09-30 12:01:29 -06:00
70e2017dbc 🎞️ Support sequence classification models in clone_chat_template (#4097) 2025-09-30 11:42:56 -06:00
4368f54c97 👾 Use our own require_bitsandbytes (#4137) 2025-09-30 11:11:29 -06:00
22720d176b Add logging for training completion and model saving in training scripts (#4048) 2025-09-30 10:57:33 -06:00
c8a5add88a Fix PEFT interlinks in docstrings (#4178) 2025-09-30 18:32:23 +02:00
a7b54f988b Fix CI ValueError: Unknown loss type: dapo (#4173) 2025-09-30 18:27:21 +02:00
78bf77abbd 🅰️ Remove apex (#4139) 2025-09-30 09:52:52 -06:00
3b9ac65a05 🖨️ Print rich table for messages (#4160) 2025-09-30 09:07:57 -06:00
7a78320f58 Fix link in docstring of RLOOTrainer (#4180) 2025-09-30 16:54:55 +02:00
67e83aee90 Fix docstring interlink to parent class for NashMDTrainer and XPOTrainer (#4179) 2025-09-30 15:43:37 +02:00
a0df357591 Fix docstrings with 'deprecated' Sphinx directive (#4174) 2025-09-30 10:13:35 +02:00
864e593e9f Add missing FDivergenceType docstring (#4165) 2025-09-29 20:03:33 +02:00
6428647063 Remove unnecessary list comprehensions (#4164) 2025-09-29 20:02:46 +02:00
8a5bfecc3a 💡 Replace <Tip> with new markdown syntax (#4161)
Co-authored-by: sergiopaniego <sergiopaniegoblanco@gmail.com>
2025-09-29 10:48:00 -06:00
910aeebe06 Pass required token_type_ids (#4148) 2025-09-29 17:40:11 +02:00
e208823b3e Add docstring for OnlineTrainerState (#4166) 2025-09-29 17:26:14 +02:00
a01b9caf81 Merge branch 'refactor_generate_4' into refactor_generate_5 2025-09-26 19:34:32 -06:00
b0e02795e2 Merge branch 'refactor_generate_3' into refactor_generate_4 2025-09-26 19:34:15 -06:00
3f02702600 Merge branch 'refactor_generate_2' into refactor_generate_3 2025-09-26 19:34:11 -06:00
4b9c1262a9 Merge branch 'refactor_generate' into refactor_generate_2 2025-09-26 19:34:00 -06:00
e82bfb4264 Merge branch 'main' into refactor_generate 2025-09-26 19:33:52 -06:00
f397a61e82 😷 Refactor GRPO/RLOO to isolate _generate for GRPO with replay buffer (#4158) 2025-09-26 19:31:06 -06:00
effb41ba5d Merge branch 'main' into refactor_generate 2025-09-26 19:12:04 -06:00
7fe9dd42ac 📽 Multi image support for GRPO replay buffer (#4157) 2025-09-26 19:11:53 -06:00
79c774af54 🟩 Drop image_split_sizes in favour of image_grid_thw (#4156) 2025-09-26 18:50:27 -06:00
c5064d61ea gfpo 2025-09-27 00:04:17 +00:00
7b7a11d833 test and doc 2025-09-27 00:00:52 +00:00
b8c0c9b219 Merge branch 'refactor_generate_2' into refactor_generate_3 2025-09-26 17:49:26 -06:00
c8041e1ccc Merge branch 'refactor_generate' into refactor_generate_2 2025-09-26 17:48:06 -06:00
55a2480195 rloo + doc 2025-09-26 23:46:50 +00:00
15c6620c84 refactor: update prepare_multimodal_messages to accept images directly and enhance handling of structured messages 2025-09-26 23:32:38 +00:00
48a1c30e7e don't re-prepare data 2025-09-26 22:20:23 +00:00
9925199ee9 move forward_kwargs outside of generate 2025-09-26 22:14:58 +00:00
8149d0578f rm truncation test 2025-09-26 21:27:47 +00:00
35f99fd867 requires padding 2025-09-26 21:27:33 +00:00
fc263a309a rm imports 2025-09-26 20:01:37 +00:00
d8af0039fa rm useless comment 2025-09-26 19:59:12 +00:00
0b5865e8f5 ensure proper truncation and side 2025-09-26 19:57:23 +00:00
acee7d817f rm truncate_with_protected_tokens 2025-09-26 19:45:09 +00:00
11acc758c2 rm enforce eager 2025-09-26 19:43:45 +00:00
46d8eb79cf revert 2025-09-26 19:43:17 +00:00
0e2ae34a93 rely on generator for prompt truncation 2025-09-26 19:41:24 +00:00
e770efeede Merge branch 'refactor_generate' into refactor_generate_2 2025-09-26 12:57:02 -06:00
8d34d546bb remove pad token removal 2025-09-26 18:56:45 +00:00
d79b9e1c8f get prompt ids from generation 2025-09-26 18:41:51 +00:00
b3bd0b05d4 another one 2025-09-26 18:05:49 +00:00
9da4830c53 simplify a bit + comment 2025-09-26 16:22:44 +00:00
236b78b455 better 2025-09-26 16:14:18 +00:00
8766fa5cc0 consistent naming 2025-09-26 16:12:07 +00:00
53772ef7b8 getting closer 2025-09-26 16:02:03 +00:00
27dc9585a0 fix num_input_tokens_seen 2025-09-26 03:09:42 +00:00
3d8ea27c68 wrong merge commit 2025-09-26 02:54:26 +00:00
d3f1d3c801 Merge branch 'main' into refactor_generate 2025-09-25 20:51:09 -06:00
9603b41d7e 😷 Refactor GRPO/RLOO to isolate _generate (#4114) 2025-09-25 20:48:52 -06:00
9435a9400f refactor in grpo 2025-09-26 02:48:11 +00:00
2dc69a68e0 Merge branch 'main' into generate-method 2025-09-25 18:01:23 -06:00
1a66b431d0 revert chage data utils 2025-09-25 23:57:14 +00:00
c1ae6aa787 back to working point 2025-09-25 23:56:11 +00:00
8b3a724602 progress again again 2025-09-25 23:27:53 +00:00
5ee56ed04f Fixed some <Tip> rendering issues (#4143) 2025-09-25 14:47:46 -06:00
0213662cd4 progress continues 2025-09-25 18:24:46 +00:00
e85e634bff Refactor trainers classes to use BaseTrainer with shared functionality (#4128) 2025-09-25 18:32:57 +02:00
ebe32c26d8 progress 2025-09-25 06:14:02 +00:00
b0dceb97ac restart 2025-09-25 04:03:39 +00:00
d633c4337f Fix import statement and GRPO test case (#4141) 2025-09-24 16:23:32 -06:00
d1e24df031 [GRPO]: Sample from a Replay Buffer To Substitute Groups with 0 std. (#4060)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-09-24 21:12:16 +01:00
b4cadde233 Merge branch 'main' into generate-method 2025-09-24 13:57:42 -06:00
094e0760d4 🌵 Mark GKD trainer test as expected failure due to OOM issue (#4126) 2025-09-24 12:26:44 -06:00
01c9b4c414 🤸‍♀️ Fix DFT test (#4135) 2025-09-24 12:25:56 -06:00
18faf03c4e Fix CI: torch.AcceleratorError: CUDA error: device-side assert triggered (#4138) 2025-09-24 20:12:17 +02:00
ec6ad259d2 nits style and align 2025-09-24 17:26:25 +00:00
c83e710831 same for rloo 2025-09-24 17:17:14 +00:00
cdb4c76a3f Merge branch 'main' into generate-method 2025-09-24 10:09:25 -06:00
d144e73e78 🪙 [Experimental] Support GSPO-token (#3820)
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-24 09:57:18 -06:00
be1ffe59d2 🌺 Fix GPT-OSS test (#4134) 2025-09-24 09:07:48 -06:00
fb6bdab33b Improve typing of SFT trainer (#4007) 2025-09-24 07:45:03 -06:00
526303edbd [SFTrainer]: Fix DFT Loss (#4112)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-24 11:46:12 +01:00
9e5e60c933 👩‍🦯 Fix usage of VLM using text only (#4080)
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-09-23 12:07:25 -06:00
5c52f46f9a Remove Python version < 3.13 constraint from vllm extra dependencies (#4125) 2025-09-23 17:04:32 +02:00
365d5017f4 Merge branch 'main' into generate-method 2025-09-23 08:55:43 -06:00
deac14a39f 🧹 Remove max_batch_tokens, num_blocks and block_size from generation kwargs (#4065) 2025-09-23 08:50:52 -06:00
3d5a30bb77 👋 Remove backend parameter from GuidedDecodingParams (#4123) 2025-09-23 08:12:13 -06:00
251fdb228a 📌 Pin vLLM version (#4122) 2025-09-23 08:02:30 -06:00
37806e618b 📤 Fix a dataset loading bug in scripts 2025-09-23 05:21:40 +00:00
008c7ad9aa [vllm] ensure MASTER_ADDR/MASTER_PORT are set safely (#4057)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-22 23:19:12 -06:00
e8ba9eaf27 📤 Fix a dataset loading bug in scripts (#4124)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-22 22:58:40 -06:00
abe07c9e32 🐯 fix: use_liger_kernel with IterableDataset (#4087)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-22 20:23:58 -06:00
d8665e1236 Merge branch 'main' into generate-method 2025-09-22 20:21:14 -06:00
fe02ea2b52 😴 Add vllm_enable_sleep_mode to RLOO Trainer (#4107)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-22 19:41:29 -06:00
a6a8c448a0 Merge branch 'main' into generate-method 2025-09-22 18:19:32 -06:00
68408d7219 📽 Multi image support for GRPO/RLOO (#4113) 2025-09-22 18:17:42 -06:00
c5004406ff Merge branch 'multi-image-support' into generate-method 2025-09-22 18:08:02 -06:00
9b6652eed4 rm VLM x RM warning 2025-09-23 00:05:23 +00:00
1c53094868 clarify image column desc 2025-09-22 23:57:13 +00:00
05270f820f update layers to ignore 2025-09-22 23:51:57 +00:00
485781cb3e Merge branch 'main' into multi-image-support 2025-09-22 17:47:19 -06:00
94f8d00a62 🔭 Align param passing to VLM configs in generate_tiny_models (#4118)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-09-22 17:14:10 -06:00
562c662c2b Merge branch 'main' into multi-image-support 2025-09-22 16:42:28 -06:00
b5ca3799ad 🟩 Drop image_split_sizes in favour of image_grid_thw (#4111) 2025-09-22 16:38:39 -06:00
efbb03a0d6 Merge branch 'drop-image_split_sizes' into multi-image-support 2025-09-22 16:20:42 -06:00
e17ec42797 Merge branch 'main' into drop-image_split_sizes 2025-09-22 16:17:57 -06:00
a68b4af50f Fix code style with make precommit (#4119) 2025-09-22 13:19:54 -06:00
9f0ed8b130 CI hotfix: xfail test_training_with_transformers_paged for transformers<4.57.0 (#4120) 2025-09-22 13:19:30 -06:00
27f22ba5a1 docs: correct option name to enable vllm sleep mode (#4102)
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
2025-09-22 13:04:00 +02:00
d3a769fe8f fix doc 2025-09-20 17:15:13 +00:00
b628744752 rm vllm 2025-09-20 17:15:02 +00:00
4fc2b5b71d gfpo 2025-09-20 17:13:23 +00:00
4d12aebf33 Merge branch 'multi-image-support' into generate-method 2025-09-20 10:53:36 -06:00
fc52e6832d test fixed! 2025-09-20 16:26:34 +00:00
dfc0d388ab Merge branch 'drop-image_split_sizes' into multi-image-support 2025-09-20 09:52:06 -06:00
52d8bd91b0 Merge branch 'main' into drop-image_split_sizes 2025-09-20 09:51:51 -06:00
fa738768c6 skip failing test 2025-09-20 15:21:36 +00:00
86f74b486f Fix VLM configs in generate_tiny_models (#4101) 2025-09-20 09:49:16 +02:00
f998432622 debug 2025-09-20 05:18:40 +00:00
ae1f497959 generate method 2025-09-20 05:08:48 +00:00
fc6b11fcae update test 2025-09-20 04:22:54 +00:00
529add673c oops 2025-09-20 03:55:03 +00:00
099a39bd6a peft rloo 2025-09-20 03:04:07 +00:00
1257796ba8 rloo test 2025-09-20 03:01:47 +00:00
f4c82bfc04 fix gfpo 2025-09-20 02:55:59 +00:00
d2adc63eb6 test peft 2025-09-20 02:52:33 +00:00
088897b9cd fix 2025-09-20 02:25:10 +00:00
86cc30bf3c gfpo 2025-09-20 00:43:43 +00:00
30ad7ca371 rloo 2025-09-20 00:37:54 +00:00
dcf4b92da0 no vlm reward models 2025-09-20 00:18:18 +00:00
3ca6ad5003 log with wandb 2025-09-19 23:31:06 +00:00
229c554929 multi-image grpo 2025-09-19 22:45:57 +00:00
c8933aa856 gfpo 2025-09-19 21:10:06 +00:00
449ef07919 simpler 2025-09-19 21:05:47 +00:00
552e899015 Refactor image handling: replace image_split_sizes with image_grid_thw in GRPO and RLOO trainers; update split_pixel_values_by_grid to use image_grid_thw 2025-09-19 20:57:51 +00:00
26b497ea63 Fix typos (#4109) 2025-09-19 09:44:07 -06:00
d22bdb8031 Fix typos (#4106)
Signed-off-by: Yuanyuan Chen <cyyever@outlook.com>
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
2025-09-19 16:58:43 +02:00
0e204482e6 Some nits GRPO and RLOO trainer docs (#4108)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-09-19 16:37:25 +02:00
3c8d7209f1 👁️ Add VLM support to RLOO trainer (#4067)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-18 21:54:06 -06:00
0450f05ad9 [GKD] Fix batchmean reduce op in GKDTrainer's loss (#4105) 2025-09-18 19:44:04 +02:00
7e2075347e Fix get_peft_model() so that prepare_model_for_kbit_training does not reapply to an instance of PeftModel, thus freezing all the layers (#4081)
Co-authored-by: Hoesu <hoesu.chung@qraftec.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-09-18 10:31:03 +02:00
20cc58d777 ℹ️ Enable XPU for vLLM client (#4031) 2025-09-17 22:06:25 -06:00
a6c0c57f6b ℹ️ feat: Add NPU and XPU support for activation offloading (#4056) 2025-09-17 22:03:56 -06:00
10dc36d610 🌪️ [GFPO]: implement GFPO in GRPOTrainer (#3989) 2025-09-17 19:14:40 -06:00
d2d1912d96 ⚖️ Align SFT and DPO for model creation and deprecate DPOConfig.padding_value in favour or pad_token_id (#4006)
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
2025-09-17 18:39:26 -06:00
08ea00289a 🧶 feat: Add WeaveCallback for W&B Weave integration (#4089)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-17 18:10:45 -06:00
4ff8b4e007 📜 Convert set to list of tags (#4092) 2025-09-17 14:05:41 -06:00
6356343fd2 Add deprecation warnings to docstrings (#4083)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-09-17 09:30:43 +02:00
45e59f77ea ⌨️ Pin num2words (#4094)
Co-authored-by: sergiopaniego <sergiopaniegoblanco@gmail.com>
2025-09-16 08:48:09 -06:00
4bd4acf172 🏞️ Context Parallelism benchmark guide (#4075)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-09-16 08:46:12 -06:00
8380869d33 Community Tutorials design adaptation for videos (#4095) 2025-09-16 16:28:22 +02:00
5139af3712 Add support for testing experimental features (#4082) 2025-09-16 07:46:48 +02:00
2f46c18a66 Align slow tests with regular tests (#4085) 2025-09-16 07:22:30 +02:00
e2b18ec4e7 ▶️ Add video to community tutorials (#4090)
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
2025-09-15 10:51:23 -06:00
78f1a928ce 🗑️ Remove deprecated AlignPropTrainer, DDPOTrainer and IterativeSFTTrainer (#4068) 2025-09-15 09:56:41 -06:00
1d0b196f6b Reviewed HF jobs updated docs (#4088) 2025-09-15 08:41:08 -06:00
5a1c2f9b3b Aux loss is already included in the loss returned by Transformers (#4078) 2025-09-14 16:56:58 +01:00
9955ee7eaa 🐳 Docker update + Simplify Jobs doc (#3931)
Co-authored-by: sergiopaniego <sergiopaniegoblanco@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-09-13 18:35:55 -06:00
304eaf8053 🛠️ Fix CI (#4076) 2025-09-13 12:38:48 -06:00
69e288ebad ✂️ [GRPO VLM] Update split sizes to generalize (#4032)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-09-12 19:11:32 -06:00
d655ce48f8 🌾 [Experimental] BEMA for ref model (#3898)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-09-12 11:47:44 -06:00
91c4bba922 🧪 Add trl.experimental Submodule (#4073)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-09-12 11:02:23 -06:00
2845d024a4 Set Ruff src for first-party imports (#4074) 2025-09-12 15:43:04 +02:00
f4ff248407 ♨️ [GRPO] Fix potential hang in get_high_entropy_mask (#4041) 2025-09-11 19:33:39 -06:00
b8eb5c5d2d Improve docstring of AlignPropTrainer (#4059)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-09-11 11:42:31 -06:00
07f9ad982d 💡 Fix type hint to make_parser function in multiple scripts (#4050) 2025-09-11 11:36:05 -06:00
417915a3e4 Fix CI failure in slow GRPO test due to missing pillow dependency (#4064) 2025-09-11 17:35:57 +02:00
44ddc28bcd Hotfix: Add ParallelismConfig fallback for transformers with old accelerate (#4063) 2025-09-11 15:11:41 +02:00
e8b8499f1f Remove redundant 'None' from docstrings (#4058) 2025-09-11 08:16:34 +02:00
7eb7f42372 ⬆️ Bump dev version (#4054) 2025-09-09 22:17:35 -06:00
6adfd138d8 Release: 0.23 (#4053) 2025-09-09 22:16:17 -06:00
a647e5a78a 🗜 Hotfix: avoid passing quantization_config=None (#4019) 2025-09-09 14:50:15 -06:00
816ac610c0 🪪 Update SFTTrainer to handle labels correctly and add configuration example in paper index (#4051) 2025-09-09 14:49:36 -06:00
373a64a7ce 💬 Remove setting chat template in sft script (#4037) 2025-09-09 13:24:08 -06:00
09e19244c0 Improve SFT doc (#4005)
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
2025-09-09 13:22:00 -06:00
a228cb51d1 Add autodoc for BestOfNSampler and improve docstrings (#4034) 2025-09-09 20:28:02 +02:00
6c6f13b5f3 🏂 Fix label shifting logic in SFTTrainer for compatibility with CP (#4038) 2025-09-09 12:08:38 -06:00
b3f9f613f9 Update VLM arch check to AutoModelForImageTextToText for DPO and Online DPO (#4049) 2025-09-09 11:10:27 -06:00
659d2c1284 🧨 DFT (#4042) 2025-09-09 08:23:30 -06:00
82b34e5723 Update transformers minimum version to 4.56.1 (#4047) 2025-09-09 16:05:04 +02:00
27e30f86ef CI hotfix: xfail test_training_with_transformers_paged (#4046) 2025-09-09 15:47:25 +02:00
af82b38482 ⚖️ Remove average_tokens_across_devices default replacement (#4039) 2025-09-09 07:39:12 -06:00
1b799a23c1 🥓 [docs] add CP docs (#3994)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-08 21:46:22 -06:00
e4ebf3ba11 Add autodoc for AlignPropTrainer and AlignPropConfig (#4033) 2025-09-08 20:13:23 +02:00
e458df650a Add missing trainer docstrings (#4030) 2025-09-08 20:12:58 +02:00
a1ee7d2182 [doc] Group paper index by trainer (#4027) 2025-09-08 18:03:48 +02:00
1d06757e57 [doc] Paper index for Truncated Importance Sampling (#4026) 2025-09-08 08:11:08 +02:00
4f9009b0f2 Fix formatting errors in docstrings (#4025) 2025-09-08 07:22:00 +02:00
c9484b161f Align docstring parameters with function definitions (#4017) 2025-09-07 10:40:09 +02:00
f5c2fec4a9 Fix typo in GRPO quickstart (#4020) 2025-09-06 10:31:09 +02:00
d1bf56020d ⚖️ Add vLLM server mode and VLM support to OnlineDPOTrainer (#3783)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-05 16:58:49 -06:00
19f9b9ee69 Add missing doc strings in SFTrainer (#4003)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-04 23:20:07 +01:00
1eb38018b7 [SFTTrainer]: Add Aux Loss for MoE models. (#4012)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-04 22:49:39 +01:00
deae7e00b8 🌵 Refactor entropy_from_logits for memory efficiency (#4013) 2025-09-04 13:59:48 -06:00
0c69fd2867 👷 Added Kernels on the Hub x TRL guide (#3969)
Co-authored-by: vb <vaibhavs10@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-09-04 15:37:02 +02:00
b5fd290b2c [SFT] fix: collator docstring (#4011) 2025-09-04 14:35:09 +02:00
67991605c0 Comprehensive Paper Index Enhancement with 9 New Algorithm Implementations (#3990)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-09-03 19:59:16 -06:00
208e9f7df7 📏 torch_dype to dtype everywhere (#4000)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-03 15:45:37 -06:00
3bfa981bd2 [GRPO]: Fix Multi-GPU training for Entropy based masking of tokens. (#3964)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-03 22:10:16 +01:00
6a5dfffe56 💾 [bugfix] fix PPO save_checkpoint (#3998) 2025-09-03 14:51:34 -06:00
18633dbb06 ✖️ Support pad-to-multiple-of and padding-free (#3996) 2025-09-03 08:37:44 -06:00
e4dbf57bf2 Fixed tags shown problem in memory usage docs (#3999) 2025-09-03 08:35:51 -06:00
12fc85fd13 [GRPO] Truncated Importance Sampling to address rollout-training mismatch (#3867)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-09-03 09:50:54 +02:00
fdd6bda111 Add pre-commit and hf-doc-builder as dev dependencies (#3993) 2025-09-03 08:25:14 +02:00
cb84da0ece fix: add return to shift_tokens_right (#3987) 2025-09-02 19:18:17 -06:00
35702ce378 ⚖️ Fix scale_rewards issue in GRPO (#3992)
Co-authored-by: Leon <leon.ericsson@foi.se>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-09-02 18:02:06 -06:00
705306d78b 🎯 Add Trackio integration documentation and update TOC (#3971)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-09-02 13:06:06 -06:00
edbe8234bc [GRPO] Adds an option to sleep vllm when running in colocated mode (#3968)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
2025-09-01 09:59:52 +02:00
4c47b32811 🪃 args.gradient_checkpointing = False instead of args = dataclasses.replace(args, gradient_checkpointing=False) (#3981) 2025-08-30 16:01:33 -07:00
92046bb972 👮 Fix GRPO CLI by setting parameters for get_soft_overlong_punishment (#3972) 2025-08-30 16:00:26 -07:00
39faf36a91 Refactor version retrieval to use importlib.metadata for improved reliability 2025-08-29 20:44:05 +00:00
1cb4150dfb ⬆️ Bump dev version (#3978) 2025-08-29 13:21:55 -07:00
3a6b365c0d Release: v0.22 (#3977) 2025-08-29 13:19:34 -07:00
7ae16d3234 🧱 PyPI publishing workflow (#3976) 2025-08-29 12:52:25 -07:00
ab984fabac Style 2025-08-29 19:50:23 +00:00
419d716a6b Fix CI (#3975) 2025-08-29 12:23:20 -07:00
f538bd3085 📜 GSPO docs - Sequence importance ratio and differences in relation to GRPO (#3816)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-08-29 12:08:40 -07:00
8aa0eed816 ℹ️ Validate examples on xpu (#3897)
Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
2025-08-29 10:56:57 -07:00
e7b37d4e8d 🔥 [Refactor] RLOOTrainer (#3801)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>
2025-08-29 09:27:28 -06:00
b7676d1701 Fixed some typos and added small details about trackio to docs (#3965) 2025-08-27 17:57:19 +02:00
515e9eb255 [CI] Modify tests to handle device allocation for models (#3962) 2025-08-27 17:23:37 +02:00
26442abff2 Add HF jobs tag when creating model card via jobs (#3956) 2025-08-27 12:18:05 +02:00
0c91515b58 🧭 HF jobs x TRL guide (#3890)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-08-26 21:44:29 -07:00
4b3517facc 📸 Return position_ids for flash_attention_3 (#3942) 2025-08-26 20:32:17 -07:00
6f5865131b 🦥 Unsloth Docs update (#3955)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-08-26 20:17:21 -07:00
0c7ab76a01 LitePPO: Fix Docs for paper index (#3954) 2025-08-26 20:16:43 -07:00
ffc061b5e5 ✂️ fix: handle list tensors in split_tensor_dict function (#3951) 2025-08-25 09:56:16 -07:00
38fc1f6ecf 🤸 [SFT] Drop entropy calculation when using liger (#3947) 2025-08-25 09:14:39 +02:00
39cc9a826a [GKD] add liger loss (#3946)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-08-24 19:25:25 +02:00
1f15f187c3 [DPO] Adding support for different losses which are now supported by Liger (#3815)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-08-24 18:53:35 +02:00
181a841877 🗂 Update paper_index section (#3937)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-08-22 12:13:22 -07:00
da167d88b2 🎆 Add entropy logging in SFT (#3940) 2025-08-22 10:40:23 -07:00
2324245cad 🏌️ DAPO loss type (#3938) 2025-08-22 10:38:28 -07:00
fe44806b68 🪶 [GRPO] PPO Lite: Scale rewards by Std of Batch (#3935)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-08-21 12:47:07 -07:00
251c0488c8 📦 Wrapping the main execution code to avoid multi-processing issues from vLLM (#3932)
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
2025-08-21 12:45:13 -07:00
e2eaa2334d 🗞 bugfix 'TrainerState' object is not subscriptable (#3936)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-08-21 12:33:23 -07:00
48d7ecc67b 🗑️ Deprecate setup_chat_format (#3929) 2025-08-20 14:06:23 -07:00
215294872e prepare_multimodal_messages fix 2025-08-20 17:25:51 +00:00
MQY
85ead751f5 ♻️ Reuse multimodal message preparation from SFTTrainer in GRPOTrainer (#3919)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-08-20 10:04:54 -07:00
8793a46760 🧾 Use logger.warning instead of warnings.warn (#3923)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-08-20 09:20:09 -07:00
730e19d939 🤹‍♂️ Multi-image testing dataset (#3916) 2025-08-20 08:27:14 -07:00
7233b981ce 🧹 Clean SFT tests (#3922) 2025-08-20 07:36:03 -07:00
18836f078e ✏️ Fix typos (#3921)
Signed-off-by: cyy <cyyever@outlook.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-08-19 10:07:34 -07:00
e575ea3815 📚 Update BEMACallback documentation to ignore docstyle and fix lag parameter description (#3917) 2025-08-18 17:57:45 -07:00
52eaa552aa ➡️ SFTTrainer for VLM: support completion-only loss (#3908) 2025-08-18 17:23:41 -07:00
0227d68e50 🌓 SFTTrainer for VLM: Support for prompt-completion data (#3907) 2025-08-18 16:46:17 -07:00
b08bc7f33e ♻️ use_cache should be set in the forward pass (#3891) 2025-08-18 14:47:33 -07:00
152235a8e5 🗑 Deprecate IterativeSFTTrainer (#3905) 2025-08-18 14:28:04 -07:00
4fcef6c32d 🐯 Support assistant-only training and Liger (#3914) 2025-08-18 14:23:46 -07:00
d15049bf71 🗳️ Extend BCO Trainer dataset format support (#3134) 2025-08-17 00:35:23 -07:00
b9718449a8 🗿 [CPO] Add AlphaPO method via CPOTrainer (#3824)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-08-16 23:26:02 -07:00
0e7c99ab07 Optimize completion_ids list conversion in GRPO trainer (#3874)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-08-16 21:47:13 -07:00
MQY
c99cd2361e 🌳 Enhance segment tree implementation for non-power-of-2 values (#3888)
Co-authored-by: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-08-16 21:39:57 -07:00
68937969b4 Add tests for get_position_ids_from_packed_seq_lengths (#3883) 2025-08-16 21:36:53 -07:00
a6f802f41d ⚔️ Optimize truncate_with_protected_tokens to use vectorized operations (#3875)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-08-16 21:17:54 -07:00
jp
dfb96af810 ☑️ Check eval batch size in grpo (#3889) 2025-08-15 21:41:04 -07:00
485e7d1c74 ✏️ Fix SFTTrainer token accuracy computation with PromptEncoder (#3821) 2025-08-14 20:22:05 -07:00
7ee8f796ff 👔 HF Doc Builder style (#3498) 2025-08-14 18:58:12 -07:00
64b7028fe9 🪄 Improve quickstart documentation with updated API examples (#3873)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-08-14 17:17:16 -07:00
1324448c6f 👁️ VLM blog (#3899) 2025-08-14 17:09:16 -07:00
206964ce16 🎢 [Callbacks] BEMA (#3855)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-08-14 13:54:52 -07:00
39efa8affb 🧩 Fix reward_processing_classes validation in GRPOTrainer (#3876)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-08-13 15:47:45 -07:00
499d9fb32c Minor optimizations in SFT. (#3884) 2025-08-13 14:27:31 -07:00
44e6c153a5 🔮 Native VLM support for SFTTrainer (#3862)
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-08-12 20:43:00 -07:00
f5b1ed24a0 Replaced unittest.TestCase with TrlTestCase that handles tmp dir (#3863) 2025-08-12 12:37:19 -07:00
7f53ac08f2 🕹️ [GRPO] Fix vllm mode validation in distributed setting (#3886)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-08-12 11:15:31 -07:00
b4c418110c 💇 Add soft overlong punishment reward function and update documentation (#3804) 2025-08-12 10:58:41 -07:00
80b660de76 ⌨️ Add py.typed (#3841)
Signed-off-by: cyy <cyyever@outlook.com>
2025-08-12 10:06:53 -07:00
65d7894b6a Integrate PEFT model preparation across trainers and utilities (#3882) 2025-08-12 10:02:27 -07:00
72d4d82b8c 🎚️ Add dataset mixer (#3791)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-08-11 20:14:50 -07:00
de27d612b0 🦦 Validate vllm_mode param in GRPO (#3866) 2025-08-08 21:00:18 -07:00
a222aeb462 🎀 New defaults: gradient_checkpointing=True (#3510) 2025-08-08 20:59:37 -07:00
cb95323429 👋 Remove --bf16 value in scripts (#3869) 2025-08-07 12:25:36 -07:00
2fb7090231 👁️ From AutoModelForVision2Seq to AutoModelForImageTextToText (#3836) 2025-08-07 08:00:16 -07:00
f23543fc96 [GRPO] 👁️ Fix vLLM server mode for VLM GRPO training incompatibility for certain AutoProcessors (#3832)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
2025-08-07 11:04:02 +02:00
d3f63ca292 Small style fix in README (#3861) 2025-08-07 09:51:30 +02:00
ad0b9dae1e Typo fix in new model description (#3854) 2025-08-06 11:23:01 +02:00
f3289be384 🔗 Fix collection link in doc (#3852) 2025-08-05 15:51:31 -07:00
f9b0947155 ⬆️ Bump dev version (#3850) 2025-08-05 09:52:43 -07:00
46d09bd240 Release: v0.21 (#3849) 2025-08-05 09:50:17 -07:00
17393b8c82 🌺 OpenAI GPT OSS & Harmony support (#3848)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
2025-08-05 09:44:59 -07:00
21060b25a5 🪦 Remove deprecated (#3817)
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
2025-08-05 09:14:59 -07:00
5d914a4125 [GRPO]: Fix Entropy Mask Threshold Calculation when using Multi-GPU training (#3833)
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
2025-08-05 12:27:59 +02:00
67763762bc Add 'Post training a VLM for reasoning with GRPO using TRL' recipe to Community tutorials (#3843) 2025-08-04 18:46:53 +02:00
072d7dd5a6 Improve trainer doc (#3818) 2025-08-01 11:14:16 +02:00
ead5aaf934 Performance optimization: Replace list comprehensions with tensor operations in BCO and KTO trainers (#3813)
Co-authored-by: chiliu <chiliu@paypal.com>
2025-08-01 11:11:20 +02:00
dbbc770f45 fix CI docs and grpo slow test (#3814)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-07-31 14:10:00 +02:00
294e8cb093 Fix citation 2025-07-31 03:10:19 +00:00
79c5797d92 GSPO parameters update from v2 (#3798)
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-07-30 20:11:00 -06:00
ab2400029a add xpu support for mergekit (#3800)
Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
2025-07-30 20:07:55 -06:00
3ae60cd1b4 Add GSPO script examples (VLM/LLM) (#3810) 2025-07-30 20:07:23 -06:00
9a1e6a4508 Correction parameter description (#3803)
Co-authored-by: lunzhongwang <lunzhongwang@soulapp.cn>
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
2025-07-30 21:41:15 +02:00
90c7876da5 Add vLLM transformers backend to online methods (#3773)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: sergiopaniego <sergiopaniegoblanco@gmail.com>
2025-07-30 18:24:50 +02:00
72bbc6dd0d Examples list updated in docs (#3806) 2025-07-30 04:09:29 -06:00
25ce0f31ae 🐙 Add MPO VLM example script (#3799) 2025-07-29 20:52:32 -06:00
9269f9f151 Fix broken PEFT+TRL docs link in using_llama_models.md (#3794)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-07-29 20:24:11 +02:00
eb5d0fe484 ⬆️ Bump dev version (#3793) 2025-07-28 22:11:46 -06:00
30576d2ddc Release: v0.20 (#3792) 2025-07-28 22:08:54 -06:00
5522cc0a3f 👐 FSDP2+GRPO (#3687)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-07-28 22:01:08 -06:00
303d3b1d63 📘 SFT doc rewrite (#3619)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-07-28 17:06:45 -06:00
3d765b0702 🔍 Add guidance on choosing max_length value and include visualization tool (#3630) 2025-07-28 16:29:35 -06:00
fcd3e0fd15 🌋 [GRPO] add support for pixel_attention_mask (SmolVLM2) and image_sizes (LLaVa-Next) (#3760)
Co-authored-by: sergiopaniego <sergiopaniego@users.noreply.huggingface.co>
Co-authored-by: sergiopaniego <sergiopaniegoblanco@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-07-28 16:28:29 -06:00
8a23c866f8 💬 Fix clone_chat_template vocab size and support PEFT instruction tuning (#3763)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-07-28 11:47:17 -06:00
5bb3ca4b21 📍 Support training peft model with gradient checkpointing (#3785)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-07-28 11:27:57 -06:00
fd70021cd7 📐 Add epsilon hyperparameter recommendation to GSPO (#3790) 2025-07-28 09:34:45 -06:00
a902450e85 🤏 [SFT] Improve doc on training on assistant only messages (#3784) 2025-07-27 22:00:53 -06:00
03034317d0 🎞️ GSPO (#3775) 2025-07-27 06:14:29 -06:00
23ea671c5e 🍿 [SFT] Fix dataset indexing which crashed with a IterableDataset (#3771)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-07-26 16:42:07 -06:00
fc08f55518 🩹 [Hotfix] Fix pynccl communicator assertion error with VLLMClient (#3774)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-07-26 16:33:18 -06:00
2f4cb38f28 📐 Fix CI and GeometricMixtureWrapper (#3779) 2025-07-26 16:15:08 -06:00
eee9ec94ef Update missing uv dep (#3772) 2025-07-25 08:00:03 -07:00
a043fd74a3 Add uv scripts headers (#3767) 2025-07-25 07:48:40 -07:00
d16b960dfa 🤓 [GRPO] Documentation for entropy metric (#3770) 2025-07-25 07:26:10 -06:00
daad892730 🌌 [GRPO] Log generation entropy (#3700)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-07-24 23:55:23 -06:00
097d6153a2 🔠 Support model str in OnlineDPO (#3765)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-07-24 23:29:54 -06:00
bc3eebb73e 🔔 Add deprecation warnings for AlignPropTrainer and DDPOTrainer (#3755)
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
2025-07-24 23:27:41 -06:00
1fb115daff Prevent NCCL Device Conflicts Between vLLM Server and Trainers (#3762)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-07-24 23:16:15 -06:00
3a40f18192 Add MPO recipe to Community tutorials (#3766) 2025-07-24 09:16:35 -07:00
56f4201db6 👁️ [GRPO] Add VLM training capabilities to the trainer (#3072) 2025-07-22 20:31:08 -07:00
a50bdc6388 👨‍💼 [SFT] Packing with completion_only and assistant_only training (#3749)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-07-21 21:49:10 -07:00
e102ac8df1 ⚰️ Remove deprecated (#3704) 2025-07-21 18:16:29 -07:00
d870230218 🐙 MPO (#2544)
Co-authored-by: ariG23498 <aritra.born2fly@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: sergiopaniego <sergiopaniegoblanco@gmail.com>
2025-07-21 11:13:05 -07:00
68ce3a3f07 Add Object detection grounding recipe to Community tutorials (#3752) 2025-07-21 11:02:48 +02:00
5787f3bf63 [GRPO] Fix: Processing ref logprobs in batches (#3740)
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
2025-07-20 16:17:02 +02:00
116ec493fa 🏗️ Refactor top-entropy in GRPO (#3727) 2025-07-19 13:48:57 -07:00
1b17fa78ae uses steps_per_generation in vllm max_num_seqs (#3747) 2025-07-19 09:58:14 -07:00
c389599057 Add comment for average_tokens_across_devices (#3746) 2025-07-19 07:35:32 -07:00
e333da8cf0 Updated missing processing_class docs for rest of trainers (#3745) 2025-07-18 19:51:07 +02:00
c8347b4287 Updated processing_class docs for trainers (#3737) 2025-07-16 07:26:32 -07:00
8684cb4666 🕸 Use wandb.run.url instead of wandb.run.get_url() (deprecated) (#3726) 2025-07-15 18:44:18 -07:00
508d551db1 🔧 Fix GRPO sampling logic (#3725) 2025-07-15 13:39:09 -07:00
569d60e999 [GRPO] remove common activation offloading substring in all cases (#3738) 2025-07-15 13:33:48 -07:00
640a9f3916 📥 Set environment variables for vLLM distributed training in GRPOTrainer (#3723) 2025-07-11 20:15:22 -07:00
5a2b04a699 ↔️ Fix CB in GRPO (#3722)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
2025-07-11 18:21:24 -07:00
dffd1acb94 👋 Remove --bf16 flag from training scripts (#3724)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
2025-07-11 18:20:15 -07:00
43e6b24e70 Remove deprecated processor.tokenizer (#3720) 2025-07-11 15:46:34 -06:00
2ae43f80d9 [Online DPO] Safeguard logit slice against empty prompt (#3719) 2025-07-11 12:40:17 +02:00
c949b66f01 Fix ORPOTrainer loss scaling with gradient accumulation (#3716) 2025-07-11 00:37:00 +02:00
97085539a3 BUG: Disregard pad token entropies for entropy threshold calculation (#3715) 2025-07-10 16:06:26 +02:00
68ed863eed ⚗️ Tiny MoE for test (#3712) 2025-07-09 08:25:47 -07:00
0462dd7f12 [SFT] Add seq_lengths to signature columns (#3699)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-07-08 19:20:13 +02:00
68db24e010 🔭 Fix package discovery configuration in setup.cfg (#3703) 2025-07-07 19:50:56 -07:00
2d086f26a5 📣 Use explicit version for checking datasets version (#3702) 2025-07-07 11:35:57 -07:00
b674989f15 ✂️ [BUG when vllm and prompt_truncation are used]: Strip out pad tokens in truncated prompt text (#3698)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-07-07 11:29:34 -07:00
0353d67661 Fix mislabeling: "First-fit decreasing" is actually "Best-fit-decreasing" (#3696)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
2025-07-07 19:47:18 +02:00
d98d53983b Add type hints to dpo_trainer.py (#3631)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
2025-07-06 10:33:36 +02:00
c30344e9ee Restore the effect of liger_kernel's monkey_patch on global modules in UT. (#3680)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-07-06 09:40:44 +02:00
db19d79e30 [CI] Fix slow grpo CI (#3693) 2025-07-04 19:46:21 +02:00
e8abe03a06 [fix] type error of quantile (#3667) 2025-07-04 17:30:26 +02:00
7eb52c1b4e fix: support dict access in SFT Trainer (#3677)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-07-04 11:27:46 +02:00
686cd35a72 Fix non-serializable torch.dtype bug in VLLM weight sync (#3690)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-07-03 21:25:29 +02:00
601a25693e Update steps_per_generation default description grpo_config.py (#3685) 2025-07-03 20:47:05 +02:00
d42188b17f Support datasets 4 (#3688)
Co-authored-by: Quentin Lhoest <quentinlhoest@Quentin-Ls-MacBook-Pro.local>
2025-07-03 11:45:37 -06:00
4ccc5ca7bd Faster position_ids computation for FFD packing (#3649)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
2025-07-03 13:43:22 +02:00
d1e116c67d [SFT] drop attention_mask if we have position ids for fa2 (#3673)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
2025-07-03 09:18:41 +02:00
90cdf96418 🖼️ Add mlflow support for generate_during_eval DPOTrainer (#3660)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
2025-07-02 14:42:11 -06:00
b520378b97 Enable completion-only loss in SFTTrainer when using Liger Kernel (#3674)
Co-authored-by: kwhitecross <kwhitecross@cs.umass.edu>
Co-authored-by: shirinyamani <75791599+shirinyamani@users.noreply.github.com>
2025-07-02 12:12:14 -06:00
e04f7eb3b9 feat: Pass trainer state to reward functions (#3669)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-07-01 14:16:26 +02:00
02cce41d06 Add support for CB with native transformers (#3471)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-07-01 12:26:09 +02:00
6a6d4345c9 Add paranthesis to correct the check. (#3658) 2025-06-28 07:19:01 +02:00
79ec242aef [GRPO] Make sure special tokens aren't lost when truncating prompt. (#3651) 2025-06-26 09:29:20 +02:00
7e8ef867ae Add entropy based filtering inside the GRPOTrainer. (#3563)
Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-06-25 22:38:41 +02:00
32df09358e 🤝 validate gradient_accumulation_steps vs steps_per_generation for on-policy GRPO (#3493)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
2025-06-25 18:03:22 +02:00
0336e4bcbb ️ GRPO script reward_funcs error (#3639)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
2025-06-25 16:47:08 +02:00
ab331bfd56 Update dpo_vlm.py (#3629)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
2025-06-24 13:56:34 +02:00
84d7b5bbfa env var for vllm colocate exp added (#3638) 2025-06-24 13:44:19 +02:00
b40c959c00 fixing num_processes (#3637) 2025-06-24 13:42:58 +02:00
34fa6b9af2 🐛 fix grpo generation_kwargs (#3634)
Signed-off-by: ahatamizadeh <ahatamizadeh@nvidia.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-06-24 11:43:45 +02:00
eef7a43427 Revert "🔍 Add guidance on choosing max_length value and include visualization tool"
This reverts commit 89c699f59839bb1e2917c2da770015320d087a88.
2025-06-22 23:08:26 +02:00
89c699f598 🔍 Add guidance on choosing max_length value and include visualization tool 2025-06-22 23:06:36 +02:00
559a99f053 ⬆️ Bump dev version (#3626) 2025-06-20 19:02:19 +02:00
5b3ea9dd43 Release: v0.19 (#3625) 2025-06-20 18:43:31 +02:00
c262674ea7 🧰 [SFT] Tool support (#3597)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-06-20 17:39:24 +02:00
5c3dd3ab24 🔍 Add test to verify chat template consistency (#3624) 2025-06-20 17:16:52 +02:00
4c92de0000 ⚔️ Fix bf16 fp16 config conflict issue (#3598)
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-06-20 15:00:39 +02:00
67f17f7ea4 📜 Add chat_template_path parameter to SFTConfig (#3599) 2025-06-20 14:15:03 +02:00
37a71e82bf 🧬 Add generation_kwargs as a property of GRPOConfig to support additional generation arguments. (#3617)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-06-20 14:14:48 +02:00
b0958c6f8f [GRPO] Fix prompt truncation (max_prompt_length) with vLLM. (#3601)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-06-20 12:12:33 +02:00
8bad863ffa Add vllm_gpu_memory_utilization recommendation script (#3554)
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-06-19 23:17:47 +02:00
d00441505d 🎁 Put the reward computation in a separate function (#3620)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-06-19 22:59:44 +02:00
9554c2f319 🤵‍♂️ SFT on assistant messages only (#3586)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-06-19 22:59:26 +02:00
712afd5dd1 🦘 Skip no-op ChatML conversion for datasets already in ChatML format (#3594) 2025-06-19 22:37:58 +02:00
086e9d56e3 📚 SFTTrainer support chat template kwargs (#3609) 2025-06-19 22:12:30 +02:00
5206c927f6 🔖 Fix: ensure user-provided labels are retained in self._signature_columns (#3589)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-06-19 16:03:58 +02:00
e4b586a389 👔 Apply doc-builder style (#3615)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-06-19 12:02:51 +02:00
0576346758 🏛️ Fix CI and Iterative SFT (#3614) 2025-06-19 11:33:20 +02:00
e63588a56a 🏁 Refactor reference model initialization in GRPOTrainer (#3575)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
2025-06-18 16:20:36 +02:00
d9d25a71b2 [SFT] Clarify default collator docs (#3606) 2025-06-18 14:43:09 +02:00
58ea227d4c Change enforce_eager default value in vLLM server. (#3607) 2025-06-18 14:42:53 +02:00
a768484d47 Fix Typos in Comments and Improve Clarity in Trainer Modules (#3596) 2025-06-18 14:42:42 +02:00
d17ec7ad72 Fix: list-typed tags handling in Trainer::create_model_card (#3613) 2025-06-18 14:32:36 +02:00
ed9b78a5f7 🗳️ Remove logging_steps parameter from for simpler setup (#3612) 2025-06-18 13:52:21 +02:00
d6a969ff7d ♻️ Avoids redundant calculation of ref logps in the new policy update loop (#3600)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-06-18 11:56:45 +02:00
FT
8a235a9b71 Fix Typo in Documentation and Notebook; Improve Library Installation Comment (#3593) 2025-06-15 16:46:41 +02:00
afa06c3b56 Fix typos and improve metric descriptions in documentation (#3585) 2025-06-15 16:00:38 +02:00
77ec43ce31 🛡️ Adding trust_remote_code to vllm-serve (#3588)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-06-15 16:00:07 +02:00
4126803875 💬 Fix setup_chat_format and add clone_chat_template (#3404)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-06-15 15:59:42 +02:00
91b3f5ee9a 💡 Fix wrong type hint for formatting_func argument in SFTTrainer (#3584)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-06-15 15:38:12 +02:00
b6e255a9d3 💡 Fix type hints in trainer/utils.py (#3591) 2025-06-15 12:43:54 +02:00
0d54f05fa3 Adjust max_num_batched_tokens (#3565)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-06-13 16:08:07 +02:00
72c91e77f5 📨 [SFT] Tokenize directly when applying the chat template (#3572) 2025-06-13 16:03:55 +02:00
32ffa1170e 🎀 New defaults: bf16=True (#3515) 2025-06-13 13:40:12 +02:00
fd4c9e3b72 Add Community Tutorial: GRPO text summarization example with Unsloth optimizations (#3576) 2025-06-13 13:08:10 +02:00
c5e64b479b 🫸 Push model card with checkpoint (#3550) 2025-06-13 11:18:02 +02:00
15ff54790b 🏗️ Add test for training with multiple dataloader workers and update worker initialization for compatibility with transformers 4.52.0 (#3568) 2025-06-12 19:13:19 +02:00
3d077fd3de Add support for IterableDataset in DPO Trainer (#3559)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-06-12 13:06:34 +02:00
53c4a7c2b8 [Liger] liger DPO support (#2568)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Vaibhav Jindal <32337828+vaibhavjindal@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-06-12 12:25:12 +02:00
aff16a5b2f Fix dev version (#3570) 2025-06-12 10:06:20 +02:00
1314aac502 ℹ️ Unify autocast behavior to torch.autocast and make it cover XPU (#3541)
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-06-10 09:13:00 +02:00
e99a8aec4b Update tests_latest.yml (#3558) 2025-06-09 21:15:17 -07:00
b9572737b4 🆙 Bump transformers to 4.51 and use _VALID_DICT_FIELDS (#3553) 2025-06-09 21:50:57 +02:00
4cafb2744a 🧮 Rearrange DPOTrainer (#3501)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-06-09 19:44:24 +02:00
c49c7b7d4e 🛋️ Fix CI and bump accelerate (#3551) 2025-06-09 14:56:20 +02:00
b773a4c191 💽 [TRLParser] Fail when unknown args are provided in the config file. (#3543)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-06-05 21:43:21 -07:00
7c8355d038 📦 Packing with flash attn kwargs to avoid cross-contamination (#3526)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-06-05 21:18:46 -07:00
50a2fa8ec8 Faster FFD packing (#3537)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-06-04 14:37:28 -07:00
0333108854 🎀 [SFT][Bugfix] sets average_tokens_across_devices to true in SFTConfig (#3538)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-06-04 14:20:57 -07:00
6ffde23a45 💭 [Data] Fix DeepSeek-R1 case (#3522)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-06-04 11:48:16 -07:00
6f288c2d9d 🐳 Add DeepseekV3 model configurations and update tests for new models (#3536) 2025-06-04 09:34:28 -07:00
8cf6220cef 🧭 Remove useless transformers version checks (#3534) 2025-06-04 09:03:38 -07:00
da7b3fe745 🎯 Don't use getattr to get gradient_checkpointing (#3535) 2025-06-04 09:03:24 -07:00
24ef9eb8e7 📰 Add blog "No GPU left behind: Unlocking Efficiency with Co-located vLLM in TRL" (#3527) 2025-06-03 13:22:50 -07:00
b0eff324aa 🎀 New defaults: logging_steps=10 (#3514) 2025-06-03 11:45:08 -07:00
026fc9439c 🪦 RIP trl chat (#3531) 2025-06-03 12:19:03 -06:00
a912ad1bcf 🎀 New defaults: preparing the new structure (#3530) 2025-06-03 10:48:26 -07:00
fef915e36f 📉 FFD packing (#3521) 2025-06-02 13:15:22 -07:00
0db63f0f50 Add "🐯 Liger GRPO meets TRL" (#3525) 2025-06-02 11:32:31 -07:00
7359ddcc6f 🎀 New default: beta=0.0 for GRPO (#3516) 2025-05-30 09:51:07 -07:00
0844936930 🧭 Patch release guide (#3512) 2025-05-30 09:50:31 -07:00
897c87fa91 📚 Fix doc building by removing vLLM from dev dependencies in setup.cfg (#3511) 2025-05-29 11:39:40 -07:00
c13de6f9c0 📎 Fix clip ratio logging (#3506) 2025-05-28 08:46:35 -07:00
722847abbc ⬆️ Bump dev version (#3505) 2025-05-27 19:03:59 -07:00
ef4b0b225c Release: v0.18 (#3504) 2025-05-27 18:43:58 -07:00
8e8e62b380 ✂️ [DPO] Fix truncation keep_end leading to zero'd out samples (#3398)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-27 16:36:01 -07:00
824100ce25 🏰 [vllm] Support base_url parameter for vLLM client initialization (#3324)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-27 16:05:40 -07:00
4e7f0a5eb9 🤧 LD-DPO support (#3458)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-27 16:05:30 -07:00
17a9069710 📏 Completion length logging fix + remainder logging fix (#3482)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-27 14:31:03 -07:00
cb07c44920 Forgotten commit from #3502 2025-05-27 20:02:22 +00:00
0b6a1874f1 🔭 [GRPO] Log advantages and fraction of samples with an std of zero (#3502)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-27 12:58:41 -07:00
ac18c9d532 🐌 Clean two-sided clipping (#3499) 2025-05-27 09:39:37 -07:00
d1174adc5b 🛠️ Initialize reward_kwargs to prevent UnboundLocalError in GRPOTrainer (#3459)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-26 18:28:27 -07:00
cd838417e4 👇 Update grpo.py to fix bugs for cli grpo --reward_funcs my_lib.my_reward (#3454)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-26 17:59:57 -07:00
c7e3f096a5 [GKD] fix the gkd script (#3497) 2025-05-26 20:22:15 +02:00
5c08897570 [GRPO] disabling top_k sampling default (#3494)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-26 11:32:07 +02:00
3ef9faf257 [Docs] sync logging doc to current metrics (#3478) 2025-05-25 17:46:28 +02:00
9ac614fb08 Fix mis-aligned prompts and completions in colocate mode (#3491)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-24 16:50:45 -06:00
29401e790e [Doc][SFT] Update sft_trainer.md. link prompt-completion dataset example (#3486)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-24 19:13:00 +02:00
31bf3f9244 Fix typo (#3489) 2025-05-24 13:24:15 +02:00
7f32792c07 [CI] fix sampler api to make the CI green (#3488) 2025-05-23 17:32:23 +02:00
3d8727918a [SFT] update minimal liger version (#3483) 2025-05-23 13:44:20 +02:00
65245f6be8 Update .pre-commit-config.yaml (#3479) 2025-05-22 16:08:23 +02:00
a528b9c465 [NashMD] fix the edge case where the model is a peft model (#3473)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-20 17:02:04 +02:00
e0dd525021 🙅 PPO value_model can't be None, so it shouldn't be Optional (#3300) 2025-05-19 17:01:08 -07:00
64aa06499b enable activation offloading on XPU (#3444)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-19 11:56:14 +02:00
be93a0c30c enable vllm c-s tests on XPU (#3445)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-19 11:55:57 +02:00
f9fbd91ea9 [CI] fix CI failure of transformer dev (#3457) 2025-05-19 10:08:42 +02:00
54d4f6b13a 🎁 Reward submodule (#3430) 2025-05-15 19:10:22 -07:00
05bc43e960 feat: Implement Two-Sided Clipping for GRPO Trainer (#3434)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-13 20:36:39 +02:00
d3dc8ff654 use device agnostic empty_cache in ppo & rloo (#3439)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-13 20:10:14 +02:00
21738c3732 enable trl env on xpu (#3438)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-13 11:36:01 +02:00
eab175d434 🏹 Support kv_cache_dtype to quantize kv-cache in vllm (#3422)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-08 17:11:16 -07:00
4da4dc9117 Update README.md 2025-05-07 20:49:35 -07:00
6b3a02385d Update README.md (#3420) 2025-05-07 20:48:22 -07:00
abbbb93d6a 🧪 Testing support for Qwen3 tiny (#3415) 2025-05-07 19:32:42 -07:00
cafa663c84 [Models] Activation checkpointing from TorchTune (#2954)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: DanFosing <danfoss12340@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Robert <robert.veres00@gmail.com>
Co-authored-by: Robert Veres <robert.veres@languagetool.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Mathew Shen <datahonor@gmail.com>
Co-authored-by: Ishan Kumar <ishankumar216@gmail.com>
Co-authored-by: Huazhong Ji <hzji210@gmail.com>
Co-authored-by: tpoisonooo <khj.application@aliyun.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-07 12:36:11 +02:00
fd04a5461a 🐍 Support Python 3.13 (#2593) 2025-05-06 21:38:23 -07:00
56e5766205 🎁 Reward takes completion ids (#3272)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-06 10:34:50 -07:00
89d44caece 📝 vLLM-integration documentation (#3376)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-06 09:37:02 -06:00
adfa7fd59a 🎲 [GRPO] Shuffle mini batches (#3391)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-06 11:09:00 +02:00
cf5183db7f 💔 [GRPO] Decouple gradient accumulation from the number of minibatches generated (#3388)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-06 09:59:32 +02:00
1954c02d86 🤝 Compatibility of the TRL CLI with accelerate arguments (#3409)
Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
2025-05-06 00:09:23 -07:00
45f4c58832 ✌️ Add support for FSDP2 (#3317)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-06 08:29:11 +02:00
cc044e35b2 🕊️ Un-restrict diffusers (#3407) 2025-05-02 15:06:53 -07:00
999acd53ec 🕺 Migrate setup configuration from setup.py to setup.cfg and make rich an optional dep (#3403)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-02 11:03:57 -07:00
8606b1ad09 🪪 Remove license classifier (#3402) 2025-05-02 10:03:39 -07:00
a673da5773 👉 [DPO] Model forward pass padding side fix (#3307)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-01 20:37:55 -07:00
00b8e311aa 🦁 Fix liger initialization (#3401)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-01 20:36:46 -07:00
c163cf5081 💔 [SFT] Raise error when formatting_func is used with completion_only_loss (#3385)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-01 16:23:27 -07:00
bc9c019c43 [IterativeSFT] Small refresher (#3378) 2025-05-01 16:18:41 -07:00
18596cf232 🧑‍🤝‍🧑 Co-Locating vLLM w/ training to for higher throughput and GPU utilization (#3394)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-01 16:17:26 -07:00
280d35301b 🌊 Add MLflow metrics in profiling context (#3400)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-01 16:15:38 -07:00
13fa8402a3 [GRPO] Reference model initialization bug fix (#3397) 2025-05-01 17:31:21 +02:00
09b669fbf7 [🐯+GRPO] Support FSDP + Fix bug when using LigerGRPO with DDP (#3260)
Co-authored-by: Ubuntu <azureuser@liger-ci-h100-vm.kvghai4yzzmufguwws3040dwlf.dx.internal.cloudapp.net>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-04-30 22:49:45 +02:00
01d0be15cb Deprecate TextEnvironment and tools (#3389) 2025-04-29 20:25:36 +02:00
3a42af1c78 DPO fixes for evaluations (#3377) 2025-04-29 17:16:30 +02:00
aaf39604ba PEFT support for Liger GRPO (#3355)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-04-29 17:05:35 +02:00
2bf48478e8 📋 Allow calling trl cli in sft mode with config file (#3380)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-28 14:23:42 -07:00
a8cfca6d01 ⚰️ Remove deprecated (#3364) 2025-04-26 11:11:35 -07:00
1bca49515e Better guards for DeepSpeed imports (#3351) 2025-04-26 10:18:11 +02:00
39e96394a9 🎭 Fix train and eval mode checking in GRPOTrainer and SFTTrainer (#3337)
Co-authored-by: Jiaming Ma <jiaming.ma@connect.polyu.hk>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-25 17:42:43 -07:00
8e6ed93dfd 🥸🔢 Adding pad_multiple to SFT trainer (#3365) 2025-04-25 18:12:35 -06:00
29c5e05e3a 🔢 Pad to multiple of (#3362) 2025-04-25 09:53:20 -07:00
a9b27f82d6 ⬆️ Bump dev version (#3357) 2025-04-24 16:22:12 -07:00
cd6b3de356 Release: v0.17 (#3356) 2025-04-24 16:15:45 -07:00
36685c8bba Up to 4x faster: Data Parallel for vLLM server (#3310)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
2025-04-24 15:14:16 -07:00
89556c8cbf 🍡 Fix using reward model and DeepSpeed ZeRO 3 (#3326) 2025-04-23 15:09:33 -07:00
f3e8c23044 Define default chat template for SFT (#3309)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-23 15:49:42 +02:00
9ee6c3aa56 🏁 Fix adding special tokens in SFT (#3328) 2025-04-22 17:51:51 -07:00
ef05331752 [CPO] Check that max_prompt_length < max_length (#3341) 2025-04-22 15:45:15 -07:00
05e2ba6e01 🦄 Add optional uvicorn log level for vLLM serve (#3338)
Co-authored-by: Jiaming Ma <jiaming.ma@connect.polyu.hk>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-22 11:45:13 -07:00
1b4f189e09 💡 Fix type hint in _generate_and_score_completions (#3336) 2025-04-22 08:57:29 -07:00
1faa7f9b36 🧸 Fix unset tokenizer pad_token (#3290)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-21 17:20:09 -07:00
66e6eab9bb [doc] Update sft_trainer.md in table x->✓ (#3313)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-21 17:05:20 -07:00
27af0aaf4a Fix typo in text_environments.md (#3305) 2025-04-21 16:39:55 -07:00
b4ffda769e 🙋 Add Optional Eager Execution Mode for vLLM Serving (#3335)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-21 15:33:59 -07:00
0dad4eb7ca 🎲 [GRPO] Make training dataset shuffle optional (#3334)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-04-21 14:34:31 -07:00
c82f626f94 Empty commit to test new protection rules 2025-04-20 23:07:28 +00:00
33add19161 Empty commit to trigger CI 2025-04-20 23:00:31 +00:00
294f35bf3c ☝️ [GRPO] Generate once per effective batch (#3283)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-04-17 16:35:58 -07:00
9874b3aa04 [GRPO] Add metrics for low and high clipped token probabilities (#3289)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-16 14:43:34 +02:00
1e61f6cc5a 🅾️ Fixes typo in SFTTrainer (#3282) 2025-04-15 15:23:40 -07:00
27adc30162 🧗 Add Ascend NPU support for vLLM server (#3286)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-15 15:22:46 -07:00
df737f99c1 🏷️ Fixed naming error in output_dir for Gemma 3 VLM script (#3297) 2025-04-15 14:51:26 -07:00
c04e84c454 Expose EOS token in SFTConfig (#3299)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-15 21:53:28 +02:00
d625c5533a ⏱️ Fix vLLM server to support V1 Engine (#3276) 2025-04-10 18:29:50 -07:00
6cdd24a360 🦾 Test vLLM client-server (#3277) 2025-04-10 18:29:04 -07:00
8b38570258 🕊️ Un-restrict diffusers (#3274) 2025-04-10 07:24:11 -07:00
95b1a9f612 Add Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset) guide to docs (#3235)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-10 09:33:41 +02:00
5c1511423b 🔗 Fix Dr. GRPO paper link (#3275) 2025-04-09 19:31:15 -07:00
5e2e9cb442 🩺 Dr. GRPO loss (#3256)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-04-09 11:13:22 -07:00
227df8271e ♾️ [CI] Remove test_raise_error_not_causallm (#3265) 2025-04-09 10:39:36 -07:00
ae1581474e 🚧 Temporarily restrict diffusers to <0.33.0 due to ftfy optional dep issue breaking doc builds (#3273) 2025-04-09 10:20:43 -07:00
47b9515fb1 👎 [GRPO] Adds option to disable dropout (#3234)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-09 09:59:06 -07:00
c4891dcfee 🕷 Fix online DPO crash when model is a DataParallel object (#3225)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-09 09:29:13 -07:00
055cee255a Revert "reward takes completion ids"
This reverts commit 73a2fb05545db3c2e92f9311473738278b0d9cd0.
2025-04-09 14:41:55 +00:00
73a2fb0554 reward takes completion ids 2025-04-09 14:40:42 +00:00
982ba08092 🐯 is_liger_kernel_available with min version (#3266) 2025-04-09 06:43:59 -07:00
e03e7acc5c ⛏️ Add cli dict parsing for grpo_config (#3082) 2025-04-08 15:55:33 -07:00
9df19e8a75 📜 Fix license and copyrights (#3264) 2025-04-08 15:22:58 -07:00
1d7b8c4f70 Overlong-filtering for GRPO (#3248)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-08 12:52:52 -06:00
7e170612a4 💠 Fix multi-gpu padding free (#3245) 2025-04-08 11:43:56 -07:00
559724ee2c 📦 [SFT] Deprecate batched formatting_func (#3147)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-08 09:42:17 -07:00
a5a46725c8 🗑️ Deprecate ConstantLengthDataset (#3242) 2025-04-08 08:03:57 -07:00
b6bcafb8bb 🏃 Fix and make CI faster (#3160) 2025-04-08 06:12:08 -07:00
4bfb8eb0d1 🔭 Add support for better KL estimator (k3) in PPOTrainer (#3240)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-05 22:33:28 -07:00
4d66bad208 ☑ Update PULL_REQUEST_TEMPLATE.md (#3241) 2025-04-05 16:28:19 -07:00
e90117b3e1 PPOTrainer: fix progress bar for num_mini_batches > 1 (#2531)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-05 15:47:28 -07:00
31b54a6237 🌊 Add error for iterable datasets in GRPOTrainer (#3216) 2025-04-05 15:41:53 -07:00
17e33cdaa0 🎀 Simplify logging text (#3219)
Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
2025-04-05 15:38:32 -07:00
5a0cebc786 📢 Improve GRPO trainer error message for invalid num_generations (#3199)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-04 21:52:00 -07:00
65308cfd84 ⏯️ Fix logging when resuming from checkpoint GRPO (#3185) 2025-04-04 21:51:36 -07:00
1755e03f6f Update ruff to 11.3 and base Python version to 3.9 (#3230)
Signed-off-by: cyy <cyyever@outlook.com>
2025-04-04 13:50:14 +02:00
793735a698 🐯 Integrate Liger GRPO Loss to GRPO Trainer (#3184)
Co-authored-by: Ubuntu <azureuser@liger-ci-h100-vm.kvghai4yzzmufguwws3040dwlf.dx.internal.cloudapp.net>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-03 19:17:00 +02:00
e70a0efeca Group completion metrics by common prefix (#3212)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-03 08:11:35 +02:00
7eaca76ed1 📚 Accumulate completions for logging (#3217) 2025-04-02 17:00:43 -07:00
657f9ce6ee 🗝️ Fix type hint in vLLM client (#3205) 2025-04-02 09:40:21 -07:00
485852c942 😷 Fix SFT masking EOS when equal to PAD (#3200) 2025-04-02 08:56:05 -07:00
9f3702f6be [GRPO] Improve completion length logging (#3188) 2025-04-01 10:00:40 +02:00
e751a16df5 🐗 [CI] Fix trufflehog false positives (#3192) 2025-03-31 11:01:55 -07:00
582bc5684b Show unique prompts in GRPO WandB tables (#3191) 2025-03-31 18:50:21 +02:00
c5ba70d4fc Fix breaking typo for flash_attention reducing_memory_usage.md (#3190) 2025-03-31 12:17:10 +02:00
5b586da3cc 📎 Fix is_clipped to compute the effective clip_ratio (#3175)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-30 22:24:14 -07:00
488025cd87 ⏯️ Fix: handle None inputs when resuming GRPO Trainer from checkpoint (#3148)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-30 21:25:53 -07:00
2594cb39de ❤️‍🩹 [CI] fix transformers dev CI failure (#3176)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-29 18:39:40 -07:00
2fe2337067 🏃 Migrate CI to self-hosted runners (#3174) 2025-03-29 11:56:44 -07:00
f6b4d6e569 [Liger] Liger KTO support (#2812)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-28 20:56:59 +01:00
26d86757a7 💎 Gemma 3 VLM SFT example script for single-image and multi-image (#3131)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-26 08:16:02 -07:00
9771f259ed 💰 Richer rich table - log all the rewards (#3156) 2025-03-26 07:45:51 -07:00
7bdedd4075 👨‍🍳 vLLM serve: destroy process group on exit and pass worker_cls as string (#3159) 2025-03-26 07:00:57 -07:00
a069a2f19c 🔫 Disable triggering CI when PR is draft (#3154) 2025-03-25 10:59:01 -07:00
ea45f513f3 ⚰️ Remove deprecated (#3153) 2025-03-25 09:57:50 -07:00
a91023990a 🩹 Fix CI (#3155) 2025-03-25 09:16:23 -07:00
1a9387b922 Enable number of printed completions to be set (#3149)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-25 08:47:13 +01:00
1884ff1bb8 🤝 Align GRPO equation doc with the implementation (#3151) 2025-03-24 11:37:06 -07:00
bfe2075608 🐇 [Research] Layer Skip SFT (#3111)
Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-24 11:02:00 -07:00
6067e2a669 BCOTrainer version upgrade fixes (#2867)
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
2025-03-24 10:55:00 +01:00
dee37342a8 📊 Fix clip_ratio logging and better document logged values (#3145) 2025-03-23 16:05:42 -07:00
8037f18cdf Fix: Multi gpu hang for ORPO and CPO Trainer (#3069) 2025-03-23 16:25:15 +01:00
a0a53171cc ⬆️ Bump dev version 2025-03-22 21:14:59 +00:00
23a635ed61 Release: v0.16 (#3137) 2025-03-22 14:03:54 -07:00
9b38b0b5ee ⚖️ Add option not to scale rewards (Dr. GRPO) (#3135) 2025-03-22 13:47:52 -07:00
0f26049ea2 ☎️ Documentation for disable gathering of model weights for generation in DeepSpeed ZeRO-3 (#3136) 2025-03-22 13:29:47 -07:00
7511aa4e36 Pack 300 times faster, truncate 100 times faster (#3009)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-22 12:33:31 -07:00
f713f614e9 🚀 Scaling GRPO to 70B+ Models and Multi-Node Training with vLLM Server & NCCL Communication (#3094)
* 🚀allow GRPO to connect to VLLM in remote/local node with NCCL communication

* Update trl/extras/remote_vllm_helper.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* use argparse for options

* add  imports for remote vllm helper

* formatting

* fix arguments

* use cli options

* vllm serve

* clean server

* better naming

* client

* style

* new params in generate

* this method is the new default

* update config

* do not use asserts

* update config

* separate host and post

* proper deprectation

* deprecated arg in the vllm server

* simplify moving

* document host and port

* style

* update trainer

* new generate args

* update doc

* Fix for zero3

* Better naming

* Remove remote_vllm_helper

* remove grpo_with_remote_vllm

* remove cloudpickle from deps

* Some consistency

* Update docs/source/grpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update setup.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add revision argument to vllm server

* Update docs/source/grpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/grpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Reset the prefix cache after updating weights

* Update vllm_client.py

* Update vllm_client.py

* Update vllm_serve.py

* Add health check endpoint to vLLM server

* connection timeout

* style

* fix doc langauge hint

* move reset_prefix_cache to its own endpoint

* async

* merge peft adaptor to send to vllm

* Looks simple. Wasn't.

* Peft compatibility

* Update docs/source/speeding_up_training.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/speeding_up_training.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/extras/vllm_client.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* GatheredParameters can be disabled

* gather and ungather peft weights within the same deepseed context

* use is_vllm_available

* minor consistency fixes

* fix error when deepspeed is not installed

* fix deepspeed import when not peft

* simpler

* multinode doc

* minor code and comments changes

* style

* optional deps

* vllm_server_timeout as arg

* small refinement in doc

* update deps

* Fix VLLMClient argument in grpo_trainer; Add zero3+peft vllm transfer solution

* Revert "Fix VLLMClient argument in grpo_trainer; Add zero3+peft vllm transfer solution"

This reverts commit d759c9c4d12ff4531482c465c6257a59987ba748.

* log num_tokens

* disable vllm test (in the future we'll add a mock for vllm server for them)

* style

* fix ds3_gather_for_generation

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-03-21 12:12:08 -07:00
a34987956c 🎬 Clip higher (#3118)
* epsilon range added

* epsilon doc str updated

* test removed

* pre-commit run

* Update trl/trainer/grpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* upper epsilon updated

* precommit updates added

* minor format and dtype fixes

* moving upper bound computation in init

* hf.co for paper link

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-19 19:28:19 -06:00
0f88c179e3 Merge pull request #3079 from huggingface/flexible_reward
Flexible_reward
2025-03-18 11:32:16 -06:00
beda4328cc Use main process for dataset.map (#3106) 2025-03-18 17:36:12 +01:00
07cfe1677e add "_prepare_fsdp" for DPOTrainer (#2539)
* enable prepare fsdp

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* remove activation_checkpointing

* move to utils.py

* fix style

* Update utils.py

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-03-17 14:37:15 +01:00
9f7755d8ed 🕊️ Padding-free for SFT (#3076) 2025-03-15 12:52:24 -07:00
4e3f569eb8 Update grpo_trainer.md [ci skip] 2025-03-14 18:48:50 -07:00
979fda1548 title multi-task added for example4 2025-03-15 01:19:31 +00:00
f6fb6a88a9 precommit fixed applied 2025-03-15 01:10:32 +00:00
6cbf8fbc9f Merge branch 'flexible_reward' of github.com:huggingface/trl into flexible_reward 2025-03-15 01:08:08 +00:00
5cb390cd30 Add EOS token to processed input in SFT (#3091)
* Add EOS token to processed input

* Update sft_trainer.py

* fix test
2025-03-14 18:06:15 -07:00
b3c391e628 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 19:03:31 -06:00
1b85ca6147 grpo doc updated 2025-03-15 01:03:04 +00:00
e7a1290b0a Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:57:13 -06:00
3822edd67b Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:56:54 -06:00
230455cab0 Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:56:33 -06:00
08f014d559 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:50:56 -06:00
10740333bd Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:49:07 -06:00
058a733c30 Update tests/test_grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:48:59 -06:00
3f193972d8 Update tests/test_grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:48:39 -06:00
b575596b89 Update tests/test_grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:45:55 -06:00
118c43f0e0 Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:44:05 -06:00
40b1c33edf Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:38:08 -06:00
1a2e74cc5a Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:35:38 -06:00
80f7dcb16d Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:35:04 -06:00
4404ccd24a Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:34:50 -06:00
39f77ca2d8 Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:34:36 -06:00
52085dd96b final version 2025-03-15 00:19:34 +00:00
c7a1c95017 Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:07:38 -06:00
3003058418 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:07:31 -06:00
a759cee2e0 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:07:24 -06:00
0a3bad44f0 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:07:13 -06:00
bb5b96a823 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:07:06 -06:00
8466c7273e Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:06:59 -06:00
a871ec8e91 Update tests/test_grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:06:36 -06:00
f7572221db Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:06:29 -06:00
8ec2e42833 Online fixes 2025-03-14 23:58:33 +00:00
218d493d11 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 17:15:54 -06:00
1a9f78eb3a Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 16:57:18 -06:00
a10978ebdf reviews reflected 2025-03-14 22:27:46 +00:00
87fbb831d3 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 14:04:39 -06:00
52f39d6a24 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 13:57:48 -06:00
931f7a14d2 conflict 2 pushes fixed 2025-03-14 19:47:05 +00:00
9951105a90 Merge remote-tracking branch 'origin/flexible_reward' into flexible_reward 2025-03-14 19:36:32 +00:00
5a6e23aac9 review commnts reflected + unittest n doc added 2025-03-14 19:28:59 +00:00
d9104c8b0d Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-13 16:27:55 -06:00
d5a5840307 Remove simple_test.py from version control 2025-03-13 22:23:09 +00:00
f3cbd41e2c interactive reward_func added 2025-03-13 22:09:12 +00:00
d41a32f619 restriction removed from util 2025-03-13 18:58:07 +00:00
fc4dae256d 🫣 [GRPO] add cache_implementation option in GRPO (#3075)
* add cache_implementation option in GRPO

* add cache_implementation to config

* Update trl/trainer/grpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-13 19:21:36 +01:00
e4e5671e80 💎 Gemma 3 SFT example on Codeforces dataset (#3070)
* Gemma 3 and padding free

* remove padding free changes

* style

* update sft cli

* update script

* revert

* style
2025-03-13 10:50:52 -07:00
7c76f103da irrelavant reward ignorance added 2025-03-13 17:39:49 +00:00
aad18ef52a 🎭 Minor spelling fix in documentation (caracteres -> characters) (#3074)
Signed-off-by: Ed Snible <snible@us.ibm.com>
2025-03-13 08:59:24 -07:00
b55d9f0412 Fixing JSD loss computation as per definition (#3043)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-03-13 11:52:50 +01:00
4871c82b0c 🏊 [SFT] Compatibility with padding free and iterable dataset (#3053)
* Compatibilitywith padding free and iterable dataset

* Fix collator test

* add a test for streaming

* some cleaning

* improve and fix tests

* tiny revert

* bump datasets to 3.0.0
2025-03-12 11:44:25 -07:00
fd9e5a7cab 🦥 Fixed SFTTrainer.compute_loss hang by re-summing before the gather (#3056) 2025-03-12 05:43:33 -07:00
5463e49a55 use argument names with processing_class (#3062) 2025-03-12 13:03:45 +01:00
22759c8208 👯 [GRPO] Relax the assumption that prompts are unique within a batch (#3052)
* Relax the assumption that prompts are unique within a batch

* style
2025-03-11 15:24:06 -07:00
2ee6fd369f 💠 Fixing SFTTrainer.compute_loss crash with accelerate (#3048)
* Fixed crash in SFTTrainer due to accelerator.gather_for_metrics during training

* Moved sum outside of accelerator.gather_for_metrics

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-11 11:08:51 -07:00
844a9c665f 🏁 Passing custom BOS/EOS token to GPROTrainer.generation_config (#3046)
* Passing custom BOS/EOS token to fallback GRPOTrainer.generation_config

* Reordered kwargs per PR comment

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-11 11:08:33 -07:00
04f6597377 🌡️ Fix temperature inconsistency in GRPO trainer (#3029)
* fix temperature inconsistency in GRPO trainer

* adding 1e-7 isn't necessary

* comment

---------

Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-11 10:36:42 -07:00
e3244d2d09 🚀 Supporting deepspeed>=0.16.4's rename (#2963)
* Added else clause to avoid NameError on optimizer_offload

* Accounted for deepspeed's renaming in 0.16.4

* Switched to packaging.version.parse over the (broken) tuple split

* Moved from NotImplementedError to RuntimeError in else clause
2025-03-05 15:49:21 +01:00
6a02c69789 🎲 Add support for additional generation kwargs in GRPO Trainer (#2989)
* Add support for additional generation kwargs in GRPO Trainer

- Extend GRPOConfig to support additional generation kwargs
- Update GRPOTrainer to incorporate additional generation parameters
- Add tests for training with additional generation kwargs for both standard and vLLM modes

* Add missing vllm_gpu_memory_utilization=0.5

* 🔧 Refactor GRPO generation parameters and configuration

- Restructure GRPOConfig to separate generation parameters
- Add support for top_p, top_k, min_p, repetition_penalty, and length_penalty
- Remove additional_generation_kwargs in favor of explicit parameters
- Update GRPOTrainer to use new generation parameter configuration

* Update tests

* Remove length_penalty and fix tests

* Update defaults and docs

- Change temperature type from Optional[float] to float
- Set default top_p to 1.0 instead of None
- Simplify parameter descriptions by removing redundant "if set to None" text
- Maintain consistent type hints and default values for generation parameters

* GRPO remove optional type hint for temperature parameter

* Remove length_penalty from sampling_kwargs dict in GRPOTrainer

* some refactoring

* top k None support

* change value of in test to amke them work

---------

Co-authored-by: Robert Veres <robert.veres@languagetool.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-05 09:58:00 +01:00
a1c58aa42a 🗜️ Loosened tokenizer type hint on apply_chat_template (#3005) 2025-03-04 17:41:42 +01:00
3f0695a4ca 🌍 Use global normalization for KL logging (to match normalization for loss) (#3004)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-04 17:14:22 +01:00
a72b50b772 📚 Update customization and distributing training documentation (#2991) 2025-03-04 16:37:54 +01:00
ea1d9be2a7 ✌️ Remove double compute of sum in SFTTrainer (#3001)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-04 16:35:30 +01:00
402187baab Improve ci (#3007)
* Create codeQL.yml

* Create custom-queries.qls

* Update custom-queries.qls
2025-03-04 15:53:51 +01:00
5858ceab7e 🪙 [SFT] Log num_tokens and some logging fixes (#3006) 2025-03-04 15:45:11 +01:00
7442d42c21 Update pr_style_bot.yml (#3003) 2025-03-03 19:23:16 +01:00
98de0e7c62 🚀 DeepSpeed integration documentation (#2993)
* ds doc

* Update docs/source/deepspeed_integration.md

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-03-03 14:51:45 +01:00
491921c1a4 🛣️ inference_mode to no_grad when computing old_per_token_logps (#2987) 2025-02-28 22:58:05 +01:00
ad6a35bdd5 🫔 [GRPO] Pass wrapped model to unwrap_model_for_generation for DeepSpeed Stage-3 compatibility (#2871)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-28 18:17:04 +01:00
7bc9858a8f 🔍 Update GRPO config documentation for beta parameter stability (#2992) 2025-02-28 17:39:12 +01:00
b882f57d93 ⚰️ Deprecate liger-kernel (#2949)
* Deprecate liger

* remove import

* oops, shouldn't be here

* Fix other deprecations

* remove liger from gkd for now

* remove liger for teacher

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-02-28 14:58:47 +01:00
ac7bde5832 📑 Fix logged metrics for KTO (#2982) 2025-02-28 14:58:31 +01:00
3d94e4e25c 📜 Update README and doc index (#2986)
* Update readme and doc index

* bold

* consistent uppercase
2025-02-28 13:51:58 +01:00
1a303cca8e 🧬 Fix typo in grpo_trainer.py (#2988) 2025-02-28 13:49:47 +01:00
ac327d5e84 🪪 Adds a more fine-grained profiling context (#2975)
* adds a more fine grained profiling context

* precommit

* fix reward func name

* add reward to RM name

* Update trl/extras/profiling.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* some doc and fixes

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-02-27 21:58:39 +01:00
c0854c32c9 🌌 Fix logits computation in trainer prediction step (#2969)
* Fix logits computation in DPO trainer prediction step

* fix compute_metrics for bco and test

* same for cpo

* same from dpo

* for kto

* anf finally orpo

* Apply style fixes

---------

Co-authored-by: kyungdae-jo <kyungdae.jo@navercorp.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-02-27 17:09:11 +01:00
aa18ecfde7 👂 Update learning rate doc in KTOConfig (#2912)
* Update kto_config.py

Fix the mismatch between documentation (and suggested) kto learning rate

* fix doc

---------

Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-27 14:40:54 +01:00
6849c050b9 🕸 Add distributing training guide (#2956) 2025-02-27 14:31:52 +01:00
27a6f2201b 🧗 Add GRPO Trainer support for third-party accelerators (#2836)
* Add GRPO Trainer support for Ascend NPU

* 更新 grpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* code format

* 更新 grpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* patch mem_get_info

* stylre

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-02-27 13:25:24 +01:00
f074dcdc86 👧🏽 Adding DoRA support to model config (#2974) 2025-02-27 12:37:22 +01:00
0caff61600 Update grpo_trainer.py (#2973) 2025-02-27 09:38:32 +01:00
019fc6dbaa 🔢 Fix GRPO doc about num_iterations (#2966) 2025-02-26 12:46:08 +01:00
69ad852e56 Parameterize enable_prefix_caching (#2900)
* parameterize enable_prefix_caching

* apply review suggestion

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-25 00:40:09 +01:00
45ccdefac4 📌 Pin liger-kernel and vLLM (#2952)
* pin liger-kernel

* style
2025-02-25 00:34:16 +01:00
703484a8c2 🗿 Updated DPO default values for alpha and tau (#2918)
* updated DPO default values for alpha and tau

* same for grpo

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-02-25 00:19:48 +01:00
9b76d5f2e9 ↩️ Fix typo in TextEnvironment init param, should be max_tool_response (#2921)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-25 00:08:06 +01:00
cbe0681ba1 📇 GRPO: print completions to console and update docs (#2951)
*  Enhance GRPO logging with configurable completions sampling

- Update `GRPOConfig` to replace `log_completions` with `log_completions_steps`
- Add `print_prompt_completions_sample()` utility function for rich console logging
- Modify `GRPOTrainer` to additionally print 5 random prompt-completion pairs every log_completions_steps steps

* GRPO trainer completions logging, move wandb checks together

* Add rich availability check and use fallback in print_prompt_completions_sample when rich is not available

* Update docstrings on print_prompt_completions_sample

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Revert back to simple log_completions bool

* GRPO log completions fully

* Remove print fallback from print_prompt_completions_sample

* Move accelerator main process check up for grpo log completions

* Explicit variable names in print_prompt_completions_sample

* Make GRPOConfig docstring match field description

* Update log_completions docs again

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update GRPOConfig docs to match field

* improve readibility when prompt or completions are multilines

* log reward

* prevent hanging, don't print without rich, print reward

* style

---------

Co-authored-by: Robert Veres <robert.veres@languagetool.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-02-24 23:53:13 +01:00
4e0cf01aef Prevent applying the chat template to tokenized datasets (#2939)
* Update sft_config.py

* Update sft_trainer.py

* Update sft_config.py

* Update sft_trainer.py

* Apply style fixes

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-02-24 23:14:49 +01:00
5c05913196 🐯 Fix LigerKernel for SFTTrainer (#2940)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-24 17:29:48 +01:00
caba04da42 ☠️ Update max_seq_length to max_length in SFTConfig (#2947) 2025-02-24 16:26:20 +01:00
be5a088337 📋 Add vLLM version to environment printout (#2946) 2025-02-24 14:22:43 +01:00
38861475e6 ♻️ Fix caching in SFT (#2945) 2025-02-24 10:54:39 +01:00
f69707dab4 🐈 Bye bye chat (#2934)
* Bye chat

* better warning

* style error

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-02-23 19:18:28 +01:00
76f00fc394 Ensure precommit exits 0 status 2025-02-23 16:34:54 +00:00
8453017622 🧼 Upgrade ruff (#2938) 2025-02-23 17:33:50 +01:00
3608709529 Update pr_style_bot.yml 2025-02-23 14:32:36 +01:00
21f0055893 🤖 Style bot (#2935) 2025-02-23 14:29:22 +01:00
013d360b8f 🔹 Fix: Miscalculated mask shape in comments (#2925) 2025-02-21 17:01:53 +01:00
e5ae703d35 🐦🔥 6x faster GRPO with multi-step optimization (#2899)
* Add num_updates and epsilon parameters to GRPOConfig and GRPOTrainer

* test sampler

* update the loss computation

* fix eval sampler

* should work now

* buffer inputs with grad accum

* optimize when num_iterations == 1

* test

* minor comment removal and fix log metric

* beta position

* clarify comment [ci skip]

* clarify sampler doc [ci skip]

* fix collision with eval logging

* clarify
2025-02-20 19:51:45 +01:00
a92e00e810 🪪 Adds profiling decorators for GRPOTrainer (#2889)
* adds profiling decorator

* naming + precommit

* style

* revert inclusion of slider table

* revert 2

* revert3

* revert4

* revert 5 fml

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-20 09:57:42 +01:00
9b3c5bf64f 📍 [GRPO] add gradient_checkpointing (#2848)
* add gradient_checkpointing

* added a helper

* Update trl/trainer/grpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/grpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* minor refactor for better readability

* use acceelrate util

* enable_input_require_grads is in base class

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-18 18:09:16 +01:00
15fec312d5 🍃 GRPO - Do not load reference model when beta == 0 (#2806)
* 🔧 Optimize GRPO training by conditionally loading reference model based on beta value

*  Add test for GRPOTrainer with beta=0 to ensure no reference model and KL divergence

* 🔧 Refactor GRPOTrainer code for improved readability and maintainability

* 🔧 Simplify per_token_loss calculation in GRPOTrainer for clarity

* fix test, style, and some struct for clarity

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-18 17:57:15 +01:00
be1e34003c 🩳 max_seq_length to max_length (#2895)
* `max_seq_length` to `max_length`

* remove in 0.20
2025-02-18 16:53:37 +01:00
6aaf379a82 ⚰️ Remove deprecated (#2894) 2025-02-18 16:53:21 +01:00
49adf74833 Add vLLM guided decoding support to GRPO Trainer (#2811)
*  Add vLLM guided decoding support to GRPO Trainer

* 🔧 Update vLLM guided decoding in GRPO to use regex parameter

* style and docstring

* test

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-18 16:53:05 +01:00
6c54f023ae 🪂 Don't gather logits in SFT to avoid hanging (#2890)
* Don't gather logits

* Remove unused function and test
2025-02-18 15:31:08 +01:00
963243a7d1 Optimize vllm num_generations (#2855)
* small optimization of vllm batching

* style

* adds comment

* style
2025-02-18 11:44:15 +01:00
aafd8cbea5 🍟 [SFT] Handles the dataset if it has been preprocessed (#2863)
* return dataset if it's preprocessed

* add is_processed flag variable

* add test

* move test_sft_trainer_directly_with_pretokenized_data to Tester2

* Update sft_trainer.py

* no need for padding and truncation

* minor reorganization

* Update trl/trainer/sft_trainer.py

* let the collator pad

* style

* fix tests

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-18 09:56:47 +01:00
822653824b 🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading (#2873) 2025-02-17 20:34:07 +01:00
ba036576d4 💬 Add maybe_convert_to_chatml map for conversational datasets in SFT (#2862)
* add back get_formatting_func_from_dataset

* maybe_convert_to_chatml

* maybe_convert_to_chatml before maybe_apply_chat_template map

* remove comment

* test

* desc

* style

* Update trl/data_utils.py

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-17 16:47:06 +01:00
293b620950 [GRPO] Fix loss normalization (#2881)
* fix GRPO loss normalization

* fix sum dim

* fix loss= repeated
2025-02-17 13:26:21 +01:00
ae3bd0d07a 🆙 Bump vLLM min version to 0.7.2 (#2860)
Bumps vllm as there were a number of throughput improvements in vllm==0.7.2

Also may resolve issue such as https://github.com/huggingface/trl/issues/2851
2025-02-17 10:54:07 +01:00
6d9fc11fd6 [SFT] fix check for AutoLigerKernelForCausalLM (#2874)
* fix check for AutoLigerKernelForCausalLM

* fix case where AutoLigerKernelForCausalLM is not defined

* update min liger version

* formatting

* fix win CI
2025-02-17 07:50:55 +01:00
ffcb9f4aee ⬆️ Bump dev version 2025-02-13 14:33:44 +00:00
00e5889380 Release: v0.15 2025-02-13 14:28:36 +00:00
5c9cf2003d 👨‍👩‍👧 GRPO + PEFT + vLLM (#2818)
* peft + grpo + vllm

* test change

* support model alread peft

* Update tests/test_grpo_trainer.py

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-13 15:23:36 +01:00
8830786a23 🪆 Fix for Incorrect ValueError Handling in reward_weights in grpo_trainer.py (#2843)
- Fixed a bug where an extra `len` call inside the error message caused a `TypeError` instead of the expected `ValueError`.
- Replaced `len(len(args.reward_weights))` with the correct `len(args.reward_weights)` to properly calculate the number of reward weights.
- Ensured that a `ValueError` is now raised with an accurate and clear message when the number of reward weights does not match the number of reward functions.

This fix prevents confusion during debugging and ensures proper error handling during validation.

Tested with cases where:
- `args.reward_weights` is None (default case).
- `args.reward_weights` has mismatched lengths with `reward_funcs`.
2025-02-13 13:46:18 +01:00
b0f513c13d Fix PeftModel check when moving weights to vlllm (#2850)
This check meant that peft now because a required dep when running GRPO with vllm. 

This PR should resolve this.
2025-02-13 12:23:10 +01:00
81221661c6 Fix GRPO PEFT (#2725) 2025-02-12 18:36:01 +01:00
7347c292c3 🥾 Allow bootstrap GRPO (#2829)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-02-11 18:56:22 +01:00
2106b31298 👴 Update tokenizer parameter to processing_class in tests (#2828) 2025-02-11 11:46:26 +01:00
9b67eea473 🙌 Share vLLM device with training when only 1 available (#2827)
* Fix GPU device selection in GRPOTrainer in case training with onyl one

* update doc

* style

* update warning
2025-02-11 11:30:37 +01:00
e752fc6c2e ⚖️ Add reward weight in multi-reward settings for GRPO (#2676)
* added reward weights for multi-reward runs in GRPO

* reward_weights are float, moved from GRPOTrainer to GRPOConfig

* minor comment fix

* minor

* fix test

* missing link

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-11 11:15:41 +01:00
674bb75f59 🫘 Add set_seed() call in GRPO to ensure unique seed for each process (#2824)
* Add set_seed() function to ensure unique seed for each process

* share seed sampler

* style
2025-02-11 10:30:27 +01:00
b9df81045b 📤 GRPO refactor loading the model weights to vllm (#2817)
* GRPO refactor loading the model weights to vllm

* style

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-10 15:20:38 +01:00
55e680e142 fix: typos in documentation files (#2804) 2025-02-08 20:46:47 +01:00
09eefa73ab ⛰️ Reduce peak vram consumption with efficient selective log_softmax (#2799)
* Reduce mem consumption across many trainers with efficient selective log-softmax approach

* rename

* typo fix

* precommit

* Update tests/test_core.py

* relocate

* precommit

* style

* smaller values for test, and run on cpu

* nit doc improvements

* style

* fix test

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-08 00:59:46 +01:00
7fdb69aa7d Fix GRPO example in README (#2800) 2025-02-08 00:29:26 +01:00
5b9236d1e8 🔬 SFT simplification (#2405)
* initial commit

* update

* Refactor SFTTrainer and SFTConfig

* Update SFTConfig class in sft_config.py

* Fix SFTConfig torch_dtype validation and dataset preprocessing flag

* Refactor dataset mapping and conversion

* Refactor dataset mapping in SFTTrainer

* Fix SFTTrainerTester unit test by removing unnecessary code

* Remove unused variables and update tokenization logic

* Remove pack_dataset function

* Add deprecation warning for tokenizer in SFTTrainer constructor

* add docstring back

* Update model parameter type annotation

* Update SFTTrainer class definition

* style

* preprocess_dataset -> _prepare_dataset

* Retro compat

* Update formatting_func type hint in SFTTrainer constructor

* typo

* better comment

* simplify tokenize row

* Fix type hint for peft_config

* fix doc

* Add pack_examples function to `test_data_utils.py`

* promote pack_examples and document

* improve doc

* Add new SFTTrainerTester2 class for testing

* test was reversed

* ©️ Copyrights update (#2454)

* First changes

* Other files

* Finally

* rm comment

* fix nashmd

* Fix example

* Fix example

* 💬 Fix chat for windows (#2443)

* fix chat for windows

* add some tests back

* Revert "add some tests back"

This reverts commit 350aef52f53f8cf34fccd7ad0f78a3dd63867e06.

* 🆔 Add `datast_config` to `ScriptArguments` (#2440)

* datast_config_name

* Update trl/utils.py

* sort import

* typo

* Trigger CI

* Rename `dataset_config_name` to `dataset_config`

* 🏎 Fix deepspeed preparation of `ref_model` in `OnlineDPOTrainer` (#2417)

* Remove unused deepspeed code

* add model prep back

* add deepspeed even if it doesn't work

* rm old code

* 👯 Standardize `model_args` (#2442)

* `model_config` -> `model_args`

* sort

* refactor config

* drop skip prepare dataset

* add sep to packing

* drop prompt-completion for now

* Revert "drop prompt-completion for now"

This reverts commit 16ef195031ac9c860f8f2ac383ff34133fcbe70f.

* Revert "add sep to packing"

This reverts commit dc84d08da7a4b7804c064be1a15605f1770549e2.

* Revert "drop skip prepare dataset"

This reverts commit d2ee070d994a4b29ad33128a8ef99f101994a6c7.

* Revert "refactor config"

This reverts commit f732aa8728e42623ee5817b514263912cab337e4.

* Format

* Update doc-builder workflow to use specific commit sha

* add peft edge cases

* no logits when using liger

* remove unused columns

* proper handle of prompt-completion

* trick to keep messages

* fix messages missing

* for Liger kernel, ensure only input_ids is present

* packing and liger are compatible

* shinny doc and final nits

* another nit

* refactor config and doc

* re add truncation

* fix ci

* drop deprecated params in tests

* fix link

* fix config docstring

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-02-08 00:21:36 +01:00
82d12eb751 📠 Log completions for GRPO (#2772)
* log completions

* typo

* wandb

* Fix completions

* Fix style?

* Remove double import

* Revert

* group logging

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-02-07 12:41:58 +01:00
84d73fd00b 🎯 [SFT] add token accuracy metric (#2597)
* add token accuracy metric

* fix return type

* shift tokens

* use compute_loss so that the model is called only once

* add to logs

* log from main process

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-07 11:09:46 +01:00
2241f17914 🆚 Distinguish padding and eos when they differ (#2793) 2025-02-07 11:08:49 +01:00
cf97133d51 📉 Optimize GRPO memory usage by redefining per_device_batch_size as generations per device (#2776)
* Distribute

* fix some logic errors

* fix and document RepeatRandomSampler

* comment

* doc clarification

* fix type hint

* more readable

* fix eval

* fix tests

* roll back to distribute generation

* improve comment [ci skip]

* fix slice

* catch for eval batch size as well; fix completion_ids in vllm

* log completions

* Revert "log completions"

This reverts commit 1e4af8ffb8dda15d7596e707ac784208db88135a.

* Before the first training step, the model has no optimizer: fix ds3
2025-02-06 20:20:44 +01:00
724acb9716 💡 Add 'Post training an LLM for reasoning with GRPO in TRL' tutorial (#2785) 2025-02-06 18:28:05 +01:00
7134a1e73f Revert "Before the first training step, the model has no optimizer: fix ds3"
This reverts commit bf6e7edea54f2e34b2f6802468ee3224c4aa8030.
2025-02-06 17:19:57 +00:00
bf6e7edea5 Before the first training step, the model has no optimizer: fix ds3 2025-02-06 17:19:05 +00:00
e95f9fb74a 🙃 Fix reward function in GRPO example (#2777) 2025-02-06 09:51:44 +01:00
a85768f120 💡 GRPO vram-efficiency improvement; only compute relevant logprobs (#2773) 2025-02-06 08:52:21 +01:00
78c5ce23fd ↔️ GRPO: Set max_model_len when initializing vLLM instance (#2728)
* Set max_model_len when initializing vLLM instance

* Introduce vllm_max_model_len arg

* Replace vllm args with vllm_init_kwargs

* Update docstring

* Add missing import

* Remove default values from newly deprecated args

* Docs update

* Reverted to adding single arg for max_model_len

* Remove spurious import

* Remove spurious line

* style

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-06 00:12:31 +01:00
af4ad47035 🚧 Add Optional ZeRO-3 Weight Gathering for GRPO in Sequence Generation (#2667)
* Add (grpo) unwrap_model_generation zero3 gathering

* proper placement

* Disabling this option is not compatible with vLLM generation.

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-04 23:24:35 +01:00
b2ae99925d 🔁 🦈 Support iterative GRPO (#2700)
* support for synchronization ref-model added

* support for synchronization ref-model added

* tests for sync_ref_model added

* Update tests/test_grpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* split and fix test

* style

* doc

* move after init to ensure accelerator exists

* Update tests/test_grpo_trainer.py

* style

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-04 23:10:13 +01:00
bd946f93c1 🤖 Properly unwrap torch.compile-ed models in GRPO (#2750)
* properly unwrap torch.compile-ed models with GRPO

* add test and compat with reward models

* ignore test windows

* properly unwrap torch.compile-ed models with GRPO

* add test and compat with reward models

* ignore test windows

* chore: lint

* style

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-04 22:22:10 +01:00
f42e34e613 🔎 Add missing script argument in PPO documentation (#2720) 2025-02-04 21:53:10 +01:00
338fbd546b 📖 Clarification max len in Reward documentation (#2740)
* Nit fix about max_lenth argument.

* copy to docstring

* typo

* consistency

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-04 21:16:29 +01:00
32f8fa8aad 📐 Add vLLM dtype configuration for GRPO trainer (#2738)
* feat: Add vLLM dtype configuration for GRPO trainer

* added vllm dtype info in docstring

* send to vLLM doc

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-04 21:10:56 +01:00
1a2276402f 📌 vLLM >= 0.7.1 for device fix (#2766)
see https://github.com/huggingface/trl/issues/2745
2025-02-04 20:12:22 +01:00
1f344c9377 💔 Decouple loss computing and generation in GRPO (#2762) 2025-02-04 13:21:51 +01:00
85121fc300 🔂 Use vLLM prefix caching for speedup (#2757)
* use vllm prefix caching for speedup

* comment

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-04 11:20:50 +01:00
bbdd6db17c ⚠️ Fix attention masking in GRPO (#2708)
* Update grpo_trainer.py

* Update grpo_trainer.py

* Update grpo_trainer.py

* Slight name change

* Fix typo

* Improve readability + move attn mask to args

* revert adding "completion_"

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-02 20:44:54 +01:00
6e088d165c docs: Fix typos in alias descriptions (#2729) 2025-02-02 11:59:46 +01:00
a325a0eec5 fix: Fix typo in filename in ultrafeedback-prompt.py (#2716) 2025-02-01 14:53:47 +01:00
0ec1ccd990 💰 Fix incorrect calculation in Olivia's baguette spending logic (#2727) 2025-02-01 14:52:08 +01:00
1c35a48b50 🏰 num_logits_to_keep to logits_to_keep (#2721) 2025-01-31 20:19:39 +01:00
2ce36ae889 📖 Nit fix in SFT Documentation (#2722) 2025-01-31 16:46:23 +01:00
bf6919117e Improve GRPO example (#2717) 2025-01-31 12:04:44 +01:00
265663af6a 📖 Add GRPOTrainer to README.md (#2713)
* [DOCS] add GRPOTrainer to README.md

I replaced RLOOTrainer with GRPOTrainer because you thought you might want to keep it limited, but let me know if you want both.

* Update README.md

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-01-31 10:30:44 +01:00
5ab15d3fef fix: Fix typo in filename Update ultrafeedback.py (#2699) 2025-01-31 10:01:32 +01:00
fecaa991de 📋 Add eval loss logging during prediction in GRPO (#2694)
* add eval loss logging during predition

* make sure the train and eval logs aren't mixed

* test grpo in eval

* fix tests

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-01-30 18:37:45 +01:00
ab30a01baf 💡 Add "Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial" (#2697)
* more readable

* add tuto
2025-01-30 17:12:04 +01:00
6dc278a042 ☠️ Remove deprecated (#2692)
* remove deprecated

* remove from test

* remove from test 2
2025-01-30 16:30:40 +01:00
67441bb432 🧠 Fix typo in "understand" in ppo_trainer.md (#2695) 2025-01-30 16:30:24 +01:00
62685fbf20 docs: Fix broken "Good First Issue" link in CONTRIBUTING.md (#2693)
* docs: Fix broken "Good First Issue" link in CONTRIBUTING.md

* Update CONTRIBUTING.md

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-01-30 13:15:37 +01:00
4197956395 🙈 Fixed typo in the GRPO documentation (#2691) 2025-01-30 11:17:02 +01:00
9ac8d9773b 📄 Add GRPO batch size note in docs (#2672)
* add note for OOM error

* update note

* Apply suggestions from code review

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-01-30 09:57:43 +01:00
094d51b599 📖 Docs fix spelling issues (#2682)
* Update alignprop_trainer.md

* Update best_of_n.md

* Update clis.md

* Update community_tutorials.md

* Update cpo_trainer.md

* Update dataset_formats.md

* Update detoxifying_a_lm.md

* Update dpo_trainer.md

* Update rloo_trainer.md

* Update clis.md

* Update rloo_trainer.md
2025-01-30 09:42:14 +01:00
df8f619ec5 📦 trl.templates in excluded packages (#2690) 2025-01-30 09:31:08 +01:00
56880ba73d ⬆️ Bump dev version (#2689) 2025-01-30 09:23:31 +01:00
801582ec24 📉 Use num_logits_to_keep to reduce memory usage in GRPO (#2683)
* use num_logits to keep

* add comment back

* Update trl/trainer/grpo_trainer.py
2025-01-29 17:12:18 +01:00
ed14ed9043 vLLM for fast generation in GRPO (#2600)
* doc

* fsdp

* use vllm config

* vllm

* Update trl/trainer/grpo_config.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/grpo_config.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* typo

* top_k, top_p

* Link to vllm pr

* fix missing device

* fix tests

* fix citation

* fix title and paper_id

* formatting

* output the correct number of generations

* initial async vllm

* fix missing args

* fix promps

* Pass prompt_token_ids directly

* Repeat each prompt num_generations times

* get the slice of results per processor

* undo citation

* OMG

* nothing can resist me!!!!

* working

* vllm_device to "auto"

* add vllm test

* add initial vllm docs

* add vllm link and pip instructions

* add multi-gpu strategy fot vllm

* Update docs/source/grpo_trainer.md

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update docs/source/grpo_trainer.md

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update docs/source/grpo_trainer.md

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* add doc strings

* Update docs/source/grpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/grpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/grpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add important tag

* fix typo

* overrides default batch size and grad accum and better doc

* Under no circumstances should you examine the contents of this commit.

* auto device, warnings, errors

* better error message

* require_torch_accelerator test vllm

* speeding up traing doc

* device as str

* does it prevent deepspeed init to hang?

* update docs

* require torch accelertor for vllm test

* unwrap compat with ds z3

* simplify examble in doc

* More comments, fix ds3 hanging

* faster, not sure why

* style

* move doc about speed

* revert change in config files

* fix default value in doc [ci skip]

* style [ci skip]

* better comment [ci skip]

* fix warning

* Update grpo_config.py

* Update deepspeed_zero1.yaml

* Update trl/trainer/grpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/grpo_trainer.md

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-01-29 13:01:10 +01:00
4659ad916f 🖊 Fix typos (#2673)
* fix typos

* fix typo

* fix typo

* fix typos

* fix typos

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo
2025-01-28 11:26:36 +01:00
1123bd0f51 🏷️ Add model tags to model trained with GRPO (#2663) 2025-01-26 13:37:15 +01:00
55a329e9f0 🌀 Fix GRPO default completion length doc (#2662) 2025-01-26 10:05:21 +01:00
4720656654 📏 Log completion length in GRPO (#2659) 2025-01-25 20:56:09 +01:00
807046b7d7 📍 Disable caching when grad checkpointing enable in GRPO (#2653)
* disable caching when grad checkpointing

* style
2025-01-25 13:14:34 +01:00
317d2d477b 🔎 Finegrained reward logging for GRPO (#2651) 2025-01-25 11:43:00 +01:00
aeb03cf1a9 👐 DeepSpeed integration for GRPO (#2652) 2025-01-25 10:10:29 +01:00
2578e95023 🚛 Provide all columns of the dataset to the reward function (#2650)
* The reward function is provided with all col from the dataset

* Minor clarifications

* minor renaming in doc [ci skip]

* fix indentation
2025-01-24 20:31:07 +01:00
6f99f42f72 🥞 Fix KTO gradient accumulation loss scaling (#2648) 2025-01-24 16:23:16 +01:00
d14f7f3eb2 🥞 Fix GRPO gradient accumulation loss scaling (#2647) 2025-01-24 16:22:54 +01:00
8e65825d4c 🥞 Fix CPO gradient accumulation loss scaling (#2645) 2025-01-24 12:22:46 +01:00
5e4d7be0e1 Update grpo_trainer.md 2025-01-24 09:06:16 +01:00
f34b70a32e 🌯 Fix context manager runtime error when gather is disabled (#2639) 2025-01-23 21:23:54 +01:00
0e216f7411 🍭 Custom reward function for RLOO (#2612)
* rloo custom reward function and test

* idont even know why i did that

* removing get_reward_custom

* remove get_reward_custom test

* fix code quality check

* adding test

* end this mysery already

* fix test
2025-01-23 22:46:37 +03:30
59c201433c 🥞 Fix BCO gradient accumulation loss scaling (#2638) 2025-01-23 18:57:43 +01:00
40c238395e 🥞 Fix DPO gradient accumulation loss scaling (#2615)
* fix DPO for gradient accumulation

* Update trl/trainer/dpo_trainer.py

* Update trl/trainer/dpo_trainer.py

* Update trl/trainer/dpo_trainer.py

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-01-23 18:12:06 +01:00
a1d2955116 🏆 Custom reward function for GRPO and shiny doc (#2606)
* initial commit

* doc on custom reward function

* test

* doc doc doc

* fix collator

* style

* links?

* I need a docdoc 🎵

* fix link

* I do like writing doc tbh

* it takes time, but it's worth it

* no return!

* type hint

* it's probably the best of both worlds [ci skip]

* new doc before implementation

* tests

* more doc

* style

* multiple pretrained funcs

* fix arg name

* main?

* example for R1

* fix script

* clearer

* import [ci skip]

* Update docs/source/grpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-01-23 17:39:45 +01:00
887c1f3fa3 💎 Rename an inner var in GRPO to improve clarity (#2616)
* rename advatages to per_token_loss for clarity

* doc ci
2025-01-23 17:30:22 +01:00
949db2357e 👋 Drop MDX (#2611) 2025-01-23 13:38:15 +01:00
fe4b5efe4e ✂️ Reintroduce truncation_mode in DPOTrainer (#2551)
* reintroduce truncation mode in DPOTrainer

* move truncation_mode in dataset.map invocation

* truncate full sequence

* "." [ci skip]

* Empty commit

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-01-22 15:33:50 +01:00
a9b54a852e 🫷 Include stop token in policy model's generation_config (#2528)
* Include stop token in policy model's generation_config

* Fix formatting

* Update trl/trainer/ppo_trainer.py

* Update trl/trainer/ppo_trainer.py

* don't modify args

* clarify doc

* more nice doc

* missing no [ci skip]

* really don't modify args

* oups

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-01-22 12:24:42 +01:00
d4222a1e08 🧩 PPO/RLOO/OnlineDPO sequence generation: make deepsped 3 weight gathering optional (#2557)
* PPO/RLOO/OnlineDPO: add ds3_gather_for_generation argument to control weights gathering for generation

* code formatting

* rephrase and document

* more doc

* style [ci skip]

* Trigger CI

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-01-21 22:44:18 +01:00
a5c88d6c75 Add uv installation instructions (#2601)
* add uv

* Update docs/source/installation.mdx

* Update docs/source/installation.mdx

* pypi -> PyPI

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-01-21 22:09:18 +01:00
b6a084c46e 💾 Reduce memory peak in GRPO by adding max_prompt_length and loop usage in logp computation (#2598)
* add max_prompt len to config

* truncate prompt and compute log probs line by line
2025-01-21 15:12:04 +01:00
d9f056862f 🧰 Tool fine-tuning support DPO (#2479)
* adding tool fine-tuning support for DPO

* precommit

* adding test for DPOTrainer with tool usage

* style

* fix test

* a comment

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-01-21 09:32:31 +03:30
3d2c1e49b1 Fix merge error (#2595) 2025-01-20 22:17:39 +01:00
5fd78367ae 🫣 Ignore CLI test for Python 3.9 (#2592)
* ignore cli test for python 3.9

* move import inside tests
2025-01-20 21:26:11 +01:00
0f5ffad26e 👨‍👨‍👧‍👧 GRPO (#2565)
* init grpo [ci skip]

* initial version

* refine args defs

* model card

* initial doc

* fix badges

* fix spaces

* try link to super in doc

* temperature, fix indexing, and std=0.0

* grpo script for cli

* peft support

* move data preparation in `compute_loss`

* weird doc trial

* fix device and some logging

* unwrap_model_for_generation for distributed setting

* Compat with distrib training

* revert grpo config doc trial (didn't work)

* test

* allow model to be str and processing_class to be none; fix loss computation

* advantage is always 0.0: don't log

* fix peft not installed

* proper reward model for testing

* fix script for cli

* add trl grpo to cli doc

* test peft

* flush left

* fix reward calculation

* new reward model

* support any reward model

* fix reward processing class def

* log reward std

* fix reward logging

* fix grad computation

* skip embed layer in test

* remove optimizer_cls_and_kwargs

* improve GRPO default args

* reduce mem usage for grpo test

* reduce mem usage in test grpo

* reduce memory usage for test

* Fix the test

* remove redondant

* fix min version

* Update test_grpo_trainer.py

* Update test_grpo_trainer.py

* Fix test, finally found the solution!

* some doc

* Update doc-builder workflow to use specific commit sha

* more doc

* advantages

* drop cancel fo no grad

* logged metrics [ci skip]

* completion col is ignored [ci skip]

* fix latex

* double space? ~?

* try a latex fix

* with branch

* Empty commit

* Empty commit

* double space seems to be the solution
2025-01-20 19:02:15 +01:00
88514d51e3 Update reducing_memory_usage.md 2025-01-18 21:12:25 +01:00
76837e82b9 🎞️ Fix documentation SFT -max_seq_length instead of max_length (#2590) 2025-01-18 21:10:33 +01:00
35553930da 🫢 Add max_prompt_length parameter in tests (#2588)
* Add max_prompt_length parameter to tokenizer

* style [ci skip]
2025-01-17 19:40:38 +01:00
fd4b283b82 ✂️ Truncate by default (#2587)
* set default for max_length and max prompt lenngth and add guidelines for defaults

* remove dep kwargs

* truncate prompt in prm

* Update CONTRIBUTING.md [ci skip]
2025-01-17 17:03:41 +01:00
1b1140aa69 [RLOO] fix token_level_kl (#2575)
* fix token_level_kl

* fix non_score_reward and rlhf_reward

* add rloo test

* update test

* fix docs

* fix doc
2025-01-17 14:59:25 +01:00
4c7eb6fe29 🐛 Simplify bug report template (#2585) 2025-01-17 14:40:37 +01:00
564fc86759 Update issue_auto_labeller.yml [ci skip] 2025-01-17 14:10:33 +01:00
3215a1c586 Update issue_auto_labeller.yml 2025-01-17 13:59:14 +01:00
cdc16f3ac6 🔖 Issues Auto-Labeller (#2542)
* Initial commit for auto labeller

* Using HF instead of openai

* secrets name change

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-01-17 13:46:24 +01:00
2ecd53ad77 🏎️ vLLM for Online DPO (#2558)
* vllm online dpo

* new arg and add back generation config [skip ci]

* import utils

* optional import and comment

* is_vllm_available

* support conv and not conv [ci skip]

* add old code back

* use func [skip ci]

* fix _generate call

* fix and dedicated func

* top k 50

* style

* add import error

* new testing model

* Update OnlineDPOTrainer class with new features

* test vllm

* fix generate tiny script

* max len arg

* fix comment [ci skip]

* revert num_return_sequences

* vllm dep

* Add require_torch_accelerator import and skip test if vllm is not available

* proper require_torch_accelerator

* add vllm section

* Add hfoption sections to speeding_up_training.md

* no, an id

* Update vllm dependency to exclude Windows platform

* Note on future release

* style
2025-01-17 11:39:13 +01:00
5877786b5a 🪄 Minor comment style modif (#2582) 2025-01-17 11:12:00 +01:00
57d9a97394 Refine model card method docstring (#2566)
* refine model card docstring

* bco

* prm
2025-01-13 15:58:01 +01:00
751fb1d84b 🏛️ Improve DPO configuration documentation structure (#2561)
* better structure dpo config

* fix tests

* fix regex

* add contributing guidelines
2025-01-12 15:23:19 +01:00
edabe0a2d8 [RLOO] Reinforce++ (#2552)
* Reinforce++

* formatting

* fix link
2025-01-09 12:09:29 +01:00
abfffc510b 💔 Fix dataset type unpair conversion docs (#2550)
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
2025-01-08 19:33:05 +01:00
ed7de87dc7 🎴 Add readme for datasets (#2491)
* adding readme for ultrafeedback dataset

* using ModelCard as DatasetsCard like hf datasets is understaffed

* more info in readme.md of the dataset

* generated readme for all dataset scripts

* precommit

* fixing test

* md format; corrections; generation script link

* some collections

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-01-08 17:25:51 +01:00
beb892bfe0 ↩️ Revert ORPO loss changes (#2527)
* revert orpo changes

* add comment
2025-01-08 16:13:20 +01:00
f2d42fa0c2 🔠 Fix SFT truncation documentation (#2521)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-01-08 15:35:49 +01:00
d6a7e9d6f5 ℹ️ XPU support for DPO (#2533)
* add xpu support

* bug fix

* remove header

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix import and use the util

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-01-08 15:32:03 +01:00
451677203d 🕊️ DPO padding free (#2520)
* padding free

* specify dtype

* test

* warnings when not flash attention

* fix test

* remove

* docstring padding-free

* flash-attn dep

* Stronger warning

* require_flash_attn in test

* flash-attn in CI

* rm flash-attn from dep

* Remove flash-attn dependency from test workflows

* refactor

* Update .github/workflows/tests.yml

* Update trl/trainer/dpo_trainer.py

* drop require flash-attn

* fix dtype

* refine warning

* Update trl/trainer/dpo_config.py

* Add logic to compute mean logits for chosen and rejected tokens with padding-free

* format

* Update trl/trainer/dpo_trainer.py

* Update trl/trainer/dpo_trainer.py

* fix comment [ci skip]

* fix num logits to keep
2025-01-08 09:22:17 +01:00
2f25f54ab9 ✒️ Fix typo in formatting_func's documentation in ConstantLengthDataset (#2549) 2025-01-07 21:26:28 +01:00
a50124dd3a 🧑‍🤝‍🧑 Proper metrics gathering across ranks before logging (#2474)
* dpo_trainer gather metrics across ranks before logging

according to https://github.com/huggingface/trl/issues/2468

* fix everywhere

* gather_for_metrics

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-01-07 15:05:54 +01:00
1d23ecc36f ©️ Update copyrights year (#2547)
* happy new year

* fix wandb import sort
2025-01-07 14:53:09 +01:00
52d213173f 🚜 Use field in dataclasses (#2494)
* in hh-rlhf-helpful-base

* delete tokenize ds

* dataset scripts

* alignprop

* judge tldr

* ddpo

* zen

* sft video

* literal to choices

* chat

* script args

* alignprop

* bco

* better help format

* cpo

* ddpo

* whether or not -> whether

* dpo

* dont set the possible values

* `Optional[...]` to ... or `None`

* xpo

* gkd

* kto

* nash

* online dpo

* Fix typo in learning rate help message

* orpo

* more ... or `None`

* model config

* ppo

* prm

* reward

* rloo

* sft

* online policy config

* make style
2025-01-06 18:29:09 +01:00
d9ee2fd202 Remove graph breaks for torch.compile() in padding free branch in DataCollatorForCompletionOnlyLM (#2158)
* feat: Add info to batch in DataCollatorForCompletionOnlyLM

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: formatting

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* feat: Add info to batch in DataCollatorForCompletionOnlyLM

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: formatting

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: max_length_k to int

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix:Added comments

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* test cases

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* test cases

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* test cases

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* feat: Add info to batch in DataCollatorForCompletionOnlyLM

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* fix: formatting

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* feat: Add info to batch in DataCollatorForCompletionOnlyLM

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* test cases

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* test cases

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* test cases

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* unit test changes

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>

* style

* add test

* remove test

---------

Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-01-06 15:50:29 +01:00
763738f457 ☄️ Update Comet integration to include LogCompletionsCallback and Trainer.evaluation_loop() (#2501)
* Implemented integration with Comet in `LogCompletionsCallback`. Implemented related integration test.

* Implemented integration with Comet in `CPOTrainer.evaluation_loop()` during logging of `game_log` table.

* Implemented integration with Comet in `CPOTrainer.evaluation_loop()` during logging of `game_log` table.

* Implemented integration with Comet in `DPOTrainer.evaluation_loop()` during logging of `game_log` table.

* Implemented integration with Comet in `BCOTrainer.evaluation_loop()` during logging of `game_log` table.

* Implemented integration with Comet in `KTOTrainer.evaluation_loop()` during logging of `game_log` table.

* Implemented integration with Comet in `ORPOTrainer.evaluation_loop()` during logging of `game_log` table.
2024-12-28 18:35:01 +01:00
aed5da580e 📦 Packing documentation (#2503) 2024-12-22 12:44:07 +01:00
99451b421a 👬 Rename collator PreferenceCollator to DataCollatorForPreference (#2510) 2024-12-22 12:43:55 +01:00
5239b9462d 💧 Generalize disable_dropout (#2511) 2024-12-22 12:19:17 +01:00
8fb267ff1e 👨‍🍳 Clarify DPO data preparation (#2512) 2024-12-22 12:18:22 +01:00
2e1adbb6ff Remove RLOO example test (#2513) 2024-12-22 12:16:14 +01:00
b668048fe1 Update community_tutorials.md (#2509)
* Update community_tutorials.md

* Update community_tutorials.md
2024-12-20 17:40:42 +01:00
8c49ea39ec 🏚 Remove unused components (#2480) 2024-12-19 19:29:39 +01:00
88ad1a099c fix orpo chosen-nll loss (#2502) 2024-12-19 11:33:06 +01:00
9908dda6d9 🗂️ Reorganize documentation (#2483)
* reorganize doc

* consistent ing

* Add reducing_memory_usage.md

* integration with peft

* Add new files and update table of contents

* Add speeding_up_training.md to docs/source and update _toctree.yml

* unsloth

* Liger kernel

* Truncation

* Update truncation parameters for DPO and SFT

* dedicated Intergation section

* clarify

* illustrate

* Sort

* badge for prm
2024-12-18 16:28:11 +01:00
5e204e1eaa 🏞️ Proper dataset for documentation images (#2499)
* first images

* almost all!

* Final

* Some were missing
2024-12-18 11:28:45 +01:00
82cfeb8930 🤩 Add SmolVLM tutorials to Community Tutorials page (#2498) 2024-12-17 23:31:34 +01:00
0fe73a8af5 🗣️ Improve prose for smol course (#2487) 2024-12-16 11:17:29 +01:00
33fb9efc43 ⚰️ Remove deprecated (#2485) 2024-12-15 21:02:59 +01:00
f68d11f9f9 Bump version 2024-12-15 19:56:54 +01:00
aeca63774f 👨‍🏫 smol course links and badges (#2484)
* smol course links and badges

* try without space

* revert space
2024-12-15 19:38:48 +01:00
117c6d4b52 📥 Fix missing BitsAndBytesConfig import in doc (#2478) 2024-12-15 16:54:38 +01:00
6d4ed070f1 ☄️ Add support for Comet experiment management SDK integration (#2462)
* Added support for Comet URL integration into model cards created by trainers.

* Moved `get_comet_experiment_url()` into utils.py

* Updated Comet badge in the model card to use PNG image instead of text.

* Fixed bug related to running PPO example during model saving. The error as following: 'GPTNeoXForCausalLM' object has no attribute 'policy'. Introduced guard check that attribute `policy` exists.

* Implemented utility method to handle logging of tabular data to the Comet experiment.

* Implemented logging of the completions table to Comet by `PPOTrainer`.

* Implemented logging of the completions table to Comet by `WinRateCallback`.

* Implemented logging of the completions table to Comet by `RLOOTrainer` and `RewardTrainer`.

* Restored line to the main branch version.

* Moved Comet related utility methods into `trainer/utils.py` to resolve merge conflict with master branch,

* Update trl/trainer/utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Implemented raising of `ModuleNotFoundError` error when logging table to Comet if `comet-ml` is not installed.

* import comet with other imports

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-12-13 22:08:10 +01:00
cd7156fb34 👀 Add "PaliGemma 🤝 Direct Preference Optimization" in community tutorials (#2475) 2024-12-13 20:29:35 +01:00
ca850be0a2 🕹️ CLI refactor (#2380)
* Refactor main function in dpo.py

* Update setup.py and add cli.py

* Add examples to package data

* style

* Refactor setup.py file

* Add new file t.py

* Move dpo to package

* Update MANIFEST.in and setup.py, refactor trl/cli.py

* Add __init__.py to trl/scripts directory

* Add license header to __init__.py

* File moved instruction

* Add Apache License and update file path

* Move dpo.py to new location

* Refactor CLI and DPO script

* Refactor import structure in scripts package

* env

* rm config from chat arg

* rm old cli

* chat init

* test cli [skip ci]

* Add `datast_config_name` to `ScriptArguments` (#2440)

* add missing arg

* Add test cases for 'trl sft' and 'trl dpo' commands

* Add sft.py script and update cli.py to include sft command

* Move sft script

* chat

* style [ci skip]

* kto

* rm example config

* first step on doc

* see #2442

* see #2443

* fix chat windows

* ©️ Copyrights update (#2454)

* First changes

* Other files

* Finally

* rm comment

* fix nashmd

* Fix example

* Fix example [ci skip]

* 💬 Fix chat for windows (#2443)

* fix chat for windows

* add some tests back

* Revert "add some tests back"

This reverts commit 350aef52f53f8cf34fccd7ad0f78a3dd63867e06.

* 🆔 Add `datast_config` to `ScriptArguments` (#2440)

* datast_config_name

* Update trl/utils.py [ci skip]

* sort import

* typo [ci skip]

* Trigger CI

* Rename `dataset_config_name` to `dataset_config`

* 🏎 Fix deepspeed preparation of `ref_model` in `OnlineDPOTrainer` (#2417)

* Remove unused deepspeed code

* add model prep back

* add deepspeed even if it doesn't work

* rm old code

* Fix config name

* Remove `make dev` in favor of `pip install -e .[dev]`

* Update script paths and remove old symlink related things

* Fix chat script path [ci skip]

* style
2024-12-13 17:52:23 +01:00
179ba53671 🐾 Process-supervised RM Trainer (#2127)
* initial skeleton

* tokenize fn

* adding bos and eos to tokenization fn

* prmtrainer

* fixing small typo in tokenize

* typo in input_ids and labels construction

* numpy dimension

* introduce the stepwise reward trainer

* update markdown files

* let user decide post step separator in config

* doc post_step_separator

* do not add post step_tokens to last step of the reasoning process

* renaming prm to stepwisereward

* formatting

* fix tokenize kwargs

* adapt test to the new post_token args

* adding example script

* fix small typo

* add create_model_card and renaming

* fixing booleans

* Adding the new stepwise_preference instead of placeholders for datasets

* formatting

* Update docs/source/_toctree.yml

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update examples/scripts/stepwise_reward_modeling.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/stepwise_reward_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/stepwise_reward_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* update push to hub

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* step_separator can't be None

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix suggested typos

* add citation

* reformat doc

* reordering init

* push to hub prm800k

* changing dataset in example

* change dataset format to align with the sky is blue example

* fix tokenization column names

* fix num labels in openai example

* add support for conversational dataset

* remove training whitespace

* replace tokenizer with processing class

* Update docs/source/dataset_formats.mdx

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* remove openai_prm800k

* Update trl/trainer/stepwise_reward_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/stepwise_reward_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update docs/source/stepwise_reward_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/stepwise_reward_trainer.mdx

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* renaming

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* renaming

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* minor renamings in docs

* using prm800k instead of openai_prm800k

* update num labels to 2 following the new format

* changing doc examples to math examples

* change reference to dataset_formats.mdx

* changing dataset config in test

* remove conversational dataset support

* remove conv dataset support

* fix bos token

* fix scriptarguments in example

* completion to completions

* remove valuerror for step_separator inside steps

* run precommit

* remove conv dataset support

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* renaming zen dataset

* remove unused printing

* unknown label column

* introduce the train on last step arg

* _tokenize support train_on_last_step

* incorporate train_on_last_step to tests

* formatting

* remove comments in trainer

* Refactor `tokenize_row`

* Update max_completion_length parameter in StepwiseRewardConfig

* Collator

* Update comment

* Update type hint

* fix table

* Remove collator

* don't need pad token id

* add error back

* max length args

* use tokenizer arg

* Update doc

* label -> labels

* fixing tokenization issues in tokenize row

* correct labels for token classification

* adding max_length to tokenize_row

* reformat tests

* adding tests for tokenize row

* fixing typos in comments

* update doc

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Add math_shepherd.py script for dataset processing

* split the dataset

* formatting

* same evaluation method for the two training methods

* adding filtering to example script

* formatting

* Add features to avoid casting labels to bool in dataset tokenization

* Update docs/source/stepwise_reward_trainer.mdx [ci skip]

* Add learning_rate parameter to StepwiseRewardConfig class

* update doc

* Remove unused setup_chat_format function

* Fix warning message in stepwise_reward_modeling.py

* Update logging steps in stepwise_reward_trainer.mdx

* little doc change [ci skip]

* Fix copyrights

* fix space after copyrights

* Update dataset loading in stepwise_reward_modeling.py

* refine compute_accuracy and proper test

* fix tests

* style

* renamings

* renaming in init

* doc renaming

* fix sorting and tag

* experiemental [ci skip]

* trigger CI

* other doc fix

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-12-13 15:56:10 +01:00
e3e171a26b 🔨 Support for tools for data utils (#2455)
* function calling training support for SFTTraining

* adding tool support to data_utils

* adding test for function calling tokenizer

* reverting changes to sfttrainer and config,added maybe_apply_chat_template

* arg for maybe_apply_chat_templates docstring

* Doc sectioning

* minor test modification

* minor doc modification

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-12-12 17:11:50 +01:00
b3aff441ff 🎞️ Add "Fine-tuning open AI models using Hugging Face TRL" YouTube video to community tutorials (#2467) 2024-12-12 16:40:28 +01:00
efc687db62 🛠️ Update tests and fix PPO (#2463)
* [bugfix] critic not update

* Update ppo_trainer.py

* Update ppo_trainer.py

* add failing test

* test both policy and critic

* formatting

* fix tests

* formatting

* Update tests/test_ppo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix test

---------

Co-authored-by: NINGBENZHE <53843873+NINGBENZHE@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-12-12 12:53:32 +01:00
f2e362656c ⚖️ Add tests_latest.yml workflow file (#2457)
* Add tests_latest.yml workflow file

* don't check the branch

* Fix workflow
2024-12-11 18:11:41 +01:00
c9c4f18039 [bugfix] Fix DataCollatorForChatML unexpected generation prompt (#2450)
* [bugfix] Fix DataCollatorForChatML unexpected generation prompt

* Update utils.py

* Update test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

* Update test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-12-11 15:18:54 +01:00
460e780265 👯 Standardize model_args (#2442)
* `model_config` -> `model_args`

* sort
2024-12-10 12:51:20 +01:00
7ba118a229 🏎 Fix deepspeed preparation of ref_model in OnlineDPOTrainer (#2417)
* Remove unused deepspeed code

* add model prep back

* add deepspeed even if it doesn't work

* rm old code
2024-12-10 12:40:13 +01:00
6a05feff02 🆔 Add datast_config to ScriptArguments (#2440)
* datast_config_name

* Update trl/utils.py [ci skip]

* sort import

* typo [ci skip]

* Trigger CI

* Rename `dataset_config_name` to `dataset_config`
2024-12-10 11:09:26 +01:00
2f72f47191 💬 Fix chat for windows (#2443)
* fix chat for windows

* add some tests back

* Revert "add some tests back"

This reverts commit 350aef52f53f8cf34fccd7ad0f78a3dd63867e06.
2024-12-10 10:40:23 +01:00
9410874787 ©️ Copyrights update (#2454)
* First changes

* Other files

* Finally

* rm comment

* fix nashmd

* Fix example

* Fix example [ci skip]
2024-12-10 10:40:00 +01:00
9c5388b69e 🔗 Add "Open in Colab" badges in community tutorials page (#2441) 2024-12-06 10:51:55 +01:00
b02189aaa5 🗂️ Harmonize run and example batch sizes in RLOO docs (#2439)
Doc has different grad_accumulation_steps and per_device_batch size than the actual hyperparameters, can be verified from wandb run.
2024-12-04 19:19:14 +01:00
52201d3c18 🧮 Fix max_steps calculation in RLOOTrainer (#2433) 2024-12-03 21:31:32 +01:00
9ff79a65e3 🔮 Fix unused precomputed ref log probs in DPO (#2431) 2024-12-03 11:36:57 +01:00
9001a8682c 📑 Refactor TrlParser (#2412)
* refactor parser

* Only document some methods

* Update imports in cli_utils.py and remove config option in utils.py

* add `test_parse_args_and_arg_override_config` and remove unnecessary mocks [ci skip]

* fix comment [ci skip]

* fix comment [ci skip]

* Extra arg in config also returned

* fix docstring [ci skip]

* add mock back

* use `deprecate_kwarg`
2024-12-02 19:57:35 +01:00
f6f42651e2 🧑‍🍳 Add precompute batch size argument in DPOTrainer for reference model (#2426)
* added precompute_batch

* review-fixes

* moving up

* Update trl/trainer/dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_config.py

* Update trl/trainer/dpo_config.py [ci skip]

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-12-02 17:17:41 +01:00
148b592313 Update modeling_base.py (#2419) 2024-11-30 12:14:36 +01:00
d6a8f2c2f6 ⚠️ Add warning guidelines and update codebase to follow best practices (#2350)
* Add guidelines for working with warnings in the codebase

* Remove unnecessary warnings and improve code initialization

* Fix warnings and improve accuracy calculation

* Add rich library dependency for text formatting

* Update LoRA weight loading warning message

* Fix logging and import issues in AlignPropConfig

* Fix warnings and improve code readability

* Remove unused import statements

* Refactor CPOTrainer class in cpo_trainer.py

* Remove unnecessary warnings and raise ValueError for missing model

* Fix warnings and improve code consistency

* Update CONTRIBUTING.md to clarify the purpose of warnings

* Fix string formatting in DataCollatorForCompletionOnlyLM class

* Update SimPO loss parameters in CPOTrainer

* Fix warnings and remove unnecessary code in ConstantLengthDataset class

* Clarify warning guidelines

* Rewrite the entire section

* Fix capitalization in CONTRIBUTING.md

* Fix formatting in CONTRIBUTING.md
2024-11-29 16:07:38 +01:00
8d9cfaafeb 🌋 Add support for LLaVA-Next in DPOTrainer (#2413)
* add support for llava-next in dpotrainer

* enable unit test

* code style

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Ignore last layer in test

---------

Co-authored-by: zesong.cwz <zesong.cwz@taobao.com>
Co-authored-by: 1rubbishyuan <2773496952@qq.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-11-29 15:53:50 +01:00
94e4135a17 🔓 Remove lm_head check in AutoModelForCausalLMWithValueHead (#2398)
* Remove lm_head check in `AutoModelForCausalLMWithValueHead`

* Style

* Remove test
2024-11-29 15:52:35 +01:00
ac267781ec 🌐 Community Tutorials (#2411)
* Add community notebooks to API documentation

* fix extension

* add table of community tutorials

* respond to feedback - fix links and split table

* add class references

* rename file and update toc

* Update docs/source/community_tutorials.md

* Update docs/source/community_tutorials.md

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-11-29 11:39:37 +01:00
2c6e0d9705 Add note about special tokens in chat templates for LoRA SFT (#2414) 2024-11-29 10:35:39 +01:00
e1d781353b 👁️ Added SFT support for SmolVLM models via standalone script sft_vlm_smol_vlm.py (#2409)
* Added SFT VLM script for SmolVLM

* Run make precommit

* Updated command example
2024-11-28 18:45:37 +01:00
a34e9bf84f 🖨 Add Script Utilities section to the documentation (#2407)
* Add script_utils.md to the documentation

* Refactor ScriptArguments class documentation

* Refactor TrlParser class to improve code organization and readability
2024-11-28 16:43:08 +01:00
c10cc8995b 🗝️ Update type hints (#2399)
* New type hint structure

* Update type hints

* Delete wrong file

* Remove dict import
2024-11-26 20:37:27 +01:00
9368dccef6 🐢 Fix slow tests (#2397)
* fix slow CI

* fix dpo

* formatting

* Apply suggestions from code review

* `setup_chat_format` may add a pad token

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-11-26 15:38:46 +01:00
43df3a485a 🧳 Move zen generation script and fix tests (#2393)
* Move zen

* step -> stepwise_supervision

* Fix train_test_split shuffle issue

* Fix tests

* Update tests/test_sft_trainer.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* Fix typo in key name

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-11-26 14:08:06 +01:00
baee06f2e8 🖋️ Fix warning message formatting in KTOTrainer (#2394) 2024-11-26 13:05:25 +01:00
bbd8cbb720 🤐 Fix deprecation warnings (#2395) 2024-11-26 11:29:07 +01:00
4f937c7629 🤐 Fix deprecation warnings (#2392) 2024-11-26 11:18:43 +01:00
16fa13ce72 👮 Deprecate policy in favor of model in PPOTrainer (#2386) 2024-11-26 08:13:10 +01:00
453db5cd79 🤏 New models for tests (#2287)
* first commit

* uncomment

* other tests adaptations

* Remove unused variable in test_setup_chat_format

* Remove unused import statement

* style

* Add Bart model

* Update BCOTrainerTester class in test_bco_trainer.py

* Update model IDs and tokenizers in test files

* Add new models and processors

* Update model IDs in test files

* Fix formatting issue in test_dataset_formatting.py

* Refactor dataset formatting in test_dataset_formatting.py

* Fix dataset sequence length in SFTTrainerTester

* Remove tokenizer

* Remove print statement

* Add reward_model_path and sft_model_path to PPO trainer

* Fix tokenizer padding issue

* Add chat template for testing purposes in PaliGemma model

* Update PaliGemma model and chat template

* Increase learning rate to speed up test

* Update model names in run_dpo.sh and run_sft.sh scripts

* Update model and dataset names

* Fix formatting issue in test_dataset_formatting.py

* Fix formatting issue in test_dataset_formatting.py

* Remove unused chat template

* Update model generation script

* additional models

* Update model references in test files

* Remove unused imports in test_online_dpo_trainer.py

* Add is_llm_blender_available import and update reward_tokenizer

* Refactor test_online_dpo_trainer.py: Move skipped test case decorator

* remove models without chat templates

* Update model names in scripts and tests

* Update model_id in test_modeling_value_head.py

* Update model versions in test files

* Fix formatting issue in test_dataset_formatting.py

* Update embedding model ID in BCOTrainerTester

* Update test_online_dpo_trainer.py with reward model changes

* Update expected formatted text in test_dataset_formatting.py

* Add reward_tokenizer to TestOnlineDPOTrainer

* fix tests

* Add SIMPLE_CHAT_TEMPLATE to T5 tokenizer

* Fix dummy_text format in test_rloo_trainer.py

* Skip outdated test for chatML data collator

* Add new vision language models

* Commented out unused model IDs in test_vdpo_trainer

* Update model and vision configurations in generate_tiny_models.py and test_dpo_trainer.py

* Update model and tokenizer references

* Don't push if it already exists

* Add comment explaining test skip

* Fix model_exists function call and add new models

* Update LlavaForConditionalGeneration model and processor

* `qgallouedec` -> `trl-internal-testing`
2024-11-25 16:31:56 +01:00
ee3cbe1946 💾 Deprecate config in favor of args in PPOTrainer (#2384) 2024-11-25 14:48:08 +01:00
17e8060984 📦 Support for packing tokenized datasets for SFT (#2011)
* feat: add support for packing tokenized datasetS

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* fix: address review comments

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

* feat: add tests for pretokenized dataset packing

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>

---------

Signed-off-by: Mehant Kammakomati <mehant.kammakomati2@ibm.com>
2024-11-25 10:36:58 +01:00
163695e85c 🙈 Suppress warning for estimating tokens in trainers (#2389)
* Suppress warning for estimating tokens in trainer

* Suppress warning for estimating FLOPs in ORPO and Reward trainers
2024-11-24 16:55:43 +01:00
672c96546d Update log method to include start_time parameter (#2381) 2024-11-21 21:30:10 +01:00
bdeb117320 📝 Fix typo in dataset generation script (#2379) 2024-11-21 20:37:44 +01:00
6578fdc101 🔀 Add MergeModelCallBack (#2282)
* Create mergekit_utils.py

* adding mergekit as an optional dependancy

* adding MergeModel to callbacks

* adding mergekit_utils dependencies to callbacks

* setting lower bound for mergekit

* setting mergekit lower band to 0.0.5.1

* adding support for MergeModelCallBack __init__.py

* adding support for mergemodelcallback

* mergemodelcallback tests

* Update callbacks.py

* Update __init__.py

* Update __init__.py

* Update test_callbacks.py

* Update trl/trainer/callbacks.py

removing ## from docs

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/callbacks.py

removing ## from docs

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/trainer/callbacks.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* using different dataset for tests

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/mergekit_utils.py

adding types

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/mergekit_utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Apply suggestions from code review

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* replacing get_last_checkpoint

* renaming Merge to merge_models

* setting mergers default value to linear

* removing unnecessary docs and comments

* adding docstring to Mergeconfig

* adding mergekits link to docstring

* precommit

* removing duplicated import

* typos in mergekit_utils docstring

* fixing tests

* making mergemodelcallback tests optional

* Make import optional

* minor

* use tmp dir in test

* sort

* Add import error checks for mergekit extra

* use a common _merge_and_maybe_push method and compat with windows path

* debug windows

* Update dependencies for mergekit and add test dependencies

* Add assertion to check if merged folder exists in the last checkpoint

* Fix temporary directory cleanup in test_callbacks.py

* Add sys import and skip test for Python versions below 3.10 due to cleanup errors with temp dir

* revert change for debug

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-11-21 14:06:45 +01:00
a0066f47f8 Add start_time to _maybe_log_save_evaluate (#2373) 2024-11-20 12:49:49 +01:00
5626806aef 🧲 Use our own require_bitsandbytes (#2370)
* use our own require_bitsandbytes

* rephrase
2024-11-20 11:51:05 +01:00
bb0afc2459 remove redunant call to eval and train (#2372) 2024-11-20 11:24:41 +01:00
066fc37bd3 Fix dev install (#2369) 2024-11-19 13:30:09 +01:00
b80c1a6fb8 🎲 Move random judges in testing utilities (#2365)
* Update judges and testing utilities

* Update judges in test files

* Update judges in test files
2024-11-18 18:43:18 +01:00
b5eabbeb07 🤝 Mixture of judges (#2159)
* base judge

* adding mixture of judges

* update doc

* update doc

* formatting

* fix small typo in doc

* fix randomcontraintjudge

* replace arxiv by hf papers

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* formatting

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix naming in __init__

* run precommi

* adding gold answers to judges

* cgpo llm judges

* fix init

* output type

* adjust booleans in test

* adapt moj doc

* renaming and removing factuality and safety judges

* fix typo in import

* fix small typo in naming

* formatting

* Update trl/trainer/judges.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* update parameter name

* update tests

* update doc

* Update trl/trainer/judges.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update doc

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix alltruejudge type

* Refactor judge variable names and update test names

* Clarify judgment logic

* Fix invalid binary judgment check in AllTrueJudge class

* Fix invalid binary judgment check in AllTrueJudge class

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-11-18 16:54:57 +01:00
cbf9abcd07 🗺️ Implementation DiscoPOP Loss (#2323)
* Implement DiscoPOP Loss

* Updated DiscoPOP documentation

* Corrected docs/source/dpo_trainer.mdx

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update docs/source/dpo_trainer.mdx

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/dpo_config.py

* Delete scripts directory

* style

* empty commit

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2024-11-18 14:15:00 +01:00
6f8fe59aeb 📃 Fix description for parameter "generate_during_eval" in dpo_config (#2364) 2024-11-18 14:03:02 +01:00
1293f37c5f 📉 Add PEFT support for PPOTrainer (#2344)
* Add peft/lora support for

* Fix: style

* Fix: typo

* Add ppo.py PEFT example

* Fixed the optional dependencies error

* skip peft test if peft is unavailable

* Update trl/trainer/ppo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-11-18 11:54:09 +01:00
e7870dd5d6 🗃️ Use specified data_collator in RLOOTrainer and PPOTrainer (#2360)
* Fix "Use specified data_collator instead of hard-coding the option"

* Remove query_responses = [] since it's immediately overwritten afterwards.

* Use self.data_collator

* Use specified data_collator instead of hard-coded one in PPOTrainer

* Move the data_collator creation

* Run make precommit
2024-11-18 11:53:47 +01:00
21d5baf338 🔮 Inference mode in GeometricMixtureWrapper.forward (#2345)
* geom mixture model train

* use inference_mode
2024-11-18 09:58:26 +01:00
76dbb1a576 🪜 Stepwise supervision dataset type (#2148) 2024-11-18 09:58:00 +01:00
b8c9d9c7bc ⚖️ Add use_soft_judge option to WinRateCallback (#2347)
* add `use_soft_judge` option to WinRateCallback

* formatting

* Update trl/trainer/callbacks.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* renamed soft_win_rate to avg_win_prob

* Update trl/trainer/callbacks.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* fix tests

* keep orignal

* formatting

* Update tests/test_callbacks.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/callbacks.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update tests/test_callbacks.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update tests/test_callbacks.py

* fix test

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2024-11-15 15:49:43 +01:00
623963126b 👋 Remove deprecated tokenizer argument in BCO, GKD, Iterative SFT, Nash MD and XPO (#2349) 2024-11-12 09:22:17 -04:00
2d24d35013 Adding video llm fine-tuning example (#2336)
* adding video example

* exposing more parameters

* fixing formatting

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-11-12 12:56:38 +01:00
dde20b23cf 🖨️ Fix error text in BCO and KTO tokenizing function (#2286) 2024-11-11 19:18:36 -04:00
015321e135 👈 Add tokenizer arg back and add deprecation guidelines (#2348)
* Add deprecation and backward compatibility guidelines

* Update tokenizer argument in trainer classes

* Add warning message for TRL Judges API
2024-11-11 19:06:20 -04:00
454f36d951 💣 Remove transformers version check (#2343) 2024-11-11 09:34:26 -04:00
9b7f9f3519 🪡 Various RLOO fixes (#2325) 2024-11-11 08:59:03 -04:00
518e29ca9c 🫴 Better guide users in error reporting (#2327)
* update issue template

* Add checklist for bug report template

* Fix formatting in bug report template

* Update bug report template with additional instructions for code formatting and screenshots

* Update bug report template with code formatting instructions

* Update bug report template with code examples

* Update code block placeholder in bug report template

* Update .github/ISSUE_TEMPLATE/bug-report.yml

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-11-11 08:42:16 -04:00
ac7b6cfdfa 🧞 Add output_layer to the list of lm_head_namings in AutoModelForCausalLMWithValueHead (#2328) 2024-11-11 08:16:09 -04:00
0238d96c6f DPO trainer supports num_logits_to_keep to save memory (#2129)
* Support num_logits_to_keep, which computes necessary logits in the forward pass.

* update doc

* bug fix

* update

* check is model supports num_logits_to_keep

* ruff format

* update test file

* peft model support

* test passed

* update

* apply use_num_logits_to_keep

* fix num_logits_to_keep compute bug

* compare all outputs

* pytest

* pass test

* use check_min_version

* format

* test_dpo_trainer_use_num_logits_to_keep passed

* add some comments

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-11-10 11:34:51 +01:00
c86b51cd12 Bump liger-kernel to fix grad acc and more features (#2333)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2024-11-08 12:16:33 +01:00
ac77c09223 Fix gradient_checkpointing_kwargs assignment in examples (#2331)
Co-authored-by: Ping <ping.zhu@jmuse.cn>
2024-11-07 09:28:10 +01:00
7f2ccbe3a2 fix truncating index in DPOTrainer's concatenated_forward() (#2332) 2024-11-07 09:27:32 +01:00
74e20cbbbc 🪪 Check with token_id instead of token in DPOTrainer (#2324) 2024-11-04 21:08:41 +01:00
27b9e3a93f 🪧 Fix slack notification titles (#2322) 2024-11-04 21:02:27 +01:00
dc2b8b9e90 🧽 Fix judge documentation (#2320)
* Bump dev version to `0.13.0.dev0`

* Update version number to 0.12 in CITATION.cff

* 🧽 Fix judge documentation (#2318)

* Update judge examples and documentation

* without ':'

* Clean doc

* Fix typo in example code

* Add space after Attributes

* Update attribute name in judges.py

* Add installation instructions for llm-blender library

* Update PairRMJudge attributes documentation

* Fix return type in PairRMJudge

* Revert "🧽 Fix judge documentation (#2318)"

This reverts commit 337005d95169371935fb87f1c559c7412f8472a4.

* Revert "🧽 Fix judge documentation (#2318)"

This reverts commit 337005d95169371935fb87f1c559c7412f8472a4.

* 🧽 Fix judge documentation (#2318)

* Update judge examples and documentation

* without ':'

* Clean doc

* Fix typo in example code

* Add space after Attributes

* Update attribute name in judges.py

* Add installation instructions for llm-blender library

* Update PairRMJudge attributes documentation

* Fix return type in PairRMJudge
2024-11-04 19:00:27 +01:00
5e90682836 ⚰️ Remove deprecated args, script arguments, and PPOv2 (#2306)
* Remove deprecated args

* Remove deprecated args in SFTTrainer

* Remove deprecated script argument classes

* Remove deprecated PPOv2Config and PPOv2Trainer classes

* Commented out sync_ref_model line in test_trainers_args.py
2024-11-04 16:07:26 +01:00
3b439967f4 📰 Update blog posts in documentation (#2319)
* Bump dev version to `0.13.0.dev0`

* Update version number to 0.12 in CITATION.cff

* Add publication date to blog post

* 🧽 Fix judge documentation (#2318)

* Update judge examples and documentation

* without ':'

* Clean doc

* Fix typo in example code

* Add space after Attributes

* Update attribute name in judges.py

* Add installation instructions for llm-blender library

* Update PairRMJudge attributes documentation

* Fix return type in PairRMJudge

* Revert "🧽 Fix judge documentation (#2318)"

This reverts commit 337005d95169371935fb87f1c559c7412f8472a4.

* Update blog post publication dates

* revert to p5

* Update image URLs in index.mdx

* Sort and uniform thumbnail

* Update image alignment in index.mdx
2024-11-04 16:00:27 +01:00
2f34a161cd Bump dev version to 0.13.0.dev0 (#2305)
* Bump dev version to `0.13.0.dev0`

* Update version number to 0.12 in CITATION.cff

* 🧽 Fix judge documentation (#2318)

* Update judge examples and documentation

* without ':'

* Clean doc

* Fix typo in example code

* Add space after Attributes

* Update attribute name in judges.py

* Add installation instructions for llm-blender library

* Update PairRMJudge attributes documentation

* Fix return type in PairRMJudge

* Revert "🧽 Fix judge documentation (#2318)"

This reverts commit 337005d95169371935fb87f1c559c7412f8472a4.
2024-11-04 15:59:52 +01:00
337 changed files with 52464 additions and 24648 deletions

View File

@ -7,36 +7,7 @@ body:
value: |
Thanks for taking the time to fill out this bug report! 🤗
Before you submit your bug report:
- If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
- type: textarea
id: system-info
attributes:
label: System Info
description: Please share your system info with us. You can run the command `trl env` and copy-paste its output below.
placeholder: trl version, transformers version, platform, python version, ...
validations:
required: true
- type: checkboxes
id: information-scripts-examples
attributes:
label: Information
description: 'The problem arises when using:'
options:
- label: "The official example scripts"
- label: "My own modified scripts"
- type: checkboxes
id: information-tasks
attributes:
label: Tasks
description: "The tasks I am working on are:"
options:
- label: "An officially supported task in the `examples` folder"
- label: "My own task or dataset (give details below)"
🚩 If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
- type: textarea
id: reproduction
@ -50,18 +21,47 @@ body:
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
placeholder: |
Steps to reproduce the behavior:
value: |
```python
from trl import ...
1.
2.
3.
```
outputs:
```
Traceback (most recent call last):
File "example.py", line 42, in <module>
...
```
- type: textarea
id: expected-behavior
id: system-info
attributes:
label: System Info
description: |
Please provide information about your system: platform, Python version, PyTorch version, Transformers version, devices, TRL version, ...
You can get this information by running `trl env` in your terminal.
placeholder: Copy-paste the output of `trl env`
validations:
required: true
- type: checkboxes
id: terms
attributes:
label: Expected behavior
description: "A clear and concise description of what you would expect to happen."
label: Checklist
description: |
Before submitting, please confirm that you've completed each of the following.
If an item doesn't apply to your issue, check it anyway to show you've reviewed it.
options:
- label: "I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))"
required: true
- label: "I have included my system information"
required: true
- label: "Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
required: true
- label: "Any code provided is properly formatted in code blocks, (no screenshot, [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
required: true
- label: "Any traceback provided is complete"
required: true

View File

@ -21,8 +21,7 @@ Fixes # (issue)
Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the
[documentation guidelines](https://github.com/huggingface/trl/tree/main/docs).
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?

19
.github/codeql/custom-queries.qls vendored Normal file
View File

@ -0,0 +1,19 @@
import codeql
from WorkflowString interpolation, Workflow workflow
where
interpolation.getStringValue().matches("${{ github.event.issue.title }}") or
interpolation.getStringValue().matches("${{ github.event.issue.body }}") or
interpolation.getStringValue().matches("${{ github.event.pull_request.title }}") or
interpolation.getStringValue().matches("${{ github.event.pull_request.body }}") or
interpolation.getStringValue().matches("${{ github.event.review.body }}") or
interpolation.getStringValue().matches("${{ github.event.comment.body }}") or
interpolation.getStringValue().matches("${{ github.event.inputs.* }}") or
interpolation.getStringValue().matches("${{ github.event.head_commit.message }}")
interpolation.getStringValue().matches("${{ github.event.* }}") and
(
step.getKey() = "run" or // Injection in run
step.getKey() = "env" or // Injection via env
step.getKey() = "with" // Injection via with
)
select workflow, "🚨 Do not use directly as input of action"

View File

@ -14,6 +14,5 @@ jobs:
commit_sha: ${{ github.sha }}
package: trl
version_tag_suffix: ""
custom_container: huggingface/transformers-doc-builder
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}

View File

@ -9,10 +9,10 @@ concurrency:
jobs:
build:
if: github.event.pull_request.draft == false
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
with:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: trl
version_tag_suffix: ""
custom_container: huggingface/transformers-doc-builder

26
.github/workflows/codeQL.yml vendored Normal file
View File

@ -0,0 +1,26 @@
name: "CodeQL Analysis - Workflows"
on:
workflow_dispatch:
jobs:
analyze:
name: "Analyze GitHub Workflows"
runs-on: ubuntu-latest
permissions:
security-events: write
actions: read
contents: read
steps:
- name: "Checkout repository"
uses: actions/checkout@v4
- name: "Initialize CodeQL"
uses: github/codeql-action/init@v2
with:
languages: "yaml"
queries: +security-and-quality, ./.github/codeql/custom-queries.qls
- name: "Perform CodeQL Analysis"
uses: github/codeql-action/analyze@v2

View File

@ -1,95 +1,84 @@
name: Build Docker images (scheduled)
name: Build TRL Docker image
on:
push:
branches:
- main
workflow_dispatch:
workflow_call:
schedule:
- cron: "0 1 * * *"
concurrency:
group: docker-image-builds
cancel-in-progress: false
env:
CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
jobs:
trl-latest:
name: "Latest TRL GPU"
trl:
name: "Build and push TRL Docker image"
runs-on: ubuntu-latest
steps:
- name: Cleanup disk
run: |
sudo ls -l /usr/local/lib/
sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Check out code
- name: Checkout code
uses: actions/checkout@v4
- name: Get TRL version from PyPI
run: |
VERSION=$(curl -s https://pypi.org/pypi/trl/json | jq -r .info.version)
echo "VERSION=$VERSION" >> $GITHUB_ENV
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v1
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
- name: Build and Push
uses: docker/build-push-action@v4
with:
context: ./docker/trl-latest-gpu
context: docker/trl
push: true
tags: huggingface/trl-latest-gpu
tags: |
huggingface/trl:${{ env.VERSION }}
huggingface/trl
- name: Post to Slack
if: always()
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: 🤗 Results of the trl-latest-gpu Docker Image build
slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }}
title: 🤗 Results of the TRL Dev Docker Image build
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
trl-source:
name: "Latest TRL + HF ecosystem from source"
trl-dev:
name: "Build and push TRL Dev Docker image"
runs-on: ubuntu-latest
steps:
- name: Cleanup disk
run: |
sudo ls -l /usr/local/lib/
sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Check out code
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v1
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
- name: Build and Push
uses: docker/build-push-action@v4
with:
context: ./docker/trl-source-gpu
context: docker/trl-dev
push: true
tags: huggingface/trl-source-gpu
tags: |
huggingface/trl:dev
- name: Post to Slack
if: always()
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: 🤗 Results of the trl-source-gpu Docker Image build
slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }}
title: 🤗 Results of the TRL Dev Docker Image build
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

View File

@ -0,0 +1,15 @@
name: "Hugging Face Issue Labeler"
on:
issues:
types: opened
jobs:
triage:
runs-on: ubuntu-latest
permissions:
issues: write
steps:
- uses: actions/checkout@v3
- uses: August-murr/auto-labeler@main
with:
hf-api-key: ${{ secrets.CI_HF_API_TOKEN }}

127
.github/workflows/pr_style_bot.yml vendored Normal file
View File

@ -0,0 +1,127 @@
name: PR Style Bot
on:
workflow_dispatch:
permissions:
contents: write
pull-requests: write
jobs:
run-style-bot:
if: >
contains(github.event.comment.body, '@bot /style') &&
github.event.issue.pull_request != null
runs-on: ubuntu-latest
steps:
- name: Extract PR details
id: pr_info
uses: actions/github-script@v6
with:
script: |
const prNumber = context.payload.issue.number;
const { data: pr } = await github.rest.pulls.get({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber
});
// We capture both the branch ref and the "full_name" of the head repo
// so that we can check out the correct repository & branch (including forks).
core.setOutput("prNumber", prNumber);
core.setOutput("headRef", pr.head.ref);
core.setOutput("headRepoFullName", pr.head.repo.full_name);
- name: Check out PR branch
uses: actions/checkout@v3
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
with:
# Instead of checking out the base repo, use the contributor's repo name
repository: ${{ env.HEADREPOFULLNAME }}
ref: ${{ env.HEADREF }}
# You may need fetch-depth: 0 for being able to push
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Debug
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
run: |
echo "PR number: ${{ env.PRNUMBER }}"
echo "Head Ref: ${{ env.HEADREF }}"
echo "Head Repo Full Name: ${{ env.HEADREPOFULLNAME }}"
- name: Set up Python
uses: actions/setup-python@v4
- name: Install dependencies
run: |
pip install ruff pre-commit
- name: Download Makefile from main branch
run: |
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/trl/main/Makefile
- name: Compare Makefiles
run: |
if ! diff -q main_Makefile Makefile; then
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
exit 1
fi
echo "No changes in Makefile. Proceeding..."
rm -rf main_Makefile
- name: Run make style and make quality
run: |
make precommit || true
- name: Commit and push changes
id: commit_and_push
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
echo "HEADREPOFULLNAME: ${{ env.HEADREPOFULLNAME }}, HEADREF: ${{ env.HEADREF }}"
# Configure git with the Actions bot user
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
# Make sure your 'origin' remote is set to the contributor's fork
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${{ env.HEADREPOFULLNAME }}.git"
# If there are changes after running style/quality, commit them
if [ -n "$(git status --porcelain)" ]; then
git add .
git commit -m "Apply style fixes"
# Push to the original contributor's forked branch
git push origin HEAD:${{ env.HEADREF }}
echo "changes_pushed=true" >> $GITHUB_OUTPUT
else
echo "No changes to commit."
echo "changes_pushed=false" >> $GITHUB_OUTPUT
fi
- name: Comment on PR with workflow run link
if: steps.commit_and_push.outputs.changes_pushed == 'true'
uses: actions/github-script@v6
with:
script: |
const prNumber = parseInt(process.env.prNumber, 10);
const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
});
env:
prNumber: ${{ steps.pr_info.outputs.prNumber }}

43
.github/workflows/publish.yml vendored Normal file
View File

@ -0,0 +1,43 @@
name: Publish to PyPI
on:
push:
branches:
- main
- v*-release
paths:
- "VERSION"
jobs:
publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Read version
id: get_version
run: echo "version=$(cat VERSION)" >> $GITHUB_OUTPUT
- name: Debug - Show version.txt content
run: echo "Version is ${{ steps.get_version.outputs.version }}"
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.x"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build twine
- name: Build package
run: python -m build
- name: Publish to PyPI
if: ${{ !contains(steps.get_version.outputs.version, 'dev') }}
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
run: |
python -m twine upload dist/*

View File

@ -2,7 +2,7 @@ name: Slow tests (on push)
on:
push:
branches: [ main ]
branches: [main]
paths:
# Run only when python files are modified
- "trl/**.py"
@ -12,87 +12,100 @@ env:
IS_GITHUB_CI: "1"
SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
jobs:
run_all_tests_single_gpu:
strategy:
fail-fast: false
matrix:
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
runs-on:
group: aws-g4dn-2xlarge
env:
CUDA_VISIBLE_DEVICES: "0"
TEST_TYPE: "single_gpu_${{ matrix.docker-image-name }}"
TEST_TYPE: "single_gpu"
container:
image: ${{ matrix.docker-image-name }}
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all --shm-size "16gb"
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v4
- name: Pip install
- name: Git checkout
uses: actions/checkout@v4
- name: Install system dependencies
run: |
source activate trl
pip install -e ".[test]" --no-deps
pip install pytest-reportlog parameterized
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install pytest-reportlog parameterized
- name: Run slow SFT tests on single GPU
if: always()
run: |
source activate trl
source .venv/bin/activate
make slow_tests
- name: Generate Report
if: always()
run: |
pip install slack_sdk tabulate
source .venv/bin/activate
uv pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
run_all_tests_multi_gpu:
strategy:
fail-fast: false
matrix:
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
runs-on:
group: aws-g4dn-2xlarge
env:
CUDA_VISIBLE_DEVICES: "0,1"
TEST_TYPE: "multi_gpu_${{ matrix.docker-image-name }}"
TEST_TYPE: "multi_gpu"
container:
image: ${{ matrix.docker-image-name }}
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all --shm-size "16gb"
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v4
- name: Pip install
- name: Git checkout
uses: actions/checkout@v4
- name: Install system dependencies
run: |
source activate trl
pip install -e ".[test]" --no-deps
pip install pytest-reportlog parameterized
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install pytest-reportlog parameterized
- name: Run slow SFT tests on Multi GPU
if: always()
run: |
source activate trl
source .venv/bin/activate
make slow_tests
- name: Run end-to-end examples tests on multi GPU
if: always()
run: |
source activate trl
pip install deepspeed
make test_examples
- name: Generate Reports
if: always()
run: |
pip install slack_sdk tabulate
source .venv/bin/activate
uv pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY
python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY
rm *.txt

View File

@ -11,21 +11,20 @@ on:
- "scripts/**.py"
- "tests/**.py"
- "trl/**.py"
- "setup.py"
- "pyproject.toml"
env:
TQDM_DISABLE: 1
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
jobs:
check_code_quality:
name: Check code quality
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: recursive
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
@ -38,126 +37,217 @@ jobs:
name: Tests
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
os: ['ubuntu-latest', 'windows-latest']
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
fail-fast: false
runs-on: ${{ matrix.os }}
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v4
- name: Git checkout
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[dev]"
source .venv/bin/activate
uv pip install ".[dev]"
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with ${{ matrix.python-version }} on ${{ matrix.os }} with lastest dependencies
title: Results with Python ${{ matrix.python-version }} and latest dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
tests_dev:
name: Tests with dev dependencies
runs-on: 'ubuntu-latest'
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v4
- name: Git checkout
uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -U git+https://github.com/huggingface/accelerate.git
python -m pip install -U git+https://github.com/huggingface/datasets.git
python -m pip install -U git+https://github.com/huggingface/transformers.git
python -m pip install ".[dev]"
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install -U git+https://github.com/huggingface/accelerate.git
uv pip install -U git+https://github.com/huggingface/datasets.git
uv pip install -U git+https://github.com/huggingface/transformers.git
uv pip install -U git+https://github.com/huggingface/peft.git
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with ${{ matrix.python-version }} on ${{ matrix.os }} with dev dependencies
title: Results with Python 3.12 and dev dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
tests_wo_optional_deps:
name: Tests without optional dependencies
runs-on: 'ubuntu-latest'
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v4
- name: Git checkout
uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[test]"
source .venv/bin/activate
uv pip install ".[test]"
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with ${{ matrix.python-version }} on ${{ matrix.os }} without optional dependencies
title: Results with Python 3.12 without optional dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
tests_min_versions:
name: Tests with minimum versions
runs-on: 'ubuntu-latest'
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v4
- name: Git checkout
uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install accelerate==0.34.0
python -m pip install datasets==2.21.0
python -m pip install transformers==4.46.0
python -m pip install ".[dev]"
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install accelerate==1.4.0
uv pip install datasets==3.0.0
uv pip install transformers==4.56.1
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with ${{ matrix.python-version }} on ${{ matrix.os }} with minimum versions
title: Results with Python 3.12 and minimum dependencies versions
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

66
.github/workflows/tests_latest.yml vendored Normal file
View File

@ -0,0 +1,66 @@
name: Tests latest TRL release with dev dependencies
on:
schedule:
- cron: '0 0 * * *' # Runs daily at midnight UTC
workflow_dispatch:
env:
TQDM_DISABLE: 1
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
jobs:
tests:
name: Tests latest TRL release with dev dependencies
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
steps:
- name: Git checkout
uses: actions/checkout@v4
with: { ref: v0.24-release }
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install -U git+https://github.com/huggingface/accelerate.git
uv pip install -U git+https://github.com/huggingface/datasets.git
uv pip install -U git+https://github.com/huggingface/transformers.git
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results of latest TRL with Python 3.12 and dev dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

View File

@ -12,4 +12,7 @@ jobs:
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main
uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
with:
# exclude buggy postgres detector that is causing false positives and not relevant to our codebase
extra_args: --results=verified,unknown --exclude-detectors=postgres

3
.gitignore vendored
View File

@ -143,6 +143,3 @@ checklink/cookies.txt
nbs/wandb/
examples/notebooks/wandb/
wandb/
# cli scripts that are symlinked from `examples/scripts`
trl/commands/scripts/

View File

@ -1,8 +1,8 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.3
rev: v0.11.10
hooks:
- id: ruff
- id: ruff-check
types_or: [ python, pyi ]
args: [ --fix ]
- id: ruff-format

View File

@ -31,4 +31,4 @@ keywords:
- pytorch
- transformers
license: Apache-2.0
version: 0.11.1
version: "0.24"

View File

@ -1,15 +1,10 @@
# How to contribute to TRL?
Everyone is welcome to contribute, and we value everybody's contribution. Code
contributions are not the only way to help the community. Answering questions, helping
others, and improving the documentation are also immensely valuable.
Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
It also helps us if you spread the word! Reference the library in blog posts
about the awesome projects it made possible, shout out on Twitter every time it has
helped you, or simply ⭐️ the repository to say thank you.
It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
However you choose to contribute, please be mindful and respect our
[code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
However you choose to contribute, please be mindful and respect our [code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
@ -22,9 +17,7 @@ There are several ways you can contribute to TRL:
* Implement trainers for new post-training algorithms.
* Contribute to the examples or the documentation.
If you don't know where to start, there is a special [Good First
Issue](https://github.com/huggingface/trl/contribute) listing. It will give you a list of
open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
If you don't know where to start, there is a special [Good First Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀
@ -33,12 +26,12 @@ For something slightly more challenging, you can also take a look at the [Good S
Before you start contributing make sure you have installed all the dev tools:
```bash
make dev
pip install -e .[dev]
```
## Fixing outstanding issues
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#create-a-pull-request) and open a Pull Request!
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request!
## Submitting a bug-related issue or feature request
@ -48,14 +41,12 @@ Do your best to follow these guidelines when submitting a bug-related issue or a
The TRL library is robust and reliable thanks to users who report the problems they encounter.
Before you report an issue, we would really appreciate it if you could **make sure the bug was not
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
Before you report an issue, we would really appreciate it if you could **make sure the bug was not already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:
* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions.
* A short, self-contained, code snippet that allows us to reproduce the bug in
less than 30s.
* A short, self-contained, code snippet that allows us to reproduce the bug in less than 30s.
* The *full* traceback if an exception is raised.
* Attach any other additional information, like screenshots, you think may help.
@ -106,29 +97,20 @@ We're always looking for improvements to the documentation that make it more cle
## Submitting a pull request (PR)
Before writing code, we strongly advise you to search through the existing PRs or
issues to make sure that nobody is already working on the same thing. If you are
unsure, it is always a good idea to open an issue to get some feedback.
Before writing code, we strongly advise you to search through the existing PRs or issues to make sure that nobody is already working on the same thing. If you are unsure, it is always a good idea to open an issue to get some feedback.
You will need basic `git` proficiency to be able to contribute to
TRL. `git` is not the easiest tool to use but it has the greatest
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.
You will need basic `git` proficiency to be able to contribute to TRL. `git` is not the easiest tool to use but it has the greatest manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow these steps to start contributing:
1. Fork the [repository](https://github.com/huggingface/trl) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
under your GitHub user account.
1. Fork the [repository](https://github.com/huggingface/trl) by clicking on the 'Fork' button on the repository's page. This creates a copy of the code under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
2. Clone your fork to your local disk, and add the base repository as a remote. The following command assumes you have your public SSH key uploaded to GitHub. See the following guide for more [information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
```bash
$ git clone git@github.com:<your Github handle>/trl.git
$ cd trl
$ git remote add upstream https://github.com/huggingface/trl.git
git clone git@github.com:<your Github handle>/trl.git
cd trl
git remote add upstream https://github.com/huggingface/trl.git
```
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
@ -136,15 +118,15 @@ Follow these steps to start contributing:
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
```bash
$ git checkout main
$ git fetch upstream
$ git merge upstream/main
git checkout main
git fetch upstream
git merge upstream/main
```
Once your `main` branch is synchronized, create a new branch from it:
```bash
$ git checkout -b a-descriptive-name-for-my-changes
git checkout -b a-descriptive-name-for-my-changes
```
**Do not** work on the `main` branch.
@ -152,32 +134,27 @@ Follow these steps to start contributing:
4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
```bash
$ make dev
pip install -e .[dev]
```
(If TRL was already installed in the virtual environment, remove
it with `pip uninstall trl` before reinstalling it.)
(If TRL was already installed in the virtual environment, remove it with `pip uninstall trl` before reinstalling it.)
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using the provided Dev Container. Check [the documentation on how to get started with dev containers](https://code.visualstudio.com/docs/remote/containers).
5. Develop the features on your branch.
As you work on the features, you should make sure that the test suite
passes. You should run the tests impacted by your changes like this (see
below an explanation regarding the environment variable):
As you work on the features, you should make sure that the test suite passes. You should run the tests impacted by your changes like this (see below an explanation regarding the environment variable):
```bash
$ pytest tests/<TEST_TO_RUN>.py
pytest tests/<TEST_TO_RUN>.py
```
> For the following commands leveraging the `make` utility, we recommend using the WSL system when running on
> Windows. More information [here](https://docs.microsoft.com/en-us/windows/wsl/about).
> For the following commands leveraging the `make` utility.
You can also run the full suite with the following command.
```bash
$ make test
make test
```
TRL relies on `ruff` for maintaining consistent code formatting across its source files. Before submitting any PR, you should apply automatic style corrections and run code verification checks.
@ -187,21 +164,21 @@ Follow these steps to start contributing:
To apply these checks and corrections in one step, use:
```bash
$ make precommit
make precommit
```
This command runs the following:
- Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
- Runs additional scripts such as adding copyright information.
* Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
* Runs additional scripts such as adding copyright information.
If you prefer to apply the style corrections separately or review them individually, the `pre-commit` hook will handle the formatting for the files in question.
Once you're happy with your changes, add changed files using `git add` and
make a commit with `git commit` to record your changes locally:
Once you're happy with your changes, add changed files using `git add` and make a commit with `git commit` to record your changes locally:
```bash
$ git add modified_file.py
$ git commit
git add modified_file.py
git commit
```
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
@ -210,36 +187,28 @@ Follow these steps to start contributing:
repository regularly. This way you can quickly account for changes:
```bash
$ git fetch upstream
$ git rebase upstream/main
git fetch upstream
git rebase upstream/main
```
Push the changes to your account using:
```bash
$ git push -u origin a-descriptive-name-for-my-changes
git push -u origin a-descriptive-name-for-my-changes
```
6. Once you are satisfied (**and the checklist below is happy too**), go to the
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
to the project maintainers for review.
6. Once you are satisfied (**and the checklist below is happy too**), go to the webpage of your fork on GitHub. Click on 'Pull request' to send your changes to the project maintainers for review.
7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request.
### Checklist
1. The title of your pull request should be a summary of its contribution;
2. If your pull request addresses an issue, please mention the issue number in
the pull request description to make sure they are linked (and people
consulting the issue know you are working on it);
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
it from PRs ready to be merged;
2. If your pull request addresses an issue, please mention the issue number in the pull request description to make sure they are linked (and people consulting the issue know you are working on it);
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate it from PRs ready to be merged;
4. Make sure existing tests pass;
5. Add high-coverage tests. No quality testing = no merge.
### Tests
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
@ -249,10 +218,211 @@ We use `pytest` to run the tests. From the root of the
repository here's how to run tests with `pytest` for the library:
```bash
$ python -m pytest -sv ./tests
python -m pytest -sv ./tests
```
That's how `make test` is implemented (without the `pip install` line)!
You can specify a smaller set of tests to test only the feature
you're working on.
### Default values guidelines
1. **Use defaults when appropriate**:
Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should.
2. **Prioritize proven defaults**:
Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases.
3. **Ensure safety and predictability**:
Defaults must be safe, expected and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.
4. **Balance consistency and flexibility**:
Aim for consistent defaults across similar functions or methods. However, consistency should not be preferred to point 2 or 3.
5. **Opt-in for new features**:
Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt-in to use these.
### Writing documentation
High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project.
To illustrate what good documentation looks like, heres an example of a well-documented function:
````python
def replicate_str(string: str, n: int, sep: str = " ") -> str:
r"""
Replicate a string `n` times with a separator.
Args:
string (`str`):
String to replicate.
n (`int`):
Number of times to replicate the string.
sep (`str`, *optional*, defaults to `" "`):
Separator to use between each replication.
Returns:
`str`: The replicated string.
Examples:
```python
>>> replicate_str("hello", 3)
"hello hello hello"
>>> replicate_str("hello", 3, sep=", ")
"hello, hello, hello"
```
"""
return sep.join([string] * n)
````
* **Line Wrapping:** Applied a consistent line wrap at column 120 to improve readability.
* **Definite Articles:** Removed definite articles where possible to streamline language. (Eg: Changed "The string to replicate" to "String to replicate")
* **Type Annotations:**
* Always include type definitions, indicating if a parameter is optional and specifying the default value.
* Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value.
E.g., for arguments that can't be `None` and aren't required:
```txt
foo (`int`, *optional*, defaults to `4`):
```
For arguments that can be `None` and are required:
```txt
foo (`Optional[int]`):
```
for arguments that can be `None` and aren't required (in this case, if the default value is `None`, you can omit it):
```txt
foo (`Optional[int]`, *optional*):
```
* **String Defaults:**
* Ensured that default string values are wrapped in double quotes:
```txt
defaults to `"foo"`
```
* **Dictionary Typing:**
* Replaced generic `dict` type hints with more explicit `dict[str, Any]` to clarify expected key-value pairs.
* **Default Value Formatting:**
* Consistently surrounded default values with backticks for improved formatting:
```txt
defaults to `4`
```
* **Sub-sectioning:** When the number of arguments is large, consider breaking them into sub-sections for better readability.
```python
def calculate_statistics(data: list[float], precision: int = 2, include_variance: bool = False) -> dict[str, float]:
r"""
Calculates basic statistics for a given dataset.
Args:
> Data inputs
data (`list[float]`):
A list of numerical values to analyze.
> Configuration parameters
precision (`int`, *optional*, defaults to `2`):
Number of decimal places to round the results.
include_variance (`bool`, *optional*, defaults to `False`):
Whether to include the variance of the dataset in the results.
Returns:
`dict[str, float]`:
A dictionary containing calculated statistics such as mean, median, and optionally variance.
"""
...
```
### Deprecation and backward compatibility
Our approach to deprecation and backward compatibility is flexible and based on the features usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.
When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:
* **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
* **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
Example:
```python
warnings.warn(
"The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. "
"Please use the `Trainer.bar` class instead.",
FutureWarning,
)
```
The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:
* **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.
* **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.
These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
### Working with warnings
Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely.
#### Definitions
* **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
* **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.
#### Choosing the right message
* **Correct → No warning**:
If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.
* **Correct but deserves attention → No warning, possibly a log message**:
When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:
```python
logger.info("This is an informational message about a rare but correct operation.")
```
* **Correct but very likely a mistake → Warning with option to disable**:
In rare cases, you may want to issue a warning for a correct operation thats very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:
```python
def my_function(foo, bar, _warn=True):
if foo == bar:
if _warn:
logger.warning("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.")
# Do something
```
* **Supported but not correct → Warning**:
If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:
```python
def my_function(foo, bar):
if foo and bar:
logger.warning("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. Please pass only one of these arguments.")
# Do something
```
* **Not supported → Exception**:
If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:
```python
def my_function(foo, bar):
if foo and bar:
raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.")
```
By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.

View File

@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Copyright 2020-2025 The HuggingFace Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@ -1,6 +1,7 @@
include settings.ini
include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__
include trl/accelerate_configs/*.yaml
include trl/templates/*.md
recursive-exclude * __pycache__
prune tests

View File

@ -1,38 +1,19 @@
.PHONY: test precommit common_tests slow_tests test_examples tests_gpu
.PHONY: test precommit common_tests slow_tests tests_gpu test_experimental
check_dirs := examples tests trl
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
COMMAND_FILES_PATH = `pwd`/commands
dev:
[ -L "$(pwd)/trl/commands/scripts" ] && unlink "$(pwd)/trl/commands/scripts" || true
pip install -e ".[dev]"
ln -s `pwd`/examples/scripts/ `pwd`/trl/commands
test:
python -m pytest -n auto --dist=loadfile -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' ./tests/
pytest -n auto -m "not slow and not low_priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/
precommit:
pre-commit run --all-files
python scripts/add_copyrights.py
tests_gpu:
python -m pytest tests/test_* $(if $(IS_GITHUB_CI),--report-log "common_tests.log",)
pre-commit run --all-files
doc-builder style trl tests docs/source --max_len 119
slow_tests:
python -m pytest tests/slow/test_* $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
test_examples:
touch temp_results_sft_tests.txt
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
echo $$?','$${file} >> temp_results_sft_tests.txt; \
done
touch temp_results_dpo_tests.txt
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
echo $$?','$${file} >> temp_results_dpo_tests.txt; \
done
test_experimental:
pytest -k "experimental"

237
README.md
View File

@ -1,7 +1,7 @@
# TRL - Transformer Reinforcement Learning
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_banner_dark.png" alt="TRL Banner">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png" alt="TRL Banner">
</div>
<hr> <br>
@ -12,26 +12,33 @@
<p align="center">
<a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
<a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
<a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website?label=documentation&url=https%3A%2F%2Fhuggingface.co%2Fdocs%2Ftrl%2Findex&down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
<a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
<a href="https://huggingface.co/trl-lib"><img alt="Hugging Face Hub" src="https://img.shields.io/badge/🤗%20Hub-trl--lib-yellow"></a>
</p>
## 🎉 What's New
> **✨ OpenAI GPT OSS Support**: TRL now fully supports fine-tuning the latest [OpenAI GPT OSS models](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4)! Check out the:
>
> - [OpenAI Cookbook](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers)
> - [GPT OSS recipes](https://github.com/huggingface/gpt-oss-recipes)
> - [Our example script](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_gpt_oss.py)
## Overview
TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the [🤗 Transformers](https://github.com/huggingface/transformers) ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled-up across various hardware setups.
## Highlights
- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer) and more.
- **Efficient and scalable**:
- Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like DDP and DeepSpeed.
- Full integration with [`PEFT`](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
- Integrates [Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
- Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like [DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [DeepSpeed](https://github.com/deepspeedai/DeepSpeed).
- Full integration with [🤗 PEFT](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
- Integrates [🦥 Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
- **Command Line Interface (CLI)**: A simple interface lets you fine-tune and interact with models without needing to write code.
- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`ORPOTrainer`](https://huggingface.co/docs/trl/orpo_trainer) and more.
- **AutoModels**: Use pre-defined model classes like [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) to simplify reinforcement learning (RL) with LLMs.
- **Command Line Interface (CLI)**: A simple interface lets you fine-tune with models without needing to write code.
## Installation
@ -59,9 +66,91 @@ If you want to use the examples you can clone the repository with the following
git clone https://github.com/huggingface/trl.git
```
## Quick Start
For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
### `SFTTrainer`
Here is a basic example of how to use the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer):
```python
from trl import SFTTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/Capybara", split="train")
trainer = SFTTrainer(
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
trainer.train()
```
### `GRPOTrainer`
[`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer) implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).
```python
from datasets import load_dataset
from trl import GRPOTrainer
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
```
### `DPOTrainer`
[`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer) implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train [Llama 3](https://huggingface.co/papers/2407.21783) and many other models. Here is a basic example of how to use the `DPOTrainer`:
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer
)
trainer.train()
```
### `RewardTrainer`
Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):
```python
from trl import RewardTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
trainer = RewardTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct",
train_dataset=dataset,
)
trainer.train()
```
## Command Line Interface (CLI)
You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), or vibe check your model with the chat CLI:
You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):
**SFT:**
@ -79,117 +168,7 @@ trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--output_dir Qwen2.5-0.5B-DPO
```
**Chat:**
```bash
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
```
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
## How to use
For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
### `SFTTrainer`
Here is a basic example of how to use the `SFTTrainer`:
```python
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/Capybara", split="train")
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
args=training_args,
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
trainer.train()
```
### `RewardTrainer`
Here is a basic example of how to use the `RewardTrainer`:
```python
from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
args=training_args,
model=model,
processing_class=tokenizer,
train_dataset=dataset,
)
trainer.train()
```
### `RLOOTrainer`
`RLOOTrainer` implements a [REINFORCE-style optimization](https://huggingface.co/papers/2402.14740) for RLHF that is more performant and memory-efficient than PPO. Here is a basic example of how to use the `RLOOTrainer`:
```python
from trl import RLOOConfig, RLOOTrainer, apply_chat_template
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoModelForSequenceClassification,
AutoTokenizer,
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
reward_model = AutoModelForSequenceClassification.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback-prompt")
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
dataset = dataset.map(lambda x: tokenizer(x["prompt"]), remove_columns="prompt")
training_args = RLOOConfig(output_dir="Qwen2.5-0.5B-RL")
trainer = RLOOTrainer(
config=training_args,
processing_class=tokenizer,
policy=policy,
ref_policy=ref_policy,
reward_model=reward_model,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
)
trainer.train()
```
### `DPOTrainer`
`DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the `DPOTrainer`:
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()
```
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/clis) or use `--help` for more details.
## Development
@ -198,9 +177,21 @@ If you want to contribute to `trl` or customize it to your needs make sure to re
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
pip install -e .[dev]
```
## Experimental
A minimal incubation area is available under `trl.experimental` for unstable / fast-evolving features. Anything there may change or be removed in any release without notice.
Example:
```python
from trl.experimental.new_trainer import NewTrainer
```
Read more in the [Experimental docs](https://huggingface.co/docs/trl/experimental).
## Citation
```bibtex

167
RELEASE.md Normal file
View File

@ -0,0 +1,167 @@
# Making a release
> [!NOTE]
> VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We need to follow this convention to be able to retrieve versioned scripts.
## Major/Minor Release
### 1. Ensure your local repository is up to date with the upstream repository
```bash
git checkout main
git pull origin main
```
> [!WARNING]
> Do not merge other pull requests into `main` until the release is done. This is to ensure that the release is stable and does not include any untested changes. Announce internally (#trl-internal) to other maintainers that you are doing a release and that they must not merge PRs until the release is done.
### 2. Create a release branch from main
```bash
git checkout -b release-v{major}.{minor}
```
### 3. Change the version in the following files
- `.github/workflows/tests_latest.yml`:
```diff
- with: { ref: v{major}.{minor-1}-release }
+ with: { ref: v{major}.{minor}-release }
```
- `CITATION.cff`
```diff
- version: "{major}.{minor-1}"
+ version: "{major}.{minor}"
```
- `VERSION`
```diff
- {major}.{minor}.0.dev0
+ {major}.{minor}.0
```
### 4. Commit and push these changes
```shell
git add .github/workflows/tests_latest.yml CITATION.cff VERSION
git commit -m 'Release: {major}.{minor}'
git push origin release-v{major}.{minor}
```
### 5. Create a pull request
from `release-v{major}.{minor}` to `main`, named `Release: v{major}.{minor}`, wait for tests to pass, and request a review.
### 6. Once the pull request is approved, merge it into `main`
It will automatically publish the new version of the package on PyPI.
### 7. Add a tag in git to mark the release
```shell
git checkout main
git pull origin main
git tag -a v{major}.{minor}.0 -m 'Adds tag v{major}.{minor}.0 for PyPI'
git push origin v{major}.{minor}.0
```
### 8. Create a branch `v{major}.{minor}-release` for future patch releases
```shell
git checkout -b v{major}.{minor}-release
git push origin v{major}.{minor}-release
```
This ensures that future patch releases (`v{major}.{minor}.1`, `v{major}.{minor}.2`, etc.) can be made separately from `main`.
### 9. Create a GitHub Release
1. Go to the repos [releases section](https://github.com/huggingface/trl/releases) on GitHub.
2. Click **Draft a new release**.
3. Select the `v{major}.{minor}.0` tag you just created in step 7.
4. Add a title (`v{major}.{minor}.0`) and a short description of whats new.
5. Click **Publish Release**.
### 10. Bump to dev version
1. Create a branch `bump-dev-version-{major}.{minor+1}` from `main` and checkout to it.
```shell
git checkout -b bump-dev-version-{major}.{minor+1}
```
2. Change the version in file `VERSION`:
```diff
- {major}.{minor}.0
+ {major}.{minor+1}.0.dev0
```
3. Commit and push these changes
```shell
git add VERSION
git commit -m '⬆️ Bump dev version'
git push origin bump-dev-version-{major}.{minor+1}
```
4. Create a pull request from `bump-dev-version-{major}.{minor+1}` to `main`, named `⬆️ Bump dev version`, and request urgent review.
5. Once the pull request is approved, merge it into `main`.
6. The codebase is now ready for the next development cycle, inform the team in the #trl-internal channel.
## Making a patch release
### 1. Ensure your local repository is up to date with the upstream repository
```bash
git checkout v{major}.{minor}-release
git pull origin main
```
### 2. Cherry-pick the changes you want to include in the patch release
```bash
git cherry-pick <commit-hash-0>
git cherry-pick <commit-hash-1>
...
```
### 3. Change the version in the file `VERSION`
```diff
- {major}.{minor}.{patch-1}
+ {major}.{minor}.{patch}
```
### 4. Commit and push these changes
```shell
git add VERSION
git commit -m 'Release: {major}.{minor}.{patch}'
git push origin v{major}.{minor}-release
```
### 5. Wait for the CI to pass
The CI will automatically publish the new version of the package on PyPI.
### 6. Add a tag in git to mark the release
```shell
git tag -a v{major}.{minor}.{patch} -m 'Adds tag v{major}.{minor}.{patch} for PyPI'
git push origin v{major}.{minor}.{patch}
```
#### 7. Create a GitHub Release
1. Go to the repos [releases section](https://github.com/huggingface/trl/releases) on GitHub.
2. Click **Draft a new release**.
3. Select the `v{major}.{minor}.{patch}` tag you just created in step 7.
4. Add a title (`v{major}.{minor}.{patch}`) and a short description of whats new.
5. Click **Publish Release**.

1
VERSION Normal file
View File

@ -0,0 +1 @@
0.25.0.dev0

View File

@ -1,58 +0,0 @@
#!/bin/bash
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_dpo/"
MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
--load_in_4bit
"""
# This is a hack to get the number of available GPUs
NUM_GPUS=2
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
EXTRA_ACCELERATE_ARGS=""
else
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
EXTRA_TRAINING_ARGS="--fp16"
else
echo "Keeping QLoRA + PEFT"
fi
fi
CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
--num_processes $NUM_GPUS \
--mixed_precision 'fp16' \
`pwd`/examples/scripts/dpo.py \
--model_name_or_path $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
--max_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""
echo "Starting program..."
{ # try
echo $CMD
eval "$CMD"
} || { # catch
# save log for exception
echo "Operation Failed!"
exit 1
}
exit 0

View File

@ -1,59 +0,0 @@
#!/bin/bash
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_sft/"
MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
DATASET_NAME="stanfordnlp/imdb"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
--load_in_4bit
"""
# Set your number of GPUs here
NUM_GPUS=2
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
EXTRA_ACCELERATE_ARGS=""
else
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
EXTRA_TRAINING_ARGS="--fp16"
else
echo "Keeping QLoRA + PEFT"
fi
fi
CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
--num_processes $NUM_GPUS \
--mixed_precision 'fp16' \
`pwd`/examples/scripts/sft.py \
--model_name $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
--max_seq_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""
echo "Starting program..."
{ # try
echo $CMD
eval "$CMD"
} || { # catch
# save log for exception
echo "Operation Failed!"
exit 1
}
exit 0

View File

@ -0,0 +1,6 @@
FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-runtime
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade pip uv
RUN uv pip install --system --no-cache "git+https://github.com/huggingface/trl.git#egg=trl[liger,peft,vlm]"
RUN uv pip install --system hf_transfer liger_kernel trackio peft
RUN uv pip install --system https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp311-cp311-linux_x86_64.whl

View File

@ -1,66 +0,0 @@
# Builds GPU docker image of PyTorch
# Uses multi-staged approach to reduce size
# Stage 1
# Use base conda image to reduce time
FROM continuumio/miniconda3:latest AS compile-image
# Specify py version
ENV PYTHON_VERSION=3.10
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN apt-get update && \
apt-get install -y curl git wget software-properties-common git-lfs && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Install audio-related libraries
RUN apt-get update && \
apt install -y ffmpeg
RUN apt install -y libsndfile1-dev
RUN git lfs install
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
RUN python3 -m pip install --no-cache-dir --upgrade pip
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
# We don't install pytorch here yet since CUDA isn't available
# instead we use the direct torch wheel
ENV PATH /opt/conda/envs/trl/bin:$PATH
# Activate our bash shell
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
# Stage 2
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
COPY --from=compile-image /opt/conda /opt/conda
ENV PATH /opt/conda/bin:$PATH
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
RUN source activate trl && \
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
# Install apt libs
RUN apt-get update && \
apt-get install -y curl git wget && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Activate the conda env and install transformers + accelerate from source
RUN source activate trl && \
python3 -m pip install -U --no-cache-dir \
librosa \
"soundfile>=0.12.1" \
scipy \
transformers \
accelerate \
peft \
trl[test]@git+https://github.com/huggingface/trl
RUN source activate trl && \
pip freeze | grep trl
RUN echo "source activate trl" >> ~/.profile
# Activate the virtualenv
CMD ["/bin/bash"]

View File

@ -1,66 +0,0 @@
# Builds GPU docker image of PyTorch
# Uses multi-staged approach to reduce size
# Stage 1
# Use base conda image to reduce time
FROM continuumio/miniconda3:latest AS compile-image
# Specify py version
ENV PYTHON_VERSION=3.10
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN apt-get update && \
apt-get install -y curl git wget software-properties-common git-lfs && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Install audio-related libraries
RUN apt-get update && \
apt install -y ffmpeg
RUN apt install -y libsndfile1-dev
RUN git lfs install
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
RUN python3 -m pip install --no-cache-dir --upgrade pip
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
# We don't install pytorch here yet since CUDA isn't available
# instead we use the direct torch wheel
ENV PATH /opt/conda/envs/trl/bin:$PATH
# Activate our bash shell
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
# Stage 2
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
COPY --from=compile-image /opt/conda /opt/conda
ENV PATH /opt/conda/bin:$PATH
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
RUN source activate trl && \
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
# Install apt libs
RUN apt-get update && \
apt-get install -y curl git wget && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Activate the conda env and install transformers + accelerate from source
RUN source activate trl && \
python3 -m pip install -U --no-cache-dir \
librosa \
"soundfile>=0.12.1" \
scipy \
git+https://github.com/huggingface/transformers \
git+https://github.com/huggingface/accelerate \
git+https://github.com/huggingface/peft \
trl[test]@git+https://github.com/huggingface/trl
RUN source activate trl && \
pip freeze | grep transformers
RUN echo "source activate trl" >> ~/.profile
# Activate the virtualenv
CMD ["/bin/bash"]

4
docker/trl/Dockerfile Normal file
View File

@ -0,0 +1,4 @@
FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-runtime
RUN pip install --upgrade pip uv
RUN uv pip install --system trl[liger,peft,vlm] hf_transfer trackio
RUN uv pip install --system https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp311-cp311-linux_x86_64.whl

View File

@ -5,35 +5,73 @@
title: Installation
- local: quickstart
title: Quickstart
- local: clis
title: Get started with Command Line Interfaces (CLIs)
title: Getting started
- sections:
- local: dataset_formats
title: Dataset Formats
- local: how_to_train
title: PPO Training FAQ
- local: use_model
title: Use Trained Models
- local: customization
title: Customize the Training
- local: logging
title: Understanding Logs
title: Get started
- local: paper_index
title: Paper Index
- local: experimental
title: Experimental
title: Conceptual Guides
- sections:
- sections: # Sort alphabetically
- local: alignprop_trainer
title: AlignProp
- local: clis
title: Command Line Interface (CLI)
- local: jobs_training
title: Training using Jobs
- local: customization
title: Customizing the Training
- local: reducing_memory_usage
title: Reducing Memory Usage
- local: speeding_up_training
title: Speeding Up Training
- local: distributing_training
title: Distributing Training
- local: use_model
title: Using Trained Models
title: How-to guides
- sections:
- local: deepspeed_integration
title: DeepSpeed
- local: kernels_hub
title: Kernels Hub
- local: liger_kernel_integration
title: Liger Kernel
- local: peft_integration
title: PEFT
- local: trackio_integration
title: Trackio
- local: unsloth_integration
title: Unsloth
- local: vllm_integration
title: vLLM
title: Integrations
- sections:
- local: example_overview
title: Example Overview
- local: community_tutorials
title: Community Tutorials
- local: lora_without_regret
title: LoRA Without Regret
- local: sentiment_tuning
title: Sentiment Tuning
- local: multi_adapter_rl
title: Multi Adapter RLHF
title: Examples
- sections:
- sections: # Sorted alphabetically
- local: bco_trainer
title: BCO
- local: cpo_trainer
title: CPO
- local: ddpo_trainer
title: DDPO
- local: dpo_trainer
title: DPO
- local: online_dpo_trainer
title: Online DPO
- local: gkd_trainer
title: GKD
- local: grpo_trainer
title: GRPO
- local: kto_trainer
title: KTO
- local: nash_md_trainer
@ -42,19 +80,21 @@
title: ORPO
- local: ppo_trainer
title: PPO
- local: prm_trainer
title: PRM
- local: reward_trainer
title: Reward
- local: rloo_trainer
title: RLOO
- local: sft_trainer
title: SFT
- local: iterative_sft_trainer
title: Iterative SFT
- local: xpo_trainer
title: XPO
title: Trainers
- local: models
title: Model Classes
- local: model_utils
title: Model Utilities
- local: best_of_n
title: Best of N Sampling
- local: judges
@ -63,22 +103,10 @@
title: Callbacks
- local: data_utils
title: Data Utilities
- local: text_environments
title: Text Environments
- local: rewards
title: Reward Functions
- local: script_utils
title: Script Utilities
- local: others
title: Others
title: API
- sections:
- local: example_overview
title: Example Overview
- local: sentiment_tuning
title: Sentiment Tuning
- local: lora_tuning_peft
title: Training with PEFT
- local: detoxifying_a_lm
title: Detoxifying a Language Model
- local: using_llama_models
title: Training StackLlama
- local: learning_tools
title: Learning to Use Tools
- local: multi_adapter_rl
title: Multi Adapter RLHF
title: Examples

View File

@ -1,93 +0,0 @@
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation
[![](https://img.shields.io/badge/All_models-AlignProp-blue)](https://huggingface.co/models?other=alignprop,trl)
## The why
If your reward function is differentiable, directly backpropagating gradients from the reward models to the diffusion model is significantly more sample and compute efficient (25x) than doing policy gradient algorithm like DDPO.
AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
<div style="text-align: center"><img src="https://align-prop.github.io/reward_tuning.png"/></div>
## Getting started with `examples/scripts/alignprop.py`
The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`).
**Note:** one A100 GPU is recommended to get this running. For lower memory setting, consider setting truncated_backprop_rand to False. With default settings this will do truncated backpropagation with K=1.
Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running
```batch
python alignprop.py --hf_user_access_token <token>
```
To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`
The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)
- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater to 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps)
- The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`) the number should be less than the number of diffusion timesteps (sample_num_steps), it only matters when truncated_backprop_rand is set to False
## Setting up the image logging hook function
Expect the function to be given a dictionary with keys
```python
['image', 'prompt', 'prompt_metadata', 'rewards']
```
and `image`, `prompt`, `prompt_metadata`, `rewards`are batched.
You are free to log however you want the use of `wandb` or `tensorboard` is recommended.
### Key terms
- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process
- `prompt` : The prompt is the text that is used to generate the image
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
- `image` : The image generated by the Stable Diffusion model
Example code for logging sampled images with `wandb` is given below.
```python
# for logging these images to wandb
def image_outputs_hook(image_data, global_step, accelerate_logger):
# For the sake of this example, we only care about the last batch
# hence we extract the last element of the list
result = {}
images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']]
for i, image in enumerate(images):
pil = Image.fromarray(
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
)
pil = pil.resize((256, 256))
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
accelerate_logger.log_images(
result,
step=global_step,
)
```
### Using the finetuned model
Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows
```python
from diffusers import StableDiffusionPipeline
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline.to("cuda")
pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics')
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
results = pipeline(prompts)
for prompt, image in zip(prompts,results.images):
image.save(f"dump/{prompt}.png")
```
## Credits
This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation
by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://huggingface.co/papers/2310.03739).

View File

@ -1,6 +1,6 @@
# BCO Trainer
[![](https://img.shields.io/badge/All_models-BCO-blue)](https://huggingface.co/models?other=bco,trl)
[![model badge](https://img.shields.io/badge/All_models-BCO-blue)](https://huggingface.co/models?other=bco,trl)
TRL supports the Binary Classifier Optimization (BCO).
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
@ -9,9 +9,10 @@ For a full example have a look at [`examples/scripts/bco.py`].
## Expected dataset type
The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference).
The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Expected model format
The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
## Using the `BCOTrainer`
@ -20,9 +21,7 @@ For a detailed example have a look at the `examples/scripts/bco.py` script. At a
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
```py
```python
training_args = BCOConfig(
beta=0.1,
)
@ -35,9 +34,10 @@ bco_trainer = BCOTrainer(
processing_class=tokenizer,
)
```
After this one can then call:
```py
```python
bco_trainer.train()
```
@ -49,7 +49,7 @@ If the prompts in your desired and undesired datasets differ a lot, it is useful
Choose an embedding model and tokenizer:
```py
```python
embedding_model = AutoModel.from_pretrained(your_model_id)
embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id)
@ -62,9 +62,9 @@ embedding_model = Accelerator().prepare_model(self.embedding_model)
embedding_func = partial(embed_prompt, model=embedding_model)
```
Set `prompt_sample_size` to defined how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
```py
```python
training_args = BCOConfig(
beta=0.1,
prompt_sample_size=512,
@ -94,6 +94,9 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype
## BCOTrainer
[[autodoc]] BCOTrainer
- train
- save_model
- push_to_hub
## BCOConfig

View File

@ -1,5 +1,8 @@
# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning
> [!WARNING]
> Best-of-N sampling is deprecated and will be removed in TRL 0.25.0.
Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output.
As to how it fares against the RL based fine-tuning, please look in the `examples` directory for a comparison example
@ -8,7 +11,6 @@ As to how it fares against the RL based fine-tuning, please look in the `example
To get started quickly, instantiate an instance of the class with a model, a length sampler, a tokenizer and a callable that serves as a proxy reward pipeline that outputs reward scores for input queries
```python
from transformers import pipeline, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
@ -19,41 +21,33 @@ reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)
tokenizer = AutoTokenizer.from_pretrained(ref_model_name)
tokenizer.pad_token = tokenizer.eos_token
# callable that takes a list of raw text and returns a list of corresponding reward scores
def queries_to_scores(list_of_strings):
return [output["score"] for output in reward_pipe(list_of_strings)]
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler)
```
And assuming you have a list/tensor of tokenized queries, you can generate better output by calling the `generate` method
```python
best_of_n.generate(query_tensors, device=device, **gen_kwargs)
```
The default sample size is 4, but you can change it at the time of instance initialization like so
```python
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, sample_size=8)
```
The default output is the result of taking the top scored output for each query, but you can change it to top 2 and so on by passing the `n_candidates` argument at the time of instance initialization
```python
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, n_candidates=2)
```
There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method.
This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization
This is done by passing a [`~transformers.GenerationConfig`] from the `transformers` library at the time of initialization
```python
@ -67,6 +61,8 @@ best_of_n.generate(query_tensors, device=device)
```
Furthermore, at the time of initialization you can set the seed to control repeatability of the generation process and the number of samples to generate for each query
Furthermore, at the time of initialization you can set the seed to control the repeatability of the generation process and the number of samples to generate for each query
## BestOfNSampler
[[autodoc]] BestOfNSampler

View File

@ -15,3 +15,15 @@
## LogCompletionsCallback
[[autodoc]] LogCompletionsCallback
## MergeModelCallback
[[autodoc]] MergeModelCallback
## BEMACallback
[[autodoc]] BEMACallback
## WeaveCallback
[[autodoc]] WeaveCallback

414
docs/source/clis.md Normal file
View File

@ -0,0 +1,414 @@
# Command Line Interfaces (CLIs)
TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly.
## Commands
Currently supported commands are:
### Training Commands
- `trl dpo`: fine-tune a LLM with DPO
- `trl grpo`: fine-tune a LLM with GRPO
- `trl kto`: fine-tune a LLM with KTO
- `trl reward`: train a Reward Model
- `trl rloo`: fine-tune a LLM with RLOO
- `trl sft`: fine-tune a LLM with SFT
### Other Commands
- `trl env`: get the system information
- `trl vllm-serve`: serve a model with vLLM
## Fine-Tuning with the TRL CLI
### Basic Usage
You can launch training directly from the CLI by specifying required arguments like the model and dataset:
<hfoptions id="command_line">
<hfoption id="SFT">
```bash
trl sft \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb
```
</hfoption>
<hfoption id="DPO">
```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf
```
</hfoption>
<hfoption id="Reward">
```bash
trl reward \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/ultrafeedback_binarized
```
</hfoption>
</hfoptions>
### Using Configuration Files
To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file:
<hfoptions id="config_file">
<hfoption id="SFT">
```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
```
Launch with:
```bash
trl sft --config sft_config.yaml
```
</hfoption>
<hfoption id="DPO">
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
```
Launch with:
```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
<hfoption id="Reward">
```yaml
# reward_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: trl-lib/ultrafeedback_binarized
```
Launch with:
```bash
trl reward --config reward_config.yaml
```
</hfoption>
</hfoptions>
### Scaling Up with Accelerate
TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs, machines, or use advanced setups like DeepSpeed — all from the same CLI.
You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch).
<hfoptions id="launch_args">
<hfoption id="SFT inline">
```bash
trl sft \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb \
--num_processes 4
```
</hfoption>
<hfoption id="SFT w/ config file">
```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
num_processes: 4
```
Launch with:
```bash
trl sft --config sft_config.yaml
```
</hfoption>
<hfoption id="DPO inline">
```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf \
--num_processes 4
```
</hfoption>
<hfoption id="DPO w/ config file">
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
num_processes: 4
```
Launch with:
```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
<hfoption id="Reward inline">
```bash
trl reward \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_processes 4
```
</hfoption>
<hfoption id="Reward w/ config file">
```yaml
# reward_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: trl-lib/ultrafeedback_binarized
num_processes: 4
```
Launch with:
```bash
trl reward --config reward_config.yaml
```
</hfoption>
</hfoptions>
### Using `--accelerate_config` for Accelerate Configuration
The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either:
- the name of a predefined config profile (built into TRL), or
- a path to a custom Accelerate YAML config file.
#### Predefined Config Profiles
TRL provides several ready-to-use Accelerate configs to simplify common training setups:
| Name | Description |
| --- | --- |
| `fsdp1` | Fully Sharded Data Parallel Stage 1 |
| `fsdp2` | Fully Sharded Data Parallel Stage 2 |
| `zero1` | DeepSpeed ZeRO Stage 1 |
| `zero2` | DeepSpeed ZeRO Stage 2 |
| `zero3` | DeepSpeed ZeRO Stage 3 |
| `multi_gpu` | Multi-GPU training |
| `single_gpu` | Single-GPU training |
To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_config/`.
#### Example Usage
<hfoptions id="accelerate_config">
<hfoption id="SFT inline">
```bash
trl sft \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```
</hfoption>
<hfoption id="SFT w/ config file">
```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```
Launch with:
```bash
trl sft --config sft_config.yaml
```
</hfoption>
<hfoption id="DPO inline">
```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```
</hfoption>
<hfoption id="DPO w/ config file">
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```
Launch with:
```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
<hfoption id="Reward inline">
```bash
trl reward \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/ultrafeedback_binarized \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```
</hfoption>
<hfoption id="Reward w/ config file">
```yaml
# reward_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: trl-lib/ultrafeedback_binarized
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```
Launch with:
```bash
trl reward --config reward_config.yaml
```
</hfoption>
</hfoptions>
### Using dataset mixtures
You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data.
<hfoptions id="dataset_mixtures">
<hfoption id="SFT">
```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
datasets:
- path: stanfordnlp/imdb
- path: roneneldan/TinyStories
```
Launch with:
```bash
trl sft --config sft_config.yaml
```
</hfoption>
<hfoption id="DPO">
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
datasets:
- path: BAAI/Infinity-Preference
- path: argilla/Capybara-Preferences
```
Launch with:
```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
<hfoption id="Reward">
```yaml
# reward_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
datasets:
- path: trl-lib/tldr-preference
- path: trl-lib/lm-human-preferences-sentiment
```
Launch with:
```bash
trl reward --config reward_config.yaml
```
</hfoption>
</hfoptions>
To see all the available keywords for defining dataset mixtures, refer to the [`scripts.utils.DatasetConfig`] and [`DatasetMixtureConfig`] classes.
## Getting the System Information
You can get the system information by running the following command:
```bash
trl env
```
This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the transformers version, the TRL version, and any optional dependencies that are installed.
```txt
Copy-paste the following information when reporting an issue:
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.1
- accelerator(s): NVIDIA H100 80GB HBM3
- Transformers version: 4.45.0.dev0
- Accelerate version: 0.34.2
- Accelerate config:
- compute_environment: LOCAL_MACHINE
- distributed_type: DEEPSPEED
- mixed_precision: no
- use_cpu: False
- debug: False
- num_processes: 4
- machine_rank: 0
- num_machines: 1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
- Datasets version: 3.0.0
- HF Hub version: 0.24.7
- TRL version: 0.12.0.dev0+acb4d70
- bitsandbytes version: 0.41.1
- DeepSpeed version: 0.15.1
- Diffusers version: 0.30.3
- Liger-Kernel version: 0.3.0
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.12.0
- vLLM version: not installed
```
This information is required when reporting an issue.

View File

@ -1,171 +0,0 @@
# Command Line Interfaces (CLIs)
You can use TRL to fine-tune your Language Model with Supervised Fine-Tuning (SFT) or Direct Policy Optimization (DPO) or even chat with your model using the TRL CLIs.
Currently supported CLIs are:
- `trl sft`: fine-tune a LLM on a text/instruction dataset
- `trl dpo`: fine-tune a LLM with DPO on a preference dataset
- `trl chat`: quickly spin up a LLM fine-tuned for chatting
- `trl env`: get the system information
## Fine-tuning with the CLI
Before getting started, pick up a Language Model from Hugging Face Hub. Supported models can be found with the filter "text-generation" within models. Also make sure to pick up a relevant dataset for your task.
Before using the `sft` or `dpo` commands make sure to run:
```bash
accelerate config
```
and pick up the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
We also recommend you passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with `trl sft` command.
```yaml
model_name_or_path:
trl-internal-testing/tiny-random-LlamaForCausalLM
dataset_name:
stanfordnlp/imdb
report_to:
none
learning_rate:
0.0001
lr_scheduler_type:
cosine
```
Save that config in a `.yaml` and get started immediately! An example CLI config is available as `examples/cli_configs/example_config.yaml`. Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder:
```bash
trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
```
Will force-use `cosine_with_restarts` for `lr_scheduler_type`.
### Supported Arguments
We do support all arguments from `transformers.TrainingArguments`, for loading your model, we support all arguments from `~trl.ModelConfig`:
[[autodoc]] ModelConfig
You can pass any of these arguments either to the CLI or the YAML file.
### Supervised Fine-tuning (SFT)
Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
```bash
trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb
```
The SFT CLI is based on the `examples/scripts/sft.py` script.
### Direct Policy Optimization (DPO)
To use the DPO CLI, you need to have a dataset in the TRL format such as
* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
These datasets always have at least three columns `prompt, chosen, rejected`:
* `prompt` is a list of strings.
* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
* `rejected` is the rejected response [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
To do a quick start, you can run the following command:
```bash
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
```
The DPO CLI is based on the `examples/scripts/dpo.py` script.
#### Custom preference dataset
Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
```bash
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
```
## Chat interface
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
<pre><code>$ trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
<strong><span style="color: blue;">&lt;Qwen/Qwen1.5-0.5B-Chat&gt;:</span></strong>
There isn't a "best" programming language, as everyone has different style preferences, needs, and preferences. However, some people commonly use
languages like Python, Java, C++, and JavaScript, which are popular among developers for a variety of reasons, including readability, flexibility,
and scalability. Ultimately, it depends on personal preference, needs, and goals.
</code></pre>
Note that the chat interface relies on the tokenizer's [chat template](https://huggingface.co/docs/transformers/chat_templating) to format the inputs for the model. Make sure your tokenizer has a chat template defined.
Besides talking to the model there are a few commands you can use:
- `clear`: clears the current conversation and start a new one
- `example {NAME}`: load example named `{NAME}` from the config and use it as the user input
- `set {SETTING_NAME}={SETTING_VALUE};`: change the system prompt or generation settings (multiple settings are separated by a `;`).
- `reset`: same as clear but also resets the generation configs to defaults if they have been changed by `set`
- `save` or `save {SAVE_NAME}`: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
- `exit`: closes the interface
The default examples are defined in `examples/scripts/config/default_chat_config.yaml` but you can pass your own with `--config CONFIG_FILE` where you can also specify the default generation parameters.
## Getting the system information
You can get the system information by running the following command:
```bash
trl env
```
This will print out the system information including the GPU information, the CUDA version, the PyTorch version, the transformers version, and the TRL version, and any optional dependencies that are installed.
```txt
Copy-paste the following information when reporting an issue:
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.1
- CUDA device: NVIDIA H100 80GB HBM3
- Transformers version: 4.45.0.dev0
- Accelerate version: 0.34.2
- Accelerate config:
- compute_environment: LOCAL_MACHINE
- distributed_type: DEEPSPEED
- mixed_precision: no
- use_cpu: False
- debug: False
- num_processes: 4
- machine_rank: 0
- num_machines: 1
- rdzv_backend: static
- same_network: True
- main_training_function: main
- enable_cpu_affinity: False
- deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
- downcast_bf16: no
- tpu_use_cluster: False
- tpu_use_sudo: False
- tpu_env: []
- Datasets version: 3.0.0
- HF Hub version: 0.24.7
- TRL version: 0.12.0.dev0+acb4d70
- bitsandbytes version: 0.41.1
- DeepSpeed version: 0.15.1
- Diffusers version: 0.30.3
- Liger-Kernel version: 0.3.0
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.12.0
```
This information are required when reporting an issue.

View File

@ -0,0 +1,57 @@
# Community Tutorials
Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities.
## Language Models
### Tutorials
| Task | Class | Description | Author | Tutorial | Colab |
| --- | --- | --- | --- | --- | --- |
| Reinforcement Learning | [`GRPOTrainer`] | Efficient Online Training with GRPO and vLLM in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/grpo_vllm_online_training) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/grpo_vllm_online_training.ipynb) |
| Reinforcement Learning | [`GRPOTrainer`] | Post training an LLM for reasoning with GRPO in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb) |
| Reinforcement Learning | [`GRPOTrainer`] | Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/mini-deepseek-r1) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb) |
| Reinforcement Learning | [`GRPOTrainer`] | RL on LLaMA 3.1-8B with GRPO and Unsloth optimizations | [Andrea Manzoni](https://huggingface.co/AManzoni) | [Link](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) |
| Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) |
| Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) |
| Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) |
| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) |
| Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) |
### Videos
| Task | Title | Author | Video |
| --- | --- | --- | --- |
| Instruction tuning | Fine-tuning open AI models using Hugging Face TRL | [Wietse Venema](https://huggingface.co/wietsevenema) | [<img src="https://img.youtube.com/vi/cnGyyM0vOes/0.jpg">](https://youtu.be/cnGyyM0vOes) |
| Instruction tuning | How to fine-tune a smol-LM with Hugging Face, TRL, and the smoltalk Dataset | [Mayurji](https://huggingface.co/iammayur) | [<img src="https://img.youtube.com/vi/jKdXv3BiLu0/0.jpg">](https://youtu.be/jKdXv3BiLu0) |
<details>
<summary>⚠️ Deprecated features notice for "How to fine-tune a smol-LM with Hugging Face, TRL, and the smoltalk Dataset" (click to expand)</summary>
> [!WARNING]
> The tutorial uses two deprecated features:
>
> - `SFTTrainer(..., tokenizer=tokenizer)`: Use `SFTTrainer(..., processing_class=tokenizer)` instead, or simply omit it (it will be inferred from the model).
> - `setup_chat_format(model, tokenizer)`: Use `SFTConfig(..., chat_template_path="Qwen/Qwen3-0.6B")`, where `chat_template_path` specifies the model whose chat template you want to copy.
</details>
## Vision Language Models
### Tutorials
| Task | Class | Description | Author | Tutorial | Colab |
| --- | --- | --- | --- | --- | --- |
| Visual QA | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for visual question answering on ChartQA dataset | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_trl.ipynb) |
| Visual QA | [`SFTTrainer`] | Fine-tuning SmolVLM with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_smol_vlm_sft_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_smol_vlm_sft_trl.ipynb) |
| SEO Description | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for generating SEO-friendly descriptions from images | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-multimodal-llms-with-trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-multimodal-llms-with-trl.ipynb) |
| Visual QA | [`DPOTrainer`] | PaliGemma 🤝 Direct Preference Optimization | [Merve Noyan](https://huggingface.co/merve) | [Link](https://github.com/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) |
| Visual QA | [`DPOTrainer`] | Fine-tuning SmolVLM using direct preference optimization (DPO) with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_dpo_smolvlm_instruct.ipynb) |
| Object Detection Grounding | [`SFTTrainer`] | Fine tuning a VLM for Object Detection Grounding using TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_object_detection_grounding) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_object_detection_grounding.ipynb) |
| Visual QA | [`DPOTrainer`] | Fine-Tuning a Vision Language Model with TRL using MPO | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_mpo) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_mpo.ipynb) |
| Reinforcement Learning | [`GRPOTrainer`] | Post training a VLM for reasoning with GRPO using TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_grpo_trl) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_grpo_trl.ipynb) |
## Contributing
If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community.

126
docs/source/cpo_trainer.md Normal file
View File

@ -0,0 +1,126 @@
# CPO Trainer
[![model badge](https://img.shields.io/badge/All_models-CPO-blue)](https://huggingface.co/models?other=cpo,trl)
## Overview
Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high level, CPO trains models to avoid generating adequate, but not perfect, translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat.
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFTs methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
## Quick start
This example demonstrates how to train a model using the CPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_cpo.py
from datasets import load_dataset
from trl import CPOConfig, CPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO")
trainer = CPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_cpo.py
```
## Expected dataset type
CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Example script
We provide an example script to train a model using the CPO method. The script is available in [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py)
To test the CPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
```bash
accelerate launch examples/scripts/cpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_train_epochs 1 \
--output_dir Qwen2-0.5B-CPO
```
## Logged metrics
While training and evaluating, we record the following reward metrics:
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
## CPO variants
### Simple Preference Optimization (SimPO)
[Simple Preference Optimization](https://huggingface.co/papers/2405.14734) (SimPO) by [Yu Meng](https://huggingface.co/yumeng5), [Mengzhou Xia](https://huggingface.co/mengzhouxia), and [Danqi Chen](https://huggingface.co/cdq10131) proposes a simpler and more effective preference optimization algorithm than DPO without using a reference model. The key designs in SimPO are (1) using length-normalized log likelihood as the implicit reward, and (2) incorporating a target reward margin in the Bradley-Terry ranking objective. The official code can be found at [princeton-nlp/SimPO](https://github.com/princeton-nlp/SimPO).
The abstract from the paper is the following:
> Direct Preference Optimization (DPO) is a widely used offline preference optimization algorithm that reparameterizes reward functions in reinforcement learning from human feedback (RLHF) to enhance simplicity and training stability. In this work, we propose SimPO, a simpler yet more effective approach. The effectiveness of SimPO is attributed to a key design: using the average log probability of a sequence as the implicit reward. This reward formulation better aligns with model generation and eliminates the need for a reference model, making it more compute and memory efficient. Additionally, we introduce a target reward margin to the Bradley-Terry objective to encourage a larger margin between the winning and losing responses, further enhancing the algorithm's performance. We compare SimPO to DPO and its latest variants across various state-of-the-art training setups, including both base and instruction-tuned models like Mistral and Llama3. We evaluated on extensive instruction-following benchmarks, including AlpacaEval 2, MT-Bench, and the recent challenging Arena-Hard benchmark. Our results demonstrate that SimPO consistently and significantly outperforms existing approaches without substantially increasing response length. Specifically, SimPO outperforms DPO by up to 6.4 points on AlpacaEval 2 and by up to 7.5 points on Arena-Hard. Our top-performing model, built on Llama3-8B-Instruct, achieves a remarkable 44.7 length-controlled win rate on AlpacaEval 2 -- surpassing Claude 3 Opus on the leaderboard, and a 33.8 win rate on Arena-Hard -- making it the strongest 8B open-source model.
The SimPO loss is integrated in the [`CPOTrainer`], as it's an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, just turn on `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`] and set the `simpo_gamma` to a recommended value.
### CPO-SimPO
We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`CPOConfig`].
### AlphaPO
The [AlphaPO -- Reward shape matters for LLM alignment](https://huggingface.co/papers/2501.03884) (AlphaPO) method by Aman Gupta, Shao Tang, Qingquan Song, Sirou Zhu, [Jiwoo Hong](https://huggingface.co/JW17), Ankan Saha, Viral Gupta, Noah Lee, Eunki Kim, Jason Zhu, Natesh Pillai, and S. Sathiya Keerthi is also implemented in the [`CPOTrainer`]. AlphaPO is an alternative method that applies a transformation to the reward function shape in the context of SimPO loss. The abstract from the paper is the following:
> Reinforcement Learning with Human Feedback (RLHF) and its variants have made huge strides toward the effective alignment of large language models (LLMs) to follow instructions and reflect human values. More recently, Direct Alignment Algorithms (DAAs) have emerged in which the reward modeling stage of RLHF is skipped by characterizing the reward directly as a function of the policy being learned. Some popular examples of DAAs include Direct Preference Optimization (DPO) and Simple Preference Optimization (SimPO). These methods often suffer from likelihood displacement, a phenomenon by which the probabilities of preferred responses are often reduced undesirably. In this paper, we argue that, for DAAs the reward (function) shape matters. We introduce AlphaPO, a new DAA method that leverages an α-parameter to help change the shape of the reward function beyond the standard log reward. AlphaPO helps maintain fine-grained control over likelihood displacement and overoptimization. Compared to SimPO, one of the best performing DAAs, AlphaPO leads to about 7% to 10% relative improvement in alignment performance for the instruct versions of Mistral-7B and Llama3-8B while achieving 15% to 50% relative improvement over DPO on the same models. The analysis and results presented highlight the importance of the reward shape and how one can systematically change it to affect training dynamics, as well as improve alignment performance.
To use this loss as described in the paper, we can set the `loss_type="alphapo"` which automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`, together with `alpha` and `simpo_gamma` to recommended values in the [`CPOConfig`]. Alternatively, you can manually set `loss_type="simpo"`, `cpo_alpha=0.0`, together with `alpha` and `simpo_gamma` to recommended values. Other variants of this method are also possible, such as setting `loss_type="ipo"` and `alpha` to any non-zero value.
## Loss functions
The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported:
| `loss_type=` | Description |
| --- | --- |
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact, the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, and thus the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over log-likelihoods of the completion (unlike DPO, which is summed only). |
| `"simpo"` | The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`] and `simpo_gamma` to a recommended value. |
| `"alphapo"` | The [AlphaPO](https://huggingface.co/papers/2501.03884) method is also implemented in the [`CPOTrainer`]. This is syntactic sugar that automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. AlphaPO applies a transformation to the reward function shape in the context of SimPO loss when the `alpha` parameter is non-zero. |
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g., [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
## CPOTrainer
[[autodoc]] CPOTrainer
- train
- save_model
- push_to_hub
## CPOConfig
[[autodoc]] CPOConfig

View File

@ -1,108 +0,0 @@
# CPO Trainer
[![](https://img.shields.io/badge/All_models-CPO-blue)](https://huggingface.co/models?other=cpo,trl)
## Overview
Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high-level, CPO trains models to avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains like chat.
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFTs methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
## Quick start
This example demonstrates how to train a model using the CPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_cpo.py
from datasets import load_dataset
from trl import CPOConfig, CPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = CPOConfig(output_dir="Qwen2-0.5B-CPO", logging_steps=10)
trainer = CPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_cpo.py
```
## Expected dataset type
CPO requires a [preference dataset](dataset_formats#preference). The [`CPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Example script
We provide an example script to train a model using the CPO method. The script is available in [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py)
To test the CPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
```bash
accelerate launch examples/scripts/cpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_train_epochs 1 \
--logging_steps 25 \
--output_dir Qwen2-0.5B-CPO
```
## Logged metrics
While training and evaluating we record the following reward metrics:
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
## CPO variants
### Simple Preference Optimization (SimPO)
The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0` in the [`CPOConfig`].
### CPO-SimPO
We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO GitHub](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the [`CPOConfig`].
## Loss functions
The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported:
| `loss_type=` | Description |
| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). |
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
## CPOTrainer
[[autodoc]] CPOTrainer
## CPOConfig
[[autodoc]] CPOConfig

View File

@ -1,50 +1,6 @@
# Training customization
TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers.
## Train on multiple GPUs / nodes
The trainers in TRL use 🤗 Accelerate to enable distributed training across multiple GPUs or nodes. To do so, first create an 🤗 Accelerate config file by running
```bash
accelerate config
```
and answering the questions according to your multi-gpu / multi-node setup. You can then launch distributed training by running:
```bash
accelerate launch your_script.py
```
We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
```shell
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
```
Refer to the [examples page](https://github.com/huggingface/trl/tree/main/examples) for more details.
### Distributed training with DeepSpeed
All of the trainers in TRL can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run:
```shell
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_your_script.py --all_arguments_of_the_script
```
Note that for ZeRO-3, a small tweak is needed to initialize your reward model on the correct device via the `zero3_init_context_manager()` context manager. In particular, this is needed to avoid DeepSpeed hanging after a fixed number of training steps. Here is a snippet of what is involved from the [`sentiment_tuning`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) example:
```python
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
with ds_plugin.zero3_init_context_manager(enable=False):
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
else:
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
```
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers.
## Use different optimizers and schedulers
@ -130,11 +86,11 @@ trainer.train()
Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.
Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).
Read more about 8-bit model loading in `transformers` [Load in 8bit or 4bit](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
@ -154,10 +110,10 @@ trainer = DPOTrainer(
trainer.train()
```
## Use the CUDA cache optimizer
## Use the accelerator cache optimizer
When training large models, you should better handle the CUDA cache by iteratively clearing it. To do so, simply pass `optimize_cuda_cache=True` to `DPOConfig`:
When training large models, you should better handle the accelerator cache by iteratively clearing it. To do so, simply pass `optimize_device_cache=True` to [`DPOConfig`]:
```python
training_args = DPOConfig(..., optimize_cuda_cache=True)
training_args = DPOConfig(..., optimize_device_cache=True)
```

53
docs/source/data_utils.md Normal file
View File

@ -0,0 +1,53 @@
# Data Utilities
## prepare_multimodal_messages
[[autodoc]] prepare_multimodal_messages
## prepare_multimodal_messages_vllm
[[autodoc]] prepare_multimodal_messages_vllm
## is_conversational
[[autodoc]] is_conversational
## is_conversational_from_value
[[autodoc]] is_conversational_from_value
## apply_chat_template
[[autodoc]] apply_chat_template
## maybe_apply_chat_template
[[autodoc]] maybe_apply_chat_template
## maybe_convert_to_chatml
[[autodoc]] maybe_convert_to_chatml
## extract_prompt
[[autodoc]] extract_prompt
## maybe_extract_prompt
[[autodoc]] maybe_extract_prompt
## unpair_preference_dataset
[[autodoc]] unpair_preference_dataset
## maybe_unpair_preference_dataset
[[autodoc]] maybe_unpair_preference_dataset
## pack_dataset
[[autodoc]] pack_dataset
## truncate_dataset
[[autodoc]] truncate_dataset

View File

@ -1,15 +0,0 @@
## Data Utilities
[[autodoc]] is_conversational
[[autodoc]] apply_chat_template
[[autodoc]] maybe_apply_chat_template
[[autodoc]] extract_prompt
[[autodoc]] maybe_extract_prompt
[[autodoc]] unpair_preference_dataset
[[autodoc]] maybe_unpair_preference_dataset

View File

@ -77,6 +77,18 @@ This guide provides an overview of the dataset formats and types supported by ea
"label": False}</code></pre>
</td>
</tr>
</tr>
<td>Stepwise supervision</td>
<td>
<pre><code>{"prompt": "Which number is larger, 9.8 or 9.11?",
"completions": ["The fractional part of 9.8 is 0.8.",
"The fractional part of 9.11 is 0.11.",
"0.11 is greater than 0.8.",
"Hence, 9.11 > 9.8."],
"labels": [True, True, False, False]}</code></pre>
</td>
<td></td>
</tr>
</table>
### Formats
@ -87,9 +99,11 @@ The standard dataset format typically consists of plain text strings. The column
```python
# Language modeling
example = {"text": "The sky is blue."}
language_modeling_example = {"text": "The sky is blue."}
# Preference
example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
# Unpaired preference
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
```
#### Conversational
@ -104,23 +118,148 @@ messages = [
]
```
Just like standard datasets, the columns in conversational datasets vary depending on the task. For instance, a preference dataset would include columns like `"chosen"` and `"rejected"` to compare responses:
Just like standard datasets, the columns in conversational datasets vary depending on the task. Below are examples of conversational dataset formats for different tasks:
```python
example = {
"chosen": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
],
"rejected": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is green."},
],
# Prompt-completion
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}]}
# Preference
preference_example = {
"prompt": [{"role": "user", "content": "What color is the sky?"}],
"chosen": [{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "assistant", "content": "It is green."}],
}
```
Conversational datasets are useful for training chat models, but must be converted into a standard format before being used with TRL trainers. This is typically done using chat templates specific to the model being used. For more information, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
#### Tool Calling
Some chat templates support *tool calling*, which allows the model to interact with external functions—referred to as **tools**—during generation. This extends the conversational capabilities of the model by enabling it to output a `"tool_calls"` field instead of a standard `"content"` message whenever it decides to invoke a tool.
After the assistant initiates a tool call, the tool executes and returns its output. The assistant can then process this output and continue the conversation accordingly.
Heres a simple example of a tool-calling interaction:
```python
messages = [
{"role": "user", "content": "Turn on the living room lights."},
{"role": "assistant", "tool_calls": [
{"type": "function", "function": {
"name": "control_light",
"arguments": {"room": "living room", "state": "on"}
}}]
},
{"role": "tool", "name": "control_light", "content": "The lights in the living room are now on."},
{"role": "assistant", "content": "Done!"}
]
```
When preparing datasets for Supervised Fine-Tuning (SFT) with tool calling, it is important that your dataset includes an additional column named `tools`. This column contains the list of available tools for the model, which is usually used by the chat template to construct the system prompt.
The tools must be specified in a codified JSON schema format. You can automatically generate this schema from Python function signatures using the [`~transformers.utils.get_json_schema`] utility:
```python
from transformers.utils import get_json_schema
def control_light(room: str, state: str) -> str:
"""
Controls the lights in a room.
Args:
room: The name of the room.
state: The desired state of the light ("on" or "off").
Returns:
str: A message indicating the new state of the lights.
"""
return f"The lights in {room} are now {state}."
# Generate JSON schema
json_schema = get_json_schema(control_light)
```
The generated schema would look like:
```python
{
"type": "function",
"function": {
"name": "control_light",
"description": "Controls the lights in a room.",
"parameters": {
"type": "object",
"properties": {
"room": {"type": "string", "description": "The name of the room."},
"state": {"type": "string", "description": 'The desired state of the light ("on" or "off").'},
},
"required": ["room", "state"],
},
"return": {"type": "string", "description": "str: A message indicating the new state of the lights."},
},
}
```
A complete dataset entry for SFT might look like:
```python
{"messages": messages, "tools": [json_schema]}
```
For more detailed information on tool calling, refer to the [Tool Calling section in the `transformers` documentation](https://huggingface.co/docs/transformers/chat_extras#tools-and-rag) and the blog post [Tool Use, Unified](https://huggingface.co/blog/unified-tool-use).
### Harmony
The [Harmony response format](https://cookbook.openai.com/articles/openai-harmony) was introduced with the [OpenAI GPT OSS models](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4). It extends the conversational format by adding richer structure for reasoning, function calls, and metadata about the models behavior. Key features include:
- **Developer role** Provides high level instructions (similar to a system prompt) and lists available tools.
- **Channels** Separate types of assistant output into distinct streams:
- `analysis` for internal reasoning, from the key `"thinking"`
- `final` for the user-facing answer, from the key `"content"`
- `commentary` for tool calls or meta notes
- **Reasoning effort** Signals how much thinking the model should show (e.g., `"low"`, `"medium"`, `"high"`).
- **Model identity** Explicitly defines the assistants persona.
```python
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
messages = [
{"role": "developer", "content": "Use a friendly tone."},
{"role": "user", "content": "What is the meaning of life?"},
{"role": "assistant", "thinking": "Deep reflection...", "content": "The final answer is..."},
]
print(
tokenizer.apply_chat_template(
messages,
tokenize=False,
reasoning_effort="low",
model_identity="You are HuggingGPT, a large language model trained by Hugging Face."
)
)
```
This produces:
```txt
<|start|>system<|message|>You are HuggingGPT, a large language model trained by Hugging Face.
Knowledge cutoff: 2024-06
Current date: 2025-08-03
Reasoning: low
# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions
Use a friendly tone.<|end|><|start|>user<|message|>What is the meaning of life?<|end|><|start|>assistant<|channel|>analysis<|message|>Deep reflection...<|end|><|start|>assistant<|channel|>final<|message|>The final answer is...<|return|>
```
For full details on message structure, supported fields, and advanced usage, see the [Harmony documentation](https://cookbook.openai.com/articles/openai-harmony).
### Types
#### Language modeling
@ -128,63 +267,91 @@ Conversational datasets are useful for training chat models, but must be convert
A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text.
```python
# Standard format
language_modeling_example = {"text": "The sky is blue."}
# Conversational format
language_modeling_example = {"messages": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}
]}
```
#### Prompt-only
In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating the completion based on this prompt, where the model learns to continue or complete the given input.
In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating completion based on this prompt, where the model learns to continue or complete the given input.
```python
# Standard format
prompt_only_example = {"prompt": "The sky is"}
```
<Tip>
While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
```python
from transformers import AutoTokenizer
from trl import apply_chat_template
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
# Example for prompt-only type
# Conversational format
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
apply_chat_template(prompt_only_example, tokenizer)
# Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
# Example for language modeling type
lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
apply_chat_template(lm_example, tokenizer)
# Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
```
- The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistants turn and expecting the model to generate a completion.
- In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
For examples of prompt-only datasets, refer to the [Prompt-only datasets collection](https://huggingface.co/collections/trl-lib/prompt-only-datasets-677ea25245d20252cea00368).
</Tip>
> [!TIP]
> While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
>
> ```python
> from transformers import AutoTokenizer
> from trl import apply_chat_template
>
> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
>
> # Example for prompt-only type
> prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
> apply_chat_template(prompt_only_example, tokenizer)
> # Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
>
> # Example for language modeling type
> lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
> apply_chat_template(lm_example, tokenizer)
> # Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
> ```
>
> - The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistants turn and expecting the model to generate a completion.
> - In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
#### Prompt-completion
A prompt-completion dataset includes a `"prompt"` and a `"completion"`.
```python
# Standard format
prompt_completion_example = {"prompt": "The sky is", "completion": " blue."}
# Conversational format
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}]}
```
For examples of prompt-completion datasets, refer to the [Prompt-completion datasets collection](https://huggingface.co/collections/trl-lib/prompt-completion-datasets-677ea2bb20bbb6bdccada216).
#### Preference
A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response.
Some dataset may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible.
Some datasets may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible.
```python
# explicit prompt
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."} # recommended
# implicit prompt
# Standard format
## Explicit prompt (recommended)
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
# Implicit prompt
preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
# Conversational format
## Explicit prompt (recommended)
preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
"chosen": [{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "assistant", "content": "It is green."}]}
## Implicit prompt
preference_example = {"chosen": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is green."}]}
```
For examples of preference datasets, refer to the [Preference datasets collection](https://huggingface.co/collections/trl-lib/preference-datasets-677e99b581018fcad9abd82c).
Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets.
#### Unpaired preference
@ -192,44 +359,64 @@ Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](h
An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.
```python
# Standard format
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
# Conversational format
unpaired_preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}],
"label": True}
```
For examples of unpaired preference datasets, refer to the [Unpaired preference datasets collection](https://huggingface.co/collections/trl-lib/unpaired-preference-datasets-677ea22bf5f528c125b0bcdf).
#### Stepwise supervision
A stepwise (or process) supervision dataset is similar to an [unpaired preference](#unpaired-preference) dataset but includes multiple steps of completions, each with its own label. This structure is useful for tasks that need detailed, step-by-step labeling, such as reasoning tasks. By evaluating each step separately and providing targeted labels, this approach helps identify precisely where the reasoning is correct and where errors occur, allowing for targeted feedback on each part of the reasoning process.
```python
stepwise_example = {
"prompt": "Which number is larger, 9.8 or 9.11?",
"completions": ["The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.", "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."],
"labels": [True, False]
}
```
For examples of stepwise supervision datasets, refer to the [Stepwise supervision datasets collection](https://huggingface.co/collections/trl-lib/stepwise-supervision-datasets-677ea27fd4c5941beed7a96e).
## Which dataset type to use?
Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.
| Trainer | Expected dataset type |
| ----------------------- | ------------------------------------------------------------------------------------------------------ |
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) |
| --- | --- |
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`PPOTrainer`] | Tokenized language modeling |
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
<Tip>
TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
</Tip>
> [!TIP]
> TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
> For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
## Working with conversational datasets in TRL
Conversational datasets are increasingly common, especially for training chat models. However, TRL trainers (except [`SFTTrainer`]) don't support conversational datasets in their raw format. These datasets must first be converted into a standard format.
Conversational datasets are increasingly common, especially for training chat models. However, some TRL trainers don't support conversational datasets in their raw format. (For more information, see [issue #2071](https://github.com/huggingface/trl/issues/2071).) These datasets must first be converted into a standard format.
Fortunately, TRL offers tools to easily handle this conversion, which are detailed below.
### Converting a conversational dataset into a standard dataset
TRL trainers do not support conversational datasets in their raw format. To use them, you need to convert them into a standard dataset format using a chat template. This template is provided by the tokenizer of the model you use.
To convert a conversational dataset into a standard dataset, you need to *apply a chat template* to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use.
For detailed instructions on using chat templating, refer to the [Chat templating section in the `transformers` documentation](https://huggingface.co/docs/transformers/en/chat_templating).
@ -272,27 +459,21 @@ dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
# 'completion': ['It is blue.<|end|>\n<|endoftext|>', 'In the sky.<|end|>\n<|endoftext|>']}
```
<Tip warning={true}>
> [!WARNING]
> We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
> For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle conversation.
For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
</Tip>
<Tip warning={true}>
It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
```python
apply_chat_template(example, AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct"))
# Output:
# {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
# 'completion': 'It is blue.<|im_end|>\n'}
```
Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
</Tip>
> [!WARNING]
> It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
>
> ```python
> apply_chat_template(example, AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct"))
> # Output:
> # {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
> # 'completion': 'It is blue.<|im_end|>\n'}
> ```
>
> Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
## Using any dataset with TRL: preprocessing and conversion
@ -338,14 +519,15 @@ This section provides example code to help you convert between different dataset
For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification.
| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference |
| ------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------- | --------------------------------------------------------- | --------------------------------------------------------- | ------------------------------------------------------------------------- |
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A |
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A |
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A |
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) |
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) |
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A |
| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A |
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A |
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
| Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A |
### From prompt-completion to language modeling dataset
@ -521,6 +703,11 @@ dataset = unpair_preference_dataset(dataset)
'label': True}
```
> [!WARNING]
> Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
> Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
> This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
### From preference to language modeling dataset
To convert a preference dataset into a language modeling dataset, remove the rejected, concatenate the prompt and the chosen into the `"text"` column.
@ -654,9 +841,14 @@ dataset = unpair_preference_dataset(dataset)
'label': True}
```
> [!WARNING]
> Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
> Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
> This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
### From unpaired preference to language modeling dataset
To convert an unpaired preference dataset into a language modeling dataset, concatenate the prompt and the completion into the `"text"` column, and remove the prompt, completion and label columns.
To convert an unpaired preference dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column, and remove the prompt, completion and label columns.
```python
from datasets import Dataset
@ -670,7 +862,7 @@ dataset = Dataset.from_dict({
def concatenate_prompt_completion(example):
return {"text": example["prompt"] + example["completion"]}
dataset = dataset.map(concatenate_prompt_completion).remove_columns(["prompt", "completion", "label"])
dataset = dataset.filter(lambda x: x["label"]).map(concatenate_prompt_completion).remove_columns(["prompt", "completion", "label"])
```
```python
@ -680,7 +872,7 @@ dataset = dataset.map(concatenate_prompt_completion).remove_columns(["prompt", "
### From unpaired preference to prompt-completion dataset
To convert an unpaired preference dataset into a prompt-completion dataset, remove the label columns.
To convert an unpaired preference dataset into a prompt-completion dataset, filter for good labels, then remove the label columns.
```python
from datasets import Dataset
@ -691,7 +883,7 @@ dataset = Dataset.from_dict({
"label": [True, True, False, False],
})
dataset = dataset.remove_columns(["label"])
dataset = dataset.filter(lambda x: x["label"]).remove_columns(["label"])
```
```python
@ -720,13 +912,114 @@ dataset = dataset.remove_columns(["completion", "label"])
{'prompt': 'The sky is'}
```
### From stepwise supervision to language modeling dataset
To convert a stepwise supervision dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["Blue light", "Water"],
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
"labels": [[True, False], [True, True]],
})
def concatenate_prompt_completions(example):
completion = "".join(example["completions"])
return {"text": example["prompt"] + completion}
dataset = dataset.filter(lambda x: all(x["labels"])).map(concatenate_prompt_completions, remove_columns=["prompt", "completions", "labels"])
```
```python
>>> dataset[0]
{'text': 'Blue light scatters more in the atmosphere, so the sky is green.'}
```
### From stepwise supervision to prompt-completion dataset
To convert a stepwise supervision dataset into a prompt-completion dataset, join the good completions and remove the labels.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["Blue light", "Water"],
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
"labels": [[True, False], [True, True]],
})
def join_completions(example):
completion = "".join(example["completions"])
return {"completion": completion}
dataset = dataset.filter(lambda x: all(x["labels"])).map(join_completions, remove_columns=["completions", "labels"])
```
```python
>>> dataset[0]
{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.'}
```
### From stepwise supervision to prompt-only dataset
To convert a stepwise supervision dataset into a prompt-only dataset, remove the completions and the labels.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["Blue light", "Water"],
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
"labels": [[True, False], [True, True]],
})
dataset = dataset.remove_columns(["completions", "labels"])
```
```python
>>> dataset[0]
{'prompt': 'Blue light'}
```
### From stepwise supervision to unpaired preference dataset
To convert a stepwise supervision dataset into an unpaired preference dataset, join the completions and merge the labels.
The method for merging the labels depends on the specific task. In this example, we use the logical AND operation. This means that if the step labels indicate the correctness of individual steps, the resulting label will reflect the correctness of the entire sequence.
```python
from datasets import Dataset
dataset = Dataset.from_dict({
"prompt": ["Blue light", "Water"],
"completions": [[" scatters more in the atmosphere,", " so the sky is green."],
[" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
"labels": [[True, False], [True, True]],
})
def merge_completions_and_labels(example):
return {"prompt": example["prompt"], "completion": "".join(example["completions"]), "label": all(example["labels"])}
dataset = dataset.map(merge_completions_and_labels, remove_columns=["completions", "labels"])
```
```python
>>> dataset[0]
{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.', 'label': False}
```
## Vision datasets
Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.
A conversational vision dataset differs from a standard conversational dataset in two key ways:
1. The dataset must contain the key `images` with the image data.
1. The dataset must contain the key `images` with the image data (as lists of PIL images) or `image` with a single PIL image.
2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`.
Example:
@ -751,3 +1044,22 @@ An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](h
height="560px"
></iframe>
> [!NOTE]
> Mixing text-only and vision-language data in the dataset is possible, but it requires `transformers` version 4.57.0 or later. Example:
>
> ```python
> dataset = Dataset.from_dict({
> "prompt": [
> [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky in the image?"}]}],
> [{"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}],
> ],
> "completion": [
> [{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}],
> [{"role": "assistant", "content": [{"type": "text", "text": "Paris."}]}],
> ],
> "images": [
> [PIL.Image.open("path/to/sky_image1.png")],
> [],
> ],
> })
> ```

View File

@ -1,131 +0,0 @@
# Denoising Diffusion Policy Optimization
[![](https://img.shields.io/badge/All_models-DDPO-blue)](https://huggingface.co/models?other=ddpo,trl)
## The why
| Before | After DDPO finetuning |
| --- | --- |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_squirrel.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_crab.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> |
## Getting started with Stable Diffusion finetuning with reinforcement learning
The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers`
library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers.
Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made.
There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.**
There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide.
The point of the interface is to fuse the pipeline and the scheduler into one object which allows for minimalness in terms of having the constraints all in one place. The interface was designed in hopes of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at this time of writing. Also the scheduler step is a method of this pipeline interface and this may seem redundant given that the raw scheduler is accessible via the interface but this is the only way to constrain the scheduler step output to an output type befitting of the algorithm at hand (DDPO).
For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py)
Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag enabled by default and this can be turned off by passing in the flag to do so. LORA based training is faster and the LORA associated model hyperparameters responsible for model convergence aren't as finicky as non-LORA based training.
Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images.
## Getting started with `examples/scripts/ddpo.py`
The `ddpo.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`).
**Note:** one A100 GPU is recommended to get this running. Anything below a A100 will not be able to run this example script and even if it does via relatively smaller sized parameters, the results will most likely be poor.
Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running
```batch
python ddpo.py --hf_user_access_token <token>
```
To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`
The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`)
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`)
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count
## Setting up the image logging hook function
Expect the function to be given a list of lists of the form
```python
[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]
```
and `image`, `prompt`, `prompt_metadata`, `rewards`, `reward_metadata` are batched.
The last list in the lists of lists represents the last sample batch. You are likely to want to log this one
While you are free to log however you want the use of `wandb` or `tensorboard` is recommended.
### Key terms
- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process
- `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward
- `prompt` : The prompt is the text that is used to generate the image
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
- `image` : The image generated by the Stable Diffusion model
Example code for logging sampled images with `wandb` is given below.
```python
# for logging these images to wandb
def image_outputs_hook(image_data, global_step, accelerate_logger):
# For the sake of this example, we only care about the last batch
# hence we extract the last element of the list
result = {}
images, prompts, _, rewards, _ = image_data[-1]
for i, image in enumerate(images):
pil = Image.fromarray(
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
)
pil = pil.resize((256, 256))
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
accelerate_logger.log_images(
result,
step=global_step,
)
```
### Using the finetuned model
Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows
```python
import torch
from trl import DefaultDDPOStableDiffusionPipeline
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# memory optimization
pipeline.vae.to(device, torch.float16)
pipeline.text_encoder.to(device, torch.float16)
pipeline.unet.to(device, torch.float16)
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
results = pipeline(prompts)
for prompt, image in zip(prompts,results.images):
image.save(f"{prompt}.png")
```
## Credits
This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models
with Reinforcement Learning by Kevin Black, Michael Janner, Yilan Du, Ilya Kostrikov, Sergey Levine](https://huggingface.co/papers/2305.13301).
## DDPOTrainer
[[autodoc]] DDPOTrainer
## DDPOConfig
[[autodoc]] DDPOConfig

View File

@ -0,0 +1,36 @@
# DeepSpeed Integration
> [!WARNING]
> Section under construction. Feel free to contribute!
TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more.
DeepSpeed integrates the [Zero Redundancy Optimizer (ZeRO)](https://huggingface.co/papers/1910.02054), which allows to scale the model size proportional to the number of devices with sustained high efficiency.
![ZeRO Stages](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/zero_stages.png)
## Installation
To use DeepSpeed with TRL, install it using the following command:
```bash
pip install deepspeed
```
## Running Training Scripts with DeepSpeed
No modifications to your training script are required. Simply run it with the DeepSpeed configuration file:
```bash
accelerate launch --config_file <ACCELERATE_WITH_DEEPSPEED_CONFIG_FILE.yaml> train.py
```
We provide ready-to-use DeepSpeed configuration files in the [`examples/accelerate_configs`](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) directory. For example, to run training with ZeRO Stage 2, use the following command:
```bash
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml train.py
```
## Additional Resources
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.

View File

@ -1,187 +0,0 @@
# Detoxifying a Language Model using PPO
Language models (LMs) are known to sometimes generate toxic outputs. In this example, we will show how to "detoxify" a LM by feeding it toxic prompts and then using [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/index) and Proximal Policy Optimization (PPO) to "detoxify" it.
Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters!
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
| File | Description | Colab link |
|---|---| --- |
| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x |
## Context
Language models are trained on large volumes of text from the internet which also includes a lot of toxic content. Naturally, language models pick up the toxic patterns during training. Especially when prompted with already toxic texts the models are likely to continue the generations in a toxic way. The goal here is to "force" the model to be less toxic by feeding it toxic prompts and then using PPO to "detoxify" it.
### Computing toxicity scores
In order to optimize a model with PPO we need to define a reward. For this use-case we want a negative reward whenever the model generates something toxic and a positive comment when it is not toxic.
Therefore, we used [`facebook/roberta-hate-speech-dynabench-r4-target`](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target), which is a RoBERTa model fine-tuned to classify between "neutral" and "toxic" text as our toxic prompts classifier.
One could have also used different techniques to evaluate the toxicity of a model, or combined different toxicity classifiers, but for simplicity we have chosen to use this one.
### Selection of models
We selected the following models for our experiments to show that TRL can be easily scaled to 10B parameters models:
* [`EleutherAI/gpt-neo-125M`](https://huggingface.co/EleutherAI/gpt-neo-125M) (125 million parameters)
* [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters)
* [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters)
For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has shown to be a model that was the "most toxic" compared to other models. We have ran toxicity evaluation using `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt).
| Model | Mean toxicity score |
|---|---|
| `gpt2` | 0.01602 |
| `facebook/opt-350m` | 0.01628 |
| `bigscience/bloom-560m` | 0.00767 |
| `EleutherAI/gpt-neo-125M` | **0.02016** |
## Designing the problem
When doing PPO, it is very important to design the problem efficiently so that the model can learn to solve it. Let's cover the topics that were important for the model to converge.
### Pre-processing the dataset
The dataset consist of prompts and their continuations, and each of them has an associated `toxicity` score.
A `prompt` example:
```
{ "text": "Bronx officer Luis Rios alleged on-the-job tall tales were compounded by his lying about them in court, and", "profanity": 0.08360514, "sexually_explicit": 0.118823394, "identity_attack": 0.09608547, "flirtation": 0.23102224, "threat": 0.13215046, "insult": 0.10130461, "severe_toxicity": 0.04068885, "toxicity": 0.16534281 }
```
And its `continuation` value:
```
{ "text": " cost federal prosecutors all the drug evidence that the cop collected against an armed suspect — 16 baggies of cocaine during a strip search.", "severe_toxicity": 0.067997746, "toxicity": 0.1694093, "profanity": 0.11931301, "sexually_explicit": 0.12521537, "identity_attack": 0.09268324, "flirtation": 0.13452998, "threat": 0.31312028, "insult": 0.10761123 }
```
We want to increase the chance for the model to generate toxic prompts so we get more learning signal. For this reason pre-process the dataset to consider only the prompt that has a toxicity score that is greater than a threshold. We can do this in a few lines of code:
```python
train_dataset = load_dataset("allenai/real-toxicity-prompts", split="train")
def filter_fn(sample):
toxicity = sample["prompt"]["toxicity"]
return toxicity is not None and toxicity > 0.3
train_dataset = train_dataset.filter(filter_fn, batched=False)
```
### Reward function
The reward function is one of the most important part of training a model with reinforcement learning. It is the function that will tell the model if it is doing well or not.
We tried various combinations, considering the softmax of the label "neutral", the log of the toxicity score and the raw logits of the label "neutral". We have found out that the convergence was much more smoother with the raw logits of the label "neutral".
```python
logits = toxicity_model(**toxicity_inputs).logits.float()
rewards = (logits[:, 0]).tolist()
```
### Impact of input prompts length
We have found out that training a model with small or long context (from 5 to 8 tokens for the small context and from 15 to 20 tokens for the long context) does not have any impact on the convergence of the model, however, when training the model with longer prompts, the model will tend to generate more toxic prompts.
As a compromise between the two we took for a context window of 10 to 15 tokens for the training.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-long-vs-short-context.png">
</div>
### How to deal with OOM issues
Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU:
- Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2:
```python
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.bfloat16)
```
and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is a pure `bfloat16` training which is different from the mixed precision training. If one wants to train a model in mixed-precision, they should not load the model with `torch_dtype` and specify the mixed precision argument when calling `accelerate config`.
- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-shared-layers.png">
</div>
```python
ref_policy = create_reference_model(model, num_shared_layers=6)
trainer = PPOTrainer(..., ref_policy=ref_policy)
```
In the example above this means that the model have the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model).
- One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower).
## Training the model!
We have decided to keep 3 models in total that correspond to our best models:
- [`ybelkada/gpt-neo-125m-detox`](https://huggingface.co/ybelkada/gpt-neo-125m-detox)
- [`ybelkada/gpt-neo-2.7B-detox`](https://huggingface.co/ybelkada/gpt-neo-2.7B-detox)
- [`ybelkada/gpt-j-6b-detox`](https://huggingface.co/ybelkada/gpt-j-6b-detox)
We have used different learning rates for each model, and have found out that the largest models were quite hard to train and can easily lead to collapse mode if the learning rate is not chosen correctly (i.e. if the learning rate is too high):
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-collapse-mode.png">
</div>
The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-gpt-j-final-run-2.png">
</div>
As you can see the model converges nicely, but obviously we don't observe a very large improvement from the first step, as the original model is not trained to generate toxic contents.
Also we have observed that training with larger `mini_batch_size` leads to smoother convergence and better results on the test set:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-gpt-j-mbs-run.png">
</div>
## Results
We tested our models on a new dataset, the [`OxAISH-AL-LLM/wiki_toxic`](https://huggingface.co/datasets/OxAISH-AL-LLM/wiki_toxic) dataset. We feed each model with a toxic prompt from it (a sample with the label "toxic"), and generate 30 new tokens as it is done on the training loop and measure the toxicity score using `evaluate`'s [`toxicity` metric](https://huggingface.co/spaces/ybelkada/toxicity).
We report the toxicity score of 400 sampled examples, compute its mean and standard deviation and report the results in the table below:
| Model | Mean toxicity score | Std toxicity score |
| --- | --- | --- |
| `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 |
| `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** |
| --- | --- | --- |
| `EleutherAI/gpt-neo-2.7B` | 0.1884 | 0.3178 |
| `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** |
| --- | --- | --- |
| `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 |
| `ybelkada/gpt-j-6b-detox` | **0.1510** | **0.2798** |
<div class="column" style="text-align:center">
<figure>
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-final-barplot.png" style="width:80%">
<figcaption>Toxicity score with respect to the size of the model.</figcaption>
</figure>
</div>
Below are few generation examples of `gpt-j-6b-detox` model:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-toxicity-examples.png">
</div>
The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
### Discussions
The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for `gpt-neo-2B` model but we less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model starting with training with larger `mini_batch_size` and probably allowing to back-propagate through more layers (i.e. use less shared layers).
To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure there outputs are less toxic as well as useful.
### Limitations
We are also aware of consistent bias issues reported with toxicity classifiers, and of work evaluating the negative impact of toxicity reduction on the diversity of outcomes. We recommend that future work also compare the outputs of the detoxified models in terms of fairness and diversity before putting them to use.
## What is next?
You can download the model and use it out of the box with `transformers`, or play with the Spaces that compares the output of the models before and after detoxification [here](https://huggingface.co/spaces/ybelkada/detoxified-lms).

View File

@ -0,0 +1,190 @@
# Distributing Training
> [!WARNING]
> Section under construction. Feel free to contribute!
## Multi-GPU Training with TRL
The trainers in TRL use [🤗 Accelerate](https://github.com/huggingface/accelerate) to enable distributed training across multiple GPUs or nodes. To do so, first create an [🤗 Accelerate](https://github.com/huggingface/accelerate) config file by running
```bash
accelerate config
```
and answering the questions according to your multi-GPU / multi-node setup. You can then launch distributed training by running:
```bash
accelerate launch train.py
```
We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
```shell
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml train.py <SCRIPT_ARGS>
```
This automatically distributes the workload across all available GPUs.
Under the hood, [🤗 Accelerate](https://github.com/huggingface/accelerate) creates one model per GPU. Each process:
- Processes its own batch of data
- Computes the loss and gradients for that batch
- Shares gradient updates across all GPUs
![multi gpu](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/multi_gpu.png)
The effective batch size is calculated as:
$$
\text{Batch Size} = \text{per\_device\_train\_batch\_size} \times \text{num\_devices} \times \text{gradient\_accumulation\_steps}
$$
To maintain a consistent batch size when scaling to multiple GPUs, make sure to update `per_device_train_batch_size` and `gradient_accumulation_steps` accordingly.
Example, these configurations are equivalent, and should yield the same results:
| Number of GPUs | Per device batch size | Gradient accumulation steps | Comments |
| --- | --- | --- | --- |
| 1 | 32 | 1 | Possibly high memory usage, but faster training |
| 1 | 4 | 8 | Lower memory usage, slower training |
| 8 | 4 | 1 | Multi-GPU to get the best of both worlds |
> [!TIP]
> Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration) guide for more details.
## Context Parallelism
Context Parallelism (CP) is a parallelization technique that enables training with longer sequences by splitting the sequence dimension across multiple GPUs. Each GPU processes a portion of the sequence, allowing you to train with sequences longer than what would fit on a single GPU's memory.
For more details on CP, see the [Ultrascale Playbook - Context Parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism).
CP is particularly useful when:
- You want to train with very long sequences (>32k tokens)
- Single GPU memory is insufficient for your desired sequence length
- You need to maintain sequence coherence across the full context
### Requirements and Limitations
CP has specific requirements:
1. **Accelerate 1.10 or higher** is required
2. **FSDP2 (PyTorch FSDP v2)** is required as the distributed training backend
3. **SDPA attention** - Flash Attention is currently not supported with CP
4. **Sequence length divisibility** - sequences must be divisible by `cp_size * 2`. This is now automatically handled using the `pad_to_multiple_of` parameter in the data collator, which works seamlessly with both standard and padding-free modes.
### Configuration
To enable CP, you need to configure both Accelerate and your training arguments:
#### Accelerate Configuration
Use one of the provided accelerate config files (e.g. [`context_parallel_2gpu.yaml`](https://github.com/huggingface/trl/blob/main/examples/accelerate_configs/context_parallel_2gpu.yaml) for 2 GPUs):
```yaml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: true # Enable activation checkpointing for memory efficiency
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2 # Number of GPUs
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
parallelism_config:
parallelism_config_dp_replicate_size: 1
parallelism_config_dp_shard_size: 1
parallelism_config_tp_size: 1
parallelism_config_cp_size: 2 # Context parallel size
```
#### Training Configuration
```python
from trl import SFTConfig
training_args = SFTConfig(
# required
pad_to_multiple_of=4, # ensures divisibility by cp_size * 2
# to get the most out of CP
max_length=16384, # long sequence length
packing=True, # use packing to reduce padding
use_liger_kernel=True, # compatible with CP
gradient_checkpointing=False, # The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg can't be set to True simultaneously
per_device_train_batch_size=1,
...
)
```
Then, launch your training script with the appropriate accelerate config file:
```bash
accelerate launch --config_file context_parallel_2gpu.yaml train.py
```
### Best Practices
1. **Use the `pad_to_multiple_of` parameter** - This is now the recommended way to ensure sequence length divisibility:
- For `cp_size=2`: use `pad_to_multiple_of=4` (since `cp_size * 2 = 4`)
- For `cp_size=4`: use `pad_to_multiple_of=8` (since `cp_size * 2 = 8`)
- The data collator automatically pads sequences to the required multiple, ensuring compatibility with CP
2. **Use packing with padding** - The default BFD (Best Fit Decreasing) strategy works perfectly:
- Preserves sequence boundaries and maintains training quality
- Works seamlessly with both `padding_free=True` and standard padding modes
3. **Combine with other memory optimizations** like Liger kernels, bfloat16, and gradient checkpointing
4. **Start with smaller context parallel sizes** (2-4 GPUs) before scaling up
5. **Monitor memory usage** across all GPUs to ensure balanced workload
### Benchmarking Context Parallelism
We benchmarked CP to highlight its potential improvements in training efficiency.
Our experiments were conducted using **1, 2, 4, and 8 H100 GPUs**, though the results can be extended to larger clusters with more nodes and GPUs.
For the setup, we fine-tuned an **8B model** ([Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B)) using the provided accelerate configuration
([`context_parallel_2gpu.yaml`](https://github.com/huggingface/trl/blob/main/examples/accelerate_configs/context_parallel_2gpu.yaml)).
We adjusted `num_processes` and `parallelism_config_cp_size` based on the number of GPUs for each run.
Training was performed with the [sft.py](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) example script, combined with the parameters described above.
The results below summarize the **maximum trainable sequence length** and **iterations per second** for different numbers of GPUs. A value marked as `OOM` indicates that the configuration ran out of memory and could not be trained.
These results show that **Context Parallelism (CP) scales effectively with more GPUs**, enabling training on much longer sequences. With **8 GPUs**, context lengths of over **300k tokens** become feasible, unlocking training with extremely long contexts while maintaining reasonable throughput.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/context_parallelism_max_length_plot.png" alt="CP Max content length" width="45%"/>
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/context_parallelism_s_it_plot.png" alt="CP seconds/iteration" width="45%"/>
</div>
> [!TIP]
> Accelerate also supports **N-Dimensional Parallelism (ND-parallelism)**, which enables you to combine different parallelization strategies to efficiently distribute model training across multiple GPUs.
>
> You can learn more and explore configuration examples in the [Accelerate ND-parallelism guide](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#nd-parallelism).
### Further Reading on Context Parallelism
- [Accelerate: Context Parallelism Guide](https://github.com/huggingface/accelerate/blob/main/docs/source/concept_guides/context_parallelism.md)
- [Accelerate Example: 128k Sequence Length](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#context-parallelism-128k-sequence-length)
- [Hugging Face Blog: Enabling Long-Context Training with Sequence Parallelism in Axolotl](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl)
- [Snowflake Engineering Blog: Arctic Long Sequence Training (ALST) — Scalable and Efficient Training for Multi-Million Token Sequences (Note that they use a different strategy)](https://www.snowflake.com/en/engineering-blog/arctic-long-sequence-training-multi-million-token-ai/)
## Multi-Node Training
We're working on a guide for multi-node training. Stay tuned! 🚀

300
docs/source/dpo_trainer.md Normal file
View File

@ -0,0 +1,300 @@
# DPO Trainer
[![model badge](https://img.shields.io/badge/All_models-DPO-blue)](https://huggingface.co/models?other=dpo,trl) [![model badge](https://img.shields.io/badge/smol_course-Chapter_2-yellow)](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
## Overview
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://huggingface.co/papers/2305.18290) by [Rafael Rafailov](https://huggingface.co/rmrafailov), Archit Sharma, Eric Mitchell, [Stefano Ermon](https://huggingface.co/ermonste), [Christopher D. Manning](https://huggingface.co/manning), [Chelsea Finn](https://huggingface.co/cbfinn).
The abstract from the paper is the following:
> While large-scale unsupervised language models (LMs) learn broad world knowledge and some reasoning skills, achieving precise control of their behavior is difficult due to the completely unsupervised nature of their training. Existing methods for gaining such steerability collect human labels of the relative quality of model generations and fine-tune the unsupervised LM to align with these preferences, often with reinforcement learning from human feedback (RLHF). However, RLHF is a complex and often unstable procedure, first fitting a reward model that reflects the human preferences, and then fine-tuning the large unsupervised LM using reinforcement learning to maximize this estimated reward without drifting too far from the original model. In this paper we introduce a new parameterization of the reward model in RLHF that enables extraction of the corresponding optimal policy in closed form, allowing us to solve the standard RLHF problem with only a simple classification loss. The resulting algorithm, which we call Direct Preference Optimization (DPO), is stable, performant, and computationally lightweight, eliminating the need for sampling from the LM during fine-tuning or performing significant hyperparameter tuning. Our experiments show that DPO can fine-tune LMs to align with human preferences as well as or better than existing methods. Notably, fine-tuning with DPO exceeds PPO-based RLHF in ability to control sentiment of generations, and matches or improves response quality in summarization and single-turn dialogue while being substantially simpler to implement and train.
The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppo_trainer):
1. **Data collection**: Gather a [preference dataset](dataset_formats#preference) with positive and negative selected pairs of generation, given a prompt.
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
This process is illustrated in the sketch below (from [Figure 1 of the DPO paper](https://huggingface.co/papers/2305.18290)):
![Figure 1 DPO](https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d)
Read more about DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290).
## Quick start
This example demonstrates how to train a model using the DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_dpo.py
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO")
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_dpo.py
```
Distributed across 8 GPUs, the training takes approximately 3 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/dpo-qwen2-reward-margin.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
<pre><code>$ transformers chat trl-lib/Qwen2-0.5B-DPO
<strong><span style="color: red;">&lt;shirin_yamani&gt;:</span></strong>
What is Huggingface?
<strong><span style="color: blue;">&lt;trl-lib/Qwen2-0.5B-DPO&gt;:</span></strong>
Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets for the development of machine learning models and applications. It provides a repository of over 300, 000 pre-trained models in Huggingface is a platform that allows users to access a variety of open-source machine learning resources such as pre-trained models and datasets for the development of machine learning models and applications. It provides a repository of over 300, 000 pre-trained models in a variety of languages, enabling users to explore and utilize the latest techniques and technologies in the field of machine learning.
</code></pre>
## Expected dataset type
DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
### Special considerations for vision-language models
The [`DPOTrainer`] supports fine-tuning vision-language models (VLMs). For these models, a vision dataset is required. To learn more about the specific format for vision datasets, refer to the [Vision dataset format](dataset_formats#vision-datasets) section.
Additionally, unlike standard text-based models where a `tokenizer` is used, for VLMs, you should replace the `tokenizer` with a `processor`.
```diff
- model = AutoModelForCausalLM.from_pretrained(model_id)
+ model = AutoModelForImageTextToText.from_pretrained(model_id)
- tokenizer = AutoTokenizer.from_pretrained(model_id)
+ processor = AutoProcessor.from_pretrained(model_id)
trainer = DPOTrainer(
model,
args=training_args,
train_dataset=train_dataset,
- processing_class=tokenizer,
+ processing_class=processor,
)
```
For a complete example of fine-tuning a vision-language model, refer to the script in [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py).
## Example script
We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py)
To test the DPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
```bash
accelerate launch trl/scripts/dpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_train_epochs 1 \
--output_dir Qwen2-0.5B-DPO
```
## Logged metrics
While training and evaluating, we record the following reward metrics:
- `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta
- `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta
- `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
## Loss functions
The DPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`DPOConfig`]. The following loss functions are supported:
| `loss_type=` | Description |
| --- | --- |
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). |
| `"exo_pair"` | The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large. |
| `"nca_pair"` | The [NCA](https://huggingface.co/papers/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. |
| `"robust"` | The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) |
| `"bco_pair"` | The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. For unpaired data, we recommend the dedicated [`BCOTrainer`]. |
| `"sppo_hard"` | The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. |
| `"aot"` or `loss_type="aot_pair"` | The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. |
| `"apo_zero"` or `loss_type="apo_down"` | The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. |
| `"discopop"` | The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). |
| `"sft"` | SFT (Supervised Fine-Tuning) loss is the negative log likelihood loss, used to train the model to generate preferred responses. |
### Multi-loss combinations
The DPO trainer supports combining multiple loss functions with different weights, enabling more sophisticated optimization strategies. This is particularly useful for implementing algorithms like MPO (Mixed Preference Optimization). MPO is a training approach that combines multiple optimization objectives, as described in the paper [Enhancing the Reasoning Ability of Multimodal Large Language Models via Mixed Preference Optimization](https://huggingface.co/papers/2411.10442).
To combine multiple losses, specify the loss types and corresponding weights as lists:
```python
# MPO: Combines DPO (sigmoid) for preference and BCO (bco_pair) for quality
training_args = DPOConfig(
loss_type=["sigmoid", "bco_pair", "sft"], # Loss types to combine
loss_weights=[0.8, 0.2, 1.0] # Corresponding weights, as used in the MPO paper
)
```
If `loss_weights` is not provided, all loss types will have equal weights (1.0 by default).
### Label smoothing
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0).
### Syncing the reference model
The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model=True` in the [`DPOConfig`].
### RPO loss
The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to `1.0`.
### WPO loss
The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
### LD-DPO loss
The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient \\( \alpha \\). To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`.
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
## Accelerate DPO fine-tuning using `unsloth`
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks for DPO listed below:
| GPU | Model | Dataset | 🤗 | 🤗 + FlashAttention 2 | 🦥 Unsloth | 🦥 VRAM saved |
| --- | --- | --- | --- | --- | --- | --- |
| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% |
| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% |
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
```diff
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ from unsloth import FastLanguageModel
- model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+ model, tokenizer = FastLanguageModel.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+ model = FastLanguageModel.get_peft_model(model)
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
- training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO")
+ training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", bf16=True)
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
## Reference model considerations with PEFT
You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA.
1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient.
2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below.
3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls.
### Downsides to merging QLoRA before DPO (approach 2)
As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).
However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).
### Using option 3 - load the adapter twice
To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in [`DPOTrainer`].
For example:
```python
# Load the base model.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
"mistralai/mixtral-8x7b-v0.1",
load_in_4bit=True,
quantization_config=bnb_config,
attn_implementation="flash_attention_2",
dtype=torch.bfloat16,
device_map="auto",
)
# Load the adapter.
model = PeftModel.from_pretrained(
model,
"/path/to/peft",
is_trainable=True,
adapter_name="train",
)
# Load the adapter a second time, with a different name, which will be our reference model.
model.load_adapter("/path/to/peft", adapter_name="reference")
# Initialize the trainer, without a ref_model param.
training_args = DPOConfig(
model_adapter_name="train",
ref_adapter_name="reference",
)
dpo_trainer = DPOTrainer(
model,
args=training_args,
...
)
```
## DPOTrainer
[[autodoc]] DPOTrainer
- train
- save_model
- push_to_hub
## DPOConfig
[[autodoc]] DPOConfig
## DataCollatorForPreference
[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference
## FDivergenceType
[[autodoc]] trainer.dpo_trainer.FDivergenceType

View File

@ -1,282 +0,0 @@
# DPO Trainer
[![](https://img.shields.io/badge/All_models-DPO-blue)](https://huggingface.co/models?other=dpo,trl)
## Overview
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://huggingface.co/papers/2305.18290) by [Rafael Rafailov](https://huggingface.co/rmrafailov), Archit Sharma, Eric Mitchell, [Stefano Ermon](https://huggingface.co/ermonste), [Christopher D. Manning](https://huggingface.co/manning), [Chelsea Finn](https://huggingface.co/cbfinn).
The abstract from the paper is the following:
> While large-scale unsupervised language models (LMs) learn broad world knowledge and some reasoning skills, achieving precise control of their behavior is difficult due to the completely unsupervised nature of their training. Existing methods for gaining such steerability collect human labels of the relative quality of model generations and fine-tune the unsupervised LM to align with these preferences, often with reinforcement learning from human feedback (RLHF). However, RLHF is a complex and often unstable procedure, first fitting a reward model that reflects the human preferences, and then fine-tuning the large unsupervised LM using reinforcement learning to maximize this estimated reward without drifting too far from the original model. In this paper we introduce a new parameterization of the reward model in RLHF that enables extraction of the corresponding optimal policy in closed form, allowing us to solve the standard RLHF problem with only a simple classification loss. The resulting algorithm, which we call Direct Preference Optimization (DPO), is stable, performant, and computationally lightweight, eliminating the need for sampling from the LM during fine-tuning or performing significant hyperparameter tuning. Our experiments show that DPO can fine-tune LMs to align with human preferences as well as or better than existing methods. Notably, fine-tuning with DPO exceeds PPO-based RLHF in ability to control sentiment of generations, and matches or improves response quality in summarization and single-turn dialogue while being substantially simpler to implement and train.
The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppo_trainer):
1. **Data collection**: Gather a [preference dataset](dataset_formats#preference) with positive and negative selected pairs of generation, given a prompt.
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
This process is illustrated in the sketch below (from [Figure 1 of the DPO paper](https://huggingface.co/papers/2305.18290)):
![](https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d)
Read more about DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290).
## Quick start
This example demonstrates how to train a model using the DPO method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) as the base model. We use the preference data from the [UltraFeedback dataset](https://huggingface.co/datasets/openbmb/UltraFeedback). You can view the data in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_dpo.py
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10)
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_dpo.py
```
Distributed across 8 GPUs, the training takes approximately 3 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/dpo-qwen2-reward-margin.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-DPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
<strong><span style="color: blue;">&lt;trl-lib/Qwen2-0.5B-DPO&gt;:</span></strong>
The best programming language for specific applications can vary depending on the use case and knowledge level of the programmer. Here are some general factors that can be used as input to choose the best programming language:
<strong><span style="color: green;">1</span></strong> Ease of use: Some programming languages are more user-friendly than others, such as Python, Java, or Ruby. Python is popular due to its simplicity and great scalability.
<strong><span style="color: green;">2</span></strong> Versatility: The ability to work with a wide range of data structures and frameworks can define the language as versatile.
<strong><span style="color: green;">3</span></strong> Ease of learning: Different programming languages have different learning curves, so users must be willing to take some time to master one.
<strong><span style="color: green;">4</span></strong> Community support: The broader community of developers and enthusiasts in the selected programming language can provide great support and resources.
<strong><span style="color: green;">5</span></strong> Reusability: Languages that emphasize code reuse and can be easily modifiable can be more suitable for software development.
The best programming language based on these factors is subjective and depends on what the programmer intends to accomplish.
</code></pre>
## Expected dataset type
DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
### Special considerations for vision-language models
The [`DPOTrainer`] supports fine-tuning vision-language models (VLMs). For these models, a vision dataset is required. To learn more about the specific format for vision datasets, refer to the [Vision dataset format](dataset_formats#vision-datasets) section.
Additionally, unlike standard text-based models where a `tokenizer` is used, for VLMs, you should replace the `tokenizer` with a `processor`.
```diff
- model = AutoModelForCausalLM.from_pretrained(model_id)
+ model = AutoModelForVision2Seq.from_pretrained(model_id)
- tokenizer = AutoTokenizer.from_pretrained(model_id)
+ processor = AutoProcessor.from_pretrained(model_id)
trainer = DPOTrainer(
model,
args=training_args,
train_dataset=train_dataset,
- processing_class=tokenizer,
+ processing_class=processor,
)
```
For a complete example of fine-tuning a vision-language model, refer to the script in [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py).
## Example script
We provide an example script to train a model using the DPO method. The script is available in [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py)
To test the DPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:
```bash
accelerate launch examples/scripts/dpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_train_epochs 1 \
--logging_steps 25 \
--output_dir Qwen2-0.5B-DPO
```
## Logged metrics
While training and evaluating we record the following reward metrics:
- `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta
- `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta
- `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
## Loss functions
The DPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`DPOConfig`]. The following loss functions are supported:
| `loss_type=` | Description |
| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only). |
| `"exo_pair"` | The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large. |
| `"nca_pair"` | The [NCA](https://huggingface.co/papers/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. |
| `"robust"` | The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) |
| `"bco_pair"` | The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. For unpaired data, we recommend the dedicated [`BCOTrainer`]. |
| `"sppo_hard"` | The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. |
| `"aot"` or `loss_type="aot_pair"` | The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. |
| `"apo_zero"` or `loss_type="apo_down"` | The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. |
### Label smoothing
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0).
### Syncing the reference model
The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model=True` in the [`DPOConfig`].
### RPO loss
The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to `1.0`.
### WPO loss
The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
## Accelerate DPO fine-tuning using `unsloth`
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks for DPO listed below:
| GPU | Model | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
| -------- | --------- | ---------- | --- | --------------------- | --------- | ------------ |
| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% |
| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% |
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
```diff
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ from unsloth import FastLanguageModel
- model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+ model, tokenizer = FastLanguageModel.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+ model = FastLanguageModel.get_peft_model(model)
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
- training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10)
+ training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10, bf16=True)
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
## Reference model considerations with PEFT
You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA.
1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient.
2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below.
3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls.
### Downsides to merging QLoRA before DPO (approach 2)
As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).
However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).
### Using option 3 - load the adapter twice
To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in [`DPOTrainer`].
For example:
```python
# Load the base model.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
"mistralai/mixtral-8x7b-v0.1",
load_in_4bit=True,
quantization_config=bnb_config,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
device_map="auto",
)
model.config.use_cache = False
# Load the adapter.
model = PeftModel.from_pretrained(
model,
"/path/to/peft",
is_trainable=True,
adapter_name="train",
)
# Load the adapter a second time, with a different name, which will be our reference model.
model.load_adapter("/path/to/peft", adapter_name="reference")
# Initialize the trainer, without a ref_model param.
training_args = DPOConfig(
model_adapter_name="train",
ref_adapter_name="reference",
)
dpo_trainer = DPOTrainer(
model,
args=training_args,
...
)
```
## DPOTrainer
[[autodoc]] DPOTrainer
## DPOConfig
[[autodoc]] DPOConfig
## PreferenceCollator
[[autodoc]] trainer.dpo_trainer.PreferenceCollator

View File

@ -1,25 +1,24 @@
# Examples
## Introduction
The examples should work in any of the following settings (with the same script):
- single GPU
- multi GPUS (using PyTorch distributed mode)
- multi GPUS (using DeepSpeed ZeRO-Offload stages 1, 2, & 3)
- fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)
To run it in each of these various modes, first initialize the accelerate
configuration with `accelerate config`
- single GPU
- multi GPUs (using PyTorch distributed mode)
- multi GPUs (using DeepSpeed ZeRO-Offload stages 1, 2, & 3)
- fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)
**NOTE to train with a 4-bit or 8-bit model**, please run
To run it in each of these various modes, first initialize the accelerate configuration with `accelerate config`.
To train with a 4-bit or 8-bit model, please run:
```bash
pip install --upgrade trl[quantization]
```
## Accelerate Config
For all the examples, you'll need to generate a 🤗 Accelerate config file with:
```shell
@ -28,44 +27,52 @@ accelerate config # will prompt you to define the training configuration
Then, it is encouraged to launch jobs with `accelerate launch`!
## Maintained Examples
# Maintained Examples
Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly.
| File | Description |
| ----------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
| --- | --- |
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
| [`examples/scripts/chat.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/chat.py) | This script allows you to load and use a model as a chatbot. |
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
| [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a model. |
| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
| [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a stable to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py) | This script shows how to use [`HfPairwiseJudge`] or [`OpenAIPairwiseJudge`] to judge model generations. |
| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
| [`trl/scripts/grpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a model. |
| [`examples/scripts/grpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
| [`examples/scripts/gspo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune model for reasoning using the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. |
| [`examples/scripts/gspo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo_vlm.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
| [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights with weights. |
| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. |
| [`examples/scripts/online_dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
| [`examples/scripts/online_dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/online_dpo_vlm.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a a Vision Language Model. |
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a reward model on your own dataset. |
| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model or adapters into a target dataset. |
| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a Outcome Reward Model (ORM) on your own dataset. |
| [`examples/scripts/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to solve math questions. |
| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model. |
| [`examples/scripts/sft_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model. |
| [`examples/scripts/sft_video_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_video_llm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Video Language Model. |
| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. |
| [`examples/scripts/sft_vlm_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model on vision to text tasks. |
| [`examples/scripts/sft_vlm_smol_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a SmolVLM model. |
| [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`XPOTrainer`] to fine-tune a model. |
Here are also some easier-to-run colab notebooks that you can use to get started with TRL:
| File | Description |
| --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------- |
| --- | --- |
| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
We also have some other examples that are less maintained but can be used as a reference:
1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
## Distributed training
All of the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments.)
All the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments).
```shell
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
@ -75,7 +82,7 @@ You can also adjust the parameters of the 🤗 Accelerate config file to suit yo
### Distributed training with DeepSpeed
Most of the scripts can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine, `--all_arguments_of_the_script` with your arguments, and `--deepspeed_config` with the path to the DeepSpeed config file such as `examples/deepspeed_configs/deepspeed_zero1.yaml`):
Most of the scripts can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine, `--all_arguments_of_the_script` with your arguments, and `--deepspeed_config` with the path to the DeepSpeed config file such as `examples/deepspeed_configs/deepspeed_zero1.yaml`):
```shell
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script

163
docs/source/experimental.md Normal file
View File

@ -0,0 +1,163 @@
# Experimental Features
The `trl.experimental` namespace provides a minimal, clearly separated space for fast iteration on new ideas.
> [!WARNING]
> **Stability contract:** Anything under `trl.experimental` may change or be removed in *any* release (including patch versions) without prior deprecation. Do not rely on these APIs for production workloads.
## Current Experimental Features
The following modules are currently available under [`trl.experimental`](https://github.com/huggingface/trl/tree/main/trl/experimental).
This list is not exhaustive and may change at any time.
### BEMA for Reference Model
This feature implements the BEMA algorithm to update the reference model during DPO training.
```python
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
bema_callback = BEMACallback(update_ref_model=True)
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer.pad_token = tokenizer.eos_token
trainer = DPOTrainer(
model=model,
ref_model=ref_model,
train_dataset=pref_dataset,
processing_class=tokenizer,
callbacks=[bema_callback],
)
trainer.train()
```
### GFPO
This feature implements the GFPO algorithm to enforce concise reasoning in the model's output generation, as proposed in the paper [Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning](https://huggingface.co/papers/2508.09726).
To activate GFPO in [`GFPOTrainer`]:
- set `num_remains_in_group` in [`GFPOConfig`]
- define a group filter function and set it to `group_filter_func` in [`GFPOTrainer`]. `group_filter_func` will score the `num_generations` completions and The GFPOTrainer filters groups according to their scores to get top `num_remains_in_group` completions as a new group. Model will be trained on the filtered group.
```python
# train_gfpo.py
from trl.experimental.gfpo import GFPOConfig, GFPOTrainer
# dummy group filter to scores the completions based on its indice in group
class GroupFilter:
def __call__(self, group_completions, group_rewards, **kwargs):
group_scores = []
for completions, rewards in zip(group_completions, group_rewards):
scores = [float(i) for i in range(len(completions))]
group_scores.append(scores)
return group_scores
training_args = GFPOConfig(
output_dir="Qwen3-0.6B-GFPO",
per_device_train_batch_size=4,
num_remains_in_group=2,
bf16=True,
)
trainer = GFPOTrainer(
model="Qwen/Qwen3-0.6B",
reward_funcs=...,
train_dataset=...,
args=training_args,
group_filter_func=GroupFilter(),
)
trainer.train()
```
### GSPO-token
In the paper [Group Sequence Policy Optimization](https://huggingface.co/papers/2507.18071), the authors propose a token-level objective variant to GSPO, called GSPO-token. To use GSPO-token, you can use the `GRPOTrainer` class in `trl.experimental.gspo_token`.
```python
from trl.experimental.gspo_token import GRPOTrainer
from trl import GRPOConfig
training_args = GRPOConfig(
importance_sampling_level="sequence_token",
...
)
```
> [!WARNING]
> To leverage GSPO-token, the user will need to provide the per-token advantage \\( \hat{A_{i,t}} \\) for each token \\( t \\) in the sequence \\( i \\) (i.e., make \\( \hat{A_{i,t}} \\) varies with \\( t \\)—which isn't the case here, \\( \hat{A_{i,t}}=\hat{A_{i}} \\)). Otherwise, GSPO-Token gradient is just equivalent to the original GSPO implementation.
### GRPO With Replay Buffer
This experimental trainer, trains a model with GRPO but replaces groups (and corresponding completions) that have 0 standard deviation with groups with high rewards and standard deviation that've been used to train a model in prior batches.
#### Usage
```python
from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferTrainer
from datasets import load_dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
# Guarantee that some rewards have 0 std
def custom_reward_func(completions, **kwargs):
if torch.rand(1).item() < 0.25:
return [0] * len(completions) # simulate some None rewards
else:
return torch.rand(len(completions)).tolist()
training_args = GRPOWithReplayBufferConfig(
output_dir=self.tmp_dir,
learning_rate=1e-4,
per_device_train_batch_size=4,
num_generations=4,
max_completion_length=8,
replay_buffer_size=8,
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=[custom_reward_func],
args=training_args,
train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
```
To silence the runtime notice:
```bash
export TRL_EXPERIMENTAL_SILENCE=1
```
## Promotion Path (Simple)
1. **Prototype outside the main repo:** Start development in your own fork or a separate repository to iterate quickly.
2. **Experimental inclusion:** Once its ready for early users, move the idea into `trl.experimental.<feature>`.
3. **Improve:** Add tests, a short doc/example, and demonstrate the usage.
4. **Promote:** Once the API proves stable and there is clear interest or adoption from the community, move it into `trl.<feature>` (stable module).
## FAQ
**Why not just use branches?**
Because branches are not shipped to users; experimental code inside the package lets early adopters try things and give feedback.
**Can these APIs change or vanish without warning?**
Yes. Anything inside `trl.experimental` can change or disappear in *any* release.
**Should I use this in production?**
Only if you are fine with updating your code quickly when things change.
**Will maintainers promptly fix issues in `trl.experimental`?**
Not necessarily. The experimental module is a playground for new ideas, and maintainers may not prioritize bug fixes or feature requests there. Issues may remain unresolved until (or unless) the feature graduates to the stable API.

View File

@ -1,6 +1,6 @@
# Generalized Knowledge Distillation Trainer
[![](https://img.shields.io/badge/All_models-GKD-blue)](https://huggingface.co/models?other=gkd,trl)
[![model badge](https://img.shields.io/badge/All_models-GKD-blue)](https://huggingface.co/models?other=gkd,trl)
## Overview
@ -10,8 +10,8 @@ The abstract from the paper is the following:
> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning.
The key aspects of GKD are:
1. It addresses the train-inference distribution mismatch in auto-regressive sequence models by training the student model on its self-generated output sequences.
2. GKD allows flexibility in choosing different divergence measures between student and teacher models via the generalized Jensen-Shannon Divergence (JSD), which can be useful when the student lacks the capacity to fully mimic the teacher.
@ -20,6 +20,7 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.
## Usage tips
The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`] namely:
* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and token-specific feedback on these sequences from the teacher. For values in between [0, 1] it is random between the two based on the `lmbda` value for each batch.
* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated out). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.
@ -85,13 +86,16 @@ trainer.train()
### Expected dataset type
The dataset should be formatted as a list of "messages" where each message is a list of dictionaries with the following keys:
* `role`: either `system`, `assistant` or `user`
* `content`: the message content
## GKDTrainer
[[autodoc]] GKDTrainer
- train
- save_model
- push_to_hub
## GKDConfig

592
docs/source/grpo_trainer.md Normal file
View File

@ -0,0 +1,592 @@
# GRPO Trainer
[![model badge](https://img.shields.io/badge/All_models-GRPO-blue)](https://huggingface.co/models?other=grpo,trl)
## Overview
TRL supports the GRPO Trainer for training language models, as described in the paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300) by [Zhihong Shao](https://huggingface.co/syhia), [Peiyi Wang](https://huggingface.co/peiyiwang89), [Qihao Zhu](https://huggingface.co/zqh11), Runxin Xu, [Junxiao Song](https://huggingface.co/haha-point), Mingchuan Zhang, Y. K. Li, Y. Wu, [Daya Guo](https://huggingface.co/guoday).
The abstract from the paper is the following:
> Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO.
This post-training method was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec).
## Quick start
This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [UltraFeedback prompts dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt). You can view the data in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model.
```python
# train_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
# Dummy reward function for demonstration purposes
def reward_num_unique_letters(completions, **kwargs):
"""Reward function that rewards completions with more unique letters."""
completion_contents = [completion[0]["content"] for completion in completions]
return [float(len(set(content))) for content in completion_contents]
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO")
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_num_unique_letters,
args=training_args,
train_dataset=dataset,
)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_grpo.py
```
Distributed across 8 GPUs, the training takes approximately 1 day.
![GRPO curves](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_curves.png)
## Looking deeper into the GRPO method
GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**.
![GRPO visual](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/grpo_visual.png)
### Generating completions
At each training step, we sample a batch of prompts and generate a set of \\( G \\) completions for each prompt (denoted as \\( o_i \\)).
### Computing the advantage
For each of the \\( G \\) sequences, we compute the reward using a reward model or reward function. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:
$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$
This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.
> [!TIP]
> It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`].
> [!TIP]
> [Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)](https://huggingface.co/papers/2508.08221) showed that calculating the mean at the local (group) level and the standard deviation at the global (batch) level enables more robust reward shaping. You can use this scaling strategy by setting `scale_rewards="batch"` in [`GRPOConfig`].
### Estimating the KL divergence
KL divergence is estimated using the approximator introduced by [Schulman et al. (2020)](http://joschu.net/blog/kl-approx.html). The approximator is defined as follows:
$$\mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1,
$$
### Computing the loss
The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:
$$
\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$
where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.
> [!TIP]
> Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that this introduces a response-level length bias. More details in [loss types](#loss-types).
> [!TIP]
> Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we use \\( \beta = 0.0 \\) by default, meaning that the KL divergence term is not used. This choice is motivated by several recent studies (e.g., [Open-Reasoner-Zero: An Open Source Approach to Scaling Up Reinforcement Learning on the Base Model](https://huggingface.co/papers/2503.24290)) which have shown that the KL divergence term is not essential for training with GRPO. As a result, it has become common practice to exclude it (e.g. [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783), [DAPO: An Open-Source LLM Reinforcement Learning System at Scale](https://huggingface.co/papers/2503.14476)). If you wish to include the KL divergence term, you can set `beta` in [`GRPOConfig`] to a non-zero value.
In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:
$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$
where \\(\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) \\) ensures that updates do not deviate excessively from the reference policy by bounding the policy ratio between \\( 1 - \epsilon \\) and \\( 1 + \epsilon \\).
When \\( \mu = 1 \\) (default in TRL), the clipped surrogate objective simplifies to the original objective.
#### Loss Types
Several formulations of the objective have been proposed in the literature. Initially, the objective of GRPO was defined as follows:
$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} l_{i,t},
$$
where
$$
l_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right].
$$
The [DAPO paper](https://huggingface.co/papers/2503.14476) highlights the limitations of the GRPO algorithms sample-level loss in long-CoT scenarios, where longer responses are under-penalized, leading to poorer quality outputs. The proposed solution is a token-level normalization, which better handles longer sequences by assigning more balanced rewards to individual tokens, regardless of response length:
$$
\mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
$$
To use this formulation, set `loss_type="dapo"` in [`GRPOConfig`].
Furthermore, it was demonstrated in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that the initial GRPO formulation introduces a response length bias. They show that while the DAPO formulation reduces this bias, it does not eliminate it completely. To fully remove this bias, they propose dividing by a constant instead of the sequence length, resulting in the following formulation:
$$
\mathcal{L}_{\text{Dr. GRPO}}(\theta) = - \frac{1}{LG} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
$$
This constant is recommended to be the maximum completion length. To use this formulation, set `loss_type="dr_grpo"` in the [`GRPOConfig`].
## Logged metrics
While training and evaluating, we record the following reward metrics:
- `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
- `completions/mean_length`: The average length of generated completions.
- `completions/min_length`: The minimum length of generated completions.
- `completions/max_length`: The maximum length of generated completions.
- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS.
- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS.
- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS.
- `completions/clipped_ratio`: The ratio of truncated (clipped) completions.
- `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
- `reward`: The overall average reward after applying reward weights.
- `reward_std`: The standard deviation of rewards after applying reward weights.
- If `scale_rewards` is `"group"` or `"none"`, this is the average of the per-group standard deviations.
- If `scale_rewards` is `"batch"`, this is the standard deviation computed over all rewards in the batch (ignoring groups).
- `frac_reward_zero_std`: The fraction of samples in the generation batch with a reward std of zero, implying there is little diversity for that prompt (all answers are correct or incorrect).
- `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, masked sequences tokens are excluded.)
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
- `clip_ratio/region_mean`: The ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities where the GRPO objective is clipped to stay within the trust region:
$$
\text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,.
$$
A higher value means more tokens are clipped, which constrains how much the policy $\pi_\theta$ can change.
- `clip_ratio/low_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
- `clip_ratio/low_min`: The minimum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
- `clip_ratio/high_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\)
- `clip_ratio/high_max`: The maximum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).
## Customization
### Speed up training with vLLM-powered generation
Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with
```shell
pip install trl[vllm]
```
We support two ways of using vLLM during training: **server mode** and **colocate mode**.
> [!TIP]
> By default, Truncated Importance Sampling is activated for vLLM generation to address the generation-training mismatch that occurs when using different frameworks. This can be turned off by setting `vllm_importance_sampling_correction=False`. For more information, see [Truncated Importance Sampling](paper_index#truncated-importance-sampling)
#### 🔌 Option 1: Server mode
In this mode, vLLM runs in a separate process (and using separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.
1. **Start the vLLM server**:
```bash
trl vllm-serve --model <model_name>
```
2. **Enable server mode in your training script**:
```python
from trl import GRPOConfig
training_args = GRPOConfig(
...,
use_vllm=True,
vllm_mode="server", # default value, can be omitted
)
```
> [!WARNING]
> Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable.
#### 🧩 Option 2: Colocate mode
In this mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs.
```python
from trl import GRPOConfig
training_args = GRPOConfig(
...,
use_vllm=True,
vllm_mode="colocate",
)
```
> [!TIP]
> Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`GRPOConfig`] to avoid underutilization or out-of-memory errors.
>
> We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-memory) to help estimate the recommended GPU memory utilization based on your model configuration and experiment settings. Simply use it as follows to get `vllm_gpu_memory_utilization` recommendation:
>
> <iframe src="https://trl-lib-recommend-vllm-memory.hf.space" frameborder="0" width="850" height="450"></iframe>
>
> If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.
>
> If you still find you are getting out-of-memory errors set `vllm_enable_sleep_mode` to True and the vllm parameters and cache will be offloaded during the optimization step. For more information, see [Reducing Memory Usage with vLLM Sleep Mode](reducing_memory_usage#vllm-sleep-mode).
> [!TIP]
> By default, GRPO uses `MASTER_ADDR=localhost` and `MASTER_PORT=12345` for vLLM, but you can override these values by setting the environment variables accordingly.
For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
### GRPO at scale: train a 70B+ Model on multiple nodes
When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:
- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see [DeepSpeed Integration](deepspeed_integration).
- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training).
- **vLLM**: See the previous section on how to use vLLM to speed up generation.
Below is an example SLURM script to train a 70B model with GRPO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation.
```sh
#!/bin/bash
#SBATCH --nodes=5
#SBATCH --gres=gpu:8
# Get the list of allocated nodes
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
# Assign the first 4 nodes for training and the 5th node for vLLM
TRAIN_NODES="${NODELIST[@]:0:4}" # Nodes 0, 1, 2, 3 for training
VLLM_NODE="${NODELIST[4]}" # Node 4 for vLLM
# Run training on the first 4 nodes (Group 1)
srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
--num_processes 32 \
--num_machines 4 \
--main_process_ip ${NODELIST[0]} \
--machine_rank $SLURM_PROCID \
--rdzv_backend c10d \
train_grpo.py \
--server_ip $VLLM_NODE &
# Run vLLM server on the 5th node (Group 2)
srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &
wait
```
```python
import argparse
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP")
args = parser.parse_args()
# Example dataset from TLDR
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
training_args = GRPOConfig(
output_dir="Qwen2.5-72B-GRPO",
per_device_train_batch_size=4,
bf16=True,
gradient_checkpointing=True,
use_vllm=True,
vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."), # from ip-X-X-X-X to X.X.X.X
)
trainer = GRPOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset)
trainer.train()
if __name__=="__main__":
main()
```
### Using a custom reward function
The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:
1. **Input arguments**:
- The function must accept the following as keyword arguments:
- `prompts` (contains the prompts),
- `completions` (contains the generated completions),
- `completions_ids` (contains the tokenized completions),
- `trainer_state` ([`~transformers.TrainerState`]): The current state of the trainer. This can be used to implement dynamic reward functions, such as curriculum learning, where the reward is adjusted based on the training progress.
- All column names (but `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
- Depending on the dataset format, the input will vary:
- For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings.
- For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries.
2. **Return value**: The function must return a list of floats. Each float represents the reward corresponding to a single completion.
#### Example 1: Reward longer completions
Below is an example of a reward function for a standard format that rewards longer completions:
```python
def reward_func(completions_ids, **kwargs):
"""Reward function that assigns higher scores to longer completions (in terms of token count)."""
return [float(len(ids)) for ids in completions_ids]
```
You can test it as follows:
```python
>>> prompts = ["The sky is", "The sun is"] # not used in the reward function, but the trainer will pass it
>>> completions = [" blue.", " in the sky."] # not used in the reward function, but the trainer will pass it
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[2.0, 4.0]
```
#### Example 1.1: Reward longer completions (based on the number of characters)
Same as the previous example, but this time the reward function is based on the number of characters instead of tokens.
```python
def reward_func(completions, **kwargs):
"""Reward function that assigns higher scores to longer completions (in terms of character count)."""
return [float(len(completion)) for completion in completions]
```
You can test it as follows:
```python
>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]] # not used in the reward function, but the trainer will pass it
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[6.0, 12.0]
```
#### Example 2: Reward completions with a specific format
Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
It is designed for a conversational format, where prompts and completions consist of structured messages.
```python
import re
def format_reward_func(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]
```
You can test this function as follows:
```python
>>> prompts = [
... [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
... [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
... ]
>>> completions = [
... [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
... [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
... ]
>>> format_reward_func(prompts=prompts, completions=completions)
[1.0, 0.0]
```
#### Example 3: Reward completions based on a reference
Below is an example of a reward function that checks if the completion is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`.
```python
import re
def reward_func(completions, ground_truth, **kwargs):
# Regular expression to capture content inside \boxed{}
matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
contents = [match.group(1) if match else "" for match in matches]
# Reward 1 if the content is the same as the ground truth, 0 otherwise
return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
```
You can test this function as follows:
```python
>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
>>> ground_truth = ["2", "5"]
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
[1.0, 0.0]
```
#### Example 4: Multi-task reward functions
Below is an example of using multiple reward functions in the [`GRPOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works.
```python
from datasets import Dataset
from trl import GRPOTrainer
# Define a dataset that contains both math and coding problems
dataset = Dataset.from_list(
[
{"prompt": "What is 2+2?", "task": "math"},
{"prompt": "Write a function that returns the sum of two numbers.", "task": "code"},
{"prompt": "What is 3*4?", "task": "math"},
{"prompt": "Write a function that returns the product of two numbers.", "task": "code"},
]
)
# Math-specific reward function
def math_reward_func(prompts, completions, task, **kwargs):
rewards = []
for prompt, completion, t in zip(prompts, completions, task):
if t == "math":
# Calculate math-specific reward
correct = check_math_solution(prompt, completion)
reward = 1.0 if correct else -1.0
rewards.append(reward)
else:
# Return None for non-math tasks
rewards.append(None)
return rewards
# Coding-specific reward function
def coding_reward_func(prompts, completions, task, **kwargs):
rewards = []
for prompt, completion, t in zip(prompts, completions, task):
if t == "coding":
# Calculate coding-specific reward
works = test_code_solution(prompt, completion)
reward = 1.0 if works else -1.0
rewards.append(reward)
else:
# Return None for non-coding tasks
rewards.append(None)
return rewards
# Use both task-specific reward functions
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=[math_reward_func, coding_reward_func],
train_dataset=dataset,
)
trainer.train()
```
In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return `None`, and the [`GRPOTrainer`] will continue with the valid functions and tasks. This allows the [`GRPOTrainer`] to handle multiple reward functions with different applicability.
Note that the [`GRPOTrainer`] will ignore the `None` rewards returned by the reward functions and only consider the rewards returned by the relevant functions. This ensures that the model is trained on the relevant tasks and ignores the tasks for which there is no relevant reward function.
#### Passing the reward function to the trainer
To use your custom reward function, pass it to the [`GRPOTrainer`] as follows:
```python
from trl import GRPOTrainer
trainer = GRPOTrainer(
reward_funcs=reward_func,
...,
)
```
If you have multiple reward functions, you can pass them as a list:
```python
from trl import GRPOTrainer
trainer = GRPOTrainer(
reward_funcs=[reward_func1, reward_func2],
...,
)
```
and the reward will be computed as the sum of the rewards from each function, or the weighted sum if `reward_weights` is provided in the config.
Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details.
## Vision-Language Model (VLM) Training
GRPO supports training Vision-Language Models (VLMs) on multimodal datasets containing both text and images.
### Supported Models
Tested with:
- **Gemma3** — e.g., `google/gemma-3-4b-it`
- **LLaVA-NeXT** — e.g., `llava-hf/llava-v1.6-mistral-7b-hf`
- **Qwen2-VL** — e.g., `Qwen/Qwen2-VL-2B-Instruct`
- **Qwen2.5-VL** — e.g., `Qwen/Qwen2.5-VL-3B-Instruct`
- **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct`
> [!TIP]
> Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes.
### Quick Start
Use [grpo\_vlm.py](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) to fine-tune a VLM. Example command for training on [`lmms-lab/multimodal-open-r1-8k-verified`](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified):
```bash
accelerate launch \
--config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/grpo_vlm.py \
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
--output_dir grpo-Qwen2.5-VL-3B-Instruct \
--learning_rate 1e-5 \
--gradient_checkpointing \
--dtype bfloat16 \
--max_prompt_length 2048 \
--max_completion_length 1024 \
--use_vllm \
--vllm_mode colocate \
--use_peft \
--lora_target_modules "q_proj", "v_proj" \
--log_completions
```
### Configuration Tips
> [!WARNING]
> VLM training may fail if image tokens are truncated. We highly recommend disabling truncation by setting `max_prompt_length` to `None`.
- Use LoRA on vision-language projection layers
- Enable 4-bit quantization to reduce memory usage
- VLMs are memory-intensive — start with smaller batch sizes
- Most models are compatible with vLLM (`server` and `colocate` modes)
### Dataset Format
Each training sample should include:
- `prompt`: Text formatted via the processor's chat template
- `image`/`images`: PIL Image or list of PIL Images
The trainer automatically handles image-to-tensor conversion via the models image processor.
## GRPOTrainer
[[autodoc]] GRPOTrainer
- train
- save_model
- push_to_hub
## GRPOConfig
[[autodoc]] GRPOConfig

View File

@ -1,65 +0,0 @@
# Training FAQ
## What Metrics Should I Look at?
When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves.
To address this, we recommend focusing on two key metrics first:
**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training.
**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces.
However, there are more metrics that can be useful for debugging, checkout the [logging section](logging).
## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence?
When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.
However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kl-example.png">
<p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://huggingface.co/papers/1909.08593">https://huggingface.co/papers/1909.08593</a>. </p>
</div>
To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.
## What Is the Concern with Negative KL Divergence?
If you generate text by purely sampling from the model distribution things work fine in general. But when you use the `generate` method there are a few caveats because it does not always purely sample depending on the settings which can cause KL-divergence to go negative. Essentially when the active model achieves `log_p_token_active < log_p_token_ref` we get negative KL-div. This can happen in a several cases:
- **top-k sampling**: the model can smooth out the probability distribution causing the top-k tokens having a smaller probability than those of the reference model but they still are selected
- **min_length**: this ignores the EOS token until `min_length` is reached. thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached
These are just a few examples. Why is negative KL an issue? The total reward `R` is computed `R = r - beta * KL` so if the model can learn how to drive KL-divergence negative it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than actually learning the reward function. In addition the KL can become arbitrarily small thus the actual reward can be very small compared to it.
So how should you generate text for PPO training? Let's have a look!
## How to generate text for training?
In order to avoid the KL issues described above we recommend to use the following settings:
```python
generation_kwargs = {
"min_length": -1, # don't ignore the EOS token (see above)
"top_k": 0.0, # no top-k sampling
"top_p": 1.0, # no nucleus sampling
"do_sample": True, # yes, we want to sample
"pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead
"max_new_tokens": 32, # specify how many tokens you want to generate at most
}
```
With these settings we usually don't encounter any issues. You can also experiments with other settings but if you encounter issues with negative KL-divergence try to go back to these and see if they persist.
## How can debug your own use-case?
Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:
- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from.
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that.
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations.
- **Inspect the reward model**: If you reward is not improving over time maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query which the model can't affect so you might need to normalize this (e.g. reward of query+response minus reward of the query).
These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!

135
docs/source/index.md Normal file
View File

@ -0,0 +1,135 @@
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png">
</div>
# TRL - Transformer Reinforcement Learning
TRL is a full stack library where we provide a set of tools to train transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more.
The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
Below is the current list of TRL trainers, organized by method type (⚡️ = vLLM support).
## Taxonomy
<div style="display: flex; justify-content: space-between; width: 100%; gap: 2rem;">
<div style="flex: 1; min-width: 0;">
### Online methods
- [`GRPOTrainer`] ⚡️
- [`RLOOTrainer`] ⚡️
- [`OnlineDPOTrainer`] ⚡️
- [`NashMDTrainer`] ⚡️
- [`XPOTrainer`] ⚡️
- [`PPOTrainer`]
### Reward modeling
- [`PRMTrainer`]
- [`RewardTrainer`]
</div>
<div style="flex: 1; min-width: 0;">
### Offline methods
- [`SFTTrainer`]
- [`DPOTrainer`]
- [`ORPOTrainer`]
- [`BCOTrainer`]
- [`CPOTrainer`]
- [`KTOTrainer`]
### Knowledge distillation
- [`GKDTrainer`]
</div>
</div>
## 🎉 What's New
**✨ OpenAI GPT OSS Support**: TRL now fully supports fine-tuning the latest [OpenAI GPT OSS models](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4)! Check out the:
- [OpenAI Cookbook](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transfomers)
- [GPT OSS recipes](https://github.com/huggingface/gpt-oss-recipes)
- [Our example script](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_gpt_oss.py)
You can also explore TRL-related models, datasets, and demos in the [TRL Hugging Face organization](https://huggingface.co/trl-lib).
## Learn
Learn post-training with TRL and other libraries in 🤗 [smol course](https://github.com/huggingface/smol-course).
## Contents
The documentation is organized into the following sections:
- **Getting Started**: installation and quickstart guide.
- **Conceptual Guides**: dataset formats, training FAQ, and understanding logs.
- **How-to Guides**: reducing memory usage, speeding up training, distributing training, etc.
- **Integrations**: DeepSpeed, Liger Kernel, PEFT, etc.
- **Examples**: example overview, community tutorials, etc.
- **API**: trainers, utils, etc.
## Blog posts
<div class="mt-10">
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-vlm-alignment">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/trl_vlm/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on August 7, 2025</p>
<p class="text-gray-700">Vision Language Model Alignment in TRL ⚡️</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/vllm-colocate">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/vllm-colocate/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on June 3, 2025</p>
<p class="text-gray-700">NO GPU left behind: Unlocking Efficiency with Co-located vLLM in TRL</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/liger-grpo">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/liger-grpo/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on May 25, 2025</p>
<p class="text-gray-700">🐯 Liger GRPO meets TRL</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/open-r1">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/open-r1/thumbnails.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on January 28, 2025</p>
<p class="text-gray-700">Open-R1: a fully open reproduction of DeepSeek-R1</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo_vlm">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/dpo_vlm/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on July 10, 2024</p>
<p class="text-gray-700">Preference Optimization for Vision Language Models with TRL</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/putting_rl_back_in_rlhf_with_rloo">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/putting_rl_back_in_rlhf_with_rloo/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on June 12, 2024</p>
<p class="text-gray-700">Putting RL back in RLHF</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-ddpo">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/166_trl_ddpo/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on September 29, 2023</p>
<p class="text-gray-700">Finetune Stable Diffusion Models with DDPO via TRL</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo-trl">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/157_dpo_trl/dpo_thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on August 8, 2023</p>
<p class="text-gray-700">Fine-tune Llama 2 with DPO</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/stackllama">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/138_stackllama/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on April 5, 2023</p>
<p class="text-gray-700">StackLLaMA: A hands-on guide to train LLaMA with RLHF</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-peft">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/133_trl_peft/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on March 9, 2023</p>
<p class="text-gray-700">Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/rlhf">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/120_rlhf/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on December 9, 2022</p>
<p class="text-gray-700">Illustrating Reinforcement Learning from Human Feedback</p>
</a>
</div>
</div>

View File

@ -1,65 +0,0 @@
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_banner_dark.png">
</div>
# TRL - Transformer Reinforcement Learning
TRL is a full stack library where we provide a set of tools to train transformer language models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step.
The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
</div>
Check the appropriate sections of the documentation depending on your needs:
## API documentation
- [Model Classes](models): *A brief overview of what each public model class does.*
- [`SFTTrainer`](sft_trainer): *Supervise Fine-tune your model easily with `SFTTrainer`*
- [`RewardTrainer`](reward_trainer): *Train easily your reward model using `RewardTrainer`.*
- [`PPOTrainer`](ppo_trainer): *Further fine-tune the supervised fine-tuned model using PPO algorithm*
- [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
- [`DPOTrainer`](dpo_trainer): *Direct Preference Optimization training using `DPOTrainer`.*
- [`TextEnvironment`](text_environments): *Text environment to train your model using tools with RL.*
## Examples
- [Sentiment Tuning](sentiment_tuning): *Fine tune your model to generate positive movie contents*
- [Training with PEFT](lora_tuning_peft): *Memory efficient RLHF training using adapters with PEFT*
- [Detoxifying LLMs](detoxifying_a_lm): *Detoxify your language model through RLHF*
- [StackLlama](using_llama_models): *End-to-end RLHF training of a Llama model on Stack exchange dataset*
- [Learning with Tools](learning_tools): *Walkthrough of using `TextEnvironments`*
- [Multi-Adapter Training](multi_adapter_rl): *Use a single base model and multiple adapters for memory efficient end-to-end training*
## Blog posts
<div class="mt-10">
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo_vlm">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/dpo_vlm/thumbnail.png" alt="thumbnail">
<p class="text-gray-700">Preference Optimization for Vision Language Models with TRL</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/rlhf">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/120_rlhf/thumbnail.png" alt="thumbnail">
<p class="text-gray-700">Illustrating Reinforcement Learning from Human Feedback</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-peft">
<img src="https://github.com/huggingface/blog/blob/main/assets/133_trl_peft/thumbnail.png?raw=true" alt="thumbnail">
<p class="text-gray-700">Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/stackllama">
<img src="https://github.com/huggingface/blog/blob/main/assets/138_stackllama/thumbnail.png?raw=true" alt="thumbnail">
<p class="text-gray-700">StackLLaMA: A hands-on guide to train LLaMA with RLHF</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo-trl">
<img src="https://github.com/huggingface/blog/blob/main/assets/157_dpo_trl/dpo_thumbnail.png?raw=true" alt="thumbnail">
<p class="text-gray-700">Fine-tune Llama 2 with DPO</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-ddpo">
<img src="https://github.com/huggingface/blog/blob/main/assets/166_trl_ddpo/thumbnail.png?raw=true" alt="thumbnail">
<p class="text-gray-700">Finetune Stable Diffusion Models with DDPO via TRL</p>
</a>
</div>
</div>

View File

@ -0,0 +1,42 @@
# Installation
You can install TRL either from PyPI or from source:
## PyPI
Install the library with pip or [uv](https://docs.astral.sh/uv/):
<hfoptions id="install">
<hfoption id="uv">
uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions.
```bash
uv pip install trl
```
</hfoption>
<hfoption id="pip">
```bash
pip install trl
```
</hfoption>
</hfoptions>
## Source
You can also install the latest version from source. First clone the repo and then run the installation with `pip`:
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .
```
If you want the development install you can replace the pip install with the following:
```bash
pip install -e ".[dev]"
```

View File

@ -1,24 +0,0 @@
# Installation
You can install TRL either from pypi or from source:
## pypi
Install the library with pip:
```bash
pip install trl
```
### Source
You can also install the latest version from source. First clone the repo and then run the installation with `pip`:
```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .
```
If you want the development install you can replace the pip install with the following:
```bash
pip install -e ".[dev]"
```

View File

@ -1,57 +0,0 @@
# Iterative Trainer
[![](https://img.shields.io/badge/All_models-Iterative_SFT-blue)](https://huggingface.co/models?other=iterative-sft,trl)
Iterative fine-tuning is a training method that enables to perform custom actions (generation and filtering for example) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code.
## Usage
To get started quickly, instantiate an instance a model, and a tokenizer.
```python
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
trainer = IterativeSFTTrainer(
model,
tokenizer
)
```
You have the choice to either provide a list of strings or a list of tensors to the step function.
#### Using a list of tensors as input:
```python
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask
}
trainer.step(**inputs)
```
#### Using a list of strings as input:
```python
inputs = {
"texts": texts
}
trainer.step(**inputs)
```
For causal language models, labels will automatically be created from input_ids or from texts. When using sequence to sequence models you will have to provide your own labels or text_labels.
## IterativeTrainer
[[autodoc]] IterativeSFTTrainer

View File

@ -0,0 +1,274 @@
# Training with Jobs
[![model badge](https://img.shields.io/badge/All_models-HF_Jobs-blue)](https://huggingface.co/models?other=hf_jobs,trl)
[Hugging Face Jobs](https://huggingface.co/docs/huggingface_hub/guides/jobs) lets you run training scripts on fully managed infrastructure—no need to manage GPUs or local environment setup.
In this guide, you'll learn how to:
* Use [TRL Jobs](https://github.com/huggingface/trl-jobs) to easily run pre-optimized TRL training
* Run any TRL training script with uv scripts
For general details about Hugging Face Jobs (hardware selection, job monitoring, etc.), see the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/guides/jobs).
## Requirements
* A [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan
* Logged in to the Hugging Face Hub (`hf auth login`)
## Using TRL Jobs
[TRL Jobs](https://github.com/huggingface/trl-jobs) is a high-level wrapper around Hugging Face Jobs and TRL that streamlines training. It provides optimized default configurations so you can start quickly without manually tuning parameters.
Example:
```bash
pip install trl-jobs
trl-jobs sft --model_name Qwen/Qwen3-0.6B --dataset_name trl-lib/Capybara
```
TRL Jobs supports everything covered in this guide, with additional optimizations to simplify workflows.
## Using uv Scripts
For more control, you can run Hugging Face Jobs directly with your own scripts, using [uv scripts](https://docs.astral.sh/uv/guides/scripts/).
Create a Python script (e.g., `train.py`) containing your training code:
```python
from datasets import load_dataset
from trl import SFTTrainer
dataset = load_dataset("trl-lib/Capybara", split="train")
trainer = SFTTrainer(
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
trainer.train()
trainer.push_to_hub("Qwen2.5-0.5B-SFT")
```
Launch the job using either the [`hf jobs` CLI](https://huggingface.co/docs/huggingface_hub/guides/cli#hf-jobs) or the Python API:
<hfoptions id="script_type">
<hfoption id="bash">
```bash
hf jobs uv run \
--flavor a100-large \
--with trl \
--secrets HF_TOKEN \
train.py
```
</hfoption>
<hfoption id="python">
```python
from huggingface_hub import run_uv_job
run_uv_job(
"train.py",
dependencies=["trl"],
flavor="a100-large",
secrets={"HF_TOKEN": "hf_..."},
)
```
</hfoption>
</hfoptions>
To run successfully, the script needs:
* **TRL installed**: Use the `--with trl` flag or the `dependencies` argument. uv installs these dependencies automatically before running the script.
* **An authentication token**: Required to push the trained model (or perform other authenticated operations). Provide it with the `--secrets HF_TOKEN` flag or the `secrets` argument.
> [!WARNING]
> When training with Jobs, be sure to:
>
> * **Set a sufficient timeout**. Jobs time out after 30 minutes by default. If your job exceeds the timeout, it will fail and all progress will be lost. See [Setting a custom timeout](https://huggingface.co/docs/huggingface_hub/guides/jobs#setting-a-custom-timeout).
> * **Push the model to the Hub**. The Jobs environment is ephemeral—files are deleted when the job ends. If you dont push the model, it will be lost.
You can also run a script directly from a URL:
<hfoptions id="script_type">
<hfoption id="bash">
```bash
hf jobs uv run \
--flavor a100-large \
--with trl \
--secrets HF_TOKEN \
"https://gist.githubusercontent.com/qgallouedec/eb6a7d20bd7d56f9c440c3c8c56d2307/raw/69fd78a179e19af115e4a54a1cdedd2a6c237f2f/train.py"
```
</hfoption>
<hfoption id="python">
```python
from huggingface_hub import run_uv_job
run_uv_job(
"https://gist.githubusercontent.com/qgallouedec/eb6a7d20bd7d56f9c440c3c8c56d2307/raw/69fd78a179e19af115e4a54a1cdedd2a6c237f2f/train.py",
flavor="a100-large",
dependencies=["trl"],
secrets={"HF_TOKEN": "hf_..."},
)
```
</hfoption>
</hfoptions>
To make a script self-contained, declare dependencies at the top:
```python
# /// script
# dependencies = [
# "trl",
# "peft",
# ]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer
dataset = load_dataset("trl-lib/Capybara", split="train")
trainer = SFTTrainer(
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
peft_config=LoraConfig(),
)
trainer.train()
trainer.push_to_hub("Qwen2.5-0.5B-SFT")
```
You can then run the script without specifying dependencies:
<hfoptions id="script_type">
<hfoption id="bash">
```bash
hf jobs uv run \
--flavor a100-large \
--secrets HF_TOKEN \
train.py
```
</hfoption>
<hfoption id="python">
```python
from huggingface_hub import run_uv_job
run_uv_job(
"train.py",
flavor="a100-large",
secrets={"HF_TOKEN": "hf_..."},
)
```
</hfoption>
</hfoptions>
TRL example scripts are fully uv-compatible, so you can run a complete training workflow directly on Jobs. You can customize training with standard script arguments plus hardware and secrets:
<hfoptions id="script_type">
<hfoption id="bash">
```bash
hf jobs uv run \
--flavor a100-large \
--secrets HF_TOKEN \
https://raw.githubusercontent.com/huggingface/trl/refs/heads/main/examples/scripts/prm.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/prm800k \
--output_dir Qwen2-0.5B-Reward \
--push_to_hub
```
</hfoption>
<hfoption id="python">
```python
from huggingface_hub import run_uv_job
run_uv_job(
"https://raw.githubusercontent.com/huggingface/trl/refs/heads/main/examples/scripts/prm.py",
flavor="a100-large",
secrets={"HF_TOKEN": "hf_..."},
script_args=[
"--model_name_or_path", "Qwen/Qwen2-0.5B-Instruct",
"--dataset_name", "trl-lib/prm800k",
"--output_dir", "Qwen2-0.5B-Reward",
"--push_to_hub"
]
)
```
</hfoption>
</hfoptions>
See the full list of examples in [Maintained examples](example_overview#maintained-examples).
### Docker Images
An up-to-date Docker image with all TRL dependencies is available at [huggingface/trl](https://hub.docker.com/r/huggingface/trl) and can be used directly with Hugging Face Jobs:
<hfoptions id="script_type">
<hfoption id="bash">
```bash
hf jobs uv run \
--flavor a100-large \
--secrets HF_TOKEN \
--image huggingface/trl \
train.py
```
</hfoption>
<hfoption id="python">
```python
from huggingface_hub import run_uv_job
run_uv_job(
"train.py",
flavor="a100-large",
secrets={"HF_TOKEN": "hf_..."},
image="huggingface/trl",
)
```
</hfoption>
</hfoptions>
Jobs runs on a Docker image from Hugging Face Spaces or Docker Hub, so you can also specify any custom image:
<hfoptions id="script_type">
<hfoption id="bash">
```bash
hf jobs uv run \
--flavor a100-large \
--secrets HF_TOKEN \
--image <docker-image> \
--secrets HF_TOKEN \
train.py
```
</hfoption>
<hfoption id="python">
```python
from huggingface_hub import run_uv_job
run_uv_job(
"train.py",
flavor="a100-large",
secrets={"HF_TOKEN": "hf_..."},
image="<docker-image>",
)
```
</hfoption>
</hfoptions>

View File

@ -1,16 +1,19 @@
# Judges
> [!WARNING]
> TRL Judges is an experimental API which is subject to change at any time.
TRL provides judges to easily compare two completions.
Make sure to have installed the required dependencies by running:
```bash
pip install trl[llm_judge]
pip install trl[judges]
```
## Using the provided judges
TRL provides several judges out of the box. For example, you can use the `HfPairwiseJudge` to compare two completions using a pre-trained model from the Hugging Face model hub:
TRL provides several judges out of the box. For example, you can use the [`HfPairwiseJudge`] to compare two completions using a pre-trained model from the Hugging Face model hub:
```python
from trl import HfPairwiseJudge
@ -46,34 +49,38 @@ judge.judge(
) # Outputs: [0, 1]
```
## BaseJudge
## Provided judges
[[autodoc]] BaseJudge
## BaseRankJudge
[[autodoc]] BaseRankJudge
## BasePairwiseJudge
[[autodoc]] BasePairwiseJudge
## RandomRankJudge
[[autodoc]] RandomRankJudge
## RandomPairwiseJudge
[[autodoc]] RandomPairwiseJudge
## PairRMJudge
### PairRMJudge
[[autodoc]] PairRMJudge
## HfPairwiseJudge
### HfPairwiseJudge
[[autodoc]] HfPairwiseJudge
## OpenAIPairwiseJudge
### OpenAIPairwiseJudge
[[autodoc]] OpenAIPairwiseJudge
### AllTrueJudge
[[autodoc]] AllTrueJudge
## Base classes
### BaseJudge
[[autodoc]] BaseJudge
### BaseBinaryJudge
[[autodoc]] BaseBinaryJudge
### BaseRankJudge
[[autodoc]] BaseRankJudge
### BasePairwiseJudge
[[autodoc]] BasePairwiseJudge

View File

@ -0,0 +1,96 @@
# Kernels Hub Integration and Usage
<img src="https://github.com/user-attachments/assets/4b5175f3-1d60-455b-8664-43b2495ee1c3" width="450" height="450" alt="kernel-builder logo">
The [`kernels`](https://huggingface.co/blog/hello-hf-kernels#get-started-and-next-steps) library allows optimized compute kernels to be loaded directly from the Hub.
You can find `kernels` in [dedicated orgs](https://huggingface.co/kernels-community) or by searching for the [`kernel` tag](https://huggingface.co/models?other=kernel) within the Hub.
Kernels are **optimized code pieces** that help in model development, training, and inference. Here, well focus on their **integration with TRL**, but check out the above resources to learn more about them.
## Installation
To use kernels with TRL, you'd need to install the library in your Python environment:
```bash
pip install kernels
```
## Using Kernels from the Hub in TRL
Kernels can directly replace attention implementations, removing the need to manually compile attention backends like Flash Attention and boosting training speed just by pulling the respective attention kernel from the Hub.
You can specify a kernel when loading a model:
```python
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"your-model-name",
attn_implementation="kernels-community/flash-attn" # other options: kernels-community/vllm-flash-attn3, kernels-community/paged-attention
)
```
Or when running a TRL training script:
```bash
python sft.py ... --attn_implementation kernels-community/flash-attn
```
Or using the TRL CLI:
```bash
trl sft ... --attn_implementation kernels-community/flash-attn
```
> [!TIP]
> Now you can leverage faster attention backends with a pre-optimized kernel for your hardware configuration from the Hub, speeding up both development and training.
## Comparing Attention Implementations
We evaluated various attention implementations available in transformers, along with different kernel backends, using **TRL** and **SFT**.
The experiments were run on a single **H100 GPU** with **CUDA 12.9**, leveraging **Qwen3-8B** with a **batch size of 8**, **gradient accumulation of 1**, and **bfloat16** precision.
Keep in mind that the results shown here are specific to this setup and may vary with different training configurations.
The following figure illustrates both **latency** (time per training step) and **peak allocated memory** for the different attention implementations and kernel backends.
Kernel-based implementations perform on par with custom-installed attention, and increasing the models `max_length` further enhances performance. Memory consumption is similar across all implementations, showing no significant differences. We get the same performance but with less friction, as described in [the following section](#flash-attention-vs-hub-kernels).
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kernels_guide_latency.png" alt="Latency and Memory Usage" width="45%"/>
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kernels_guide_peak_allocated_memory.png" alt="Latency and Memory Usage" width="45%"/>
</div>
## Flash Attention vs. Hub Kernels
Building Flash Attention from source can be time-consuming, often taking anywhere from several minutes to hours, depending on your hardware, CUDA/PyTorch configuration, and whether precompiled wheels are available.
In contrast, **Hugging Face Kernels** provide a much faster and more reliable workflow. Developers dont need to worry about complex setups—everything is handled automatically. In our benchmarks, kernels were ready to use in about **2.5 seconds**, with no compilation required. This allows you to start training almost instantly, significantly accelerating development. Simply specify the desired version, and `kernels` takes care of the rest.
## Combining FlashAttention Kernels with Liger Kernels
You can combine **FlashAttention kernels** with **Liger kernels** for additional TRL performance improvements.
First, install the Liger kernel dependency:
```bash
pip install liger-kernel
```
Then, combine both in your code:
```python
from transformers import AutoModelForCausalLM
from trl import SFTConfig
model = AutoModelForCausalLM.from_pretrained(
"your-model-name",
attn_implementation="kernels-community/flash-attn" # choose the desired FlashAttention variant
)
training_args = SFTConfig(
use_liger_kernel=True,
# ... other TRL training args
)
```
Learn more about the [Liger Kernel Integration](./liger_kernel_integration).

View File

@ -1,12 +1,11 @@
# KTO Trainer
[![](https://img.shields.io/badge/All_models-KTO-blue)](https://huggingface.co/models?other=kto,trl)
[![model badge](https://img.shields.io/badge/All_models-KTO-blue)](https://huggingface.co/models?other=kto,trl)
## Overview
Kahneman-Tversky Optimization (KTO) was introduced in [KTO: Model Alignment as Prospect Theoretic Optimization](https://huggingface.co/papers/2402.01306) by [Kawin Ethayarajh](https://huggingface.co/kawine), [Winnie Xu](https://huggingface.co/xwinxu), [Niklas Muennighoff](https://huggingface.co/Muennighoff), Dan Jurafsky, [Douwe Kiela](https://huggingface.co/douwekiela).
The abstract from the paper is the following:
> Kahneman & Tversky's prospect theory tells us that humans perceive random variables in a biased but well-defined manner; for example, humans are famously loss-averse. We show that objectives for aligning LLMs with human feedback implicitly incorporate many of these biases -- the success of these objectives (e.g., DPO) over cross-entropy minimization can partly be ascribed to them being human-aware loss functions (HALOs). However, the utility functions these methods attribute to humans still differ from those in the prospect theory literature. Using a Kahneman-Tversky model of human utility, we propose a HALO that directly maximizes the utility of generations instead of maximizing the log-likelihood of preferences, as current methods do. We call this approach Kahneman-Tversky Optimization (KTO), and it matches or exceeds the performance of preference-based methods at scales from 1B to 30B. Crucially, KTO does not need preferences -- only a binary signal of whether an output is desirable or undesirable for a given input. This makes it far easier to use in the real world, where preference data is scarce and expensive.
@ -38,7 +37,7 @@ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train")
training_args = KTOConfig(output_dir="Qwen2-0.5B-KTO", logging_steps=10)
training_args = KTOConfig(output_dir="Qwen2-0.5B-KTO")
trainer = KTOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
@ -51,11 +50,11 @@ accelerate launch train_kto.py
Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kto-qwen2-reward-margin.png)
![kto qwen2 reward margin](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kto-qwen2-reward-margin.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-KTO
<pre><code>$ transformers chat trl-lib/Qwen2-0.5B-KTO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
@ -74,22 +73,21 @@ Here are some other factors to consider when choosing a programming language for
KTO requires an [unpaired preference dataset](dataset_formats#unpaired-preference). Alternatively, you can provide a *paired* preference dataset (also known simply as a *preference dataset*). In this case, the trainer will automatically convert it to an unpaired format by separating the chosen and rejected responses, assigning `label = True` to the chosen completions and `label = False` to the rejected ones.
The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
The [`KTOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
In theory, the dataset should contain at least one chosen and one rejected completion. However, some users have successfully run KTO using *only* chosen or only rejected data. If using only rejected data, it is advisable to adopt a conservative learning rate.
## Example script
We provide an example script to train a model using the KTO method. The script is available in [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py)
We provide an example script to train a model using the KTO method. The script is available in [`trl/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/kto.py)
To test the KTO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/kto-mix-14k), run the following command:
```bash
accelerate launch examples/scripts/kto.py \
accelerate launch trl/scripts/kto.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/kto-mix-14k \
--num_train_epochs 1 \
--logging_steps 25 \
--output_dir Qwen2-0.5B-KTO
```
@ -103,7 +101,6 @@ To ensure that we train MOEs similarly during preference-tuning, it is beneficia
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
### Batch size recommendations
Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor.
@ -119,20 +116,23 @@ By default, they are both 1. However, if you have more of one or the other, then
## Logged metrics
While training and evaluating we record the following reward metrics:
While training and evaluating, we record the following reward metrics:
- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
- `logps/chosen`: the mean log probabilities of the chosen completions
- `logps/rejected`: the mean log probabilities of the rejected completions
- `logits/chosen`: the mean logits of the chosen completions
- `logits/rejected`: the mean logits of the rejected completions
- `kl`: the KL divergence between the policy model and the reference model
- `rewards/chosen_sum`: the sum of log probabilities of the policy model for the chosen responses scaled by beta
- `rewards/rejected_sum`: the sum of log probabilities of the policy model for the rejected responses scaled by beta
- `logps/chosen_sum`: the sum of log probabilities of the chosen completions
- `logps/rejected_sum`: the sum of log probabilities of the rejected completions
- `logits/chosen_sum`: the sum of logits of the chosen completions
- `logits/rejected_sum`: the sum of logits of the rejected completions
- `count/chosen`: the count of chosen samples in a batch
- `count/rejected`: the count of rejected samples in a batch
## KTOTrainer
[[autodoc]] KTOTrainer
- train
- save_model
- push_to_hub
## KTOConfig

View File

@ -1,233 +0,0 @@
# Learning Tools (Experimental 🧪)
Using Large Language Models (LLMs) with tools has been a popular topic recently with awesome works such as [ToolFormer](https://huggingface.co/papers/2302.04761) and [ToolBench](https://huggingface.co/papers/2305.16504). In TRL, we provide a simple example of how to teach LLM to use tools with reinforcement learning.
Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):
| File | Description |
|---|---|
| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. |
| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. |
| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. |
<Tip warning={true}>
Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
</Tip>
## Learning to Use a Calculator
The rough idea is as follows:
1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parse a text calculation like `"14 + 34"` and return the calulated number:
```python
from transformers import AutoTokenizer, load_tool
tool = load_tool("ybelkada/simple-calculator")
tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places
```
1. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.
1. Create a prompt on how to use the tools
```python
# system prompt
prompt = """\
What is 13.1-3?
<request><SimpleCalculatorTool>13.1-3<call>10.1<response>
Result=10.1<submit>
What is 4*3?
<request><SimpleCalculatorTool>4*3<call>12<response>
Result=12<submit>
What is 12.1+1?
<request><SimpleCalculatorTool>12.1+1<call>13.1<response>
Result=13.1<submit>
What is 12.1-20?
<request><SimpleCalculatorTool>12.1-20<call>-7.9<response>
Result=-7.9<submit>"""
```
3. Create a `trl.TextEnvironment` with the model
```python
env = TextEnvironment(
model,
tokenizer,
{"SimpleCalculatorTool": tool_fn},
reward_fn,
prompt,
generation_kwargs=generation_kwargs,
)
```
4. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will show visualize the tokens.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools.png)
1. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss, make sure to pass that argument to `step`.
## Experiment results
We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.
```
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
--command "python examples/research_projects/tools/calculator.py" \
--num-seeds 10 \
--start-seed 1 \
--workers 10 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 8 \
--slurm-template-path benchmark/trl.slurm_template
```
We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot.
```
# pip install openrlbenchmark==0.2.1a5
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
'wandb?tag=calculator_final&cl=calculator_mask' \
--env-ids trl \
--check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename static/0compare \
--scan-history
```
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools_chart.png)
As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near perfect proficiency in the calculator task.
## (Early Experiments 🧪): learning to use a wiki tool for question answering
In the [ToolFormer](https://huggingface.co/papers/2302.04761) paper, it shows an interesting use case that utilizes a Wikipedia Search tool to help answer questions. In this section, we attempt to perform similar experiments but uses RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.
<Tip warning={true}>
**Note that many settings are different so the results are not directly comparable.**
</Tip>
### Building a search index
Since [ToolFormer](https://huggingface.co/papers/2302.04761) did not open source, we needed to first replicate the search index. It is mentioned in their paper that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT)
Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.
```python
from pyserini.search.lucene import LuceneSearcher
import json
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
def search(query):
hits = searcher.search(query, k=1)
hit = hits[0]
contents = json.loads(hit.raw)['contents']
return contents
print(search("tennis racket"))
```
```
Racket (sports equipment)
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.
The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
...
```
We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pyserini.png)
### Experiment settings
We use the following settings:
* use the `bigcode/starcoderbase` model as the base model
* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only uses the first paragrahs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
* test if the response contain the answer string, if so, give a reward of 1, otherwise, give a reward of 0.
* notice this is a simplified evaluation criteria. In [ToolFormer](https://huggingface.co/papers/2302.04761), the authors checks if the first 20 words of the response contain the correct answer.
* used the following prompt that demonstrates the usage of the wiki tool.
```python
prompt = """\
Answer the following question:
Q: In which branch of the arts is Patricia Neary famous?
A: Ballets
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
Result=Ballets<submit>
Q: Who won Super Bowl XX?
A: Chicago Bears
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 4610, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
Result=Chicago Bears<submit>
Q: """
```
### Result and Discussion
Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves would go up mostly, but one of the experiment did crash.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/triviaqa_learning_curves.png)
Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.
Note that the correct rate of the trained model is on the low end, which could be due to the following reasons:
* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"` if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988.` But a correct search should be `Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (19851989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (19882013) and other roles.[1][2]"
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/real_first_name.png)
* **unnecessarily long response**: The wiki tool by default sometimes output very long sequences. E.g., when the wiki tool searches for "Brown Act"
* Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
* [ToolFormer](https://huggingface.co/papers/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/brown_act.png)
## (Early Experiments 🧪): solving math puzzles with python interpreter
In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:
```python
prompt = """\
Example of using a Python API to solve math questions.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
<request><PythonInterpreter>
def solution():
money_initial = 23
bagels = 5
bagel_cost = 3
money_spent = bagels * bagel_cost
money_left = money_initial - money_spent
result = money_left
return result
print(solution())
<call>72<response>
Result = 72 <submit>
Q: """
```
Training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gms8k_learning_curve.png)

View File

@ -0,0 +1,29 @@
# Liger Kernel Integration
> [!WARNING]
> Section under construction. Feel free to contribute!
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, with more to come. The kernel works out of the box with [FlashAttention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost the performance.
| Speed Up | Memory Reduction |
| --- | --- |
| ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |
1. To use Liger-Kernel in [`SFTTrainer`], first install it by:
```bash
pip install liger-kernel
```
2. Once installed, set `use_liger_kernel` in [`SFTConfig`]. No other changes are needed!
```python
training_args = SFTConfig(
use_liger_kernel=True,
...
)
```
To learn more about Liger-Kernel, visit their [official repository](https://github.com/linkedin/Liger-Kernel/).

View File

@ -1,74 +0,0 @@
# Logging
As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
By default, the TRL [`PPOTrainer`] saves a lot of relevant information to wandb or tensorboard.
Upon initialization, pass one of these two options to the [`PPOConfig`]:
```
training_args = PPOConfig(..., report_to="wandb") # or "tensorboard"
```
If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.
## PPO Logging
Here's a brief explanation for the logged metrics provided in the data:
Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is sed to specifically monitor the reward model.
1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias ``ppo/std_scores`, which is sed to specifically monitor the reward model.
1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment.
1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function.
1. `objective/kl_dist`: The histogram distribution of the `objective/kl`.
1. `objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function.
1. `ppo/mean_non_score_reward`: The **KL penalty** calculated by `objective/kl * objective/kl_coef` as the total reward for optimization to prevent the new policy from deviating too far from the old policy.
1. `objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration.
Training stats:
1. `ppo/learning_rate`: The learning rate for the PPO algorithm.
1. `ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy.
1. `ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process.
1. `ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html
1. `ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html
1. `ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.
1. `ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.
1. `ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`.
1. `ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details.
1. `ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance.
1. `ppo/val/mean`: The mean of the values, used to monitor the value function's performance.
1. `ppo/val/var` : The variance of the values, used to monitor the value function's performance.
1. `ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance.
1. `ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped.
1. `ppo/val/vpred`: The predicted values from the value function.
1. `ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance.
1. `ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm.
1. `ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards.
1. `ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss.
Stats on queries, responses, and logprobs:
1. `tokens/queries_len_mean`: The average length of the queries tokens.
1. `tokens/queries_len_std`: The standard deviation of the length of the queries tokens.
1. `tokens/queries_dist`: The histogram distribution of the length of the queries tokens.
1. `tokens/responses_len_mean`: The average length of the responses tokens.
1. `tokens/responses_len_std`: The standard deviation of the length of the responses tokens.
1. `tokens/responses_dist`: The histogram distribution of the length of the responses tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`)
1. `objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model.
1. `objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model.
### Crucial values
During training, many values are logged, here are the most important ones:
1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment" / reward model
1. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
1. `ppo/loss/value`: it will spike / NaN when not going well.
1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well.
1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy.
1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.

View File

@ -0,0 +1,442 @@
# LoRA Without Regret
Recent research from the team at [Thinking Machines Lab](https://thinkingmachines.ai/blog/lora/) (Schulman et al., 2025) shows that **LoRA can match full fine-tuning performance** when configured correctly, while using only ~67% of the compute. These findings are exciting to TRL users because they're straightforward to implement and can improve model performance on smaller budgets.
This guide provides simple instructions to reproduce the results of the blog post in TRL.
> [!TIP]
> It is recommended to read the blog post before following this guide, or to consult both resources in parallel for best results.
## Benefits of LoRA over full fine-tuning
First of all, let's remind ourselves of the benefits of [LoRA over full fine-tuning](https://huggingface.co/docs/trl/en/peft_integration).
LoRA adds adapter layers on top of the base model, which contains significantly fewer parameters than the base model itself. This design reduces GPU memory requirements and enables more efficient training. As described in the [blog](https://thinkingmachines.ai/blog/lora/), this approach was originally thought to involve a performance trade-off, although careful configuration can overcome this trade-off and match full fine-tuning performance.
## Examples with TRL
Let's implement and train LoRA adapters in TRL scripts based on the core findings of the blog post. Afterwards, we'll revisit each finding in light of the TRL results.
### Supervised Fine-Tuning (SFT)
The blog post performs SFT on a range of models and datasets from the Hub, which we can reproduce in TRL.
| Model | Dataset |
| --- | --- |
| [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B) | [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) |
| [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B) | [open-thoughts/OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) |
| [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B) | [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) |
| [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B) | [open-thoughts/OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) |
<hfoptions id="sft">
<hfoption id="python">
We can integrate these findings with the TRL Python API like so:
```python
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
dataset = load_dataset("open-thoughts/OpenThoughts-114k", split="train")
peft_config = LoraConfig(r=256, lora_alpha=16, target_modules="all-linear")
training_args = SFTConfig(
learning_rate=2e-4,
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
num_train_epochs=1,
report_to=["trackio"],
)
trainer = SFTTrainer(
model="Qwen/Qwen2.5-3B-Instruct",
train_dataset=dataset,
peft_config=peft_config,
args=training_args,
)
trainer.train()
```
</hfoption>
<hfoption id="jobs">
```bash
hf jobs uv run \
--flavor a100-large \
--timeout 8h \
--secrets HF_TOKEN \
"https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" \
--model_name_or_path Qwen/Qwen2.5-3B-Instruct \
--dataset_name open-thoughts/OpenThoughts-114k \
--learning_rate 2.0e-5 \
--num_train_epochs 1 \
--packing \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 16 \
--use_peft \
--lora_r 256 \
--lora_alpha 16 \
--lora_target_modules all-linear \
--output_dir Qwen2.5-3B-OpenThoughts-LoRA \
--report_to trackio \
--push_to_hub
```
To use Hugging Face Jobs, you will need to be logged in to the Hugging Face Hub (`hf auth login`) and have a [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan. Check out the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) for more details.
</hfoption>
<hfoption id="local">
```bash
uv run "https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" \
--model_name_or_path Qwen/Qwen2.5-3B-Instruct \
--dataset_name open-thoughts/OpenThoughts-114k \
--learning_rate 2.0e-5 \
--num_train_epochs 1 \
--packing \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 16 \
--gradient_checkpointing \
--eval_strategy no \
--use_peft \
--lora_r 256 \
--lora_alpha 16 \
--lora_target_modules all-linear \
--output_dir Qwen2.5-3B-OpenThoughts-LoRA \
--report_to trackio \
--push_to_hub
```
To run the script locally, you will need to have `uv` installed. Check out the [uv documentation](https://docs.astral.sh/uv/) for more details.
</hfoption>
</hfoptions>
Once training starts, you can monitor the progress in [Trackio](https://huggingface.co/trackio), which will log the URL.
### Reinforcement Learning (GRPO)
The blog post performs GRPO on a range of models and datasets from the Hub, and once again we can reproduce the results in TRL.
| Model | Dataset |
| --- | --- |
| [Llama-3.1-8B-Base](https://huggingface.co/meta-llama/Llama-3.2-1B) | [GSM8k](https://huggingface.co/datasets/openai/gsm8k) |
| [Llama-3.1-8B-Base](https://huggingface.co/meta-llama/Llama-3.2-1B) | [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) |
| [Qwen3-8b-base](https://huggingface.co/Qwen/Qwen3-8b-base) | [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) |
For reinforcement learning, the blog uses a math reasoning task that we can reproduce as a Python function.
<details>
<summary>Reward function</summary>
```python
def strip_reasoning_accuracy_reward(
completions: list[list[dict[str, str]]], solution: list[str], **kwargs
) -> list[Optional[float]]:
"""Reward function that strips reasoning tags and checks mathematical accuracy.
This function:
1. Extracts the content from completions
2. Removes <think></think> tags (for reasoning that shouldn't be evaluated)
3. Parses both the gold solution and the predicted answer
4. Uses math_verify to check if they are mathematically equivalent
Args:
completions: List of model completions, each containing a list of messages
solution: List of ground truth solutions
**kwargs: Additional arguments (ignored but required for trainer compatibility)
Returns:
List of rewards where:
- 1.0 if the answer is correct
- 0.0 if the answer is incorrect
- None if the solution is not parseable (skips this example)
"""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
# Strip reasoning tags from completion
while "<think>" in content and "</think>" in content:
start = content.find("<think>")
end = content.find("</think>", start)
if start != -1 and end != -1:
content = content[:start] + content[end + len("</think>") :]
else:
break
# Parse gold solution
gold_parsed = parse(
f"${sol}$",
extraction_config=[
LatexExtractionConfig(
boxed_match_priority=0, try_extract_without_anchor=True
)
],
)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
boxed_match_priority=0,
normalization_config=NormalizationConfig(
basic_latex=True,
units=True,
malformed_operators=False,
nits=False,
boxed=True,
),
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Compute binary rewards if verifiable, `None` otherwise to skip this example
try:
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(
f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}"
)
reward = None
else:
# If the gold solution is not parseable, we assign `None` to skip this example
reward = None
rewards.append(reward)
return rewards
```
</details>
<hfoptions id="grpo">
<hfoption id="python">
We can implement these recommendations with the TRL Python API like so:
```python
from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
dataset = load_dataset("HuggingFaceH4/OpenR1-Math-220k-default-verified", split="train")
def strip_reasoning_accuracy_reward(completions, **kwargs):
"""Reward function that strips reasoning and accuracy scores from the model outputs."""
...
peft_config = LoraConfig(
r=1,
lora_alpha=32,
target_modules="all-linear"
)
training_args = GRPOConfig(
learning_rate=5e-5,
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
num_train_epochs=1,
num_generations=8,
generation_batch_size=8,
report_to=["trackio"],
)
trainer = GRPOTrainer(
model="Qwen/Qwen3-0.6B",
reward_funcs=strip_reasoning_accuracy_reward,
args=training_args,
train_dataset=dataset,
peft_config=peft_config,
)
trainer.train()
```
> [!WARNING]
> This snippet skips the reward function which is defined above to keep the example concise.
</hfoption>
<hfoption id="jobs">
```bash
hf jobs uv run \
--flavor a100-large \
--timeout 4h \
--secrets HF_TOKEN \
--env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
"https://huggingface.co/datasets/burtenshaw/lora-without-regrets/resolve/main/grpo.py" \
--model_name_or_path Qwen/Qwen3-0.6B \
--dataset_name HuggingFaceH4/OpenR1-Math-220k-default-verified \
--output_dir grpo-full-qwen3-0.6b \
--learning_rate 1.0e-6 \
--lr_scheduler_type cosine \
--warmup_ratio 0.0 \
--max_grad_norm 1.0 \
--beta 0.0 \
--max_prompt_length 1024 \
--max_completion_length 4096 \
--num_generations 16 \
--generation_batch_size 16 \
--gradient_accumulation_steps 8 \
--per_device_train_batch_size 1 \
--num_train_epochs 1 \
--lora_r 1 \
--lora_alpha 32 \
--lora_dropout 0.0 \
--lora_target_modules all-linear \
--vllm_mode colocate \
--save_strategy steps \
--save_steps 50 \
--save_total_limit 1 \
--logging_steps 1 \
--max_steps 200 \
--report_to trackio
```
To use Hugging Face Jobs, you will need to be logged in to the Hugging Face Hub (`hf auth login`) and have a [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan. Check out the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) for more details.
</hfoption>
<hfoption id="local">
```bash
uv run "https://huggingface.co/datasets/burtenshaw/lora-without-regrets/resolve/main/grpo.py" \
--model_name_or_path Qwen/Qwen3-0.6B \
--dataset_name HuggingFaceH4/OpenR1-Math-220k-default-verified \
--output_dir grpo-full-qwen3-0.6b \
--learning_rate 1.0e-6 \
--lr_scheduler_type cosine \
--warmup_ratio 0.0 \
--max_grad_norm 1.0 \
--beta 0.0 \
--max_prompt_length 1024 \
--max_completion_length 4096 \
--num_generations 16 \
--generation_batch_size 16 \
--gradient_accumulation_steps 8 \
--per_device_train_batch_size 1 \
--num_train_epochs 1 \
--lora_r 1 \
--lora_alpha 32 \
--lora_dropout 0.0 \
--lora_target_modules all-linear \
--vllm_mode colocate \
--save_strategy steps \
--save_steps 50 \
--save_total_limit 1 \
--logging_steps 1 \
--max_steps 200 \
--report_to trackio
```
To run the script locally, you will need to have `uv` installed. Check out the [uv documentation](https://docs.astral.sh/uv/) for more details.
</hfoption>
</hfoptions>
The reinforcement learning script with GRPO is implemented as a custom script in TRL, which uses the reward function shown above. You can review it at [`grpo.py`](https://huggingface.co/datasets/burtenshaw/lora-without-regrets/blob/main/grpo.py) - Reinforcement learning with LoRA best practices
## Key findings in optimizing LoRA
The authors recommend applying LoRA to all weight matrices rather than limiting it to attention layers, as increasing the rank does not compensate for this restriction. In TRL, this can be configured using `--lora_target_modules all-linear` to apply LoRA to all weight matrices.
We were able to reproduce the results of the blog post using TRL and the SmolLM3 model. We trained the model for 500 steps on the [Math 220k dataset](https://huggingface.co/datasets/HuggingFaceH4/OpenR1-Math-220k-default-verified) with the reward function and configuration above. As you can see in the figure below, the LoRA model's average train reward curve matches the full fine-tuning curve.
![train reward](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/5.png)
And most importantly, the LoRA model uses significantly less memory than the full fine-tuning model, as we can see in the figure below.
![memory usage](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/6.png)
Here are the parameters we used to train the above models
| Parameter | LoRA | Full FT |
| --- | --- | --- |
| `--model_name_or_path` | HuggingFaceTB/SmolLM3-3B | HuggingFaceTB/SmolLM3-3B |
| `--dataset_name` | HuggingFaceH4/OpenR1-Math-220k-default-verified | HuggingFaceH4/OpenR1-Math-220k-default-verified |
| `--learning_rate` | 1.0e-5 | 1.0e-6 |
| `--max_prompt_length` | 1024 | 1024 |
| `--max_completion_length` | 4096 | 4096 |
| `--lora_r` | 1 | - |
| `--lora_alpha` | 32 | - |
| `--lora_dropout` | 0.0 | - |
| `--lora_target_modules` | all-linear | - |
Let's break down the key findings of the blog post and how we were able to reproduce them.
### 1. *LoRA performs better when applied to all weight matrices*
The authors recommend applying LoRA to all weight matrices rather than limiting it to attention layers, as increasing the rank does not compensate for this restriction.
![all layers](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/1.png)
Attention-only LoRA underperforms even when using a higher rank to match parameter count. In TRL, this can be configured using `--lora_target_modules all-linear` to apply LoRA to all weight matrices. In Python, we can do this like so:
```python
from peft import LoraConfig
peft_config = LoraConfig(target_modules="all-linear")
```
### 2. *The adapter needs sufficient capacity to learn from the dataset*
The blog post recommends using a sufficient LoRA rank to learn from the dataset. The rank determines the number of trainable parameters in the LoRA adapter. Therefore, "For datasets that exceed LoRA capacity, LoRA underperforms FullFT".
![learning rate](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/3.png)
In the TRL script, we could use `--lora_r` to set the rank and adapt it based on the task and dataset we're training on. The blog post recommends the following ranks based on the task and dataset size:
Reinforcement learning tasks typically require lower capacity, so smaller LoRA ranks can be used. This is because policy gradient algorithms extract roughly ~1 bit of information per episode, demanding minimal parameter capacity.
The blog post defines the ideal dataset size for LoRA to match full fine-tuning as "Post-training scale". Which we can use to determine the recommended rank for SFT and RL LoRAs as:
| Task Type | Dataset Size | Recommended Rank |
| --- | --- | --- |
| **SFT** | Post-training scale | 256 |
| **RL** | Any size | 1-32 |
### 3. *"FullFT and high-rank LoRAs have similar learning curves"*
Counterintuitively, the blog post recommends using a higher learning rate than for full fine-tuning. In the table above, we used 1.0e-5 for LoRA and 1.0e-6 for full fine-tuning. In the TRL script, we could use `--learning_rate` to set the learning rate. The \\( \frac{1}{r} \\) scaling in LoRA makes the optimal learning rate approximately rank-independent.
![learning rate](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/2.png)
### 4. *"In some scenarios, LoRA is less tolerant of large batch sizes than full fine-tuning."*
The blog post recommends using an effective batch size < 32 because the authors found LoRA to be less tolerant of large batch sizes. This could not be mitigated by increasing the LoRA rank. In the TRL script, we could use `--per_device_train_batch_size` and `--gradient_accumulation_steps` to set the batch size.
![learning rate](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/4.png)
## Takeaways
Using TRL, you can efficiently implement LoRA adapters to match full fine-tuning performance, applying the core insights (targeting all weight matrices, choosing the right rank, and managing batch size and learning rate) without the heavy compute cost of FullFT.
## Citation
```bibtex
@article{schulman2025lora,
title = {{LoRA Without Regret}},
author = {John Schulman and Thinking Machines Lab},
year = 2025,
journal = {Thinking Machines Lab: Connectionism},
doi = {10.64434/tml.20250929},
note = {https://thinkingmachines.ai/blog/lora/}
}
```

View File

@ -0,0 +1,9 @@
# Model Utilities
## clone_chat_template
[[autodoc]] clone_chat_template
## get_act_offloading_ctx_manager
[[autodoc]] models.get_act_offloading_ctx_manager

View File

@ -8,7 +8,6 @@ With the `AutoModelForCausalLMWithValueHead` class TRL supports all decoder mode
## AutoModelForCausalLMWithValueHead
[[autodoc]] AutoModelForCausalLMWithValueHead
- __init__
- forward

View File

@ -48,6 +48,7 @@ trainer = PPOTrainer(
...
```
Then inside your PPO training loop, call the `compute_reward_score` method by accessing the `model` attribute from `PPOTrainer`.
```python
@ -71,6 +72,7 @@ rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_
For more memory efficient fine-tuning, you can load your base model in 8-bit or 4-bit while keeping the adapters in the default precision (float32).
Just pass the appropriate arguments (i.e. `load_in_8bit=True` or `load_in_4bit=True`) to `AutoModelForCausalLMWithValueHead.from_pretrained` as follows (assuming you have installed `bitsandbytes`):
```python
model_name = "llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
@ -88,7 +90,7 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained(
model_name,
peft_config=lora_config,
reward_adapter=rm_adapter_id,
load_in_8bit=True,
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
...

View File

@ -1,6 +1,6 @@
# Nash-MD Trainer
[![](https://img.shields.io/badge/All_models-Nash--MD-blue)](https://huggingface.co/models?other=nash-md,trl)
[![model badge](https://img.shields.io/badge/All_models-Nash--MD-blue)](https://huggingface.co/models?other=nash-md,trl)
## Overview
@ -36,7 +36,7 @@ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
training_args = NashMDConfig(output_dir="Qwen2-0.5B-NashMD", logging_steps=10)
training_args = NashMDConfig(output_dir="Qwen2-0.5B-NashMD")
trainer = NashMDTrainer(
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
@ -51,9 +51,9 @@ accelerate launch train_nash_md.py
Distributed across 8 GPUs, the training takes approximately 3 hours.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [TRL Chat CLI](clis#chat-interface).
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-NashMD
<pre><code>$ transformers chat trl-lib/Qwen2-0.5B-NashMD
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
@ -63,7 +63,7 @@ The best programming language depends on personal preference, the complexity of
## Expected dataset type
Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Nash-MD requires a [prompt-only dataset](dataset_formats#prompt-only). The [`NashMDTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Usage tips
@ -81,15 +81,12 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht
trainer = NashMDTrainer(
...
- judge=judge,
+ reward_model=reward_model,
+ reward_funcs=reward_model,
)
```
<Tip warning={true}>
Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
</Tip>
> [!WARNING]
> Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
### Encourage EOS token generation
@ -111,7 +108,7 @@ trainer.add_callback(completions_callback)
This callback logs the model's generated completions directly to Weights & Biases.
![Logged Completions](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/wandb_completions.png)
![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png)
## Example script
@ -125,7 +122,6 @@ python examples/scripts/nash_md.py \
--judge pair_rm \
--dataset_name trl-lib/ultrafeedback-prompt \
--learning_rate 5.0e-7 \
--logging_steps 25 \
--output_dir Qwen2.5-0.5B-NashMD-PairRM \
--warmup_ratio 0.1 \
--push_to_hub
@ -133,7 +129,7 @@ python examples/scripts/nash_md.py \
## Logged metrics
The logged metrics are as follows:
While training and evaluating, we record the following reward metrics:
* `loss/kl`: The mean KL divergence between the model and reference data.
* `objective/entropy`: The mean entropy of the model and reference data.
@ -153,6 +149,9 @@ The logged metrics are as follows:
## NashMDTrainer
[[autodoc]] NashMDTrainer
- train
- save_model
- push_to_hub
## NashMDConfig

View File

@ -1,6 +1,6 @@
# Online DPO Trainer
[![](https://img.shields.io/badge/All_models-Online_DPO-blue)](https://huggingface.co/models?other=online-dpo,trl)
[![model badge](https://img.shields.io/badge/All_models-Online_DPO-blue)](https://huggingface.co/models?other=online-dpo,trl)
## Overview
@ -36,7 +36,7 @@ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO", logging_steps=10)
training_args = OnlineDPOConfig(output_dir="Qwen2-0.5B-OnlineDPO")
trainer = OnlineDPOTrainer(
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
@ -51,11 +51,11 @@ accelerate launch train_online_dpo.py
Distributed across 8 GPUs, the training takes approximately 1 hour. You can verify the training progress by checking the reward graph. An increasing trend in both the reward for rejected and chosen completions indicates that the model is improving and generating better responses over time.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/online-dpo-qwen2.png)
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/online-dpo-qwen2.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-OnlineDPO
<pre><code>$ transformers chat trl-lib/Qwen2-0.5B-OnlineDPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
@ -65,7 +65,7 @@ The best programming language depends on your specific needs and priorities. Som
## Expected dataset type
Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Online DPO only requires a [prompt-only dataset](dataset_formats#prompt-only) (unlike offline DPO, that expects [preference dataset](dataset_formats#preference)). The [`OnlineDPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
## Usage tips
@ -84,7 +84,7 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht
trainer = OnlineDPOTrainer(
...
- judge=judge,
+ reward_model=reward_model,
+ reward_funcs=reward_model,
+ reward_processing_class=reward_tokenizer,
...
)
@ -110,8 +110,7 @@ trainer.add_callback(completions_callback)
This callback logs the model's generated completions directly to Weights & Biases.
![Logged Completions](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/wandb_completions.png)
![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png)
## Example script
@ -125,7 +124,6 @@ python examples/scripts/dpo_online.py \
--judge pair_rm \
--dataset_name trl-lib/ultrafeedback-prompt \
--learning_rate 5.0e-7 \
--logging_steps 25 \
--output_dir Qwen2.5-0.5B-Online-DPO-PairRM \
--warmup_ratio 0.1 \
--push_to_hub
@ -133,7 +131,7 @@ python examples/scripts/dpo_online.py \
## Logged metrics
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/w4apmsi9)
While training and evaluating, we record the following reward metrics. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/w4apmsi9)
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model.
* `objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model.
@ -154,8 +152,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
To validate the online DPO implementation works, we ran experiments with the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
```
```shell
# 1B Online DPO experiment
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \
examples/scripts/dpo_online.py \
@ -171,7 +168,6 @@ accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
@ -190,8 +186,6 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--bf16 \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
@ -210,18 +204,15 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
--max_new_tokens 53 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--bf16 \
--gradient_checkpointing \
--logging_steps 20 \
--save_steps 0.1 \
--push_to_hub
```
Checkpoints and experiment tracking are available at:
- [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07)
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0)
* [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07)
* [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0)
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
For more information on how to use judges, see [Judges](judges).
@ -265,13 +256,14 @@ plt.tight_layout()
plt.show()
```
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/online_dpo_scaling.png)
The online DPO checkpoint gets increasingly more win rate as we scale up the model sizes. This is a good sign that the online DPO implementation is working as intended.
## OnlineDPOTrainer
[[autodoc]] OnlineDPOTrainer
- train
- save_model
- push_to_hub
## OnlineDPOConfig

View File

@ -1,6 +1,6 @@
# ORPO Trainer
[![](https://img.shields.io/badge/All_models-ORPO-blue)](https://huggingface.co/models?other=orpo,trl)
[![model badge](https://img.shields.io/badge/All_models-ORPO-blue)](https://huggingface.co/models?other=orpo,trl) [![model badge](https://img.shields.io/badge/smol_course-Chapter_2-yellow)](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
## Overview
@ -41,7 +41,7 @@ model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = ORPOConfig(output_dir="Qwen2-0.5B-ORPO", logging_steps=10)
training_args = ORPOConfig(output_dir="Qwen2-0.5B-ORPO")
trainer = ORPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
@ -54,11 +54,11 @@ accelerate launch train_orpo.py
Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/orpo-qwen2-reward-margin.png)
![orpo qwen2 reward margin](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/orpo-qwen2-reward-margin.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-ORPO
<pre><code>$ transformers chat trl-lib/Qwen2-0.5B-ORPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
@ -94,7 +94,6 @@ accelerate launch examples/scripts/orpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized \
--num_train_epochs 1 \
--logging_steps 25 \
--output_dir Qwen2-0.5B-ORPO
```
@ -110,7 +109,7 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype
## Logged metrics
While training and evaluating we record the following reward metrics:
While training and evaluating, we record the following reward metrics:
- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
@ -123,6 +122,9 @@ While training and evaluating we record the following reward metrics:
## ORPOTrainer
[[autodoc]] ORPOTrainer
- train
- save_model
- push_to_hub
## ORPOConfig

9
docs/source/others.md Normal file
View File

@ -0,0 +1,9 @@
# Other
## profiling_decorator
[[autodoc]] extras.profiling.profiling_decorator
## profiling_context
[[autodoc]] extras.profiling.profiling_context

582
docs/source/paper_index.md Normal file
View File

@ -0,0 +1,582 @@
# Paper Index
> [!WARNING]
> Section under construction. Feel free to contribute!
## Group Relative Policy Optimization
Papers relating to the [`GRPOTrainer`]
### Group Sequence Policy Optimization
**📜 Paper**: https://huggingface.co/papers/2507.18071
GSPO is a GRPO variant that computes importance sampling weights at the sequence level instead of per-token. To reproduce the paper's setting, use this configuration:
```python
from trl import GRPOConfig
training_args = GRPOConfig(
importance_sampling_level="sequence",
loss_type="grpo",
beta=0.0, # GSPO set KL regularization to zero: https://github.com/volcengine/verl/pull/2775#issuecomment-3131807306
epsilon=3e-4, # GSPO paper (v2), section 5.1
epsilon_high=4e-4, # GSPO paper (v2), section 5.1
gradient_accumulation_steps=1,
steps_per_generation=4, # partition rollout batch into 4 mini-batches. GSPO paper (v2), section 5.1. Must be 4 times gradient_accumulation_steps
)
```
Note that this method only has an effect when training goes slightly off-policy—for example, when `steps_per_generation > gradient_accumulation_steps` or `num_iterations > 1`. Otherwise, it is effectively equivalent to no modification.
TRL also provide an experimental implementation of GSPO-token, see [Experimental - GSPO-Token](experimental#gspo-token).
#### Policy ratio: GRPO vs. GSPO
In GSPO, the policy ratio is defined at the sequence-level. In other words, it is the ratio between the probability of the current policy generating a sequence over the old policy generating that same sequence.
The sequence likelihood is defined as:
$$
\pi_\theta (o_i | q) = \prod_{t=1}^{|o_i|} \pi_\theta (o_{i,t} | q, o_{i, < t} ),
$$
where \\( \pi_\theta \\) is the policy \\( \pi \\) with parameters \\(\theta\\), \\( o_i \\) is the \\( i \\)-th output sequence \\( o \\) and \\(o_{i,t}\\) is the \\( t \\)-th token in this sequence, \\( q \\) is the input query. The sequence likelihood ratio \\( s_i (\theta) \\) is defined as:
$$
s_i (\theta) = \left(\frac{\pi_\theta (o_i | q)}{\pi_{\theta_{old}} (o_i | q)} \right)^{\frac{1}{|o_i|}}
$$
The exponent \\( \frac{1}{|o_i|} \\) represents a sequence-length normalization, minimizing the influence of sequence length in sequence likelihood. In other terms, it computes the geometric mean of token probabilities, ensuring a fair comparison across sequences of varying lengths.
While GSPO defines the policy ratio at the sequence level, GRPO operates at the token level. Specifically, GRPO computes an importance ratio for each token in the sequence:
$$
w_{i,t}(\theta) = \frac{\pi_\theta (o_{i,t} | q, o_{i,< t})}{\pi_{\theta_{\text{old}}} (o_{i,t} | q, o_{i,< t})}
$$
This token-level ratio is then combined with a shared advantage \\( \hat{A}_i \\), and the GRPO objective clips and optimizes each token independently across the sequence.
### DAPO: An Open-Source LLM Reinforcement Learning System at Scale
**📜 Paper**: https://huggingface.co/papers/2503.14476
The DAPO algorithm includes 5 key components:
- Overlong Filtering
- Clip-Higher
- Soft Overlong Punishment
- Token-level Loss
- Dynamic Sampling (⚠ Not supported in TRL)
To reproduce the paper's setting, use this configuration:
```python
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
# Overlong Filtering
mask_truncated_completions=True,
# Token-level Loss
loss_type="dapo",
# Clip-Higher
epsilon_high=0.28, # DAPO paper: section 4.1
epsilon=0.2, # DAPO paper: section 4.1
# Other parameters used
per_device_train_batch_size=512, # mini-batch size for training in the paper, DAPO paper: section 4.1
num_generations=16, # number of sample responses in the paper, DAPO paper: section 4.1
max_completion_length=20480, # maximum number of tokens for generation in the paper, DAPO paper: section 4.1
beta=0.0 # section 2.3, DAPO paper
)
# Soft Overlong Punishment
sop_reward = get_soft_overlong_punishment(max_completion_len=20480, soft_punish_cache=4096) # DAPO paper: section 4.1
trainer = GRPOTrainer(
...,
args=training_args,
reward_funcs=[..., sop_reward],
)
```
### Dr. GRPO: Understanding R1-Zero-Like Training: A Critical Perspective
**📜 Paper**: https://huggingface.co/papers/2503.20783
A study of R1-Zero training identifies pretraining effects on RL performance and proffers Dr. GRPO to enhance token efficiency, achieving superior accuracy on AIME 2024. To reproduce the paper's setting, use this configuration:
```python
from trl import GRPOConfig
training_args = GRPOConfig(
loss_type="dr_grpo",
per_device_train_batch_size=1, # train_batch_size_per_device in the Training section of the repository
num_generations=8, # num_samples in the Training section of the repository
max_prompt_length=1024, # prompt_max_length in the Training section of the repository
max_completion_length=3000, # generate_max_length in the Training section of the repository
beta=0.0, # beta in the Training section of the repository
)
```
### Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)
**📜 Paper**: https://huggingface.co/papers/2508.08221
The authors of this paper find that the combination of:
1. scaling rewards by the standard deviation computed over the entire batch and
2. aggregating loss over the total number of tokens
can unlock the learning capability of critic-free policies using vanilla PPO loss. Their results demonstrate that this simple combination consistently improves performance, surpassing strategies like GRPO and [DAPO](https://huggingface.co/papers/2503.14476).
TRL supports using these learnings to train a GRPO model by:
```python
from trl import GRPOConfig
training_args = GRPOConfig(
...
scale_rewards="batch",
loss_type="dapo",
# Other parameters used
beta=0.0, # = init_kl_coef in the paper
top_p=0.99,
top_k=100,
temperature=0.99,
num_completions=8, # = num_return_sequences in the paper
num_iterations=1, # = ppo_epochs in the paper
per_device_train_batch_size=4,
gradient_accumulation_steps=32,
steps_per_generation=8, # (rollout_batch_size*num_return_sequences) / (per_device_train_batch_size*gradient_accumulation_steps)
)
```
Note that when using gradient accumulation, the loss is aggregated over the total number of tokens in the batch, but not over the accumulated batch. For more details, see the [GRPO Trainer - Loss types](grpo_trainer#loss_types).
### Truncated Importance Sampling
**📰 Blog**: https://fengyao.notion.site/off-policy-rl
Online policy learning methods commonly use an optimized inference framework for rollout generation (e.g vLLM) that is separate from the training backend. This introduces a rollout-training mismatch, exemplified in the following PPO objective:
$$
\small{
\mathbb{E}_{a\sim\textcolor{red}{\pi_{\text{inference}}}(\theta_{\mathrm{old}})}
\Bigl[
\min\Bigl(
\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})}\,\hat A,
\;\mathrm{clip}\bigl(\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})},\,1-\epsilon,\,1+\epsilon\bigr)\,\hat A
\Bigr)
\Bigr]
}
$$
Despite \\( \textcolor{red}{\pi_{\text{inference}}} \\) and \\( \textcolor{blue}{\pi_{\text{training}}} \\) sharing the same model parameters \\( \theta \\), they can produce significantly different token probabilities. This unexpected behavior implicitly breaks the on-policy assumption, and silently turns training off-policy.
Truncated Importance Sampling (TIS) addresses this issue by adapting the model update via importance-sampling correction. The gradient computation of the aforementioned PPO objective becomes
$$
\small{
\mathbb{E}_{a\sim\textcolor{red}{\pi_{\text{inference}}}(\theta_{\mathrm{old}})}
\Bigl[
\underbrace{\min(\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})}{\textcolor{red}{\pi_{\text{inference}}}(a, \theta_{\mathrm{old}})}, C)}_{\text{truncated importance ratio}} \cdot
\nabla_\theta
\min\Bigl(
\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})}\,\hat A,
\;\mathrm{clip}\bigl(\frac{\textcolor{blue}{\pi_{\text{training}}}(a, \theta)}{\textcolor{blue}{\pi_{\text{training}}}(a, \theta_{\mathrm{old}})},\,1-\epsilon,\,1+\epsilon\bigr)\,\hat A
\Bigr)
\Bigr]
}
$$
where \\( C \\) is a hyper-parameter. In TRL, TIS is implemented for GRPO, and enabled by default when vLLM is used for generation (`use_vllm=True`)
```python
from trl import GRPOConfig
training_args = GRPOConfig(
...
use_vllm=True,
vllm_importance_sampling_correction=True, # default True
vllm_importance_sampling_cap=2.0, # hyper-parameter C
)
```
### Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning
**📜 Paper**: https://huggingface.co/papers/2508.09726
See [Experimental - GFPO](experimental#gfpo).
## Direct Policy Optimization
Papers relating to the [`DPOTrainer`]
### Direct Preference Optimization (DPO): Your Language Model is Secretly a Reward Model
**📜 Paper**: https://huggingface.co/papers/2305.18290
Direct Preference Optimization (DPO) fine-tunes language models more efficiently and with better performance compared to reinforcement learning from human feedback (RLHF), by directly optimizing policy training based on human preferences. To reproduce the paper's setting, use this configuration:
```python
from trl import DPOConfig
training_args = DPOConfig(
loss_type="sigmoid", # losses in Appendix B of the paper
per_device_train_batch_size=64, # batch size in Appendix B of the paper
learning_rate=1e-6, # learning rate in Appendix B of the paper
beta=0.1, # beta in Appendix B of the paper
)
```
### A General Theoretical Paradigm to Understand Learning from Human Preferences
**📜 Paper**: https://huggingface.co/papers/2310.12036
A new general objective, \\( \Psi \\)$PO, bypasses both key approximations in reinforcement learning from human preferences, allowing for theoretical analysis and empirical superiority over DPO. To reproduce the paper's setting, use this configuration: To reproduce the paper's setting, use this configuration:
```python
from trl import DPOConfig
training_args = DPOConfig(
loss_type="ipo", # Section 5.1 of the paper
per_device_train_batch_size=90, # mini-batch size in Section C.1 of the paper
learning_rate=1e-2, # learning rate in Section C.1 of the paper
)
```
These parameters only appear in the [published version](https://proceedings.mlr.press/v238/gheshlaghi-azar24a/gheshlaghi-azar24a.pdf)
### SLiC-HF: Sequence Likelihood Calibration with Human Feedback
**📜 Paper**: https://huggingface.co/papers/2305.10425
Sequence Likelihood Calibration (SLiC) is shown to be an effective and simpler alternative to Reinforcement Learning from Human Feedback (RLHF) for learning from human preferences in language models. To reproduce the paper's setting, use this configuration:
```python
from trl import DPOConfig
training_args = DPOConfig(
loss_type="hinge", # Section 2 of the paper
per_device_train_batch_size=512, # batch size in Section 3.2 of the paper
learning_rate=1e-4, # learning rate in Section 3.2 of the paper
)
```
These parameters only appear in the [published version](https://openreview.net/pdf?id=0qSOodKmJaN)
### Towards Efficient and Exact Optimization of Language Model Alignment
**📜 Paper**: https://huggingface.co/papers/2305.10425
Efficient exact optimization (EXO) method is proposed to align language models with human preferences, providing a guaranteed and efficient alternative to reinforcement learning and direct preference optimization. To reproduce the paper's setting, use this configuration:
```python
from trl import DPOConfig
training_args = DPOConfig(
loss_type="exo_pair", # Section 3.2 of the paper
per_device_train_batch_size=64, # batch size in Section B of the paper
learning_rate=1e-6, # learning rate in Section B of the paper
beta=0.1, # $\beta_r$ in Section B of the paper
)
```
### Noise Contrastive Alignment of Language Models with Explicit Rewards
**📜 Paper**: https://huggingface.co/papers/2402.05369
A framework using Noise Contrastive Estimation enhances language model alignment with both scalar rewards and pairwise preferences, demonstrating advantages over Direct Preference Optimization. To reproduce the paper's setting, use this configuration:
```python
from trl import DPOConfig
training_args = DPOConfig(
loss_type="nca_pair", # Section 4.1 of the paper
per_device_train_batch_size=32, # batch size in Section C of the paper
learning_rate=5e-6, # learning rate in Section C of the paper
beta=0.01, # $\alpha$ in Section C of the paper
)
```
### Provably Robust DPO: Aligning Language Models with Noisy Feedback
**📜 Paper**: https://huggingface.co/papers/2403.00409
The paper introduces a robust direct preference optimization (rDPO) framework to address noise in preference-based feedback for language models, proving its sub-optimality gap and demonstrating its effectiveness through experiments. To reproduce the paper's setting, use this configuration:
```python
from trl import DPOConfig
training_args = DPOConfig(
loss_type="robust", # Section 3.1 of the paper
per_device_train_batch_size=16, # batch size in Section B of the paper
learning_rate=1e-3, # learning rate in Section B of the paper
beta=0.01, # $\beta$ in Section B of the paper,
max_prompt_length=128, # max prompt length in Section B of the paper
max_length=512, # max length in Section B of the paper
label_smoothing=0.1 # label smoothing $\epsilon$ in section 6 of the paper
)
```
### Binary Classifier Optimization for Large Language Model Alignment
**📜 Paper**: https://huggingface.co/papers/2404.04656
Theoretical analysis and a new algorithm, Binary Classifier Optimization, explain and enhance the alignment of large language models using binary feedback signals. To reproduce the paper's setting, use this configuration:
```python
from trl import DPOConfig
training_args = DPOConfig(
loss_type="bco_pair", # Section 4 of the paper
per_device_train_batch_size=128, # batch size in Section C of the paper
learning_rate=5e-7, # learning rate in Section C of the paper
beta=0.01, # $\beta$ in Section C of the paper,
max_prompt_length=1536, # max prompt length in Section C of the paper
max_completion_length=512, # max completion length in Section C of the paper
)
```
For the unpaired version, the user should utilize [`BCOConfig`] and [`BCOTrainer`].
### Self-Play Preference Optimization for Language Model Alignment
**📜 Paper**: https://huggingface.co/papers/2405.00675
A self-play method called SPPO for language model alignment achieves state-of-the-art performance by approximating Nash equilibrium policy in a constant-sum game setting, outperforming other approaches with limited data. To reproduce the paper's setting, use this configuration:
```python
from trl import DPOConfig
training_args = DPOConfig(
loss_type="sppo_hard", # Section 3 of the paper
per_device_train_batch_size=64, # batch size in Section C of the paper
learning_rate=5e-7, # learning rate in Section C of the paper
)
```
### Distributional Preference Alignment of LLMs via Optimal Transport
**📜 Paper**: https://huggingface.co/papers/2406.05882
Alignment via Optimal Transport (AOT) aligns large language models distributionally by penalizing violations of stochastic dominance between positive and negative sample distributions, achieving state-of-the-art performance on alignment benchmarks. To reproduce the paper's setting, use this configuration:
```python
from trl import DPOConfig
training_args = DPOConfig(
loss_type="aot", # Section 3 of the paper
)
```
```python
from trl import DPOConfig
training_args = DPOConfig(
loss_type="aot_pair", # Section 3 of the paper
)
```
There is no additional hyperparameter in the paper.
### Discovering Preference Optimization Algorithms with and for Large Language Models
**📜 Paper**: https://huggingface.co/papers/2406.08414
An LLM-driven method automatically discovers performant preference optimization algorithms, leading to a new algorithm called DiscoPOP that blends logistic and exponential losses. To reproduce the paper's setting, use this configuration:
```python
from trl import DPOConfig
training_args = DPOConfig(
loss_type="discopop", # Section 3 of the paper
per_device_train_batch_size=64, # batch size in Section B.1 of the paper
learning_rate=5e-7, # learning rate in Section B.1 of the paper
beta=0.05, # $\beta$ in Section B.1 of the paper,
discopop_tau=0.05 # $\tau$ in Section E of the paper
)
```
### Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment
**📜 Paper**: https://huggingface.co/papers/2408.06266
CLAIR and APO enhance LLM alignment through more contrastive preference pairs and controlled alignment objectives, improving model performance close to GPT4-turbo. To reproduce the paper's setting, use this configuration:
```python
from trl import DPOConfig
training_args = DPOConfig(
loss_type="apo_zero", # Section 4 of the paper
per_device_train_batch_size=64, # batch size in Section B.1 of the paper
learning_rate=2e-7, # learning rate in Section 5.2 of the paper
beta=0.1, # $\beta$ in Section 5.2 of the paper,
max_prompt_length=512, # prompt length in Section 5.2 of the paper
max_completion_length=512, # completion length in Section 5.2 of the paper
)
```
```python
from trl import DPOConfig
training_args = DPOConfig(
loss_type="apo_down", # Section 4 of the paper
per_device_train_batch_size=64, # batch size in Section B.1 of the paper
learning_rate=2e-7, # learning rate in Section 5.2 of the paper
beta=0.1, # $\beta$ in Section 5.2 of the paper,
max_prompt_length=512, # prompt length in Section 5.2 of the paper
max_completion_length=512, # completion length in Section 5.2 of the paper
)
```
These parameters only appear in the [published version](https://aclanthology.org/2025.tacl-1.22.pdf)
## Supervised Fine-Tuning
Papers relating to the [`SFTTrainer`]
### EMA Without the Lag: Bias-Corrected Iterate Averaging Schemes
**📜 Paper**: https://huggingface.co/papers/2508.00180
Bias-Corrected Exponential Moving Average (BEMA) improves the stability and efficiency of language model fine-tuning by reducing stochasticity and eliminating bias. To use BEMA with SFT as described in the paper, you can use the [`BEMACallback`]:
```python
from trl import BEMACallback, SFTTrainer
trainer = SFTTrainer(
...
callbacks=[BEMACallback()],
)
```
### On the Generalization of SFT: A Reinforcement Learning Perspective with Reward Rectification
**📜 Paper**: https://huggingface.co/papers/2508.05629
Dynamic Fine-Tuning (DFT) improves the generalization of Large Language Models (LLMs) by dynamically rescaling gradients, outperforming standard Supervised Fine-Tuning (SFT) and showing competitive results in offline reinforcement learning.
$$
\mathcal{L}_{\text{DFT}}(\theta) = \mathbb{E}_{(x,y) \sim \mathcal{D}} \left[ - \sum_{t=1}^{|y|} \textcolor{red}{\text{sg}\big(\pi_\theta(y_t \mid y_{<t}, x)\big)} \; \log \pi_\theta(y_t \mid y_{<t}, x) \right]
$$
where \\( \text{sg}(\cdot) \\) is the stop-gradient operator. To use DFT with SFT as described in the paper, you can use the `loss_type="dft"` argument:
```python
from trl import SFTConfig
training_args = SFTConfig(
loss_type="dft",
...
)
```
To closely match the papers setup, you can use the following configuration (see Sec. 4.1). Authors also mention that the hyperparameters are not very sensitive (Sec. 4.3):
```python
SFTConfig(
loss_type="dft",
learning_rate=5e-5,
max_length=2048,
# Target batch size 256; achieved via per-device batch 8 * grad accumulation 32
per_device_train_batch_size=8,
gradient_accumulation_steps=32,
)
```
## Reinforce Leave-One-Out
Papers relating to the [`RLOOTrainer`]
### Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs
**📜 Paper**: https://huggingface.co/papers/2402.14740
RLOO is a variant of REINFORCE that reduces variance by using leave-one-out baselines. It computes rewards by comparing each sample against the average of all other samples in the batch, providing more stable gradients than standard REINFORCE. To reproduce the paper's setting, use this configuration:
```python
from trl import RLOOConfig
training_args = RLOOConfig(
per_device_train_batch_size=512, # section C Training Detail of the paper
steps_per_generation=2 # section C Training Detail of the paper
beta=0.03 # section C Training Detail of the paper
num_generations=2, # experiments of paper different num_generations={2,4}
learning_rate=1e-6 # section C Training Detail of the paper
)
```
## Contrastive Preference Optimization
Papers relating to the [`CPOTrainer`]
### AlphaPO -- Reward shape matters for LLM alignment
**📜 Paper**: https://huggingface.co/papers/2501.03884
AlphaPO is a new Direct Alignment Algorithms (DAAs) method that leverages an alpha-parameter to help change the shape of the reward function beyond the standard log reward. AlphaPO helps maintain fine-grained control over likelihood displacement and over-optimization. To reproduce the paper's setting, use this configuration:
```python
from trl import CPOConfig
# Mistral-Instruct from Table 3 of the paper
training_args = CPOConfig(
loss_type="alphapo",
alpha=0.25,
beta=2.5,
simpo_gamma=0.1,
learning_rate=7e-7,
...
)
```
## Reward Modeling
Papers relating to the [`RewardTrainer`]
### Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking
**📜 Paper**: https://huggingface.co/papers/2312.09244
This paper proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs and thereby resolving the issue of underdetermination.
$$
\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \textcolor{red}{- \eta \cdot (r_\theta(x, y^+) + r_\theta(x, y^-))^2} \right].
$$
To use this auxiliary loss with [`RewardTrainer`], you can use the `center_rewards_coefficient` argument in [`RewardConfig`] as follows:
```python
from trl import RewardConfig
training_args = RewardConfig(
center_rewards_coefficient=0.01, # η in the paper
...
)
```
### Llama 2: Open Foundation and Fine-Tuned Chat Models
**📜 Paper**: https://huggingface.co/papers/2307.09288
In this paper, the authors propose to leverage their preference ratings being decomposed as a scale of four points (e.g., _significantly better_) to provide more informative feedback to the reward model. This is done by adding a margin to the loss function, which encourages the reward model to assign larger gaps in scores for pairs with higher preference ratings.
$$
\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-,\textcolor{red}{m}) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-) \textcolor{red}{- m}) \right].
$$
You can add a margin to the loss by adding a `margin` column to the dataset. The following example shows how to set up a the "Margin Small" setting of the paper.
```python
def add_margin(example):
preference_to_margin = {
"significantly better": 1.0,
"better": 2.0/3.0,
"slightly better": 1.0/3.0,
"negligibly better / unsure": 0.0,
}
return {"margin": preference_to_margin[example["preference_label"]]}
dataset = dataset.map(add_margin)
```

View File

@ -1,17 +1,10 @@
# Examples of using peft with trl to finetune 8-bit models with Low Rank Adaption (LoRA)
The notebooks and scripts in this examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory efficient manner. Most of PEFT methods supported in peft library but note that some PEFT methods such as Prompt tuning are not supported.
The notebooks and scripts in these examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory efficient manner. Most of PEFT methods supported in peft library but note that some PEFT methods such as Prompt tuning are not supported.
For more information on LoRA, see the [original paper](https://huggingface.co/papers/2106.09685).
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
| File | Task | Description | Colab link |
|---|---| --- |
| [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
| [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
| [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
## Installation
Note: peft is in active development, so we install directly from their Github page.
Peft also relies on the latest version of transformers.
@ -27,7 +20,7 @@ Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scr
## How to use it?
Simply declare a `PeftConfig` object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model.
Simply declare a [`~peft.PeftConfig`] object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model.
```python
from peft import LoraConfig
@ -47,7 +40,9 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained(
peft_config=lora_config,
)
```
And if you want to load your model in 8bit precision:
```python
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
@ -55,7 +50,9 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
peft_config=lora_config,
)
```
... or in 4bit precision:
```python
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
@ -64,7 +61,6 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
)
```
## Launch scripts
The `trl` library is powered by `accelerate`. As such it is best to configure and launch trainings with the following commands:
@ -77,6 +73,7 @@ accelerate launch examples/scripts/ppo.py --use_peft # launch`es training
## Using `trl` + `peft` and Data Parallelism
You can scale up to as many GPUs as you want, as long as you are able to fit the training process in a single device. The only tweak you need to apply is to load the model as follows:
```python
from peft import LoraConfig
...
@ -94,7 +91,9 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
peft_config=lora_config,
)
```
And if you want to load your model in 8bit precision:
```python
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
@ -102,7 +101,9 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
load_in_8bit=True,
)
```
... or in 4bit precision:
```python
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
config.model_name,
@ -110,6 +111,7 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
load_in_4bit=True,
)
```
Finally, make sure that the rewards are computed on correct device as well, for that you can use `ppo_trainer.model.current_device`.
## Naive pipeline parallelism (NPP) for large models (>60B models)
@ -117,9 +119,7 @@ Finally, make sure that the rewards are computed on correct device as well, for
The `trl` library also supports naive pipeline parallelism (NPP) for large models (>60B models). This is a simple way to parallelize the model across multiple GPUs.
This paradigm, termed as "Naive Pipeline Parallelism" (NPP) is a simple way to parallelize the model across multiple GPUs. We load the model and the adapters across multiple GPUs and the activations and gradients will be naively communicated across the GPUs. This supports `int8` models as well as other `dtype` models.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-npp.png">
</div>
![NPP](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-npp.png)
### How to use NPP?
@ -140,5 +140,5 @@ python PATH_TO_SCRIPT
You can easily fine-tune Llama2 model using `SFTTrainer` and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB):
```bash
python examples/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2
python trl/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2
```

View File

@ -1,10 +1,11 @@
# PPO Trainer
[![](https://img.shields.io/badge/All_models-PPO-blue)](https://huggingface.co/models?other=ppo,trl)
[![model badge](https://img.shields.io/badge/All_models-PPO-blue)](https://huggingface.co/models?other=ppo,trl)
TRL supports training LLMs with [Proximal Policy Optimization (PPO)](https://huggingface.co/papers/1707.06347).
References:
- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences)
- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback)
- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo)
@ -26,52 +27,50 @@ python examples/scripts/ppo/ppo.py \
--gradient_accumulation_steps 1 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--sft_model_path EleutherAI/pythia-1b-deduped \
--reward_model_path EleutherAI/pythia-1b-deduped \
--missing_eos_penalty 1.0
```
## Explanation of the logged metrics
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35)
* `eps`: Tracks the number of episodes per second.
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
* `objective/scores`: The mean scores returned by the reward model / environment.
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
* `lr`: lr: The current learning rate used by the optimizer.
* `episode`: episode: The current global step or episode count in the training process.
- `eps`: Tracks the number of episodes per second.
- `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
- `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
- `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
- `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
- `objective/scores`: The mean scores returned by the reward model / environment.
- `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
- `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
- `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
- `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
- `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
- `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
- `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
- `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
- `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
- `lr`: lr: The current learning rate used by the optimizer.
- `episode`: episode: The current episode count in the training process.
## Cookbook
* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try undertand why this is happening and try to fix it.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
* Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
- Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
- Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try understand why this is happening and try to fix it.
- Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
- Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
- Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
## What is my model doing exactly?
To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/ppov2_completions.gif?download=true)
![ppov2_completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/ppov2_completions.gif)
In the logs the sampled generations look like
```
```txt
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ query ┃ model response ┃ score ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
@ -175,7 +174,7 @@ This PPO implementation is based on the [The N+ Implementation Details of RLHF w
To validate the PPO implementation works, we ran experiment on the 1B model. Here are the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
```
```shell
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
examples/scripts/ppo/ppo_tldr.py \
--output_dir models/minimal/ppo_tldr \
@ -210,8 +209,7 @@ The PPO checkpoint gets a 64.7% preferred rate vs the 33.0% preference rate of t
Metrics:
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/pr-1540/ppov2.png)
![PPO v2](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/ppov2.png)
```bash
# pip install openrlbenchmark==0.2.1a5
@ -231,6 +229,9 @@ python -m openrlbenchmark.rlops_multi_metrics \
## PPOTrainer
[[autodoc]] PPOTrainer
- train
- save_model
- push_to_hub
## PPOConfig

122
docs/source/prm_trainer.md Normal file
View File

@ -0,0 +1,122 @@
# PRM Trainer
[![model badge](https://img.shields.io/badge/All_models-PRM-blue)](https://huggingface.co/models?other=prm,trl)
> [!WARNING]
> PRM Trainer is an experimental API which is subject to change at any time.
## Overview
Process-supervised Reward Models (PRM) were proposed in [Solving math word problems with process- and outcome-based feedback](https://huggingface.co/papers/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving, and Irina Higgins.
The abstract from the paper is the following:
> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use processbased supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions.
This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Agustín Piqueres](https://huggingface.co/plaguss).
## Quick start
This example demonstrates how to train a model using the PRM method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model. We use the stepwise supervision data from the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd). You can view the data in the dataset here:
<iframe
src="https://huggingface.co/datasets/trl-lib/math_shepherd/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
Below is the script to train the model:
```python
# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer
model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")
training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd")
trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```
Execute the script using the following command:
```bash
accelerate launch train_prm.py
```
Distributed across 8 GPUs, the training takes approximately 1 hour.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward-Math-Sheperd) performs, you can use the following script.
```python
from datasets import load_dataset
from transformers import pipeline
pipe = pipeline("token-classification", model="trl-lib/Qwen2-0.5B-Reward-Math-Sheperd")
dataset = load_dataset("trl-lib/math_shepherd")
example = {
"prompt": "Musa is the class teacher of a class of 45 students. He wants to split them into three groups by age. If a third of the class is under 11 years, and two-fifths are above 11 but under 13, how many students will be in the third group (13 years and above)?",
"completions": [
"Step 1: A third of the class is under 11 years because 11 - 1/3 = <<11-1/3=7>>7.",
"Step 2: Two-fifths of the class are above 11 but under 13 because 2/5 * 11 = <<2/5*11=8>>8.",
"Step 3: There are 45 students, so the third group will have 45 - 7 - 8 = <<45-7-8=20>>20 students. The answer is: 20",
],
"labels": [True, False, False],
}
separator = "\n" # It's important to use the same separator as the one used during training
for idx in range(1, len(example["completions"]) + 1):
steps = example["completions"][0:idx]
text = separator.join((example["prompt"], *steps)) + separator # Add a separator between the prompt and each steps
pred_entity = pipe(text)[-1]["entity"]
pred = {"LABEL_0": False, "LABEL_1": True}[pred_entity]
label = example["labels"][idx - 1]
print(f"Step {idx}\tPredicted: {pred} \tLabel: {label}")
```
```text
Step 1 Predicted: True Label: True
Step 2 Predicted: False Label: False
Step 3 Predicted: False Label: False
```
It's a win!
## Expected dataset type
PRM requires a [stepwise supervision](dataset_formats#stepwise-supervision).
The dataset should contain the following columns: `prompt`, `completions` and `labels`, where `completions` contains a list of reasoning steps and `labels` a list of booleans or floats indicating the correctness of each step.
The [`PRMTrainer`] only supports [standard](dataset_formats#standard) dataset format.
## Example script
We provide an example script to train a model using the PRM method. The script is available in [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py)
To use the PRM script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) on the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd), run the following command:
```bash
accelerate launch examples/scripts/prm.py \
--model_name_or_path Qwen/Qwen2-0.5B \
--dataset_name trl-lib/math_shepherd \
--num_train_epochs 1 \
--output_dir Qwen2-0.5B-Reward-Math-Sheperd
```
## PRMTrainer
[[autodoc]] PRMTrainer
- train
- save_model
- push_to_hub
## PRMConfig
[[autodoc]] PRMConfig

143
docs/source/quickstart.md Normal file
View File

@ -0,0 +1,143 @@
# Quickstart
TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO).
## Quick Examples
Get started instantly with TRL's most popular trainers. Each example uses compact models for quick experimentation.
### Supervised Fine-Tuning
```python
from trl import SFTTrainer
from datasets import load_dataset
trainer = SFTTrainer(
model="Qwen/Qwen2.5-0.5B",
train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()
```
### Group Relative Policy Optimization
```python
from trl import GRPOTrainer
from datasets import load_dataset
# Define a simple reward function (count unique chars as example)
def reward_function(completions, **kwargs):
return [len(set(completion.lower())) for completion in completions]
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct", # Start from SFT model
train_dataset=load_dataset("trl-lib/tldr", split="train"),
reward_funcs=reward_function,
)
trainer.train()
```
### Direct Preference Optimization
```python
from trl import DPOTrainer
from datasets import load_dataset
trainer = DPOTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct", # Use your SFT model
ref_model="Qwen/Qwen2.5-0.5B-Instruct", # Original base model
train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"),
)
trainer.train()
```
### Reward Modeling
```python
from trl import RewardTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
trainer = RewardTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct",
train_dataset=dataset,
)
trainer.train()
```
## Command Line Interface
Skip the code entirely - train directly from your terminal:
```bash
# SFT: Fine-tune on instructions
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/Capybara
# DPO: Align with preferences
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized
# Reward: Train a reward model
trl reward --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--dataset_name trl-lib/ultrafeedback_binarized
```
## What's Next?
### 📚 Learn More
- [SFT Trainer](sft_trainer) - Complete SFT guide
- [DPO Trainer](dpo_trainer) - Preference alignment
- [GRPO Trainer](grpo_trainer) - Group relative policy optimization
### 🚀 Scale Up
- [Distributed Training](distributing_training) - Multi-GPU setups
- [Memory Optimization](reducing_memory_usage) - Efficient training
- [PEFT Integration](peft_integration) - LoRA and QLoRA
### 💡 Examples
- [Example Scripts](https://github.com/huggingface/trl/tree/main/examples) - Production-ready code
- [Community Tutorials](community_tutorials) - External guides
## Troubleshooting
### Out of Memory?
Reduce batch size and enable optimizations:
<hfoptions id="batch_size">
<hfoption id="SFT">
```python
training_args = SFTConfig(
per_device_train_batch_size=1, # Start small
gradient_accumulation_steps=8, # Maintain effective batch size
)
```
</hfoption>
<hfoption id="DPO">
```python
training_args = DPOConfig(
per_device_train_batch_size=1, # Start small
gradient_accumulation_steps=8, # Maintain effective batch size
)
```
</hfoption>
</hfoptions>
### Loss not decreasing?
Try adjusting the learning rate:
```python
training_args = SFTConfig(learning_rate=2e-5) # Good starting point
```
For more help, open an [issue on GitHub](https://github.com/huggingface/trl/issues).

View File

@ -1,88 +0,0 @@
# Quickstart
## How does it work?
Fine-tuning a language model via PPO consists of roughly three steps:
1. **Rollout**: The language model generates a response or continuation based on a query which could be the start of a sentence.
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback, or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair. The optimization will aim at maximizing this value.
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
The full process is illustrated in the following figure:
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png"/>
## Minimal example
The following code illustrates the steps above.
```python
# 0. imports
import torch
from transformers import GPT2Tokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
# 2. initialize trainer
ppo_config = {"mini_batch_size": 1, "batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
# 3. encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)
# 4. generate model response
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"max_new_tokens": 20,
}
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])
# 5. define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0, device=model.pretrained_model.device)]
# 6. train model with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```
In general, you would run steps 3-6 in a for-loop and run it on many diverse queries. You can find more realistic examples in the examples section.
## How to use a trained model
After training a `AutoModelForCausalLMWithValueHead`, you can directly use the model in `transformers`.
```python
# .. Let's assume we have a trained model using `PPOTrainer` and `AutoModelForCausalLMWithValueHead`
# push the model on the Hub
model.push_to_hub("my-fine-tuned-model-ppo")
# or save it locally
model.save_pretrained("my-fine-tuned-model-ppo")
# load the model from the Hub
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("my-fine-tuned-model-ppo")
```
You can also load your model with `AutoModelForCausalLMWithValueHead` if you want to use the value head, for example to continue training.
```python
from trl.model import AutoModelForCausalLMWithValueHead
model = AutoModelForCausalLMWithValueHead.from_pretrained("my-fine-tuned-model-ppo")
```

View File

@ -0,0 +1,261 @@
# Reducing Memory Usage
> [!WARNING]
> Section under construction. Feel free to contribute!
## Truncation
Sequence lengths in the dataset can vary widely. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short.
![Truncation prompt-completion](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/why_you_should_truncate.png)
To reduce memory usage, it's important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
<hfoptions id="truncation">
<hfoption id="DPO">
DPO truncation is applied first to the prompt and to the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence.
![DPO truncation](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_prompt_completion.png)
To set the truncation parameters, use the following code snippet:
```python
from trl import DPOConfig
training_args = DPOConfig(..., max_prompt_length=..., max_length=...)
```
You can also use the `max_completion_length` parameter to truncate the completion, though this is less common since the goal is typically to preserve the completion's full length whenever possible.
```python
from trl import DPOConfig
training_args = DPOConfig(..., max_completion_length=...)
```
</hfoption>
<hfoption id="SFT">
SFT truncation is applied to the input sequence via the `max_length` parameter.
![Truncation input ids](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_input_ids.png)
To set the truncation parameter, use the following code snippet:
```python
from trl import SFTConfig
training_args = SFTConfig(..., max_length=...)
```
</hfoption>
</hfoptions>
### How to choose the `max_length` value?
If `max_length` is too small, a significant portion of your tokens will be discarded and won't contribute to training. If it's too large, memory usage can spike, potentially leading to OOM (Out-Of-Memory) errors. Without packing or padding-free, a large `max_length` may also result in inefficient training, as many tokens will be padding.
To help you choose an appropriate value, we provide a utility to visualize the sequence length distribution in your dataset.
<iframe src="https://trl-lib-dataset-length-profiler.hf.space" frameborder="0" width="100%" height="1000"></iframe>
## Packing
> [!TIP]
> This technique applies only to SFT.
[Truncation](#truncation) has several drawbacks:
1. **Loss of information**: Key data at the end of a sequence may be discarded.
2. **Choosing truncation length**: Too short loses data; too long undermines efficiency.
Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.10683), addresses these issues by grouping sequences instead of truncating. It concatenates and splits dataset sequences into the desired lengths.
![Packing](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/packing_2.png)
Packing reduces padding by merging several sequences in one row when possible. We use an advanced method to be near-optimal in the way we pack the dataset. To enable packing, use `packing=True` in the [`SFTConfig`].
> [!TIP]
> In TRL 0.18 and earlier, packing used a more aggressive method that reduced padding to almost nothing, but had the downside of breaking sequence continuity for a large fraction of the dataset. To revert to this strategy, use `packing_strategy="wrapped"` in [`SFTConfig`].
```python
from trl import SFTConfig
training_args = SFTConfig(..., packing=True, max_length=512)
```
> [!WARNING]
> Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see [#1230](https://github.com/huggingface/trl/issues/1230).
## Liger for reducing peak memory usage
> [Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%.
For more information, see [Liger Kernel Integration](liger_kernel_integration)
<hfoptions id="liger">
<hfoption id="DPO">
To use Liger for reducing peak memory usage, use the following code snippet:
```python
from trl import DPOConfig
training_args = DPOConfig(..., use_liger_loss=True)
```
</hfoption>
<hfoption id="GRPO">
To use Liger for reducing peak memory usage, use the following code snippet:
```python
from trl import GRPOConfig
training_args = GRPOConfig(..., use_liger_loss=True)
```
</hfoption>
<hfoption id="KTO">
To use Liger for reducing peak memory usage, use the following code snippet:
```python
from trl import KTOConfig
training_args = KTOConfig(..., use_liger_loss=True)
```
</hfoption>
</hfoptions>
## Padding-free
Padding-free batching is an alternative approach for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact.
![Padding-free](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/padding-free.png)
> [!WARNING]
> It's highly recommended to use padding-free batching with **FlashAttention 2** or **FlashAttention 3**. Otherwise, you may encounter batch contamination issues.
<hfoptions id="padding-free">
<hfoption id="DPO">
```python
from trl import DPOConfig
training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
```
</hfoption>
<hfoption id="SFT">
```python
from trl import SFTConfig
training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
```
</hfoption>
</hfoptions>
## Activation offloading
Activation offloading is a memory efficiency technique that reduces GPU VRAM usage by temporarily moving activation tensors to CPU RAM during the forward pass and bringing them back only when needed for the backward pass. This significantly reduces peak memory usage at the cost of slightly increased training time.
To enable activation offloading in your SFT training configuration:
```python
from trl import SFTConfig
training_args = SFTConfig(..., activation_offloading=True)
```
> [!WARNING]
> When using activation offloading with models that use Liger kernels, you must disable Liger cross entropy due to compatibility issues. The issue occurs specifically with `use_liger_kernel=True` because Liger cross entropy performs in-place operations which conflict with activation offloading. The default setting (`use_liger_kernel=False`) works:
>
> ```python
> # When using activation offloading with a model that uses Liger kernels:
> from trl import SFTConfig
>
> training_args = SFTConfig(
> activation_offloading=True,
> use_liger_kernel=False, # Disable Liger cross entropy
> # Other parameters...
> )
> ```
Under the hood, activation offloading implements PyTorch's [`saved_tensors_hooks`](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#hooks-for-autograd-saved-tensors) to intercept activations during the forward pass. It intelligently manages which tensors to offload based on size and context, avoiding offloading output tensors which would be inefficient. For performance optimization, it can optionally use CUDA streams to overlap computation with CPU-GPU transfers.
## Disabling model gathering for generation in online methods
When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).
If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter:
<hfoptions id="ds3_gather_for_generation">
<hfoption id="GRPO">
```python
from trl import GRPOConfig
training_args = GRPOConfig(..., ds3_gather_for_generation=False)
```
</hfoption>
<hfoption id="Online DPO">
```python
from trl import OnlineDPOConfig
training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False)
```
</hfoption>
<hfoption id="PPO">
```python
from trl import PPOConfig
training_args = PPOConfig(..., ds3_gather_for_generation=False)
```
</hfoption>
<hfoption id="RLOO">
```python
from trl import RLOOConfig
training_args = RLOOConfig(..., ds3_gather_for_generation=False)
```
</hfoption>
</hfoptions>
This adjustment prevents model weights from being gathered, avoiding OOM errors, but it may result in slower generation speeds.
## vLLM sleep mode
When using vLLM as the generation backend, you can enable _sleep mode_ to offload vLLM parameters and cache to CPU RAM during the optimization step and reload them back to GPU VRAM when needed for weight synchronization and generation.
<hfoptions id="vllm_sleep">
<hfoption id="GRPO">
```python
from trl import GRPOConfig
training_args = GRPOConfig(..., vllm_enable_sleep_mode=True)
```
</hfoption>
<hfoption id="RLOO">
```python
from trl import RLOOConfig
training_args = RLOOConfig(..., vllm_enable_sleep_mode=True)
```
</hfoption>
</hfoptions>

View File

@ -0,0 +1,238 @@
# Reward Modeling
[![model badge](https://img.shields.io/badge/All_models-Reward_Trainer-blue)](https://huggingface.co/models?other=reward-trainer,trl)
## Overview
TRL supports the Outcome-supervised Reward Modeling (ORM) Trainer for training reward models.
This post-training method was contributed by [Younes Belkada](https://huggingface.co/ybelkada).
## Quick start
This example demonstrates how to train a reward model using the [`RewardTrainer`] from TRL. We train a [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) model on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), large-scale, fine-grained, diverse preference dataset.
```python
from trl import RewardTrainer
from datasets import load_dataset
trainer = RewardTrainer(
model="Qwen/Qwen3-0.6B",
train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"),
)
trainer.train()
```
<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&metrics=train*&sidebar=hidden&runs=reward_qwen3-0.6B_ultrafeedback2" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>
## Expected dataset type and format
[`RewardTrainer`] supports [preference](dataset_formats#preference) datasets type (both implicit and explicit prompt). The [`RewardTrainer`] is compatible with both [standard](dataset_formats#standard) and [conversational](dataset_formats#conversational) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
```python
# Standard preference (implicit prompt)
{"chosen": "The sky is blue.",
"rejected": "The sky is green."}
# Conversational preference (implicit prompt)
{"chosen": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is green."}]}
# Standard preference (explicit prompt)
{"prompt": "The sky is",
"chosen": " blue.",
"rejected": " green."}
# Conversational preference (explicit prompt)
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
"chosen": [{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "assistant", "content": "It is green."}]}
```
If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. Here is an example with the [lmarena-ai/arena-human-preference-55k](https://huggingface.co/datasets/lmarena-ai/arena-human-preference-55k) dataset:
```python
from datasets import load_dataset
import json
dataset = load_dataset("lmarena-ai/arena-human-preference-55k")
# Filter out ties
dataset = dataset.filter(lambda example: example["winner_tie"] == 0)
# Create 'chosen' and 'rejected' fields based on the winner column
def response_a_b_to_chosen_rejected(example):
if example["winner_model_a"] == 1:
example["chosen"] = example["response_a"]
example["rejected"] = example["response_b"]
else:
example["chosen"] = example["response_b"]
example["rejected"] = example["response_a"]
return example
dataset = dataset.map(response_a_b_to_chosen_rejected)
# Convert to conversational format
def make_conversation(example):
prompt = json.loads(example["prompt"])[0] # '["What color is the sky?"]' -> "What color is the sky?"
chosen = json.loads(example["chosen"])[0]
rejected = json.loads(example["rejected"])[0]
return {
"chosen": [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen}],
"rejected": [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected}],
}
dataset = dataset.map(make_conversation)
# Keep only necessary columns
dataset = dataset.select_columns(["chosen", "rejected"])
print(next(iter(dataset["train"])))
```
```json
{
"chosen": [
{"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"},
{"role": "assistant", "content": "The question of whether it is morally right to aim for a certain percentage of females..."},
],
"rejected": [
{"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"},
{"role": "assistant", "content": "As an AI, I don't have personal beliefs or opinions. However, ..."},
],
}
```
## Looking deeper into the training method
Reward Models (RMs) are typically trained using supervised learning on datasets containing pairs of preferred and non-preferred responses. The goal is to learn a function that assigns higher scores to preferred responses, enabling the model to rank outputs based on preferences.
This section breaks down how reward modeling works in practice, covering the key steps: **preprocessing** and **loss computation**.
### Preprocessing and tokenization
During training, each example is expected to contain a **chosen** and **rejected** field. For more details on the expected formats, see [Dataset formats - Preference](dataset_formats#preference).
The [`RewardTrainer`] tokenizes each input using the model's tokenizer. If prompts and completions (chosen and rejected) are provided separately (explicit prompt case), they are concatenated before tokenization.
### Computing the loss
Let \\( x \\) be the input sequence (prompt) and \\( y^+ \\) and \\( y^- \\) be the chosen and rejected sequences respectively. Under the Bradley-Terry model ([Bradley & Terry, 1952](https://www.jstor.org/stable/2334029)), the probability that \\( y^+ \\) is preferred over \\( y^- \\) given a reward function \\( r \\) is \\( p(y^+ ≻ y^- |x) = \sigma(r(x, y^+)r(x, y^-)) \\), where \\( σ \\) is the sigmoid function.
The reward model \\( r_\theta(x, y) \\) is trained to assign higher scores to preferred responses \\( y^+ \\) over non-preferred ones \\( y^- \\). The loss is then defined as the negative log-likelihood of the observed preferences:
$$
\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \right].
$$
> [!TIP]
> The Bradley-Terry model is underdetermined, meaning that adding a constant to all rewards does not change the preference probabilities. To address this, [Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking](https://huggingface.co/papers/2312.09244) proposes adding an auxiliary loss term that encourages the rewards to be centered around zero. This is controlled by the `center_rewards_coefficient` parameter in the [`RewardConfig`]. The recommended value is `1e-2`.
## Logged metrics
While training and evaluating we record the following reward metrics:
* `global_step`: The total number of optimizer steps taken so far.
* `epoch`: The current epoch number, based on dataset iteration.
* `num_tokens`: The total number of tokens processed so far.
* `loss`: The average loss over the last logging interval.
* `accuracy`: The proportion of correct predictions (i.e., the model assigned a higher score to the chosen response than to the rejected one) averaged over the last logging interval.
* `min_reward`: The minimum reward score assigned by the model. This value is averaged over the logging interval.
* `mean_reward`: The average reward score assigned by the model over the last logging interval.
* `max_reward`: The maximum reward score assigned by the model. This value is averaged over the logging interval.
* `margin`: The average margin (difference between chosen and rejected rewards) over the last logging interval.
* `learning_rate`: The current learning rate, which may change dynamically if a scheduler is used.
* `grad_norm`: The L2 norm of the gradients, computed before gradient clipping.
## Customization
### Model initialization
You can directly pass the kwargs of the [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] method to the [`RewardConfig`]. For example, if you want to load a model in a different precision, analogous to
```python
model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16)
```
you can do so by passing the `model_init_kwargs={"dtype": torch.bfloat16}` argument to the [`RewardConfig`].
```python
from trl import RewardConfig
training_args = RewardConfig(
model_init_kwargs={"dtype": torch.bfloat16},
)
```
Note that all keyword arguments of [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] are supported, except for `num_labels`, which is automatically set to 1.
### Train adapters with PEFT
We support tight integration with 🤗 PEFT library, allowing any user to conveniently train adapters and share them on the Hub, rather than training the entire model.
```python
from datasets import load_dataset
from trl import RewardTrainer
from peft import LoraConfig
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
trainer = RewardTrainer(
"Qwen/Qwen3-4B",
train_dataset=dataset,
peft_config=LoraConfig(modules_to_save=["score"]) # important to include the score head when base model is not a sequence classification model
)
trainer.train()
```
You can also continue training your [`~peft.PeftModel`]. For that, first load a `PeftModel` outside [`RewardTrainer`] and pass it directly to the trainer without the `peft_config` argument being passed.
```python
from datasets import load_dataset
from trl import RewardTrainer
from peft import AutoPeftModelForCausalLM
model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-Reward-LoRA", is_trainable=True)
dataset = load_dataset("trl-lib/Capybara", split="train")
trainer = RewardTrainer(
model=model,
train_dataset=dataset,
)
trainer.train()
```
> [!TIP]
> When training adapters, you typically use a higher learning rate (≈1e3) since only new parameters are being learned.
>
> ```python
> RewardConfig(learning_rate=1e-3, ...)
> ```
## Tool Calling with Reward Modeling
The [`RewardTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include:
* The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages)
* The list of available tools in the `tools` column, typically provided as JSON schemas
For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section.
## RewardTrainer
[[autodoc]] RewardTrainer
- train
- save_model
- push_to_hub
## RewardConfig
[[autodoc]] RewardConfig
## DataCollatoForPreference
[[autodoc]] trainer.reward_trainer.DataCollatorForPreference

View File

@ -1,90 +0,0 @@
# Reward Modeling
[![](https://img.shields.io/badge/All_models-Reward_Trainer-blue)](https://huggingface.co/models?other=reward-trainer,trl)
TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model.
Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py).
## Expected dataset type
The [`RewardTrainer`] requires a [*implicit prompt* preference dataset](dataset_formats#preference). It means that the dataset should only contain the columns `"chosen"` and `"rejected"` (and not `"prompt"`).
The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`.
## Using the `RewardTrainer`
After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers.
You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training.
### Leveraging 🤗 PEFT to train a reward model
Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model!
```python
from peft import LoraConfig, TaskType
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer, RewardConfig
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
r=8,
lora_alpha=32,
lora_dropout=0.1,
)
...
trainer = RewardTrainer(
model=model,
args=training_args,
processing_class=tokenizer,
train_dataset=dataset,
peft_config=peft_config,
)
trainer.train()
```
### Adding a margin to the loss
As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly.
```python
def add_margin(row):
# Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin
return {'margin': row['score_chosen'] - row['score_rejected']}
dataset = dataset.map(add_margin)
```
### Centering rewards
In many scenarios, it's preferable to ensure that a reward model's output is mean zero. This is often done by first calculating the model's average score and then subtracting it.
[[Eisenstein et al., 2023]](https://huggingface.co/papers/2312.09244) proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs:
$$\Big( R(p, r_1) + R(p, r_2) \Big)^2 $$
This auxiliary loss is combined with the main loss function, weighted by the parameter `center_rewards_coefficient` in the `[RewardConfig]`. By default, this feature is deactivated (`center_rewards_coefficient = None`).
```python
training_args = RewardConfig(
center_rewards_coefficient=0.01,
...
)
```
For reference results, please refer PR [#1932](https://github.com/huggingface/trl/pull/1932).
## RewardTrainer
[[autodoc]] RewardTrainer
## RewardConfig
[[autodoc]] RewardConfig

15
docs/source/rewards.md Normal file
View File

@ -0,0 +1,15 @@
# Reward Functions
This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`].
## accuracy_reward
[[autodoc]] rewards.accuracy_reward
## think_format_reward
[[autodoc]] rewards.think_format_reward
## get_soft_overlong_punishment
[[autodoc]] rewards.get_soft_overlong_punishment

View File

@ -1,279 +1,617 @@
# RLOO Trainer
[![](https://img.shields.io/badge/All_models-RLOO-blue)](https://huggingface.co/models?other=rloo,trl)
[![model badge](https://img.shields.io/badge/All_models-RLOO-blue)](https://huggingface.co/models?other=rloo,trl)
TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). The idea is that instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean scores from the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, where as PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO, when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL.
## Overview
References:
- [Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs](https://huggingface.co/papers/2402.14740)
- [A2C is a special case of PPO](https://huggingface.co/papers/2205.09123)
- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences)
- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback)
- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo)
- [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031)
TRL supports the RLOO Trainer for training language models, as described in the paper [Back to Basics: Revisiting REINFORCE Style
Optimization for Learning from Human Feedback in LLMs](https://huggingface.co/papers/2402.14740) by [Arash Ahmadian](https://huggingface.co/ArashAhmadian), Chris Cremer, [Matthias Gallé](https://huggingface.co/mgalle), [Marzieh Fadaee](https://huggingface.co/MarziehFadaee), [Julia Kreutzer](https://huggingface.co/JuliaKreutzerCohere), [Ahmet Üstün](https://huggingface.co/ahmetu) and [Sara Hooker](https://huggingface.co/sarahooker).
## Get started
The abstract from the paper is the following:
To just run a RLOO script to make sure the trainer can run, you can run the following command to train a RLOO model with a dummy reward model.
> AI alignment in the shape of Reinforcement Learning from Human Feedback (RLHF) is increasingly treated as a crucial ingredient for high performance large language models. Proximal Policy Optimization (PPO) has been positioned by recent literature as the canonical method for the RL part of RLHF However, it involves both high computational cost and sensitive hyperparameter tuning. We posit that most of the motivational principles that led to the development of PPO are less of a practical concern in RLHF and advocate for a less computationally expensive method that preserves and even increases performance. We revisit the formulation of alignment from human preferences in the context of RL. Keeping simplicity as a guiding principle, we show that many components of PPO are unnecessary in an RLHF context and that far simpler REINFORCE-style optimization variants outperform both PPO and newly proposed “RL-free” methods such as DPO and RAFT. Our work suggests that careful adaptation to LLMs alignment characteristics enables benefiting from online RL optimization at low cost.
```bash
python examples/scripts/rloo/rloo.py \
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
--dataset_train_split descriptiveness \
--learning_rate 3e-6 \
--output_dir models/minimal/rloo \
--per_device_train_batch_size 64 \
--gradient_accumulation_steps 1 \
--total_episodes 10000 \
--model_name_or_path EleutherAI/pythia-14m \
--reward_model_path EleutherAI/pythia-14m \
--missing_eos_penalty 1.0
```
This post-training method was contributed by [Costa Huang](https://github.com/vwxyzjn) and later refactored by [Shirin Yamani](https://huggingface.co/ShirinYamani).
## Quick start
## Explanation of the logged metrics
This example demonstrates how to train a model using the RLOO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [UltraFeedback prompts dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt). You can view the data in the dataset here:
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34)
<iframe
src="https://huggingface.co/datasets/trl-lib/ultrafeedback-prompt/embed/viewer/default/train?row=0"
frameborder="0"
width="100%"
height="560px"
></iframe>
<!-- * `rlhf_reward_var_per_prompt`: calculated by `rlhf_reward.var(0).mean()`. This is the variance of the rewards estimated across the `args.rloo_k` samples. Usually we expect it to go down (cause policy entropy goes down). -->
* `eps`: Tracks the number of episodes per second.
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
* `objective/scores`: The mean scores returned by the reward model / environment.
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
* `lr`: lr: The current learning rate used by the optimizer.
* `episode`: episode: The current global step or episode count in the training process.
## Cookbook
* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try undertand why this is happening and try to fix it.
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
* Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
## What is my model doing exactly?
To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/ppov2_completions.gif)
In the logs the sampled generations look like
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ query ┃ model response ┃ score ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│ SUBREDDIT: r/AskReddit │ I'm in love with a friend, and │ 3.921875 │
│ │ I don't know how to get rid of │ │
│ TITLE: How do you get someone │ those feelings. I'm │ │
│ out of your head? │ desperate.<|endoftext|>[PAD][P… │ │
│ │ │ │
│ POST: Hi, │ │ │
│ I'm 22, and I have been with my │ │ │
│ girlfriend for 5 years now. We │ │ │
│ recently moved together. We've │ │ │
│ always loved each other │ │ │
│ intensely. │ │ │
│ │ │ │
│ Problem, I recently started to │ │ │
│ have feelings for an other │ │ │
│ person (a friend). This person │ │ │
│ has had a boyfriend for now 3 │ │ │
│ years, and has absolutely no │ │ │
│ ideas. Those feelings were so │ │ │
│ strong, it was hard to hide │ │ │
│ them. After 2 months of me │ │ │
│ being distant and really sad, │ │ │
│ my girlfriend forced me to say │ │ │
│ what was bothering me. I'm not │ │ │
│ a good liar, and now she knows. │ │ │
│ │ │ │
│ We decided to give us a week │ │ │
│ alone, I went to my parents. │ │ │
│ │ │ │
│ Now, I'm completely lost. I │ │ │
│ keep on thinking about this │ │ │
│ person, and I hate that. I │ │ │
│ would like for those feelings │ │ │
│ to go away, to leave me alone. │ │ │
│ But I can't. │ │ │
│ │ │ │
│ What do I do? It's been 3 │ │ │
│ months now, and I'm just │ │ │
│ desperate. │ │ │
│ │ │ │
│ TL;DR: │ │ │
├─────────────────────────────────┼─────────────────────────────────┼──────────┤
│ SUBREDDIT: r/pettyrevenge │ My mom woke me up with a loud │ 6.84375 │
│ │ TV. I blasted Gangnam Style on │ │
│ TITLE: So, my mom woke me up │ repeat, with the bass cranked │ │
│ with a loud TV. │ up as high as it could │ │
│ │ go.<|endoftext|>[PAD][PAD][PAD… │ │
│ POST: She was in her living │ │ │
│ room, watching TV. This was at │ │ │
│ about 8:30 in the morning, and │ │ │
│ she was exercising. She turned │ │ │
│ the TV up extra loud to hear it │ │ │
│ over her excercycle, and woke │ │ │
│ me up. I went in there asking │ │ │
│ for her to turn it down. She │ │ │
│ said she didn't have to; I │ │ │
│ explained that I always used │ │ │
│ headphones so she didn't have │ │ │
│ to deal with my noise and that │ │ │
│ she should give me a little │ │ │
│ more respect, given that I paid │ │ │
│ rent at the time. │ │ │
│ │ │ │
│ She disagreed. I went back to │ │ │
│ my room, rather pissed off at │ │ │
│ the lack of equality. I had no │ │ │
│ lock on my door; but I had a │ │ │
│ dresser right next to it, so I │ │ │
│ pulled one of the drawers out │ │ │
│ enough so that it caused the │ │ │
│ door to not be openable. Then, │ │ │
│ I turned my speakers up really │ │ │
│ loud and blasted Gangnam Style │ │ │
│ on repeat, with the bass │ │ │
│ cranked up as high as it could │ │ │
│ go. │ │ │
│ │ │ │
│ If you hate Gangnam Style for │ │ │
│ being overplayed, you will see │ │ │
│ why I chose that particular │ │ │
│ song. I personally don't mind │ │ │
│ it. But here's the thing about │ │ │
│ my bass; it vibrates the walls, │ │ │
│ making one hell of a lot of │ │ │
│ noise. Needless to say, my mom │ │ │
│ was not pleased and shut off │ │ │
│ the internet. But it was oh so │ │ │
│ worth it. │ │ │
│ │ │ │
│ TL;DR: │ │ │
└─────────────────────────────────┴─────────────────────────────────┴──────────┘
```
## Implementation details
The bulk of RLOOTrainer is based on the PPO implementation, which is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
Below is a vectorized advantage calculation for RLOO:
Below is the script to train the model.
```python
def test_rloo_reward():
local_batch_size = 3
rloo_k = 4
rlhf_reward = torch.tensor([
1, 2, 3, # first rlhf reward for three prompts
2, 3, 4, # second rlhf reward for three prompts
5, 6, 7, # third rlhf reward for three prompts
8, 9, 10, # fourth rlhf reward for three prompts
]).float() # here we have 3 prompts which have 4 completions each
# train_rloo.py
from datasets import load_dataset
from trl import RLOOConfig, RLOOTrainer
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
advantages = torch.zeros_like(rlhf_reward)
for i in range(0, len(advantages), local_batch_size):
other_response_rlhf_rewards = []
for j in range(0, len(advantages), local_batch_size):
if i != j:
other_response_rlhf_rewards.append(rlhf_reward[j : j + local_batch_size])
advantages[i : i + local_batch_size] = rlhf_reward[i : i + local_batch_size] - torch.stack(other_response_rlhf_rewards).mean(0)
dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
assert (1 - (2 + 5 + 8) / 3 - advantages[0].item()) < 1e-6 # First rlhf reward for the first prompt
assert (6 - (3 + 2 + 9) / 3 - advantages[7].item()) < 1e-6 # Third rlhf reward for the second prompt
# Dummy reward function for demonstration purposes
def reward_num_unique_letters(completions, **kwargs):
"""Reward function that rewards completions with more unique letters."""
completion_contents = [completion[0]["content"] for completion in completions]
return [float(len(set(content))) for content in completion_contents]
# Vectorized implementation
rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
vec_advantages = rlhf_reward - baseline
torch.testing.assert_close(vec_advantages.flatten(), advantages)
training_args = RLOOConfig(output_dir="Qwen2-0.5B-RLOO")
trainer = RLOOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_num_unique_letters,
args=training_args,
train_dataset=dataset,
)
trainer.train()
```
## Benchmark experiments
To validate the RLOO implementation works, we ran experiment on the 1B model. Here are the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
```
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
--output_dir models/minimal/rloo_tldr \
--dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
--dataset_test_split validation \
--num_ppo_epochs 2 \
--num_mini_batches 2 \
--learning_rate 3e-6 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 8 \
--total_episodes 1000000 \
--model_name_or_path EleutherAI/pythia-1b-deduped \
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
--local_rollout_forward_batch_size 16 \
--missing_eos_penalty 1.0 \
--stop_token eos \
--kl_coef 0.03
```
Checkpoints and experiment tracking are available at:
- [🤗 Model checkpoint](https://huggingface.co/vwxyzjn/rloo_tldr)
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/u2sqci34)
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
For more information on how to use judges, see [Judges](judges).
Execute the script using the following command:
```bash
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 33.00%
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 51.20%
accelerate launch train_rloo.py
```
The RLOO checkpoint gets a 51.2% preferred rate vs the 33.0% preference rate of the SFT checkpoint. This is a good sign that the RLOO training is working as intended.
## Looking deeper into the RLOO method
RLOO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind RLOO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how RLOO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**.
Metrics:
![RLOO](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/rloo.png)
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/pr-1540/rloo.png)
### Generating completions
At each training step, we sample a batch of prompts and generate a set of \\( G \\) completions for each prompt (denoted as \\( o_i \\)).
### Computing the reward
In RLOO, the reward consists of two components: the reward provided by the reward model (or reward function) and a KL penalty that discourages the policy from deviating too far from a fixed reference policy
1. For each of the \\( G \\) generated sequences \\( o_i = (o_{i,1}, \dots, o_{i,T}) \\) conditioned on a query \\( q \\), we compute a scalar reward using a reward model \\( R(o_i, q) \\).
2. Concurrently, we estimate the KL divergence between the current policy \\( \pi_\theta \\) and the fixed reference policy \\( \pi_{\text{ref}} \\) over the sequence. The KL estimate for sequence \\( o_i \\) is:
$$
\mathbb{D}_{\mathrm{KL}}\!\left[\pi_\theta\|\pi_{\mathrm{ref}}\right] = \sum_{t=1}^T \log \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}.
$$
The final reward assigned to sequence \\( o_i \\) is then:
$$
r_i = R(o_i, q) - \beta \, \mathbb{D}_{\mathrm{KL}}\!\left[\pi_\theta \|\pi_{\mathrm{ref}}\right],
$$
where \\( \beta > 0 \\) controls the strength of the KL penalty.
> [!TIP]
> In a purely online setting (`num_iterations = 1`, default), the data are generated by the current policy. In this case, the KL penalty is computed directly using the current policy.
>
> In the more general setting (e.g., multiple gradient steps per batch), the data are instead generated by an earlier snapshot \\( \pi_{\text{old}} \\). To keep the penalty consistent with the sampling distribution, the KL is defined with respect to this policy:
>
> $$
> \mathbb{D}_{\mathrm{KL}}\!\left[\pi_{\text{old}} \,\|\, \pi_{\text{ref}}\right].
> $$
>
> Equivalently, for a sampled sequence $o$, the Monte Carlo estimate is
>
> $$
> \mathbb{D}_{\mathrm{KL}}\!\left[\pi_{\text{old}} \|\pi_{\mathrm{ref}}\right] = \sum_{t=1}^T \log \frac{\pi_{\text{old}}(o_{i,t} \mid q, o_{i,<t})}{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}.
> $$
### Computing the advantage
Once the rewards for each completion have been computed, we calculate a baseline as the average reward of all other samples in the same batch, excluding the current sample. This baseline is used to reduce the variance of the policy gradient estimate. The advantage for each completion is then obtained as the difference between its own reward and this leave-one-out baseline.
Formally, for a batch of G completions, the baseline for completion is:
$$
b_i = \frac{1}{G-1} \sum_{j \neq i} r_j
$$
and then the advantage for each completion is computed as the difference between its reward and the baseline:
$$
A_i = r_i - b_i
$$
### Computing the loss
The REINFORCE loss is simply defined as:
$$
\mathcal{L}_{\text{RLOO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \hat{A}_i \, \log \pi_\theta(o_i \mid q)
$$
In practice, performing multiple gradient steps on the same batch makes the actions effectively off-policy relative to the current parameters. To correct for this, we introduce the importance sampling ratio. To prevent excessively large updates when the policy changes between sampling and gradient steps, we clip this ratio:
$$
\mathcal{L}_{\text{RLOO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \min \left( \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_\text{old}}(o_i \mid q)} \hat{A}_i, \, \text{clip}\left(\frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_\text{old}}(o_i \mid q)}, 1-\epsilon, 1+\epsilon\right) \hat{A}_i \right)
$$
In a fully online, single-step setting (default), \\( \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_\text{old}}(o_i \mid q)} = 1 \\) and this reduces to standard REINFORCE.
## Logged metrics
While training and evaluating, we record the following reward metrics:
- `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
- `completions/mean_length`: The average length of generated completions.
- `completions/min_length`: The minimum length of generated completions.
- `completions/max_length`: The maximum length of generated completions.
- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS.
- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS.
- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS.
- `completions/clipped_ratio`: The ratio of truncated (clipped) completions.
- `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
- `reward`: The overall average reward after applying reward weights.
- `reward_std`: The standard deviation of rewards after applying reward weights. This is the average of the per-group standard deviations.
- `frac_reward_zero_std`: The fraction of samples in the generation batch with a reward std of zero, implying there is little diversity for that prompt (all answers are correct or incorrect).
- `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, masked sequences tokens are excluded.)
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
- `clip_ratio/region_mean`: The ratio of sequence probabilities where the RLOO objective is clipped to stay within the trust region:
$$
\text{clip}\left( r_{i}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i}(\theta) = \frac{\pi_\theta(o_{i} \mid q)}{\pi_{\theta_{\text{old}}}(o_{i} \mid q)}\,.
$$
A higher value means more samples are clipped, which constrains how much the policy $\pi_\theta$ can change.
- `clip_ratio/low_mean`: The average ratio of sequence probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
- `clip_ratio/low_min`: The minimum ratio of sequence probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
- `clip_ratio/high_mean`: The average ratio of sequence probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\)
- `clip_ratio/high_max`: The maximum ratio of sequence probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).
## Customization
### Speed up training with vLLM-powered generation
Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with
```shell
pip install trl[vllm]
```
We support two ways of using vLLM during training: **server mode** and **colocate mode**.
#### 🔌 Option 1: Server mode
In this mode, vLLM runs in a separate process (and using separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.
1. **Start the vLLM server**:
```bash
trl vllm-serve --model <model_name>
```
2. **Enable server mode in your training script**:
```python
from trl import RLOOConfig
training_args = RLOOConfig(
...,
use_vllm=True,
vllm_mode="server", # default value, can be omitted
)
```
> [!WARNING]
> Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable.
#### 🧩 Option 2: Colocate mode
In this mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs.
```python
from trl import RLOOConfig
training_args = RLOOConfig(
...,
use_vllm=True,
vllm_mode="colocate",
)
```
> [!TIP]
> Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`RLOOConfig`] to avoid underutilization or out-of-memory errors.
>
> We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-memory) to help estimate the recommended GPU memory utilization based on your model configuration and experiment settings. Simply use it as follows to get `vllm_gpu_memory_utilization` recommendation:
>
> <iframe src="https://trl-lib-recommend-vllm-memory.hf.space" frameborder="0" width="850" height="450"></iframe>
>
> If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.
>
> If you still find you are getting out-of-memory errors set `vllm_enable_sleep_mode` to True and the vllm parameters and cache will be offloaded during the optimization step. For more information, see [Reducing Memory Usage with vLLM Sleep Mode](reducing_memory_usage#vllm-sleep-mode).
> [!TIP]
> By default, RLOO uses `MASTER_ADDR=localhost` and `MASTER_PORT=12345` for vLLM, but you can override these values by setting the environment variables accordingly.
For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
### RLOO at scale: train a 70B+ Model on multiple nodes
When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:
- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see [DeepSpeed Integration](deepspeed_integration).
- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training).
- **vLLM**: See the previous section on how to use vLLM to speed up generation.
Below is an example SLURM script to train a 70B model with RLOO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation.
```sh
#!/bin/bash
#SBATCH --nodes=5
#SBATCH --gres=gpu:8
# Get the list of allocated nodes
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
# Assign the first 4 nodes for training and the 5th node for vLLM
TRAIN_NODES="${NODELIST[@]:0:4}" # Nodes 0, 1, 2, 3 for training
VLLM_NODE="${NODELIST[4]}" # Node 4 for vLLM
# Run training on the first 4 nodes (Group 1)
srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
--num_processes 32 \
--num_machines 4 \
--main_process_ip ${NODELIST[0]} \
--machine_rank $SLURM_PROCID \
--rdzv_backend c10d \
train_rloo.py \
--server_ip $VLLM_NODE &
# Run vLLM server on the 5th node (Group 2)
srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &
wait
```
```python
import argparse
from datasets import load_dataset
from trl import RLOOTrainer, RLOOConfig
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP")
args = parser.parse_args()
# Example dataset from TLDR
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
training_args = RLOOConfig(
output_dir="Qwen2.5-72B-RLOO",
per_device_train_batch_size=4,
bf16=True,
gradient_checkpointing=True,
use_vllm=True,
vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."), # from ip-X-X-X-X to X.X.X.X
)
trainer = RLOOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset)
trainer.train()
if __name__=="__main__":
main()
```
### Using a custom reward function
The [`RLOOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:
1. **Input arguments**:
- The function must accept the following as keyword arguments:
- `prompts` (contains the prompts),
- `completions` (contains the generated completions),
- `completions_ids` (contains the tokenized completions),
- `trainer_state` ([`~transformers.TrainerState`]): The current state of the trainer. This can be used to implement dynamic reward functions, such as curriculum learning, where the reward is adjusted based on the training progress.
- All column names (but `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
- Depending on the dataset format, the input will vary:
- For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings.
- For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries.
2. **Return value**: The function must return a list of floats. Each float represents the reward corresponding to a single completion.
#### Example 1: Reward longer completions
Below is an example of a reward function for a standard format that rewards longer completions:
```python
def reward_func(completions_ids, **kwargs):
"""Reward function that assigns higher scores to longer completions (in terms of token count)."""
return [float(len(ids)) for ids in completions_ids]
```
You can test it as follows:
```python
>>> prompts = ["The sky is", "The sun is"] # not used in the reward function, but the trainer will pass it
>>> completions = [" blue.", " in the sky."] # not used in the reward function, but the trainer will pass it
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[2.0, 4.0]
```
#### Example 1.1: Reward longer completions (based on the number of characters)
Same as the previous example, but this time the reward function is based on the number of characters instead of tokens.
```python
def reward_func(completions, **kwargs):
"""Reward function that assigns higher scores to longer completions (in terms of character count)."""
return [float(len(completion)) for completion in completions]
```
You can test it as follows:
```python
>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]] # not used in the reward function, but the trainer will pass it
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[6.0, 12.0]
```
#### Example 2: Reward completions with a specific format
Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
It is designed for a conversational format, where prompts and completions consist of structured messages.
```python
import re
def format_reward_func(completions, **kwargs):
"""Reward function that checks if the completion has a specific format."""
pattern = r"^<think>.*?</think><answer>.*?</answer>$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]
```
You can test this function as follows:
```python
>>> prompts = [
... [{"role": "assistant", "content": "What is the result of (1 + 2) * 4?"}],
... [{"role": "assistant", "content": "What is the result of (3 + 1) * 2?"}],
... ]
>>> completions = [
... [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
... [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
... ]
>>> format_reward_func(prompts=prompts, completions=completions)
[1.0, 0.0]
```
#### Example 3: Reward completions based on a reference
Below is an example of a reward function that checks if the completion is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`.
```python
import re
def reward_func(completions, ground_truth, **kwargs):
# Regular expression to capture content inside \boxed{}
matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
contents = [match.group(1) if match else "" for match in matches]
# Reward 1 if the content is the same as the ground truth, 0 otherwise
return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
```
You can test this function as follows:
```python
>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
>>> ground_truth = ["2", "5"]
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
[1.0, 0.0]
```
#### Example 4: Multi-task reward functions
Below is an example of using multiple reward functions in the [`RLOOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works.
```python
from datasets import Dataset
from trl import RLOOTrainer
# Define a dataset that contains both math and coding problems
dataset = Dataset.from_list(
[
{"prompt": "What is 2+2?", "task": "math"},
{"prompt": "Write a function that returns the sum of two numbers.", "task": "code"},
{"prompt": "What is 3*4?", "task": "math"},
{"prompt": "Write a function that returns the product of two numbers.", "task": "code"},
]
)
# Math-specific reward function
def math_reward_func(prompts, completions, task, **kwargs):
rewards = []
for prompt, completion, t in zip(prompts, completions, task):
if t == "math":
# Calculate math-specific reward
correct = check_math_solution(prompt, completion)
reward = 1.0 if correct else -1.0
rewards.append(reward)
else:
# Return None for non-math tasks
rewards.append(None)
return rewards
# Coding-specific reward function
def coding_reward_func(prompts, completions, task, **kwargs):
rewards = []
for prompt, completion, t in zip(prompts, completions, task):
if t == "coding":
# Calculate coding-specific reward
works = test_code_solution(prompt, completion)
reward = 1.0 if works else -1.0
rewards.append(reward)
else:
# Return None for non-coding tasks
rewards.append(None)
return rewards
# Use both task-specific reward functions
trainer = RLOOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=[math_reward_func, coding_reward_func],
train_dataset=dataset,
)
trainer.train()
```
In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return `None`, and the [`RLOOTrainer`] will continue with the valid functions and tasks. This allows the [`RLOOTrainer`] to handle multiple reward functions with different applicability.
Note that the [`RLOOTrainer`] will ignore the `None` rewards returned by the reward functions and only consider the rewards returned by the relevant functions. This ensures that the model is trained on the relevant tasks and ignores the tasks for which there is no relevant reward function.
#### Passing the reward function to the trainer
To use your custom reward function, pass it to the [`RLOOTrainer`] as follows:
```python
from trl import RLOOTrainer
trainer = RLOOTrainer(
reward_funcs=reward_func,
...,
)
```
If you have multiple reward functions, you can pass them as a list:
```python
from trl import RLOOTrainer
trainer = RLOOTrainer(
reward_funcs=[reward_func1, reward_func2],
...,
)
```
and the reward will be computed as the sum of the rewards from each function, or the weighted sum if `reward_weights` is provided in the config.
Note that [`RLOOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details.
## Vision-Language Model (VLM) Training
RLOO supports training Vision-Language Models (VLMs) on multimodal datasets containing both text and images.
### Supported Models
Tested with:
- **Gemma3** — e.g., `google/gemma-3-4b-it`
- **LLaVA-NeXT** — e.g., `llava-hf/llava-v1.6-mistral-7b-hf`
- **Qwen2-VL** — e.g., `Qwen/Qwen2-VL-2B-Instruct`
- **Qwen2.5-VL** — e.g., `Qwen/Qwen2.5-VL-3B-Instruct`
- **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct`
> [!TIP]
> Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes.
### Quick Start
Use [rloo\_vlm.py](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo_vlm.py) to fine-tune a VLM. Example command for training on [`lmms-lab/multimodal-open-r1-8k-verified`](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified):
```bash
# pip install openrlbenchmark==0.2.1a5
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
# to use it, change `?we=huggingface&wpn=trl` to your own project and `?tag=pr-1540` to your own tag
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=huggingface&wpn=trl&xaxis=train/episode&ceik=output_dir&cen=sft_model_path&metrics=train/objective/rlhf_reward&metrics=train/objective/scores&metrics=train/objective/kl&metrics=train/objective/non_score_reward&metrics=train/objective/entropy&metrics=train/policy/approxkl_avg&metrics=train/policy/clipfrac_avg&metrics=train/loss/policy_avg&metrics=train/policy/entropy_avg&metrics=train/val/ratio&metrics=train/val/ratio_var&metrics=train/val/num_eos_tokens&metrics=train/lr&metrics=train/eps' \
"cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540" \
--env-ids models/minimal/rloo_tldr \
--pc.ncols 4 \
--pc.ncols-legend 1 \
--pc.xlabel "Episode" \
--output-filename benchmark/trl/pr-1540/rloo \
--scan-history
accelerate launch \
--config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/rloo_vlm.py \
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
--output_dir rloo-Qwen2.5-VL-3B-Instruct \
--learning_rate 1e-5 \
--gradient_checkpointing \
--dtype bfloat16 \
--max_prompt_length 2048 \
--max_completion_length 1024 \
--use_vllm \
--vllm_mode colocate \
--use_peft \
--lora_target_modules "q_proj", "v_proj" \
--log_completions
```
### Configuration Tips
> [!WARNING]
> VLM training may fail if image tokens are truncated. We highly recommend disabling truncation by setting `max_prompt_length` to `None`.
- Use LoRA on vision-language projection layers
- Enable 4-bit quantization to reduce memory usage
- VLMs are memory-intensive — start with smaller batch sizes
- Most models are compatible with vLLM (`server` and `colocate` modes)
### Dataset Format
Each training sample should include:
- `prompt`: Text formatted via the processor's chat template
- `image`/`images`: PIL Image or list of PIL Images
The trainer automatically handles image-to-tensor conversion via the models image processor.
## RLOOTrainer
[[autodoc]] RLOOTrainer
- train
- save_model
- push_to_hub
## RLOOConfig
[[autodoc]] RLOOConfig
## References
1. [RLOO Paper](https://openreview.net/pdf?id=r1lgTGL5DE)
2. [Paper Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs](https://huggingface.co/papers/2402.14740)
3. [Paper - REINFORCE++: A Simple and Efficient Approach for Aligning Large Language Models](https://huggingface.co/papers/2501.03262)
4. [Blog Post - Putting RL back in RLHF](https://huggingface.co/blog/putting_rl_back_in_rlhf_with_rloo)
5. [Blog Post - Unraveling RLHF and Its Variants: Progress and Practical Engineering Insights](https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05)
6. [Youtube - RLOO: A Cost-Efficient Optimization for Learning from Human Feedback in LLMs](https://www.youtube.com/watch?v=86asXGPK6RU&ab_channel=BuzzRobot)
## Migration Guide from the old implementation (0.21 and below)
With the release of version 0.22.0, we have revamped the [`RLOOTrainer`] to be more aligned with other online trainers in the library, like [`GRPOTrainer`]. This new implementation introduces several changes to the configuration parameters and overall structure of the trainer.
Below is a summary of the key changes for [`RLOOConfig`]:
| TRL ≤ 0.21.x | TRL ≥ 0.22.0 |
| --- | --- |
| `rloo_k` | renamed to `num_generations` |
| `cliprange` | renamed to `epsilon` |
| `kl_coef` | renamed to `beta` |
| `exp_name` | renamed to `run_name`. Use `run_name = f"{exp_name}__{seed}__{int(time.time())}"` to replicate old behavior |
| `normalize_reward` | renamed to `normalize_advantages`. Note: this always normalized advantages (despite the old name) |
| `num_ppo_epochs` | renamed to `num_iterations` (default: `1`) |
| `token_level_kl` | **removed** KL is now computed only at the sequence level |
| `dataset_num_proc` | **removed** it was unused |
| `num_mini_batches` | renamed to `steps_per_generation` |
| `total_episodes` | use `max_steps=total_episodes / gradient_accumulation_steps` instead |
| `local_rollout_forward_batch_size` | **removed** now automatically set to `per_device_train_batch_size` (or `per_device_eval_batch_size` during evaluation) |
| `num_sample_generations` | **removed** use `logging_steps` to control generation logging frequency |
| `response_length` | renamed to `max_completion_length` (default: `256`) |
| `stop_token` | **removed** |
| `stop_token_id` | **removed** use `processing_class.eos_token_id` instead |
| `missing_eos_penalty` | **removed** replicate with a custom reward function checking if `eos_token_id` is in `completion_ids` |
Below is a summary of the key changes for [`RLOOTrainer`]:
| TRL ≤ 0.21.x | TRL ≥ 0.22.0 |
| --- | --- |
| `config` | renamed to `args` |
| `reward_model` | renamed to `reward_funcs`, which now supports both reward models and custom reward functions |
| `policy` | renamed to `model` |
| `ref_policy` | **removed** the reference model is now created automatically from `model` |
| `data_collator` | **removed** |

View File

@ -0,0 +1,24 @@
# Scripts Utilities
## ScriptArguments
[[autodoc]] ScriptArguments
## TrlParser
[[autodoc]] TrlParser
- parse_args_and_config
- parse_args_into_dataclasses
- set_defaults_with_config
## get_dataset
[[autodoc]] get_dataset
## DatasetConfig
[[autodoc]] scripts.utils.DatasetConfig
## DatasetMixtureConfig
[[autodoc]] DatasetMixtureConfig

View File

@ -1,18 +1,14 @@
# Sentiment Tuning Examples
The notebooks and scripts in this examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`).
The notebooks and scripts in these examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`).
Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
| File | Description |
|------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------|
| --- |--- |
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset |
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
## Usage
@ -30,7 +26,6 @@ python examples/scripts/ppo.py --log_with wandb --mini_batch_size 1 --gradient_a
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
## Few notes on multi-GPU
To run in multi-GPU setup with DDP (distributed Data Parallel) change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism you can use `device_map="auto"`.

335
docs/source/sft_trainer.md Normal file
View File

@ -0,0 +1,335 @@
# SFT Trainer
[![All_models-SFT-blue](https://img.shields.io/badge/All_models-SFT-blue)](https://huggingface.co/models?other=sft,trl) [![smol_course-Chapter_1-yellow](https://img.shields.io/badge/smol_course-Chapter_1-yellow)](https://github.com/huggingface/smol-course/tree/main/1_instruction_tuning)
## Overview
TRL supports the Supervised Fine-Tuning (SFT) Trainer for training language models.
This post-training method was contributed by [Younes Belkada](https://huggingface.co/ybelkada).
## Quick start
This example demonstrates how to train a language model using the [`SFTTrainer`] from TRL. We train a [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) model on the [Capybara dataset](https://huggingface.co/datasets/trl-lib/Capybara), a compact, diverse multi-turn dataset to benchmark reasoning and generalization.
```python
from trl import SFTTrainer
from datasets import load_dataset
trainer = SFTTrainer(
model="Qwen/Qwen3-0.6B",
train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()
```
<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&metrics=train*&runs=sft_qwen3-0.6B_capybara" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>
## Expected dataset type and format
SFT supports both [language modeling](dataset_formats#language-modeling) and [prompt-completion](dataset_formats#prompt-completion) datasets. The [`SFTTrainer`] is compatible with both [standard](dataset_formats#standard) and [conversational](dataset_formats#conversational) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
```python
# Standard language modeling
{"text": "The sky is blue."}
# Conversational language modeling
{"messages": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}]}
# Standard prompt-completion
{"prompt": "The sky is",
"completion": " blue."}
# Conversational prompt-completion
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}]}
```
If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. Here is an example with the [FreedomIntelligence/medical-o1-reasoning-SFT](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT) dataset:
```python
from datasets import load_dataset
dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")
def preprocess_function(example):
return {
"prompt": [{"role": "user", "content": example["Question"]}],
"completion": [
{"role": "assistant", "content": f"<think>{example['Complex_CoT']}</think>{example['Response']}"}
],
}
dataset = dataset.map(preprocess_function, remove_columns=["Question", "Response", "Complex_CoT"])
print(next(iter(dataset["train"])))
```
```json
{
"prompt": [
{
"content": "Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?",
"role": "user",
}
],
"completion": [
{
"content": "<think>Okay, let's see what's going on here. We've got sudden weakness [...] clicks into place!</think>The specific cardiac abnormality most likely to be found in [...] the presence of a PFO facilitating a paradoxical embolism.",
"role": "assistant",
}
],
}
```
## Looking deeper into the SFT method
Supervised Fine-Tuning (SFT) is the simplest and most commonly used method to adapt a language model to a target dataset. The model is trained in a fully supervised fashion using pairs of input and output sequences. The goal is to minimize the negative log-likelihood (NLL) of the target sequence, conditioning on the input.
This section breaks down how SFT works in practice, covering the key steps: **preprocessing**, **tokenization** and **loss computation**.
### Preprocessing and tokenization
During training, each example is expected to contain a **text field** or a **(prompt, completion)** pair, depending on the dataset format. For more details on the expected formats, see [Dataset formats](dataset_formats).
The [`SFTTrainer`] tokenizes each input using the model's tokenizer. If both prompt and completion are provided separately, they are concatenated before tokenization.
### Computing the loss
![sft_figure](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/sft_figure.png)
The loss used in SFT is the **token-level cross-entropy loss**, defined as:
$$
\mathcal{L}_{\text{SFT}}(\theta) = - \sum_{t=1}^{T} \log p_\theta(y_t \mid y_{<t}),
$$
where \\( y_t \\) is the target token at timestep \\( t \\), and the model is trained to predict the next token given the previous ones. In practice, padding tokens are masked out during loss computation.
> [!TIP]
> [On the Generalization of SFT: A Reinforcement Learning Perspective with Reward Rectification](https://huggingface.co/papers/2508.05629) proposes an alternative loss function, called **Dynamic Fine-Tuning (DFT)**, which aims to improve generalization by rectifying the reward signal. This method can be enabled by setting `loss_type="dft"` in the [`SFTConfig`]. For more details, see [Paper Index - Dynamic Fine-Tuning](paper_index#on-the-generalization-of-sft-a-reinforcement-learning-perspective-with-reward-rectification).
### Label shifting and masking
During training, the loss is computed using a **one-token shift**: the model is trained to predict each token in the sequence based on all previous tokens. Specifically, the input sequence is shifted right by one position to form the target labels.
Padding tokens (if present) are ignored in the loss computation by applying an ignore index (default: `-100`) to the corresponding positions. This ensures that the loss focuses only on meaningful, non-padding tokens.
## Logged metrics
While training and evaluating we record the following reward metrics:
* `global_step`: The total number of optimizer steps taken so far.
* `epoch`: The current epoch number, based on dataset iteration.
* `num_tokens`: The total number of tokens processed so far.
* `loss`: The average cross-entropy loss computed over non-masked tokens in the current logging interval.
* `entropy`: The average entropy of the model's predicted token distribution over non-masked tokens.
* `mean_token_accuracy`: The proportion of non-masked tokens for which the models top-1 prediction matches the ground truth token.
* `learning_rate`: The current learning rate, which may change dynamically if a scheduler is used.
* `grad_norm`: The L2 norm of the gradients, computed before gradient clipping.
## Customization
### Model initialization
You can directly pass the kwargs of the [`~transformers.AutoModelForCausalLM.from_pretrained()`] method to the [`SFTConfig`]. For example, if you want to load a model in a different precision, analogous to
```python
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16)
```
you can do so by passing the `model_init_kwargs={"dtype": torch.bfloat16}` argument to the [`SFTConfig`].
```python
from trl import SFTConfig
training_args = SFTConfig(
model_init_kwargs={"dtype": torch.bfloat16},
)
```
Note that all keyword arguments of [`~transformers.AutoModelForCausalLM.from_pretrained()`] are supported.
### Packing
[`SFTTrainer`] supports _example packing_, where multiple examples are packed in the same input sequence to increase training efficiency. To enable packing, simply pass `packing=True` to the [`SFTConfig`] constructor.
```python
training_args = SFTConfig(packing=True)
```
For more details on packing, see [Packing](reducing_memory_usage#packing).
### Train on assistant messages only
To train on assistant messages only, use a [conversational](dataset_formats#conversational) dataset and set `assistant_only_loss=True` in the [`SFTConfig`]. This setting ensures that loss is computed **only** on the assistant responses, ignoring user or system messages.
```python
training_args = SFTConfig(assistant_only_loss=True)
```
![train_on_assistant](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/train_on_assistant.png)
> [!WARNING]
> This functionality is only available for chat templates that support returning the assistant tokens mask via the `&#123;% generation %&#125;` and `&#123;% endgeneration %&#125;` keywords. For an example of such a template, see [HugggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B/blob/main/chat_template.jinja#L76-L82).
### Train on completion only
To train on completion only, use a [prompt-completion](dataset_formats#prompt-completion) dataset. By default, the trainer computes the loss on the completion tokens only, ignoring the prompt tokens. If you want to train on the full sequence, set `completion_only_loss=False` in the [`SFTConfig`].
![train_on_completion](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/train_on_completion.png)
> [!TIP]
> Training on completion only is compatible with training on assistant messages only. In this case, use a [conversational](dataset_formats#conversational) [prompt-completion](dataset_formats#prompt-completion) dataset and set `assistant_only_loss=True` in the [`SFTConfig`].
### Train adapters with PEFT
We support tight integration with 🤗 PEFT library, allowing any user to conveniently train adapters and share them on the Hub, rather than training the entire model.
```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
dataset = load_dataset("trl-lib/Capybara", split="train")
trainer = SFTTrainer(
"Qwen/Qwen3-0.6B",
train_dataset=dataset,
peft_config=LoraConfig()
)
trainer.train()
```
You can also continue training your [`~peft.PeftModel`]. For that, first load a `PeftModel` outside [`SFTTrainer`] and pass it directly to the trainer without the `peft_config` argument being passed.
```python
from datasets import load_dataset
from trl import SFTTrainer
from peft import AutoPeftModelForCausalLM
model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-LoRA", is_trainable=True)
dataset = load_dataset("trl-lib/Capybara", split="train")
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
)
trainer.train()
```
> [!TIP]
> When training adapters, you typically use a higher learning rate (≈1e4) since only new parameters are being learned.
>
> ```python
> SFTConfig(learning_rate=1e-4, ...)
> ```
### Train with Liger Kernel
Liger Kernel is a collection of Triton kernels for LLM training that boosts multi-GPU throughput by 20%, cuts memory use by 60% (enabling up to 4× longer context), and works seamlessly with tools like FlashAttention, PyTorch FSDP, and DeepSpeed. For more information, see [Liger Kernel Integration](liger_kernel_integration).
### Train with Unsloth
Unsloth is an opensource framework for finetuning and reinforcement learning that trains LLMs (like Llama, Mistral, Gemma, DeepSeek, and more) up to 2× faster with up to 70% less VRAM, while providing a streamlined, Hugging Facecompatible workflow for training, evaluation, and deployment. For more information, see [Unsloth Integration](unsloth_integration).
## Instruction tuning example
**Instruction tuning** teaches a base language model to follow user instructions and engage in conversations. This requires:
1. **Chat template**: Defines how to structure conversations into text sequences, including role markers (user/assistant), special tokens, and turn boundaries. Read more about chat templates in [Chat templates](https://huggingface.co/docs/transformers/chat_templating#templates).
2. **Conversational dataset**: Contains instruction-response pairs
This example shows how to transform the [Qwen 3 0.6B Base](https://huggingface.co/Qwen/Qwen3-0.6B-Base) model into an instruction-following model using the [Capybara dataset](https://huggingface.co/datasets/trl-lib/Capybara) and a chat template from [HuggingFaceTB/SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B). The SFT Trainer automatically handles tokenizer updates and special token configuration.
```python
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
trainer = SFTTrainer(
model="Qwen/Qwen3-0.6B-Base",
args=SFTConfig(
output_dir="Qwen3-0.6B-Instruct",
chat_template_path="HuggingFaceTB/SmolLM3-3B",
),
train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()
```
> [!WARNING]
> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply [`clone_chat_template()`], as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in [`SFTConfig`]; for example, for `Qwen/Qwen2.5-1.5B`, one should set `eos_token="<|im_end|>"`.
Once trained, your model can now follow instructions and engage in conversations using its new chat template.
```python
>>> from transformers import pipeline
>>> pipe = pipeline("text-generation", model="Qwen3-0.6B-Instruct/checkpoint-5000")
>>> prompt = "<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\n"
>>> response = pipe(prompt)
>>> response[0]["generated_text"]
'<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\nThe capital of France is Paris.'
```
Alternatively, use the structured conversation format (recommended):
```python
>>> prompt = [{"role": "user", "content": "What is the capital of France? Answer in one word."}]
>>> response = pipe(prompt)
>>> response[0]["generated_text"]
[{'role': 'user', 'content': 'What is the capital of France? Answer in one word.'}, {'role': 'assistant', 'content': 'The capital of France is Paris.'}]
```
## Tool Calling with SFT
The [`SFTTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include:
* The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages)
* The list of available tools in the `tools` column, typically provided as JSON schemas
For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section.
## Training Vision Language Models
[`SFTTrainer`] fully supports training Vision-Language Models (VLMs). To train a VLM, you need to provide a dataset with an additional `images` column containing the images to be processed. For more information on the expected dataset structure, see the [Dataset Format — Vision Dataset](dataset_formats#vision-dataset) section.
An example of such a dataset is the [LLaVA Instruct Mix](https://huggingface.co/datasets/trl-lib/llava-instruct-mix).
```python
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
trainer = SFTTrainer(
model="Qwen/Qwen2.5-VL-3B-Instruct",
args=SFTConfig(max_length=None),
train_dataset=load_dataset("trl-lib/llava-instruct-mix", split="train"),
)
trainer.train()
```
> [!TIP]
> For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set `max_length=None` in the [`SFTConfig`]. This allows the model to process the full sequence length without truncating image tokens.
>
> ```python
> SFTConfig(max_length=None, ...)
> ```
>
> Only use `max_length` when you've verified that truncation won't remove image tokens for the entire dataset.
## SFTTrainer
[[autodoc]] SFTTrainer
- train
- save_model
- push_to_hub
## SFTConfig
[[autodoc]] SFTConfig
## DataCollatorForLanguageModeling
[[autodoc]] trainer.sft_trainer.DataCollatorForLanguageModeling
## DataCollatorForVisionLanguageModeling
[[autodoc]] trainer.sft_trainer.DataCollatorForVisionLanguageModeling

View File

@ -1,772 +0,0 @@
# Supervised Fine-tuning Trainer
[![](https://img.shields.io/badge/All_models-SFT-blue)](https://huggingface.co/models?other=sft,trl)
Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset.
Check out a complete flexible example at [`examples/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft.py).
Experimental support for Vision Language Models is also included in the example [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_vlm.py).
## Quickstart
If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model.
The following code-snippet takes care of all the data pre-processing and training for you:
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
dataset = load_dataset("stanfordnlp/imdb", split="train")
training_args = SFTConfig(
max_seq_length=512,
output_dir="/tmp",
)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
args=training_args,
)
trainer.train()
```
Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
You can also construct a model outside of the trainer and pass it as follows:
```python
from transformers import AutoModelForCausalLM
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
dataset = load_dataset("stanfordnlp/imdb", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
training_args = SFTConfig(output_dir="/tmp")
trainer = SFTTrainer(
model,
train_dataset=dataset,
args=training_args,
)
trainer.train()
```
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults pass in your modification to the `SFTConfig` constructor and pass them to the trainer via the `args` argument.
## Advanced usage
### Train on completions only
You can use the `DataCollatorForCompletionOnlyLM` to train your model on the generated prompts only. Note that this works only in the case when `packing=False`.
To instantiate that collator for instruction data, pass a response template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on completions only on the CodeAlpaca dataset:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
def formatting_prompts_func(example):
output_texts = []
for i in range(len(example['instruction'])):
text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
output_texts.append(text)
return output_texts
response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
trainer = SFTTrainer(
model,
train_dataset=dataset,
args=SFTConfig(output_dir="/tmp"),
formatting_func=formatting_prompts_func,
data_collator=collator,
)
trainer.train()
```
To instantiate that collator for assistant style conversation data, pass a response template, an instruction template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on assistant completions only on the Open Assistant Guanaco dataset:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
instruction_template = "### Human:"
response_template = "### Assistant:"
collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)
trainer = SFTTrainer(
model,
args=SFTConfig(output_dir="/tmp"),
train_dataset=dataset,
data_collator=collator,
)
trainer.train()
```
Make sure to have a `pad_token_id` which is different from `eos_token_id` which can result in the model not properly predicting EOS (End of Sentence) tokens during generation.
#### Using token_ids directly for `response_template`
Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending on whether they have context or not. For example:
```python
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
def print_tokens_with_ids(txt):
tokens = tokenizer.tokenize(txt, add_special_tokens=False)
token_ids = tokenizer.encode(txt, add_special_tokens=False)
print(list(zip(tokens, token_ids)))
prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?"""
print_tokens_with_ids(prompt) # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...]
response_template = "### Assistant:"
print_tokens_with_ids(response_template) # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)]
```
In this case, and due to lack of context in `response_template`, the same string ("### Assistant:") is tokenized differently:
- Text (with context): `[2277, 29937, 4007, 22137, 29901]`
- `response_template` (without context): `[835, 4007, 22137, 29901]`
This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text:
```
RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])
```
To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
```python
response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:] # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]`
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
```
### Add Special Tokens for Chat Format
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
- Resizes the models embedding layer to accommodate the new tokens.
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g. 64. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format
# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# Set up the chat format with default 'chatml' format
model, tokenizer = setup_chat_format(model, tokenizer)
```
With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
### Dataset format support
The [`SFTTrainer`] supports popular dataset formats. This allows you to pass the dataset to the trainer without any pre-processing directly. The following formats are supported:
* conversational format
```json
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]}
```
* instruction format
```json
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
```
If your dataset uses one of the above formats, you can directly pass it to the trainer without pre-processing. The [`SFTTrainer`] will then format the dataset for you using the defined format from the model's tokenizer with the [apply_chat_template](https://huggingface.co/docs/transformers/main/en/chat_templating#templates-for-chat-models) method.
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
...
# load jsonl dataset
dataset = load_dataset("json", data_files="path/to/dataset.jsonl", split="train")
# load dataset from the HuggingFace Hub
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
...
training_args = SFTConfig(packing=True)
trainer = SFTTrainer(
"facebook/opt-350m",
args=training_args,
train_dataset=dataset,
)
```
If the dataset is not in one of those format you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
### Format your input prompts
For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
This allows people to format examples like [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca) did as follows:
```bash
Below is an instruction ...
### Instruction
{prompt}
### Response:
{completion}
```
Let us assume your dataset has two fields, `question` and `answer`. Therefore you can just run:
```python
...
def formatting_prompts_func(example):
output_texts = []
for i in range(len(example['question'])):
text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
output_texts.append(text)
return output_texts
trainer = SFTTrainer(
model,
args=training_args,
train_dataset=dataset,
formatting_func=formatting_prompts_func,
)
trainer.train()
```
To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
### Packing dataset ([`ConstantLengthDataset`])
[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTConfig`] constructor.
```python
...
training_args = SFTConfig(packing=True)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
args=training_args
)
trainer.train()
```
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTConfig` init method.
#### Customize your prompts using packed dataset
If your dataset has several fields that you want to combine, for example if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:
```python
def formatting_func(example):
text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
return text
training_args = SFTConfig(packing=True)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
args=training_args,
formatting_func=formatting_func
)
trainer.train()
```
You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTConfig`] constructor. Please refer to that class' signature for more information.
### Control over the pretrained model
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTConfig`]. For example, if you want to load a model in a different precision, analogous to
```python
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
...
training_args = SFTConfig(
model_init_kwargs={
"torch_dtype": "bfloat16",
},
output_dir="/tmp",
)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
args=training_args,
)
trainer.train()
```
Note that all keyword arguments of `from_pretrained()` are supported.
### Training adapters
We also support tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig
dataset = load_dataset("stanfordnlp/imdb", split="train")
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
trainer = SFTTrainer(
"EleutherAI/gpt-neo-125m",
train_dataset=dataset,
args=SFTConfig(output_dir="/tmp"),
peft_config=peft_config
)
trainer.train()
```
You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed.
### Training adapters with base 8 bit models
For that, you need to first load your 8 bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:
```python
...
peft_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = AutoModelForCausalLM.from_pretrained(
"EleutherAI/gpt-neo-125m",
load_in_8bit=True,
device_map="auto",
)
trainer = SFTTrainer(
model,
train_dataset=dataset,
args=SFTConfig(),
peft_config=peft_config,
)
trainer.train()
```
## Using Flash Attention and Flash Attention 2
You can benefit from Flash Attention 1 & 2 using SFTTrainer out of the box with minimal changes of code.
First, to make sure you have all the latest features from transformers, install transformers from source
```bash
pip install -U git+https://github.com/huggingface/transformers.git
```
Note that Flash Attention only works on GPU now and under half-precision regime (when using adapters, base model loaded in half-precision)
Note also both features are perfectly compatible with other tools such as quantization.
### Using Flash-Attention 1
For Flash Attention 1 you can use the `BetterTransformer` API and force-dispatch the API to use Flash Attention kernel. First, install the latest optimum package:
```bash
pip install -U optimum
```
Once you have loaded your model, wrap the `trainer.train()` call under the `with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):` context manager:
```diff
...
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
trainer.train()
```
Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration.
Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB.
| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step |
| ---------------- | ----------------- | ----------- | ---------- | ---------------------- |
| x | facebook/opt-350m | 2048 | 8 | ~59.1s |
| | facebook/opt-350m | 2048 | 8 | **OOM** |
| x | facebook/opt-350m | 2048 | 4 | ~30.3s |
| | facebook/opt-350m | 2048 | 4 | ~148.9s |
### Using Flash Attention-2
To use Flash Attention 2, first install the latest `flash-attn` package:
```bash
pip install -U flash-attn
```
And add `attn_implementation="flash_attention_2"` when calling `from_pretrained`:
```python
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
attn_implementation="flash_attention_2"
)
```
If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device.
After loading your model, you can either train it as it is, or attach adapters and train adapters on it in case your model is quantized.
In contrast to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
### Using model creation utility
We included a utility function to create your model.
[[autodoc]] ModelConfig
```python
from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
model_config = ModelConfig(
model_name_or_path="facebook/opt-350m"
attn_implementation=None, # or "flash_attention_2"
)
torch_dtype = (
model_config.torch_dtype
if model_config.torch_dtype in ["auto", None]
else getattr(torch, model_config.torch_dtype)
)
quantization_config = get_quantization_config(model_config)
model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
trainer = SFTTrainer(
...,
model=model_config.model_name_or_path,
peft_config=get_peft_config(model_config),
)
```
### Enhance the model's performances using NEFTune
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/neft-screenshot.png">
</div>
To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer.
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
dataset = load_dataset("stanfordnlp/imdb", split="train")
training_args = SFTConfig(
neftune_noise_alpha=5,
)
trainer = SFTTrainer(
"facebook/opt-350m",
train_dataset=dataset,
args=training_args,
)
trainer.train()
```
We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssistant dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) and validated that using NEFTune led to a performance boost of ~25% on MT Bench.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-neftune-mistral-7b.png">
</div>
Note however, that the amount of performance gain is _dataset dependent_ and in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains.
### Accelerate fine-tuning 2x using `unsloth`
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:
| 1 A100 40GB | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
| --------------- | --------- | --- | --------------------- | --------- | ------------ |
| Code Llama 34b | Slim Orca | 1x | 1.01x | **1.94x** | -22.7% |
| Llama-2 7b | Slim Orca | 1x | 0.96x | **1.87x** | -39.3% |
| Mistral 7b | Slim Orca | 1x | 1.17x | **1.88x** | -65.9% |
| Tiny Llama 1.1b | Alpaca | 1x | 1.55x | **2.74x** | -57.8% |
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
```python
import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number
# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/mistral-7b",
max_seq_length=max_seq_length,
dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=16,
lora_dropout=0, # Dropout = 0 is currently optimized
bias="none", # Bias = "none" is currently optimized
use_gradient_checkpointing=True,
random_state=3407,
)
training_args = SFTConfig(output_dir="./output", max_seq_length=max_seq_length)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
)
trainer.train()
```
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
## Liger-Kernel: Increase 20% throughput and reduces 60% memory for multi-GPU training
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face Compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
With great memory reduction, you can potentially turn off cpu_offloading or gradient checkpointing to further boost the performance.
| Speed Up | Memory Reduction |
|--------------------------|-------------------------|
| ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |
1. To use Liger-Kernel in `SFTTrainer`, first install by
```bash
pip install liger-kernel
```
2. Once installed, set `use_liger` in [`SFTConfig`]. No other changes are needed!
```python
training_args = SFTConfig(
use_liger=True
)
```
To learn more about Liger-Kernel, visit their [official repository](https://github.com/linkedin/Liger-Kernel/).
## Best practices
Pay attention to the following best practices when training a model with that trainer:
- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
## Multi-GPU Training
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work you must also check the following:
- If you're using gradient_checkpointing, add the following to the TrainingArguments: `gradient_checkpointing_kwargs={'use_reentrant':False}` (more info [here](https://github.com/huggingface/transformers/issues/26969)
- Ensure that the model is placed on the correct device:
```python
from accelerate import PartialState
device_string = PartialState().process_index
model = AutoModelForCausalLM.from_pretrained(
...
device_map={'':device_string}
)
```
## GPTQ Conversion
You may experience some issues with GPTQ Quantization after completing training. Lowering `gradient_accumulation_steps` to `4` will resolve most issues during the quantization process to GPTQ format.
## Extending `SFTTrainer` for Vision Language Models
`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
### Preparing the Data
The data format is flexible, provided it is compatible with the custom collator that we will define later. A common approach is to use conversational data. Given that the data includes both text and images, the format needs to be adjusted accordingly. Below is an example of a conversational data format involving both text and images:
```python
images = ["obama.png"]
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Who is this?"},
{"type": "image"}
]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Barack Obama"}
]
},
{
"role": "user",
"content": [
{"type": "text", "text": "What is he famous for?"}
]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "He is the 44th President of the United States."}
]
}
]
```
To illustrate how this data format will be processed using the LLaVA model, you can use the following code:
```python
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(processor.apply_chat_template(messages, tokenize=False))
```
The output will be formatted as follows:
```txt
Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States.
```
<iframe src="https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft/embed/viewer/default/train" frameborder="0" width="100%" height="560px"></iframe>
### A custom collator for processing multi-modal data
Unlike the default behavior of `SFTTrainer`, processing multi-modal data is done on the fly during the data collation process. To do this, you need to define a custom collator that processes both the text and images. This collator must take a list of examples as input (see the previous section for an example of the data format) and return a batch of processed data. Below is an example of such a collator:
```python
def collate_fn(examples):
# Get the texts and images, and apply the chat template
texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
images = [example["images"][0] for example in examples]
# Tokenize the texts and process the images
batch = processor(texts, images, return_tensors="pt", padding=True)
# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
labels[labels == processor.tokenizer.pad_token_id] = -100
batch["labels"] = labels
return batch
```
We can verify that the collator works as expected by running the following code:
```python
from datasets import load_dataset
dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")
examples = [dataset[0], dataset[1]] # Just two examples for the sake of the example
collated_data = collate_fn(examples)
print(collated_data.keys()) # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels'])
```
### Training the vision-language model
Now that we have prepared the data and defined the collator, we can proceed with training the model. To ensure that the data is not processed as text-only, we need to set a couple of arguments in the `SFTConfig`, specifically `remove_unused_columns` and `skip_prepare_dataset` to `True` to avoid the default processing of the dataset. Below is an example of how to set up the `SFTTrainer`.
```python
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=collate_fn,
train_dataset=train_dataset,
processing_class=processor.tokenizer,
)
```
A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset can be found in the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py).
- [Experiment tracking](https://wandb.ai/huggingface/trl/runs/2b2c5l7s)
- [Trained model](https://huggingface.co/HuggingFaceH4/sft-llava-1.5-7b-hf)
## SFTTrainer
[[autodoc]] SFTTrainer
## SFTConfig
[[autodoc]] SFTConfig
## Datasets
In the SFTTrainer we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to re-use it directly.
### ConstantLengthDataset
[[autodoc]] trainer.ConstantLengthDataset

View File

@ -0,0 +1,97 @@
# Speeding Up Training
> [!WARNING]
> Section under construction. Feel free to contribute!
## vLLM for fast generation in online methods
Online methods such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time.
To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed.
To use [vLLM](https://github.com/vllm-project/vllm), first install it using:
```bash
pip install trl[vllm]
```
<hfoptions id="vllm examples">
<hfoption id="Online DPO">
Then, enable it by passing `use_vllm=True` in the training arguments.
```python
from trl import OnlineDPOConfig
training_args = OnlineDPOConfig(..., use_vllm=True)
```
</hfoption>
<hfoption id="GRPO">
First, start a vLLM server by running:
```bash
trl vllm-serve --model <model_name>
```
Then, run the training script and pass `use_vllm=True` in the training arguments.
```python
from trl import GRPOConfig
training_args = GRPOConfig(..., use_vllm=True)
```
You can customize the server configuration by passing additional arguments. For more information, see [vLLM integration](vllm_integration).
> [!WARNING]
> When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.
>
> Set GPUs **0-3** for vLLM generation:
>
> ```sh
> CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name>
> ```
>
> And GPUs **4-7** for training:
>
> ```sh
> CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
> ```
</hfoption>
<hfoption id="RLOO">
First, start a vLLM server by running:
```bash
trl vllm-serve --model <model_name>
```
Then, run the training script and pass `use_vllm=True` in the training arguments.
```python
from trl import RLOOConfig
training_args = RLOOConfig(..., use_vllm=True)
```
You can customize the server configuration by passing additional arguments. For more information, see [vLLM integration](vllm_integration).
> [!WARNING]
> When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.
>
> Set GPUs **0-3** for vLLM generation:
>
> ```sh
> CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name>
> ```
>
> And GPUs **4-7** for training:
>
> ```sh
> CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
> ```
</hfoption>
</hfoptions>

View File

@ -1,197 +0,0 @@
# Text Environments
Text environments provide a learning ground for language agents. It allows a language model to use tools to accomplish a task such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the models itself but can be trivial for the appropriate tools. A good example is arithmetics of large numbers that become a simple copy-paste task once you have access to a calculator.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv.png">
</div>
Let's dive into how text environments work and start with tools!
## Tools
One of the core building blocks of text environments are tools that the model can use to solve tasks. In general tools can be any Python function that takes a string as input and returns string. The `TextEnvironment` offers two options for tools: either go with predefined tools from `transformers.Tool` or define your own function or class with `__call__` method. Let's have a look at both!
### `transformers.Tool`
Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared
```Python
from transformers import load_tool
# simple calculator tool that runs +-/* operations
calc_tool = load_tool("ybelkada/simple-calculator")
# python interpreter that executes program and returns outputs
py_tool = load_tool("lvwerra/python-interpreter")
# wikipedia search index that returns best search match
wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
```
These tools are either loaded from the hub or from a local folder. Using the tool is as simple as calling them with a text query:
```Python
calc_tool("1/2")
>>> "0.5"
```
Note that both input and return values are strings to enable easy usage with a language model.
### Custom Tools
The following is an example of a tool that adds two integers:
```Python
def add(text):
int_1, int_2 = text.split("+")
result = int(int_1) + int(int_2)
return str(result)
print(add("1+1"))
>>> "2"
```
We looked at basic examples such as a calculator but the principle holds for more complex tools as well such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax.
### Call syntax
In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows:
```python
"<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"
```
There are a few special tokens involved so let's decompose it: First the model can signal that it wants to use a tool by emitting the `<request>` token. After that we want to know the name of the tool to call which is done by enclosing the tool name with `<>` brackets. Once we know which tool to call the tool query follows which is in free text form. The `<call>` tokens signifies the end of the query and stops the model generation. At this point the model output is parsed and the query sent to the tool. The environment appends the tool response to the string followed by the `<response>` token to show the end the tool output.
Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later):
```python
"<request><Calculator>1/2<call>0.5<response>"
```
Finally, the episode is ended and generation stops when the model generates `<submit>` which marks the interaction as completed.
Now let's have a look how we can create a new text environment!
## Create a `TextEnvironment`
```python
prompt = """\
What is 13-3?
<request><SimpleCalculatorTool>13-3<call>10.0<response>
Result=10<submit>
"""
def reward_fn(result, answer):
"""Simplified reward function returning 1 if result matches answer and 0 otherwise."""
result_parsed = result.split("=")[1].split("<")[0]
return int(result_parsed==answer)
text_env = TextEnvironemnt(
model=model,
tokenizer=tokenizer,
tools= {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
reward_fn=exact_match_reward,
prompt=prompt,
max_turns=1
max_tool_response=100
generation_kwargs={"do_sample": "true"}
)
```
Let's decompose the settings:
| Argument | Description |
|:-------------------|:----------------|
| `model` | Language model to interact with the environment and generate requests. |
| `tokenizer` | Tokenizer of language model handling tokenization of strings. |
| `tools` | `list` of `dict` of tools. If former the name of the tool is inferred from class name and otherwise it's the keys of the dictionary.|
| `reward_fn` | A function that takes a string as input and returns. Can have extra arguments that are passed to `.run()` such as ground truth.|
| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. |
| `max_turns` | Maximum number of interactions between model and tools before episode ends.|
| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
| `max_length` | The maximum number of tokens to allow in an episode. |
| `generation_kwargs`| Generation settings used by the language model. |
You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!
## Run an Episode
To run a set of queries through the text environment one can simply use the `run` method.
```python
queries = ["What is 1/2?"]
answers = ["0.5"]
queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)
```
This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached or to maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function.
There are five objects that are returned by `run`:
- `queries`: a list of the tokenized queries
- `responses`: all tokens that have been generated withing the environment including model and tool tokens
- `masks`: mask that indicates which tokens have been generated by the model and which tokens are generated by the tool
- `rewards`: a list of reward for each query/response
- `histories`: list of `TextHistory` objects, which are useful objects containing all the above and also the text equivalents
The masks are crucial for training as we don't want to optimize tokens that the model has not generated which are tokens produced by the tools.
Next, we'll train a PPO step with the generated responses!
### Train
Training on episodes from the `TextEnvironment` is straight forward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method:
```python
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
```
## `TextHistory`
The `TextHistory` object stores the interactions between the model and the text environment. It stores tokens and text generated in each turn and their source in each turn (model or system) as well as rewards. Let's go through the class attributes and methods.
### Attributes
The following table summarises the available attributes of the `TextEnvironment` class:
| Attribute | Description |
|:-------------------|:----------------|
| `text` | The full string of the text generated in the text environment with both model and system generated text. |
| `text_spans` | A list of tuples with the spans for each model or system generated text segment. |
| `system_spans` | A list of boolean values indicating if the segment is model or system generated. |
| `tokens` | All tokens generated in text environment with both model and system generated tokens. |
| `token_spans` | Similar to `text_spans` the `token_spans` indicate the boundaries of model andsystem generated tokens. |
| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. |
| `completed` | Indicates if the interaction with the environment has completed. |
| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. |
With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look!
### Visualization
When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` libray](https://github.com/Textualize/rich) (make sure to install it before using these methods).
You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue and in addition to the pure text output the reward is displayed as additional text in plum. Here an example of `show_text`:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_text.png" width=600>
</div>
Sometimes there can be tricky tokenization related issues that are hidden when showing the decoded text. Thus `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_tokens.png" width=800>
</div>
Note that you can turn on the colour legend by passing `show_legend=True`.
## API Documentation
[[autodoc]] TextEnvironment
[[autodoc]] TextHistory

View File

@ -0,0 +1,67 @@
# Trackio Integration
[Trackio](https://huggingface.co/docs/trackio) is a lightweight, free experiment tracking library built on top of **🤗 Datasets** and **🤗 Spaces**. It is the **recommended tracking solution for TRL** and comes natively integrated with all trainers.
To enable logging, simply set `report_to="trackio"` in your training config:
```python
from trl import SFTConfig # works with any trainer config (e.g. DPOConfig, GRPOConfig, etc.)
training_args = SFTConfig(
...,
report_to="trackio", # enable Trackio logging
)
```
## Organizing Your Experiments with Run Names and Projects
By default, Trackio will generate a name to identify each run. However, we highly recommend setting a descriptive `run_name` to make it easier to organize experiments. For example:
```python
from trl import SFTConfig
training_args = SFTConfig(
...,
report_to="trackio",
run_name="sft_qwen3-4b_lr2e-5_bs128", # descriptive run name
)
```
You can also group related experiments by project by setting the following environment variable:
```bash
export TRACKIO_PROJECT="my_project"
```
## Hosting Your Logs on 🤗 Spaces
Trackio has local-first design, meaning your logs stay on your machine. If youd like to host them and deploy a dashboard on **🤗 Spaces**, set:
```bash
export TRACKIO_SPACE_ID="username/space_id"
```
Running the following example:
```python
import os
from trl import SFTConfig, SFTTrainer
from datasets import load_dataset
os.environ["TRACKIO_SPACE_ID"] = "trl-lib/trackio"
os.environ["TRACKIO_PROJECT"] = "trl-documentation"
trainer = SFTTrainer(
model="Qwen/Qwen3-0.6B",
train_dataset=load_dataset("trl-lib/Capybara", split="train"),
args=SFTConfig(
report_to="trackio",
run_name="sft_qwen3-0.6b_capybara",
),
)
trainer.train()
```
will give you a hosted dashboard at https://huggingface.co/spaces/trl-lib/trackio.
<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&sidebar=hidden&runs=sft_qwen3-0.6B_capybara" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>

View File

@ -0,0 +1,125 @@
# Unsloth Integration
Unsloth is an opensource framework for finetuning and reinforcement learning that trains LLMs (like Llama, OpenAI gpt-oss, Mistral, Gemma, DeepSeek, and more) up to 2× faster with up to 80% less VRAM. Unsloth allows [training](https://huggingface.co/docs/trl/en/unsloth_integration#Training), evaluation, running and [deployment](https://huggingface.co/docs/trl/en/unsloth_integration#Saving-the-model) with other inference engines like llama.cpp, Ollama and vLLM.
The library provides a streamlined, Hugging Face compatible workflow for training, evaluation, inference and deployment and is fully compatible with [`SFTTrainer`].
## Key Features
- Training support for all transformer compatible models: Text-to-speech (TTS), multimodal, BERT, RL and more
- Supports full fine-tuning, pretraining, LoRA, QLoRA, 8-bit training & more
- Works on Linux, Windows, Colab, Kaggle; NVIDIA GPUs, soon AMD & Intel setups
- Supports most features TRL supports, including RLHF (GSPO, GRPO, DPO etc.)
- Hand-written Triton kernels and a manual backprop engine ensure no accuracy degradation (0% approximation error)
## Installation
### pip install
Local Installation (Linux recommended):
```sh
pip install unsloth
```
You can also install `unsloth` according to the [official documentation](https://docs.unsloth.ai/get-started/installing-+-updating). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading [`~transformers.AutoModelForCausalLM`], you just need to load a `FastLanguageModel` as follows:
```python
import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
max_length = 2048 # Supports automatic RoPE Scaling, so choose any number
# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/mistral-7b",
max_seq_length=max_length,
dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
)
# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=16,
lora_dropout=0, # Dropout = 0 is currently optimized
bias="none", # Bias = "none" is currently optimized
use_gradient_checkpointing=True,
random_state=3407,
)
training_args = SFTConfig(output_dir="./output", max_length=max_length)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
)
trainer.train()
```
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
### Docker Install
```sh
docker run -d -e JUPYTER_PASSWORD="mypassword" \
-p 8888:8888 -p 2222:22 \
-v $(pwd)/work:/workspace/work \
--gpus all \
unsloth/unsloth
```
Access Jupyter Lab at ```http://localhost:8888``` and start fine-tuning!
## Training
These are some core settings you can toggle before training:
- ```max_seq_length = 2048``` Controls context length. While Llama-3 supports 8192, we recommend 2048 for testing. Unsloth enables 4× longer context fine-tuning.
- ```dtype = None``` Defaults to None; use torch.float16 or torch.bfloat16 for newer GPUs.
- ```load_in_4bit = True``` Enables 4-bit quantization, reducing memory use 4× for fine-tuning. Disabling it allows for LoRA 16-bit fine-tuning to be enabled.
- To enable full fine-tuning (FFT), set ```full_finetuning = True```. For 8-bit fine-tuning, set ```load_in_8bit = True```. Note: Only one training method can be set to True at a time.
For more information on configuring Unsloth's hyperparameters and features, read their [documentation guide here](https://docs.unsloth.ai/get-started/fine-tuning-llms-guide).
## Saving the model
Unsloth allows you to directly save the finetuned model as a small file called a LoRA adapter. You can instead push to the Hugging Face hub as well if you want to upload your model! Remember to get a [Hugging Face token](https://huggingface.co/settings/tokens) and add your token!
### Saving to GGUF
To save to GGUF, Unsloth uses llama.cpp. To save locally:
```python
model.save_pretrained_gguf("directory", tokenizer, quantization_method = "q4_k_m")
model.save_pretrained_gguf("directory", tokenizer, quantization_method = "q8_0")
model.save_pretrained_gguf("directory", tokenizer, quantization_method = "f16")
```
To push to the hub:
```python
model.push_to_hub_gguf("hf_username/directory", tokenizer, quantization_method = "q4_k_m")
model.push_to_hub_gguf("hf_username/directory", tokenizer, quantization_method = "q8_0")
```
### Saving to vLLM
To save to 16-bit for vLLM, use:
```python
model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")
```

View File

@ -36,7 +36,7 @@ print(pipe("This movie was really")[0]["generated_text"])
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub"
base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
adapter_model_name = "path/to/my/adapter"
model = AutoModelForCausalLM.from_pretrained(base_model_name)

View File

@ -1,160 +0,0 @@
# Using LLaMA models with TRL
We've begun rolling out examples to use Meta's LLaMA models in `trl` (see [Meta's LLaMA release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) for the original LLaMA model).
## Efficient training strategies
Even training the smallest LLaMA model requires an enormous amount of memory. Some quick math: in bf16, every parameter uses 2 bytes (in fp32 4 bytes) in addition to 8 bytes used, e.g., in the Adam optimizer (see the [performance docs](https://huggingface.co/docs/transformers/perf_train_gpu_one#optimizer) in Transformers for more info). So a 7B parameter model would use `(2+8)*7B=70GB` just to fit in memory and would likely need more when you compute intermediate values such as attention scores. So you couldnt train the model even on a single 80GB A100 like that. You can use some tricks, like more efficient optimizers of half-precision training, to squeeze a bit more into memory, but youll run out sooner or later.
Another option is to use Parameter-Efficient Fine-Tuning (PEFT) techniques, such as the [`peft`](https://github.com/huggingface/peft) library, which can perform low-rank adaptation (LoRA) on a model loaded in 8-bit.
For more on `peft` + `trl`, see the [docs](https://huggingface.co/docs/trl/sentiment_tuning_peft).
Loading the model in 8bit reduces the memory footprint drastically since you only need one byte per parameter for the weights (e.g. 7B LlaMa is 7GB in memory).
Instead of training the original weights directly, LoRA adds small adapter layers on top of some specific layers (usually the attention layers); thus, the number of trainable parameters is drastically reduced.
In this scenario, a rule of thumb is to allocate ~1.2-1.4GB per billion parameters (depending on the batch size and sequence length) to fit the entire fine-tuning setup.
This enables fine-tuning larger models (up to 50-60B scale models on a NVIDIA A100 80GB) at low cost.
Now we can fit very large models into a single GPU, but the training might still be very slow.
The simplest strategy in this scenario is data parallelism: we replicate the same training setup into separate GPUs and pass different batches to each GPU.
With this, you can parallelize the forward/backward passes of the model and scale with the number of GPUs.
![chapter10_ddp.png](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/chapter10_ddp.png)
We use either the `transformers.Trainer` or `accelerate`, which both support data parallelism without any code changes, by simply passing arguments when calling the scripts with `torchrun` or `accelerate launch`. The following runs a training script with 8 GPUs on a single machine with `accelerate` and `torchrun`, respectively.
```bash
accelerate launch --multi_gpu --num_machines 1 --num_processes 8 my_accelerate_script.py
torchrun --nnodes 1 --nproc_per_node 8 my_torch_script.py
```
## Supervised fine-tuning
Before we start training reward models and tuning our model with RL, it helps if the model is already good in the domain we are interested in.
In our case, we want it to answer questions, while for other use cases, we might want it to follow instructions, in which case instruction tuning is a great idea.
The easiest way to achieve this is by continuing to train the language model with the language modeling objective on texts from the domain or task.
The [StackExchange dataset](https://huggingface.co/datasets/HuggingFaceH4/stack-exchange-preferences) is enormous (over 10 million instructions), so we can easily train the language model on a subset of it.
There is nothing special about fine-tuning the model before doing RLHF - its just the causal language modeling objective from pretraining that we apply here.
To use the data efficiently, we use a technique called packing: instead of having one text per sample in the batch and then padding to either the longest text or the maximal context of the model, we concatenate a lot of texts with a EOS token in between and cut chunks of the context size to fill the batch without any padding.
![chapter10_preprocessing-clm.png](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/blog/stackllama/chapter10_preprocessing-clm.png)
With this approach the training is much more efficient as each token that is passed through the model is also trained in contrast to padding tokens which are usually masked from the loss.
If you don't have much data and are more concerned about occasionally cutting off some tokens that are overflowing the context you can also use a classical data loader.
The packing is handled by the `ConstantLengthDataset` and we can then use the `Trainer` after loading the model with `peft`. First, we load the model in int8, prepare it for training, and then add the LoRA adapters.
```python
# load model in 8bit
model = AutoModelForCausalLM.from_pretrained(
args.model_path,
load_in_8bit=True,
device_map={"": Accelerator().local_process_index}
)
model = prepare_model_for_kbit_training(model)
# add LoRA to model
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
```
We train the model for a few thousand steps with the causal language modeling objective and save the model.
Since we will tune the model again with different objectives, we merge the adapter weights with the original model weights.
**Disclaimer:** due to LLaMA's license, we release only the adapter weights for this and the model checkpoints in the following sections.
You can apply for access to the base model's weights by filling out Meta AI's [form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) and then converting them to the 🤗 Transformers format by running this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py).
Note that you'll also need to install 🤗 Transformers from source until the `v4.28` is released.
Now that we have fine-tuned the model for the task, we are ready to train a reward model.
## Reward modeling and human preferences
In principle, we could fine-tune the model using RLHF directly with the human annotations.
However, this would require us to send some samples to humans for rating after each optimization iteration.
This is expensive and slow due to the number of training samples needed for convergence and the inherent latency of human reading and annotator speed.
A trick that works well instead of direct feedback is training a reward model on human annotations collected before the RL loop.
The goal of the reward model is to imitate how a human would rate a text. There are several possible strategies to build a reward model: the most straightforward way would be to predict the annotation (e.g. a rating score or a binary value for “good”/”bad”).
In practice, what works better is to predict the ranking of two examples, where the reward model is presented with two candidates `(y_k, y_j)` for a given prompt `x` and has to predict which one would be rated higher by a human annotator.
With the StackExchange dataset, we can infer which of the two answers was preferred by the users based on the score.
With that information and the loss defined above, we can then modify the `transformers.Trainer` by adding a custom loss function.
```python
class RewardTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
if return_outputs:
return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
return loss
```
We utilize a subset of a 100,000 pair of candidates and evaluate on a held-out set of 50,000. With a modest training batch size of 4, we train the Llama model using the LoRA `peft` adapter for a single epoch using the Adam optimizer with BF16 precision. Our LoRA configuration is:
```python
peft_config = LoraConfig(
task_type=TaskType.SEQ_CLS,
inference_mode=False,
r=8,
lora_alpha=32,
lora_dropout=0.1,
)
```
As detailed in the next section, the resulting adapter can be merged into the frozen model and saved for further downstream use.
## Reinforcement Learning from Human Feedback
With the fine-tuned language model and the reward model at hand, we are now ready to run the RL loop. It follows roughly three steps:
1. Generate responses from prompts,
2. Rate the responses with the reward model,
3. Run a reinforcement learning policy-optimization step with the ratings.
The Query and Response prompts are templated as follows before being tokenized and passed to the model:
```bash
Question: <Query>
Answer: <Response>
```
The same template was used for SFT, RM and RLHF stages.
Once more, we utilize `peft` for memory-efficient training, which offers an extra advantage in the RLHF context.
Here, the reference model and policy share the same base, the SFT model, which we load in 8-bit and freeze during training.
We exclusively optimize the policy's LoRA weights using PPO while sharing the base model's weights.
```python
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
question_tensors = batch["input_ids"]
# sample from the policy and to generate responses
response_tensors = ppo_trainer.generate(
question_tensors,
return_prompt=False,
length_sampler=output_length_sampler,
**generation_kwargs,
)
batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
# Compute sentiment score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]
# Run PPO step
stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
# Log stats to Wandb
ppo_trainer.log_stats(stats, batch, rewards)
```
For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).

View File

@ -0,0 +1,499 @@
# vLLM Integration
This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood.
> [!WARNING]
> TRL currently only supports vLLM version `0.10.2`. Please ensure you have this version installed to avoid compatibility issues.
> [!TIP]
> The following trainers currently support generation with vLLM:
>
> - [`GRPOTrainer`]
> - [`OnlineDPOTrainer`]
> - [`NashMDTrainer`]
> - [`XPOTrainer`]
> - [`RLOOTrainer`]
## 🚀 How can I use vLLM with TRL to speed up training?
💡 **Note**: Resources required for this specific example: a single node with 8 GPUs.
> [!WARNING]
> When using vLLM with TRL, the **vLLM server** and the **trainer** must run on **separate CUDA devices** to prevent conflicts.
> For guidance on configuring this properly, see [Modes of using vLLM during training](#modes-of-using-vllm-during-training).
First, install vLLM using the following command:
```bash
pip install "trl[vllm]"
```
Then run the server on specific GPUs (e.g., GPUs 0-3):
```sh
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2
```
Once the server is running, you can use it to generate completions for training. In the example below, we are using the different supported trainers using the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs.
In this example, we are sharding two copies of the model across 4 GPUs. Increasing data parallelism increases throughput, while increasing tensor parallelism allows for serving larger models. Then, run the training script on different GPUs (e.g., GPUs 4-7) by passing `use_vllm=True` in the training arguments as follows:
Sample of a simple `train.py` script:
<hfoptions id="vllm examples">
<hfoption id="GRPO">
```python
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
training_args = GRPOConfig(
output_dir="my_test",
use_vllm=True,
bf16=True,
gradient_checkpointing=True,
)
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-7B",
args=training_args,
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
```
</hfoption>
<hfoption id="OnlineDPO">
```python
from datasets import load_dataset
from trl import OnlineDPOTrainer, OnlineDPOConfig
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
training_args = OnlineDPOConfig(
output_dir="my_test",
use_vllm=True,
bf16=True,
gradient_checkpointing=True,
)
trainer = OnlineDPOTrainer(
model="Qwen/Qwen2.5-7B",
args=training_args,
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
```
</hfoption>
<hfoption id="NashMD">
```python
from datasets import load_dataset
from trl import NashMDTrainer, NashMDConfig
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
training_args = NashMDConfig(
output_dir="my_test",
use_vllm=True,
bf16=True,
gradient_checkpointing=True,
)
trainer = NashMDTrainer(
model="Qwen/Qwen2.5-7B",
args=training_args,
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
```
</hfoption>
<hfoption id="XPO">
```python
from datasets import load_dataset
from trl import XPOTrainer, XPOConfig
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
training_args = XPOConfig(
output_dir="my_test",
use_vllm=True,
bf16=True,
gradient_checkpointing=True,
)
trainer = XPOTrainer(
model="Qwen/Qwen2.5-7B",
args=training_args,
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
```
</hfoption>
<hfoption id="RLOO">
```python
from datasets import load_dataset
from trl import RLOOTrainer, RLOOConfig
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
training_args = RLOOConfig(
output_dir="my_test",
use_vllm=True,
bf16=True,
gradient_checkpointing=True,
)
trainer = RLOOTrainer(
model="Qwen/Qwen2.5-7B",
args=training_args,
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
```
</hfoption>
</hfoptions>
And the train command on separate GPUs from the server:
```sh
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
```
## Why using vLLM?
### 🎬 Flashback: Why do we need to use vLLM in online methods?
Online methods like GRPO or Online DPO require the model to generate completions during training, which are then used to compute reward signals. However, generation can be extremely time-consuming, especially with large or reasoning models. In the default setup (without vLLM), completions are generated using the [(unwrapped) model's `generate` method](https://github.com/huggingface/trl/blob/f3e8c2304428ef16e9ae5de9e5741ed84d533b7b/trl/trainer/grpo_trainer.py#L965C39-L965C66). This approach quickly becomes a major bottleneck — generation is slow and inefficient, particularly for large batches or models. As a result, training times increase significantly, and overall efficiency drops. To address this, we turn to vLLM, which enables much faster and more scalable generation, helping eliminate this bottleneck in online methods.
### 🤔 How does vLLM solve the slow generation issue?
If you've ever done autoregressive decoder training, you know all the input tokens to the LLM produce their attention key and value tensors, and these tensors are kept in GPU memory to later generate subsequent tokens based on them. These cached key and value tensors are often referred to as the KV cache. However, storing the KV cache occupies a lot of memory, so vLLM uses a technique called **PagedAttention** to solve this problem. PagedAttention, which is inspired by the OSs virtual memory concept, stores continuous keys and values in **non-contiguous memory space**, which is much more efficient. The details of this are beyond the scope of this document, but in short, it allows the model to store the keys and values in a more efficient way, reducing the memory footprint and speeding up the generation process. If you are interested, make sure to check out the [vLLM PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html) for more details.
## How vLLM Works (Under the Hood) 🔍
### 🤔 What exactly happens when you run `trl vllm-serve --model <model_name>`?
When you run for example
```sh
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4
```
the following happens:
![vllm](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/vllm-doc.png)
1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (1 × 4).
Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation. A key point to understand is that these 4 workers are running in parallel, and each one is responsible for handling a subset of the total incoming load.
2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the models weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled. Meanwhile, data parallelism (controlled by `--data-parallel-size`) ensures that different sets of requests are processed independently across the workers. In short: tensor parallelism splits the model across GPUs, and data parallelism splits the batch of requests across different model replicas.
3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--data-parallel-size=4`, each GPU processes 2 prompts).
This GPU-to-GPU communication is managed efficiently by NVIDIAs NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests — its lightweight and doesnt interfere with generation itself.
Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (like in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you would end up with 16 completions total — regardless of the number of GPUs or parallelism settings.
### 🥸 More detail on what happens under the hood when running the server
- The vLLM server starts by running the command: `trl vllm-serve --model Qwen/Qwen2.5-7B`.
- Once the server is running, it generates completions based on requests from the client (trainer) using `vllm_client.generate` [these lines](https://github.com/huggingface/trl/blob/cc044e35b285be7dc062764b3364e1e684db4c7c/trl/trainer/grpo_trainer.py#L1025-L1035).
- The client (trainer) then requests these completions from the server.
- These completions are used to compute the reward signal.
- Based on the reward signal and the models output, the loss is computed, and the backward pass is performed to update the models weights.
- **Note**: The server only handles completion generation — it doesnt train the model. Therefore, the models weights arent updated on the server. Once the backward pass is complete, the client sends the updated weights to the server using `vllm_client.update_named_param(name, param.data)`.
When using vLLM, ensure the GPUs assigned for training and generation are separate to avoid NCCL communication conflicts. If you do not set the `CUDA_VISIBLE_DEVICES` environment variable, the training script will use all available GPUs by default, which may lead to device conflicts. Starting from TRL next release after v0.19.1, the code automatically detects and prevents same-device usage, raising a error at the vllm server process:
```log
RuntimeError: Attempting to use the same CUDA device for multiple distinct roles/ranks within the same communicator.
Ensure that trainer is using different devices than vLLM server.
```
For example, if you want to use GPUs 47 for training while the server runs on GPUs 0-3, set:
```sh
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
```
## Advanced usage
### 🍷 More customization options with vLLM?
You can customize the server configuration by passing additional arguments.
```txt
$ trl vllm-serve --help
usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE] [--data_parallel_size DATA_PARALLEL_SIZE] [--host HOST]
[--port PORT] [--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE] [--max_model_len MAX_MODEL_LEN]
[--enable_prefix_caching ENABLE_PREFIX_CACHING] [--enforce_eager [ENFORCE_EAGER]] [--kv_cache_dtype KV_CACHE_DTYPE]
[--trust_remote_code [TRUST_REMOTE_CODE]] [--log_level LOG_LEVEL] [--vllm_model_impl VLLM_MODEL_IMPL]
options:
-h, --help show this help message and exit
--model MODEL Model name or path to load the model from. (default: None)
--revision REVISION Revision to use for the model. If not specified, the default branch will be used. (default: None)
--tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE
Number of tensor parallel workers to use. (default: 1)
--data_parallel_size DATA_PARALLEL_SIZE, --data-parallel-size DATA_PARALLEL_SIZE
Number of data parallel workers to use. (default: 1)
--host HOST Host address to run the server on. (default: 0.0.0.0)
--port PORT Port to run the server on. (default: 8000)
--gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the device dedicated to generation
powered by vLLM. Higher values will increase the KV cache size and thus improve the model's throughput. However, if the value is too high,
it may cause out-of-memory (OOM) errors during initialization. (default: 0.9)
--dtype DTYPE Data type to use for vLLM generation. If set to 'auto', the data type will be automatically determined based on the model configuration.
Find the supported values in the vLLM documentation. (default: auto)
--max_model_len MAX_MODEL_LEN, --max-model-len MAX_MODEL_LEN
If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced `vllm_gpu_memory_utilization`, leading to a
reduced KV cache size. If not set, vLLM will use the model context size, which might be much larger than the KV cache, leading to
inefficiencies. (default: None)
--enable_prefix_caching ENABLE_PREFIX_CACHING, --enable-prefix-caching ENABLE_PREFIX_CACHING
Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support this feature. (default: None)
--enforce_eager [ENFORCE_EAGER], --enforce-eager [ENFORCE_EAGER]
Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the model in eager mode. If `False`
(default behavior), we will use CUDA graph and eager execution in hybrid. (default: False)
--kv_cache_dtype KV_CACHE_DTYPE, --kv-cache-dtype KV_CACHE_DTYPE
Data type to use for KV cache. If set to 'auto', the dtype will default to the model data type. (default: auto)
--trust_remote_code [TRUST_REMOTE_CODE], --trust-remote-code [TRUST_REMOTE_CODE]
Whether to trust remote code when loading models. Set to True to allow executing code from model repositories. This is required for some
custom models but introduces security risks. (default: False)
--log_level LOG_LEVEL, --log-level LOG_LEVEL
Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', 'trace'. (default: info)
--vllm_model_impl VLLM_MODEL_IMPL, --vllm-model-impl VLLM_MODEL_IMPL
Model implementation to use for vLLM. Must be one of `transformers` or `vllm`. `transformers`: Use the `transformers` backend for model
implementation. `vllm`: Use the `vllm` library for model implementation. (default: vllm)
```
### 💆🏻‍♀️ What's the best distributed setup?
![tp dp throughput 8 gpus](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_8_gpus.png)
![tp dp throughput 4 gpus](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_4_gpus.png)
First and foremost, always remember that the optimal setup depends on:
- The model size
- The number of GPUs you have
- The GPU memory size
- The batch size you are using
- The number of requests you are sending to the server (prompts)
- The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size)
- The number of completions you are generating for each request (`num_generations`)
Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that:
- For reasonable-sized models (3B14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results.
- For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window.
### vLLM with Transformers Backend
vLLM can use the **Transformers backend** for model implementations, which works for both LLMs and VLMs.
To enable this, set `vllm_model_impl="transformers"` in your configuration or pass it via the command-line argument.
For more details, check out [vLLM Transformers Backend](https://blog.vllm.ai/2025/04/11/transformers-backend.html).
Example:
```sh
CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen
2.5-VL-3B-Instruct --tensor-parallel-size 1 --port 8000 --enforce_eager --vllm_model_impl transformers
```
### Modes of Using vLLM During Training
TRL supports **two modes** for integrating vLLM during training: **server mode** and **colocate mode**.
#### Server Mode
In **server mode**, vLLM runs as a separate process on dedicated GPUs and communicates with the trainer via HTTP.
This setup is ideal if you have GPUs dedicated to inference.
Example configuration:
<hfoptions id="vllm examples">
<hfoption id="GRPO">
```python
from trl import GRPOConfig
training_args = GRPOConfig(
...,
use_vllm=True,
vllm_mode="server", # default value, can be omitted
)
```
</hfoption>
<hfoption id="OnlineDPO">
```python
from trl import OnlineDPOConfig
training_args = OnlineDPOConfig(
...,
use_vllm=True,
vllm_mode="server", # default value, can be omitted
)
```
</hfoption>
<hfoption id="NashMD">
```python
from trl import NashMDConfig
training_args = NashMDConfig(
...,
use_vllm=True,
vllm_mode="server", # default value, can be omitted
)
```
</hfoption>
<hfoption id="XPO">
```python
from trl import XPOConfig
training_args = XPOConfig(
...,
use_vllm=True,
vllm_mode="server", # default value, can be omitted
)
```
</hfoption>
<hfoption id="RLOO">
```python
from trl import RLOOConfig
training_args = RLOOConfig(
...,
use_vllm=True,
vllm_mode="server", # default value, can be omitted
)
```
</hfoption>
</hfoptions>
#### Colocate Mode
In **colocate mode**, vLLM runs inside the trainer process and shares GPU memory with the training model.
This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs.
Example configuration:
<hfoptions id="vllm examples">
<hfoption id="GRPO">
```python
from trl import GRPOConfig
training_args = GRPOConfig(
...,
use_vllm=True,
vllm_mode="colocate",
)
```
</hfoption>
<hfoption id="OnlineDPO">
```python
from trl import OnlineDPOConfig
training_args = OnlineDPOConfig(
...,
use_vllm=True,
vllm_mode="colocate",
)
```
</hfoption>
<hfoption id="NashMD">
```python
from trl import NashMDConfig
training_args = NashMDConfig(
...,
use_vllm=True,
vllm_mode="colocate",
)
```
</hfoption>
<hfoption id="XPO">
```python
from trl import XPOConfig
training_args = XPOConfig(
...,
use_vllm=True,
vllm_mode="colocate",
)
```
</hfoption>
<hfoption id="RLOO">
```python
from trl import RLOOConfig
training_args = RLOOConfig(
...,
use_vllm=True,
vllm_mode="colocate",
)
```
</hfoption>
</hfoptions>
> [!WARNING]
> Check the documentation of the trainer you are using for specific details on vLLM usage and parameters.
> [!WARNING]
> To reduce GPU memory usage when running vLLM, consider [enabling vLLM sleep mode](reducing_memory_usage#vllm-sleep-mode).

View File

@ -1,10 +1,10 @@
# XPO Trainer
[![](https://img.shields.io/badge/All_models-XPO-blue)](https://huggingface.co/models?other=xpo,trl)
[![model badge](https://img.shields.io/badge/All_models-XPO-blue)](https://huggingface.co/models?other=xpo,trl)
## Overview
Exploratory Preference Optimization (XPO) was proposed in the paper [Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF](https://huggingface.co/papers/2405.21046) by Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, [Corby Rosset](https://huggingface.co/corbyrosset), [Ahmed Awadallah](https://huggingface.co/AhmedAwadallah), and Alexander Rakhlin. It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus allowing the method to explore outside the support of the intitial model and human feedback data.
Exploratory Preference Optimization (XPO) was proposed in the paper [Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF](https://huggingface.co/papers/2405.21046) by Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, [Corby Rosset](https://huggingface.co/corbyrosset), [Ahmed Awadallah](https://huggingface.co/AhmedAwadallah), and Alexander Rakhlin. It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus allowing the method to explore outside the support of the initial model and human feedback data.
The abstract from the paper is the following:
@ -35,7 +35,7 @@ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
judge = PairRMJudge()
train_dataset = load_dataset("trl-lib/ultrafeedback-prompt", split="train")
training_args = XPOConfig(output_dir="Qwen2-0.5B-XPO", logging_steps=10)
training_args = XPOConfig(output_dir="Qwen2-0.5B-XPO")
trainer = XPOTrainer(
model=model, judge=judge, args=training_args, processing_class=tokenizer, train_dataset=train_dataset
)
@ -50,9 +50,9 @@ accelerate launch train_xpo.py
Distributed across 8 GPUs, the training takes approximately 1 hour.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-XPO
<pre><code>$ transformers chat trl-lib/Qwen2-0.5B-XPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
@ -80,15 +80,12 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht
trainer = XPOTrainer(
...
- judge=judge,
+ reward_model=reward_model,
+ reward_funcs=reward_model,
)
```
<Tip warning={true}>
Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
</Tip>
> [!WARNING]
> Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
### Encourage EOS token generation
@ -110,7 +107,7 @@ trainer.add_callback(completions_callback)
This callback logs the model's generated completions directly to Weights & Biases.
![Logged Completions](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/wandb_completions.png)
![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png)
## Example script
@ -124,7 +121,6 @@ python examples/scripts/xpo.py \
--judge pair_rm \
--dataset_name trl-lib/ultrafeedback-prompt \
--learning_rate 5.0e-7 \
--logging_steps 25 \
--output_dir Qwen2.5-0.5B-XPO-PairRM \
--warmup_ratio 0.1 \
--push_to_hub
@ -132,7 +128,7 @@ python examples/scripts/xpo.py \
## Logged metrics
The logged metrics are as follows:
While training and evaluating we record the following reward metrics:
* `loss/xpo`: The mean xpo part of the full loss.
* `loss/dpo`: The mean dpo part of the full loss.
@ -152,10 +148,12 @@ The logged metrics are as follows:
* `alpha`: The weight of the XPO loss term. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`].
* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`].
## XPOTrainer
[[autodoc]] XPOTrainer
- train
- save_model
- push_to_hub
## XPOConfig

View File

@ -0,0 +1,30 @@
# Context Parallelism with FSDP for 2 GPUs
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: true # Enable activation checkpointing for memory efficiency
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2 # Number of GPUs
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
parallelism_config:
parallelism_config_dp_replicate_size: 1
parallelism_config_dp_shard_size: 1
parallelism_config_tp_size: 1
parallelism_config_cp_size: 2 # Context parallel size

View File

@ -0,0 +1,28 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: true
fsdp_offload_params: false
fsdp_reshard_after_forward: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
fsdp_version: 1
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Some files were not shown because too many files have changed in this diff Show More