Compare commits

...

253 Commits

Author SHA1 Message Date
9c46200672 adds entropy calculation 2025-05-27 19:03:39 +00:00
ac6dc65fdd save qwip 2025-05-27 14:13:30 +00:00
d1174adc5b 🛠️ Initialize reward_kwargs to prevent UnboundLocalError in GRPOTrainer (#3459)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-26 18:28:27 -07:00
cd838417e4 👇 Update grpo.py to fix bugs for cli grpo --reward_funcs my_lib.my_reward (#3454)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-26 17:59:57 -07:00
c7e3f096a5 [GKD] fix the gkd script (#3497) 2025-05-26 20:22:15 +02:00
5c08897570 [GRPO] disabling top_k sampling default (#3494)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-26 11:32:07 +02:00
3ef9faf257 [Docs] sync logging doc to current metrics (#3478) 2025-05-25 17:46:28 +02:00
9ac614fb08 Fix mis-aligned prompts and completions in colocate mode (#3491)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-24 16:50:45 -06:00
29401e790e [Doc][SFT] Update sft_trainer.md. link prompt-completion dataset example (#3486)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-24 19:13:00 +02:00
31bf3f9244 Fix typo (#3489) 2025-05-24 13:24:15 +02:00
7f32792c07 [CI] fix sampler api to make the CI green (#3488) 2025-05-23 17:32:23 +02:00
3d8727918a [SFT] update minimal liger version (#3483) 2025-05-23 13:44:20 +02:00
65245f6be8 Update .pre-commit-config.yaml (#3479) 2025-05-22 16:08:23 +02:00
a528b9c465 [NashMD] fix the edge case where the model is a peft model (#3473)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-20 17:02:04 +02:00
e0dd525021 🙅 PPO value_model can't be None, so it shouldn't be Optional (#3300) 2025-05-19 17:01:08 -07:00
64aa06499b enable activation offloading on XPU (#3444)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-19 11:56:14 +02:00
be93a0c30c enable vllm c-s tests on XPU (#3445)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-19 11:55:57 +02:00
f9fbd91ea9 [CI] fix CI failure of transformers dev (#3457) 2025-05-19 10:08:42 +02:00
54d4f6b13a 🎁 Reward submodule (#3430) 2025-05-15 19:10:22 -07:00
05bc43e960 feat: Implement Two-Sided Clipping for GRPO Trainer (#3434)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-13 20:36:39 +02:00
d3dc8ff654 use device agnostic empty_cache in ppo & rloo (#3439)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-13 20:10:14 +02:00
21738c3732 enable trl env on xpu (#3438)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-13 11:36:01 +02:00
eab175d434 🏹 Support kv_cache_dtype to quantize kv-cache in vllm (#3422)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-08 17:11:16 -07:00
4da4dc9117 Update README.md 2025-05-07 20:49:35 -07:00
6b3a02385d Update README.md (#3420) 2025-05-07 20:48:22 -07:00
abbbb93d6a 🧪 Testing support for Qwen3 tiny (#3415) 2025-05-07 19:32:42 -07:00
cafa663c84 [Models] Activation checkpointing from TorchTune (#2954)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: DanFosing <danfoss12340@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Robert <robert.veres00@gmail.com>
Co-authored-by: Robert Veres <robert.veres@languagetool.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Mathew Shen <datahonor@gmail.com>
Co-authored-by: Ishan Kumar <ishankumar216@gmail.com>
Co-authored-by: Huazhong Ji <hzji210@gmail.com>
Co-authored-by: tpoisonooo <khj.application@aliyun.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-07 12:36:11 +02:00
fd04a5461a 🐍 Support Python 3.13 (#2593) 2025-05-06 21:38:23 -07:00
56e5766205 🎁 Reward takes completion ids (#3272)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-06 10:34:50 -07:00
89d44caece 📝 vLLM-integration documentation (#3376)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-06 09:37:02 -06:00
adfa7fd59a 🎲 [GRPO] Shuffle mini batches (#3391)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-06 11:09:00 +02:00
cf5183db7f 💔 [GRPO] Decouple gradient accumulation from the number of minibatches generated (#3388)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-06 09:59:32 +02:00
1954c02d86 🤝 Compatibility of the TRL CLI with accelerate arguments (#3409)
Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
2025-05-06 00:09:23 -07:00
45f4c58832 ✌️ Add support for FSDP2 (#3317)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-06 08:29:11 +02:00
cc044e35b2 🕊️ Un-restrict diffusers (#3407) 2025-05-02 15:06:53 -07:00
999acd53ec 🕺 Migrate setup configuration from setup.py to setup.cfg and make rich an optional dep (#3403)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-02 11:03:57 -07:00
8606b1ad09 🪪 Remove license classifier (#3402) 2025-05-02 10:03:39 -07:00
a673da5773 👉 [DPO] Model forward pass padding side fix (#3307)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-01 20:37:55 -07:00
00b8e311aa 🦁 Fix liger initialization (#3401)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-01 20:36:46 -07:00
c163cf5081 💔 [SFT] Raise error when formatting_func is used with completion_only_loss (#3385)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-01 16:23:27 -07:00
bc9c019c43 [IterativeSFT] Small refresher (#3378) 2025-05-01 16:18:41 -07:00
18596cf232 🧑‍🤝‍🧑 Co-Locating vLLM w/ training for higher throughput and GPU utilization (#3394)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-01 16:17:26 -07:00
280d35301b 🌊 Add MLflow metrics in profiling context (#3400)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-01 16:15:38 -07:00
13fa8402a3 [GRPO] Reference model initialization bug fix (#3397) 2025-05-01 17:31:21 +02:00
09b669fbf7 [🐯+GRPO] Support FSDP + Fix bug when using LigerGRPO with DDP (#3260)
Co-authored-by: Ubuntu <azureuser@liger-ci-h100-vm.kvghai4yzzmufguwws3040dwlf.dx.internal.cloudapp.net>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-04-30 22:49:45 +02:00
01d0be15cb Deprecate TextEnvironment and tools (#3389) 2025-04-29 20:25:36 +02:00
3a42af1c78 DPO fixes for evaluations (#3377) 2025-04-29 17:16:30 +02:00
aaf39604ba PEFT support for Liger GRPO (#3355)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-04-29 17:05:35 +02:00
2bf48478e8 📋 Allow calling trl cli in sft mode with config file (#3380)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-28 14:23:42 -07:00
a8cfca6d01 ⚰️ Remove deprecated (#3364) 2025-04-26 11:11:35 -07:00
1bca49515e Better guards for DeepSpeed imports (#3351) 2025-04-26 10:18:11 +02:00
39e96394a9 🎭 Fix train and eval mode checking in GRPOTrainer and SFTTrainer (#3337)
Co-authored-by: Jiaming Ma <jiaming.ma@connect.polyu.hk>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-25 17:42:43 -07:00
8e6ed93dfd 🥸🔢 Adding pad_multiple to SFT trainer (#3365) 2025-04-25 18:12:35 -06:00
29c5e05e3a 🔢 Pad to multiple of (#3362) 2025-04-25 09:53:20 -07:00
a9b27f82d6 ⬆️ Bump dev version (#3357) 2025-04-24 16:22:12 -07:00
cd6b3de356 Release: v0.17 (#3356) 2025-04-24 16:15:45 -07:00
36685c8bba Up to 4x faster: Data Parallel for vLLM server (#3310)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Shirin Yamani <75791599+shirinyamani@users.noreply.github.com>
2025-04-24 15:14:16 -07:00
89556c8cbf 🍡 Fix using reward model and DeepSpeed ZeRO 3 (#3326) 2025-04-23 15:09:33 -07:00
f3e8c23044 Define default chat template for SFT (#3309)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-23 15:49:42 +02:00
9ee6c3aa56 🏁 Fix adding special tokens in SFT (#3328) 2025-04-22 17:51:51 -07:00
ef05331752 [CPO] Check that max_prompt_length < max_length (#3341) 2025-04-22 15:45:15 -07:00
05e2ba6e01 🦄 Add optional uvicorn log level for vLLM serve (#3338)
Co-authored-by: Jiaming Ma <jiaming.ma@connect.polyu.hk>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-22 11:45:13 -07:00
1b4f189e09 💡 Fix type hint in _generate_and_score_completions (#3336) 2025-04-22 08:57:29 -07:00
1faa7f9b36 🧸 Fix unset tokenizer pad_token (#3290)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-21 17:20:09 -07:00
66e6eab9bb [doc] Update sft_trainer.md in table x->✓ (#3313)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-21 17:05:20 -07:00
27af0aaf4a Fix typo in text_environments.md (#3305) 2025-04-21 16:39:55 -07:00
b4ffda769e 🙋 Add Optional Eager Execution Mode for vLLM Serving (#3335)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-21 15:33:59 -07:00
0dad4eb7ca 🎲 [GRPO] Make training dataset shuffle optional (#3334)
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-04-21 14:34:31 -07:00
c82f626f94 Empty commit to test new protection rules 2025-04-20 23:07:28 +00:00
33add19161 Empty commit to trigger CI 2025-04-20 23:00:31 +00:00
294f35bf3c ☝️ [GRPO] Generate once per effective batch (#3283)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-04-17 16:35:58 -07:00
9874b3aa04 [GRPO] Add metrics for low and high clipped token probabilities (#3289)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-16 14:43:34 +02:00
1e61f6cc5a 🅾️ Fixes typo in SFTTrainer (#3282) 2025-04-15 15:23:40 -07:00
27adc30162 🧗 Add Ascend NPU support for vLLM server (#3286)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-15 15:22:46 -07:00
df737f99c1 🏷️ Fixed naming error in output_dir for Gemma 3 VLM script (#3297) 2025-04-15 14:51:26 -07:00
c04e84c454 Expose EOS token in SFTConfig (#3299)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-15 21:53:28 +02:00
d625c5533a ⏱️ Fix vLLM server to support V1 Engine (#3276) 2025-04-10 18:29:50 -07:00
6cdd24a360 🦾 Test vLLM client-server (#3277) 2025-04-10 18:29:04 -07:00
8b38570258 🕊️ Un-restrict diffusers (#3274) 2025-04-10 07:24:11 -07:00
95b1a9f612 Add Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset) guide to docs (#3235)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-10 09:33:41 +02:00
5c1511423b 🔗 Fix Dr. GRPO paper link (#3275) 2025-04-09 19:31:15 -07:00
5e2e9cb442 🩺 Dr. GRPO loss (#3256)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-04-09 11:13:22 -07:00
227df8271e ♾️ [CI] Remove test_raise_error_not_causallm (#3265) 2025-04-09 10:39:36 -07:00
ae1581474e 🚧 Temporarily restrict diffusers to <0.33.0 due to ftfy optional dep issue breaking doc builds (#3273) 2025-04-09 10:20:43 -07:00
47b9515fb1 👎 [GRPO] Adds option to disable dropout (#3234)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-09 09:59:06 -07:00
c4891dcfee 🕷 Fix online DPO crash when model is a DataParallel object (#3225)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-09 09:29:13 -07:00
055cee255a Revert "reward takes completion ids"
This reverts commit 73a2fb05545db3c2e92f9311473738278b0d9cd0.
2025-04-09 14:41:55 +00:00
73a2fb0554 reward takes completion ids 2025-04-09 14:40:42 +00:00
982ba08092 🐯 is_liger_kernel_available with min version (#3266) 2025-04-09 06:43:59 -07:00
e03e7acc5c ⛏️ Add cli dict parsing for grpo_config (#3082) 2025-04-08 15:55:33 -07:00
9df19e8a75 📜 Fix license and copyrights (#3264) 2025-04-08 15:22:58 -07:00
1d7b8c4f70 Overlong-filtering for GRPO (#3248)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-08 12:52:52 -06:00
7e170612a4 💠 Fix multi-gpu padding free (#3245) 2025-04-08 11:43:56 -07:00
559724ee2c 📦 [SFT] Deprecate batched formatting_func (#3147)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-08 09:42:17 -07:00
a5a46725c8 🗑️ Deprecate ConstantLengthDataset (#3242) 2025-04-08 08:03:57 -07:00
b6bcafb8bb 🏃 Fix and make CI faster (#3160) 2025-04-08 06:12:08 -07:00
4bfb8eb0d1 🔭 Add support for better KL estimator (k3) in PPOTrainer (#3240)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-05 22:33:28 -07:00
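
For context, a minimal sketch of the k1 vs. k3 KL estimators this change refers to, following Schulman's approximate-KL note; the function and names below are illustrative, not TRL's actual API:

    import torch

    def kl_estimates(logp_model, logp_ref):
        # per-token log-ratio log(pi(x) / pi_ref(x)) for samples x ~ pi
        log_ratio = logp_model - logp_ref
        k1 = log_ratio                              # unbiased, but high variance
        k3 = torch.exp(-log_ratio) - 1 + log_ratio  # unbiased, lower variance, >= 0 pointwise
        return k1.mean(), k3.mean()

Both have expectation KL(pi || pi_ref) under samples from pi; k3 is generally preferred because each term is non-negative, so the estimate never goes negative on small batches.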
4d66bad208 ☑ Update PULL_REQUEST_TEMPLATE.md (#3241) 2025-04-05 16:28:19 -07:00
e90117b3e1 PPOTrainer: fix progress bar for num_mini_batches > 1 (#2531)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-05 15:47:28 -07:00
31b54a6237 🌊 Add error for iterable datasets in GRPOTrainer (#3216) 2025-04-05 15:41:53 -07:00
17e33cdaa0 🎀 Simplify logging text (#3219)
Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
2025-04-05 15:38:32 -07:00
5a0cebc786 📢 Improve GRPO trainer error message for invalid num_generations (#3199)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-04 21:52:00 -07:00
65308cfd84 ⏯️ Fix logging when resuming from checkpoint GRPO (#3185) 2025-04-04 21:51:36 -07:00
1755e03f6f Update ruff to 11.3 and base Python version to 3.9 (#3230)
Signed-off-by: cyy <cyyever@outlook.com>
2025-04-04 13:50:14 +02:00
793735a698 🐯 Integrate Liger GRPO Loss to GRPO Trainer (#3184)
Co-authored-by: Ubuntu <azureuser@liger-ci-h100-vm.kvghai4yzzmufguwws3040dwlf.dx.internal.cloudapp.net>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-03 19:17:00 +02:00
e70a0efeca Group completion metrics by common prefix (#3212)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-04-03 08:11:35 +02:00
7eaca76ed1 📚 Accumulate completions for logging (#3217) 2025-04-02 17:00:43 -07:00
657f9ce6ee 🗝️ Fix type hint in vLLM client (#3205) 2025-04-02 09:40:21 -07:00
485852c942 😷 Fix SFT masking EOS when equal to PAD (#3200) 2025-04-02 08:56:05 -07:00
9f3702f6be [GRPO] Improve completion length logging (#3188) 2025-04-01 10:00:40 +02:00
e751a16df5 🐗 [CI] Fix trufflehog false positives (#3192) 2025-03-31 11:01:55 -07:00
582bc5684b Show unique prompts in GRPO WandB tables (#3191) 2025-03-31 18:50:21 +02:00
c5ba70d4fc Fix breaking typo for flash_attention reducing_memory_usage.md (#3190) 2025-03-31 12:17:10 +02:00
5b586da3cc 📎 Fix is_clipped to compute the effective clip_ratio (#3175)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-30 22:24:14 -07:00
488025cd87 ⏯️ Fix: handle None inputs when resuming GRPO Trainer from checkpoint (#3148)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-30 21:25:53 -07:00
2594cb39de ❤️‍🩹 [CI] fix transformers dev CI failure (#3176)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-29 18:39:40 -07:00
2fe2337067 🏃 Migrate CI to self-hosted runners (#3174) 2025-03-29 11:56:44 -07:00
f6b4d6e569 [Liger] Liger KTO support (#2812)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-28 20:56:59 +01:00
26d86757a7 💎 Gemma 3 VLM SFT example script for single-image and multi-image (#3131)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-26 08:16:02 -07:00
9771f259ed 💰 Richer rich table - log all the rewards (#3156) 2025-03-26 07:45:51 -07:00
7bdedd4075 👨‍🍳 vLLM serve: destroy process group on exit and pass worker_cls as string (#3159) 2025-03-26 07:00:57 -07:00
a069a2f19c 🔫 Disable triggering CI when PR is draft (#3154) 2025-03-25 10:59:01 -07:00
ea45f513f3 ⚰️ Remove deprecated (#3153) 2025-03-25 09:57:50 -07:00
a91023990a 🩹 Fix CI (#3155) 2025-03-25 09:16:23 -07:00
1a9387b922 Enable number of printed completions to be set (#3149)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-25 08:47:13 +01:00
1884ff1bb8 🤝 Align GRPO equation doc with the implementation (#3151) 2025-03-24 11:37:06 -07:00
bfe2075608 🐇 [Research] Layer Skip SFT (#3111)
Co-authored-by: Mostafa Elhoushi <m.elhoushi@ieee.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-24 11:02:00 -07:00
6067e2a669 BCOTrainer version upgrade fixes (#2867)
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
2025-03-24 10:55:00 +01:00
dee37342a8 📊 Fix clip_ratio logging and better document logged values (#3145) 2025-03-23 16:05:42 -07:00
8037f18cdf Fix: Multi gpu hang for ORPO and CPO Trainer (#3069) 2025-03-23 16:25:15 +01:00
a0a53171cc ⬆️ Bump dev version 2025-03-22 21:14:59 +00:00
23a635ed61 Release: v0.16 (#3137) 2025-03-22 14:03:54 -07:00
9b38b0b5ee ⚖️ Add option not to scale rewards (Dr. GRPO) (#3135) 2025-03-22 13:47:52 -07:00
0f26049ea2 ☎️ Documentation for disable gathering of model weights for generation in DeepSpeed ZeRO-3 (#3136) 2025-03-22 13:29:47 -07:00
7511aa4e36 Pack 300 times faster, truncate 100 times faster (#3009)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-22 12:33:31 -07:00
f713f614e9 🚀 Scaling GRPO to 70B+ Models and Multi-Node Training with vLLM Server & NCCL Communication (#3094)
* 🚀allow GRPO to connect to VLLM in remote/local node with NCCL communication

* Update trl/extras/remote_vllm_helper.py

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

* use argparse for options

* add  imports for remote vllm helper

* formatting

* fix arguments

* use cli options

* vllm serve

* clean server

* better naming

* client

* style

* new params in generate

* this method is the new default

* update config

* do not use asserts

* update config

* separate host and port

* proper deprecation

* deprecated arg in the vllm server

* simplify moving

* document host and port

* style

* update trainer

* new generate args

* update doc

* Fix for zero3

* Better naming

* Remove remote_vllm_helper

* remove grpo_with_remote_vllm

* remove cloudpickle from deps

* Some consistency

* Update docs/source/grpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update setup.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* add revision argument to vllm server

* Update docs/source/grpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/grpo_trainer.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Reset the prefix cache after updating weights

* Update vllm_client.py

* Update vllm_client.py

* Update vllm_serve.py

* Add health check endpoint to vLLM server

* connection timeout

* style

* fix doc language hint

* move reset_prefix_cache to its own endpoint

* async

* merge peft adapter to send to vllm

* Looks simple. Wasn't.

* Peft compatibility

* Update docs/source/speeding_up_training.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update docs/source/speeding_up_training.md

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* Update trl/extras/vllm_client.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* GatheredParameters can be disabled

* gather and ungather peft weights within the same deepseed context

* use is_vllm_available

* minor consistency fixes

* fix error when deepspeed is not installed

* fix deepspeed import when not peft

* simpler

* multinode doc

* minor code and comments changes

* style

* optional deps

* vllm_server_timeout as arg

* small refinement in doc

* update deps

* Fix VLLMClient argument in grpo_trainer; Add zero3+peft vllm transfer solution

* Revert "Fix VLLMClient argument in grpo_trainer; Add zero3+peft vllm transfer solution"

This reverts commit d759c9c4d12ff4531482c465c6257a59987ba748.

* log num_tokens

* disable vllm test (in the future we'll add a mock for vllm server for them)

* style

* fix ds3_gather_for_generation

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-03-21 12:12:08 -07:00
a34987956c 🎬 Clip higher (#3118)
* epsilon range added

* epsilon doc str updated

* test removed

* pre-commit run

* Update trl/trainer/grpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* upper epsilon updated

* precommit updates added

* minor format and dtype fixes

* moving upper bound computation in init

* hf.co for paper link

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-19 19:28:19 -06:00
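
The "clip higher" idea decouples the two sides of PPO-style ratio clipping so the upper bound can be raised independently of the lower one. A minimal sketch assuming token-level log-probs and advantages; parameter names and default values are illustrative, not the trainer's exact code:

    import torch

    def clipped_surrogate(logps, old_logps, advantages, eps_low=0.2, eps_high=0.28):
        # probability ratio between the current and old policy
        ratio = torch.exp(logps - old_logps)
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1 - eps_low, 1 + eps_high) * advantages
        # standard pessimistic (min) objective, negated for gradient descent
        return -torch.min(unclipped, clipped).mean()

Raising eps_high above eps_low leaves more headroom for up-weighting low-probability tokens with positive advantage, which is the motivation behind the asymmetric range.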
0f88c179e3 Merge pull request #3079 from huggingface/flexible_reward
Flexible_reward
2025-03-18 11:32:16 -06:00
beda4328cc Use main process for dataset.map (#3106) 2025-03-18 17:36:12 +01:00
07cfe1677e add "_prepare_fsdp" for DPOTrainer (#2539)
* enable prepare fsdp

* Update trl/trainer/dpo_trainer.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* remove activation_checkpointing

* move to utils.py

* fix style

* Update utils.py

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-03-17 14:37:15 +01:00
9f7755d8ed 🕊️ Padding-free for SFT (#3076) 2025-03-15 12:52:24 -07:00
4e3f569eb8 Update grpo_trainer.md [ci skip] 2025-03-14 18:48:50 -07:00
979fda1548 multi-task title added for example 4 2025-03-15 01:19:31 +00:00
f6fb6a88a9 precommit fixes applied 2025-03-15 01:10:32 +00:00
6cbf8fbc9f Merge branch 'flexible_reward' of github.com:huggingface/trl into flexible_reward 2025-03-15 01:08:08 +00:00
5cb390cd30 Add EOS token to processed input in SFT (#3091)
* Add EOS token to processed input

* Update sft_trainer.py

* fix test
2025-03-14 18:06:15 -07:00
b3c391e628 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 19:03:31 -06:00
1b85ca6147 grpo doc updated 2025-03-15 01:03:04 +00:00
e7a1290b0a Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:57:13 -06:00
3822edd67b Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:56:54 -06:00
230455cab0 Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:56:33 -06:00
08f014d559 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:50:56 -06:00
10740333bd Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:49:07 -06:00
058a733c30 Update tests/test_grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:48:59 -06:00
3f193972d8 Update tests/test_grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:48:39 -06:00
b575596b89 Update tests/test_grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:45:55 -06:00
118c43f0e0 Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:44:05 -06:00
40b1c33edf Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:38:08 -06:00
1a2e74cc5a Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:35:38 -06:00
80f7dcb16d Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:35:04 -06:00
4404ccd24a Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:34:50 -06:00
39f77ca2d8 Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:34:36 -06:00
52085dd96b final version 2025-03-15 00:19:34 +00:00
c7a1c95017 Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:07:38 -06:00
3003058418 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:07:31 -06:00
a759cee2e0 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:07:24 -06:00
0a3bad44f0 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:07:13 -06:00
bb5b96a823 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:07:06 -06:00
8466c7273e Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:06:59 -06:00
a871ec8e91 Update tests/test_grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:06:36 -06:00
f7572221db Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 18:06:29 -06:00
8ec2e42833 Online fixes 2025-03-14 23:58:33 +00:00
218d493d11 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 17:15:54 -06:00
1a9f78eb3a Update docs/source/grpo_trainer.md
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 16:57:18 -06:00
a10978ebdf reviews reflected 2025-03-14 22:27:46 +00:00
87fbb831d3 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 14:04:39 -06:00
52f39d6a24 Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-14 13:57:48 -06:00
931f7a14d2 conflict between 2 pushes fixed 2025-03-14 19:47:05 +00:00
9951105a90 Merge remote-tracking branch 'origin/flexible_reward' into flexible_reward 2025-03-14 19:36:32 +00:00
5a6e23aac9 review comments reflected + unittest and doc added 2025-03-14 19:28:59 +00:00
d9104c8b0d Update trl/trainer/grpo_trainer.py
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-13 16:27:55 -06:00
d5a5840307 Remove simple_test.py from version control 2025-03-13 22:23:09 +00:00
f3cbd41e2c interactive reward_func added 2025-03-13 22:09:12 +00:00
d41a32f619 restriction removed from util 2025-03-13 18:58:07 +00:00
fc4dae256d 🫣 [GRPO] add cache_implementation option in GRPO (#3075)
* add cache_implementation option in GRPO

* add cache_implementation to config

* Update trl/trainer/grpo_config.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-13 19:21:36 +01:00
e4e5671e80 💎 Gemma 3 SFT example on Codeforces dataset (#3070)
* Gemma 3 and padding free

* remove padding free changes

* style

* update sft cli

* update script

* revert

* style
2025-03-13 10:50:52 -07:00
7c76f103da ignoring of irrelevant rewards added 2025-03-13 17:39:49 +00:00
aad18ef52a 🎭 Minor spelling fix in documentation (caracteres -> characters) (#3074)
Signed-off-by: Ed Snible <snible@us.ibm.com>
2025-03-13 08:59:24 -07:00
b55d9f0412 Fixing JSD loss computation as per definition (#3043)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-03-13 11:52:50 +01:00
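
Per its definition, the generalized Jensen-Shannon divergence used for distillation is JSD_beta(P, Q) = beta * KL(P || M) + (1 - beta) * KL(Q || M), with mixture M = beta * P + (1 - beta) * Q. A hedged sketch of that formula (assumes 0 < beta < 1; not the GKD trainer's exact code):

    import math
    import torch
    import torch.nn.functional as F

    def generalized_jsd(student_logits, teacher_logits, beta=0.5):
        p = F.log_softmax(teacher_logits, dim=-1)  # log P (teacher)
        q = F.log_softmax(student_logits, dim=-1)  # log Q (student)
        # mixture M = beta * P + (1 - beta) * Q, computed in log space for stability
        m = torch.logsumexp(
            torch.stack([p + math.log(beta), q + math.log(1 - beta)]), dim=0
        )
        kl_pm = F.kl_div(m, p, log_target=True, reduction="batchmean")  # KL(P || M)
        kl_qm = F.kl_div(m, q, log_target=True, reduction="batchmean")  # KL(Q || M)
        return beta * kl_pm + (1 - beta) * kl_qm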
4871c82b0c 🏊 [SFT] Compatibility with padding free and iterable dataset (#3053)
* Compatibility with padding free and iterable dataset

* Fix collator test

* add a test for streaming

* some cleaning

* improve and fix tests

* tiny revert

* bump datasets to 3.0.0
2025-03-12 11:44:25 -07:00
fd9e5a7cab 🦥 Fixed SFTTrainer.compute_loss hang by re-summing before the gather (#3056) 2025-03-12 05:43:33 -07:00
5463e49a55 use argument names with processing_class (#3062) 2025-03-12 13:03:45 +01:00
22759c8208 👯 [GRPO] Relax the assumption that prompts are unique within a batch (#3052)
* Relax the assumption that prompts are unique within a batch

* style
2025-03-11 15:24:06 -07:00
2ee6fd369f 💠 Fixing SFTTrainer.compute_loss crash with accelerate (#3048)
* Fixed crash in SFTTrainer due to accelerator.gather_for_metrics during training

* Moved sum outside of accelerator.gather_for_metrics

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-11 11:08:51 -07:00
844a9c665f 🏁 Passing custom BOS/EOS token to GRPOTrainer.generation_config (#3046)
* Passing custom BOS/EOS token to fallback GRPOTrainer.generation_config

* Reordered kwargs per PR comment

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-11 11:08:33 -07:00
04f6597377 🌡️ Fix temperature inconsistency in GRPO trainer (#3029)
* fix temperature inconsistency in GRPO trainer

* adding 1e-7 isn't necessary

* comment

---------

Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-11 10:36:42 -07:00
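
The inconsistency fixed here is that completions were sampled at temperature T while the training-time log-probs were computed from unscaled logits. A minimal sketch of the scaled computation (illustrative names, not the trainer's exact code):

    import torch.nn.functional as F

    def per_token_logps(logits, token_ids, temperature):
        # divide logits by the sampling temperature so training-time log-probs
        # match the distribution the tokens were actually sampled from
        logps = F.log_softmax(logits / temperature, dim=-1)
        return logps.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)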
e3244d2d09 🚀 Supporting deepspeed>=0.16.4's rename (#2963)
* Added else clause to avoid NameError on optimizer_offload

* Accounted for deepspeed's renaming in 0.16.4

* Switched to packaging.version.parse over the (broken) tuple split

* Moved from NotImplementedError to RuntimeError in else clause
2025-03-05 15:49:21 +01:00
6a02c69789 🎲 Add support for additional generation kwargs in GRPO Trainer (#2989)
* Add support for additional generation kwargs in GRPO Trainer

- Extend GRPOConfig to support additional generation kwargs
- Update GRPOTrainer to incorporate additional generation parameters
- Add tests for training with additional generation kwargs for both standard and vLLM modes

* Add missing vllm_gpu_memory_utilization=0.5

* 🔧 Refactor GRPO generation parameters and configuration

- Restructure GRPOConfig to separate generation parameters
- Add support for top_p, top_k, min_p, repetition_penalty, and length_penalty
- Remove additional_generation_kwargs in favor of explicit parameters
- Update GRPOTrainer to use new generation parameter configuration

* Update tests

* Remove length_penalty and fix tests

* Update defaults and docs

- Change temperature type from Optional[float] to float
- Set default top_p to 1.0 instead of None
- Simplify parameter descriptions by removing redundant "if set to None" text
- Maintain consistent type hints and default values for generation parameters

* GRPO remove optional type hint for temperature parameter

* Remove length_penalty from sampling_kwargs dict in GRPOTrainer

* some refactoring

* top k None support

* change values in test to make them work

---------

Co-authored-by: Robert Veres <robert.veres@languagetool.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-03-05 09:58:00 +01:00
a1c58aa42a 🗜️ Loosened tokenizer type hint on apply_chat_template (#3005) 2025-03-04 17:41:42 +01:00
3f0695a4ca 🌍 Use global normalization for KL logging (to match normalization for loss) (#3004)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-04 17:14:22 +01:00
a72b50b772 📚 Update customization and distributing training documentation (#2991) 2025-03-04 16:37:54 +01:00
ea1d9be2a7 ✌️ Remove double compute of sum in SFTTrainer (#3001)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-03-04 16:35:30 +01:00
402187baab Improve ci (#3007)
* Create codeQL.yml

* Create custom-queries.qls

* Update custom-queries.qls
2025-03-04 15:53:51 +01:00
5858ceab7e 🪙 [SFT] Log num_tokens and some logging fixes (#3006) 2025-03-04 15:45:11 +01:00
7442d42c21 Update pr_style_bot.yml (#3003) 2025-03-03 19:23:16 +01:00
98de0e7c62 🚀 DeepSpeed integration documentation (#2993)
* ds doc

* Update docs/source/deepspeed_integration.md

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-03-03 14:51:45 +01:00
491921c1a4 🛣️ inference_mode to no_grad when computing old_per_token_logps (#2987) 2025-02-28 22:58:05 +01:00
ad6a35bdd5 🫔 [GRPO] Pass wrapped model to unwrap_model_for_generation for DeepSpeed Stage-3 compatibility (#2871)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-28 18:17:04 +01:00
7bc9858a8f 🔍 Update GRPO config documentation for beta parameter stability (#2992) 2025-02-28 17:39:12 +01:00
b882f57d93 ⚰️ Deprecate liger-kernel (#2949)
* Deprecate liger

* remove import

* oops, shouldn't be here

* Fix other deprecations

* remove liger from gkd for now

* remove liger for teacher

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-02-28 14:58:47 +01:00
ac7bde5832 📑 Fix logged metrics for KTO (#2982) 2025-02-28 14:58:31 +01:00
3d94e4e25c 📜 Update README and doc index (#2986)
* Update readme and doc index

* bold

* consistent uppercase
2025-02-28 13:51:58 +01:00
1a303cca8e 🧬 Fix typo in grpo_trainer.py (#2988) 2025-02-28 13:49:47 +01:00
ac327d5e84 🪪 Adds a more fine-grained profiling context (#2975)
* adds a more fine grained profiling context

* precommit

* fix reward func name

* add reward to RM name

* Update trl/extras/profiling.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* some doc and fixes

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-02-27 21:58:39 +01:00
c0854c32c9 🌌 Fix logits computation in trainer prediction step (#2969)
* Fix logits computation in DPO trainer prediction step

* fix compute_metrics for bco and test

* same for cpo

* same from dpo

* for kto

* anf finally orpo

* Apply style fixes

---------

Co-authored-by: kyungdae-jo <kyungdae.jo@navercorp.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-02-27 17:09:11 +01:00
aa18ecfde7 👂 Update learning rate doc in KTOConfig (#2912)
* Update kto_config.py

Fix the mismatch between documentation (and suggested) kto learning rate

* fix doc

---------

Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-27 14:40:54 +01:00
6849c050b9 🕸 Add distributing training guide (#2956) 2025-02-27 14:31:52 +01:00
27a6f2201b 🧗 Add GRPO Trainer support for third-party accelerators (#2836)
* Add GRPO Trainer support for Ascend NPU

* Update grpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* code format

* Update grpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* patch mem_get_info

* style

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-02-27 13:25:24 +01:00
f074dcdc86 👧🏽 Adding DoRA support to model config (#2974) 2025-02-27 12:37:22 +01:00
0caff61600 Update grpo_trainer.py (#2973) 2025-02-27 09:38:32 +01:00
019fc6dbaa 🔢 Fix GRPO doc about num_iterations (#2966) 2025-02-26 12:46:08 +01:00
69ad852e56 Parameterize enable_prefix_caching (#2900)
* parameterize enable_prefix_caching

* apply review suggestion

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-25 00:40:09 +01:00
45ccdefac4 📌 Pin liger-kernel and vLLM (#2952)
* pin liger-kernel

* style
2025-02-25 00:34:16 +01:00
703484a8c2 🗿 Updated DPO default values for alpha and tau (#2918)
* updated DPO default values for alpha and tau

* same for grpo

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-02-25 00:19:48 +01:00
9b76d5f2e9 ↩️ Fix typo in TextEnvironment init param, should be max_tool_response (#2921)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-25 00:08:06 +01:00
cbe0681ba1 📇 GRPO: print completions to console and update docs (#2951)
*  Enhance GRPO logging with configurable completions sampling

- Update `GRPOConfig` to replace `log_completions` with `log_completions_steps`
- Add `print_prompt_completions_sample()` utility function for rich console logging
- Modify `GRPOTrainer` to additionally print 5 random prompt-completion pairs every log_completions_steps steps

* GRPO trainer completions logging, move wandb checks together

* Add rich availability check and use fallback in print_prompt_completions_sample when rich is not available

* Update docstrings on print_prompt_completions_sample

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Revert back to simple log_completions bool

* GRPO log completions fully

* Remove print fallback from print_prompt_completions_sample

* Move accelerator main process check up for grpo log completions

* Explicit variable names in print_prompt_completions_sample

* Make GRPOConfig docstring match field description

* Update log_completions docs again

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update GRPOConfig docs to match field

* improve readability when prompts or completions are multiline

* log reward

* prevent hanging, don't print without rich, print reward

* style

---------

Co-authored-by: Robert Veres <robert.veres@languagetool.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-02-24 23:53:13 +01:00
4e0cf01aef Prevent applying the chat template to tokenized datasets (#2939)
* Update sft_config.py

* Update sft_trainer.py

* Update sft_config.py

* Update sft_trainer.py

* Apply style fixes

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-02-24 23:14:49 +01:00
5c05913196 🐯 Fix LigerKernel for SFTTrainer (#2940)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-24 17:29:48 +01:00
caba04da42 ☠️ Update max_seq_length to max_length in SFTConfig (#2947) 2025-02-24 16:26:20 +01:00
be5a088337 📋 Add vLLM version to environment printout (#2946) 2025-02-24 14:22:43 +01:00
38861475e6 ♻️ Fix caching in SFT (#2945) 2025-02-24 10:54:39 +01:00
f69707dab4 🐈 Bye bye chat (#2934)
* Bye chat

* better warning

* style error

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-02-23 19:18:28 +01:00
76f00fc394 Ensure precommit exits 0 status 2025-02-23 16:34:54 +00:00
8453017622 🧼 Upgrade ruff (#2938) 2025-02-23 17:33:50 +01:00
3608709529 Update pr_style_bot.yml 2025-02-23 14:32:36 +01:00
21f0055893 🤖 Style bot (#2935) 2025-02-23 14:29:22 +01:00
013d360b8f 🔹 Fix: Miscalculated mask shape in comments (#2925) 2025-02-21 17:01:53 +01:00
e5ae703d35 🐦🔥 6x faster GRPO with multi-step optimization (#2899)
* Add num_updates and epsilon parameters to GRPOConfig and GRPOTrainer

* test sampler

* update the loss computation

* fix eval sampler

* should work now

* buffer inputs with grad accum

* optimize when num_iterations == 1

* test

* minor comment removal and fix log metric

* beta position

* clarify comment [ci skip]

* clarify sampler doc [ci skip]

* fix collision with eval logging

* clarify
2025-02-20 19:51:45 +01:00
a92e00e810 🪪 Adds profiling decorators for GRPOTrainer (#2889)
* adds profiling decorator

* naming + precommit

* style

* revert inclusion of slider table

* revert 2

* revert3

* revert4

* revert 5 fml

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-20 09:57:42 +01:00
9b3c5bf64f 📍 [GRPO] add gradient_checkpointing (#2848)
* add gradient_checkpointing

* added a helper

* Update trl/trainer/grpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update trl/trainer/grpo_trainer.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* minor refactor for better readability

* use accelerate util

* enable_input_require_grads is in base class

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-18 18:09:16 +01:00
15fec312d5 🍃 GRPO - Do not load reference model when beta == 0 (#2806)
* 🔧 Optimize GRPO training by conditionally loading reference model based on beta value

*  Add test for GRPOTrainer with beta=0 to ensure no reference model and KL divergence

* 🔧 Refactor GRPOTrainer code for improved readability and maintainability

* 🔧 Simplify per_token_loss calculation in GRPOTrainer for clarity

* fix test, style, and some struct for clarity

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-18 17:57:15 +01:00
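
The optimization is simply that beta == 0 makes the KL penalty vanish from the GRPO loss, so no reference model is needed. A hedged sketch of the gating (trl.create_reference_model is a real helper, but the surrounding logic here is illustrative):

    from trl import create_reference_model

    def maybe_create_ref_model(model, beta):
        # with beta == 0 the KL term is never computed, so skip the
        # reference model entirely and save its memory
        if beta == 0.0:
            return None
        return create_reference_model(model)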
be1e34003c 🩳 max_seq_length to max_length (#2895)
* `max_seq_length` to `max_length`

* remove in 0.20
2025-02-18 16:53:37 +01:00
6aaf379a82 ⚰️ Remove deprecated (#2894) 2025-02-18 16:53:21 +01:00
49adf74833 Add vLLM guided decoding support to GRPO Trainer (#2811)
*  Add vLLM guided decoding support to GRPO Trainer

* 🔧 Update vLLM guided decoding in GRPO to use regex parameter

* style and docstring

* test

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-18 16:53:05 +01:00
6c54f023ae 🪂 Don't gather logits in SFT to avoid hanging (#2890)
* Don't gather logits

* Remove unused function and test
2025-02-18 15:31:08 +01:00
963243a7d1 Optimize vllm num_generations (#2855)
* small optimization of vllm batching

* style

* adds comment

* style
2025-02-18 11:44:15 +01:00
aafd8cbea5 🍟 [SFT] Handles the dataset if it has been preprocessed (#2863)
* return dataset if it's preprocessed

* add is_processed flag variable

* add test

* move test_sft_trainer_directly_with_pretokenized_data to Tester2

* Update sft_trainer.py

* no need for padding and truncation

* minor reorganization

* Update trl/trainer/sft_trainer.py

* let the collator pad

* style

* fix tests

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
2025-02-18 09:56:47 +01:00
822653824b 🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading (#2873) 2025-02-17 20:34:07 +01:00
ba036576d4 💬 Add maybe_convert_to_chatml map for conversational datasets in SFT (#2862)
* add back get_formatting_func_from_dataset

* maybe_convert_to_chatml

* maybe_convert_to_chatml before maybe_apply_chat_template map

* remove comment

* test

* desc

* style

* Update trl/data_utils.py

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-02-17 16:47:06 +01:00
293b620950 [GRPO] Fix loss normalization (#2881)
* fix GRPO loss normalization

* fix sum dim

* fix loss= repeated
2025-02-17 13:26:21 +01:00
ae3bd0d07a 🆙 Bump vLLM min version to 0.7.2 (#2860)
Bumps vllm as there were a number of throughput improvements in vllm==0.7.2

Also may resolve issue such as https://github.com/huggingface/trl/issues/2851
2025-02-17 10:54:07 +01:00
6d9fc11fd6 [SFT] fix check for AutoLigerKernelForCausalLM (#2874)
* fix check for AutoLigerKernelForCausalLM

* fix case where AutoLigerKernelForCausalLM is not defined

* update min liger version

* formatting

* fix win CI
2025-02-17 07:50:55 +01:00
ffcb9f4aee ⬆️ Bump dev version 2025-02-13 14:33:44 +00:00
238 changed files with 10748 additions and 3964 deletions


@@ -21,8 +21,7 @@ Fixes # (issue)
Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the
[documentation guidelines](https://github.com/huggingface/trl/tree/main/docs).
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?

.github/codeql/custom-queries.qls (new file)

@@ -0,0 +1,19 @@
import codeql
from WorkflowString interpolation, Workflow workflow
where
interpolation.getStringValue().matches("${{ github.event.issue.title }}") or
interpolation.getStringValue().matches("${{ github.event.issue.body }}") or
interpolation.getStringValue().matches("${{ github.event.pull_request.title }}") or
interpolation.getStringValue().matches("${{ github.event.pull_request.body }}") or
interpolation.getStringValue().matches("${{ github.event.review.body }}") or
interpolation.getStringValue().matches("${{ github.event.comment.body }}") or
interpolation.getStringValue().matches("${{ github.event.inputs.* }}") or
interpolation.getStringValue().matches("${{ github.event.head_commit.message }}")
interpolation.getStringValue().matches("${{ github.event.* }}") and
(
step.getKey() = "run" or // Injection in run
step.getKey() = "env" or // Injection via env
step.getKey() = "with" // Injection via with
)
select workflow, "🚨 Do not use directly as input of action"


@@ -9,6 +9,7 @@ concurrency:
jobs:
build:
if: github.event.pull_request.draft == false
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
with:
commit_sha: ${{ github.event.pull_request.head.sha }}

.github/workflows/codeQL.yml (new file)

@@ -0,0 +1,26 @@
name: "CodeQL Analysis - Workflows"
on:
workflow_dispatch:
jobs:
analyze:
name: "Analyze GitHub Workflows"
runs-on: ubuntu-latest
permissions:
security-events: write
actions: read
contents: read
steps:
- name: "Checkout repository"
uses: actions/checkout@v4
- name: "Initialize CodeQL"
uses: github/codeql-action/init@v2
with:
languages: "yaml"
queries: +security-and-quality, ./.github/codeql/custom-queries.qls
- name: "Perform CodeQL Analysis"
uses: github/codeql-action/analyze@v2

.github/workflows/pr_style_bot.yml (new file)

@ -0,0 +1,127 @@
name: PR Style Bot
on:
workflow_dispatch:
permissions:
contents: write
pull-requests: write
jobs:
run-style-bot:
if: >
contains(github.event.comment.body, '@bot /style') &&
github.event.issue.pull_request != null
runs-on: ubuntu-latest
steps:
- name: Extract PR details
id: pr_info
uses: actions/github-script@v6
with:
script: |
const prNumber = context.payload.issue.number;
const { data: pr } = await github.rest.pulls.get({
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: prNumber
});
// We capture both the branch ref and the "full_name" of the head repo
// so that we can check out the correct repository & branch (including forks).
core.setOutput("prNumber", prNumber);
core.setOutput("headRef", pr.head.ref);
core.setOutput("headRepoFullName", pr.head.repo.full_name);
- name: Check out PR branch
uses: actions/checkout@v3
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
with:
# Instead of checking out the base repo, use the contributor's repo name
repository: ${{ env.HEADREPOFULLNAME }}
ref: ${{ env.HEADREF }}
# You may need fetch-depth: 0 for being able to push
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Debug
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
run: |
echo "PR number: ${{ env.PRNUMBER }}"
echo "Head Ref: ${{ env.HEADREF }}"
echo "Head Repo Full Name: ${{ env.HEADREPOFULLNAME }}"
- name: Set up Python
uses: actions/setup-python@v4
- name: Install dependencies
run: |
pip install ruff pre-commit
- name: Download Makefile from main branch
run: |
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/trl/main/Makefile
- name: Compare Makefiles
run: |
if ! diff -q main_Makefile Makefile; then
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
exit 1
fi
echo "No changes in Makefile. Proceeding..."
rm -rf main_Makefile
- name: Run make style and make quality
run: |
make precommit || true
- name: Commit and push changes
id: commit_and_push
env:
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
HEADREF: ${{ steps.pr_info.outputs.headRef }}
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
echo "HEADREPOFULLNAME: ${{ env.HEADREPOFULLNAME }}, HEADREF: ${{ env.HEADREF }}"
# Configure git with the Actions bot user
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
# Make sure your 'origin' remote is set to the contributor's fork
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${{ env.HEADREPOFULLNAME }}.git"
# If there are changes after running style/quality, commit them
if [ -n "$(git status --porcelain)" ]; then
git add .
git commit -m "Apply style fixes"
# Push to the original contributor's forked branch
git push origin HEAD:${{ env.HEADREF }}
echo "changes_pushed=true" >> $GITHUB_OUTPUT
else
echo "No changes to commit."
echo "changes_pushed=false" >> $GITHUB_OUTPUT
fi
- name: Comment on PR with workflow run link
if: steps.commit_and_push.outputs.changes_pushed == 'true'
uses: actions/github-script@v6
with:
script: |
const prNumber = parseInt(process.env.prNumber, 10);
const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: prNumber,
body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
});
env:
prNumber: ${{ steps.pr_info.outputs.prNumber }}


@@ -21,11 +21,9 @@ jobs:
check_code_quality:
name: Check code quality
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
submodules: recursive
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
@@ -38,126 +36,216 @@
name: Tests
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
os: ['ubuntu-latest', 'windows-latest']
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
fail-fast: false
runs-on: ${{ matrix.os }}
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v4
- name: Git checkout
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install ".[dev]"
source .venv/bin/activate
uv pip install ".[dev]"
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with Python ${{ matrix.python-version }} on ${{ matrix.os }} with lastest dependencies
title: Results with Python ${{ matrix.python-version }} and latest dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
tests_dev:
name: Tests with dev dependencies
runs-on: 'ubuntu-latest'
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v4
- name: Git checkout
uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -U git+https://github.com/huggingface/accelerate.git
python -m pip install -U git+https://github.com/huggingface/datasets.git
python -m pip install -U git+https://github.com/huggingface/transformers.git
python -m pip install ".[dev]"
source .venv/bin/activate
uv pip install -U git+https://github.com/huggingface/accelerate.git
uv pip install -U git+https://github.com/huggingface/datasets.git
uv pip install -U git+https://github.com/huggingface/transformers.git
uv pip install ".[dev]"
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results with Python 3.12 on ubuntu-latest with dev dependencies
title: Results with Python 3.12 and dev dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
tests_wo_optional_deps:
name: Tests without optional dependencies
runs-on: 'ubuntu-latest'
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
steps:
- - uses: actions/checkout@v4
+ - name: Git checkout
+   uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
- python -m pip install ".[test]"
+ source .venv/bin/activate
+ uv pip install ".[test]"
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
- title: Results with Python 3.12 on ubuntu-latest without optional dependencies
+ title: Results with Python 3.12 without optional dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
tests_min_versions:
name: Tests with minimum versions
- runs-on: 'ubuntu-latest'
+ runs-on:
+   group: aws-g4dn-2xlarge
+ container:
+   image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
+   options: --gpus all
defaults:
run:
shell: bash
if: github.event.pull_request.draft == false
steps:
- - uses: actions/checkout@v4
+ - name: Git checkout
+   uses: actions/checkout@v4
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
- python -m pip install --upgrade pip
- python -m pip install accelerate==0.34.0
- python -m pip install datasets==2.21.0
- python -m pip install transformers==4.46.0
- python -m pip install ".[dev]"
+ source .venv/bin/activate
+ uv pip install accelerate==0.34.0
+ uv pip install datasets==3.0.0
+ uv pip install transformers==4.46.0
+ uv pip install ".[dev]"
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
- title: Results with Python 3.12 on ubuntu-latest with minimum versions
+ title: Results with Python 3.12 and minimum dependencies versions
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

@@ -13,33 +13,54 @@ env:
jobs:
tests:
name: Tests latest TRL release with dev dependencies
runs-on: 'ubuntu-latest'
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
options: --gpus all
defaults:
run:
shell: bash
steps:
- name: Git checkout
uses: actions/checkout@v4
with: { ref: v0.15-release }
with: { ref: v0.17-release }
- name: Set up Python 3.12
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: "pip"
cache-dependency-path: |
setup.py
requirements.txt
- name: Install Make and Git
run: |
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -U git+https://github.com/huggingface/accelerate.git
python -m pip install -U git+https://github.com/huggingface/datasets.git
python -m pip install -U git+https://github.com/huggingface/transformers.git
python -m pip install ".[dev]"
source .venv/bin/activate
uv pip install -U git+https://github.com/huggingface/accelerate.git
uv pip install -U git+https://github.com/huggingface/datasets.git
uv pip install -U git+https://github.com/huggingface/transformers.git
uv pip install ".[dev]"
- name: Test with pytest
run: |
source .venv/bin/activate
make test
- name: Post to Slack
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: Results of latest TRL with Python 3.12 on ubuntu-latest with dev dependencies
title: Results of latest TRL with Python 3.12 and dev dependencies
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}

@@ -12,4 +12,7 @@ jobs:
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main
uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
with:
# exclude buggy postgres detector that is causing false positives and not relevant to our codebase
extra_args: --results=verified,unknown --exclude-detectors=postgres

.gitignore
@@ -142,4 +142,4 @@ checklink/cookies.txt
# wandb files
nbs/wandb/
examples/notebooks/wandb/
wandb/

@@ -1,8 +1,8 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.6.3
+ rev: v0.11.10
hooks:
- - id: ruff
+ - id: ruff-check
types_or: [ python, pyi ]
args: [ --fix ]
- id: ruff-format

@@ -31,4 +31,4 @@ keywords:
- pytorch
- transformers
license: Apache-2.0
- version: 0.15
+ version: 0.17

@@ -171,8 +171,7 @@ Follow these steps to start contributing:
$ pytest tests/<TEST_TO_RUN>.py
```
- > For the following commands leveraging the `make` utility, we recommend using the WSL system when running on
- > Windows. More information [here](https://docs.microsoft.com/en-us/windows/wsl/about).
+ > For the following commands leveraging the `make` utility.
You can also run the full suite with the following command.
@@ -457,3 +456,193 @@ Warnings play a critical role in guiding users toward resolving potential issues
```
By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.
## Making a release
> [!NOTE]
> VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We need to follow this convention to be able to retrieve versioned scripts.
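As a quick illustration (an editor's sketch, not part of the release tooling), a tag can be checked against this convention with a few lines of Python:

```python
import re

# Matches the v{major}.{minor}.{patch} convention, e.g. v0.17.0 (illustrative check only)
VERSION_TAG = re.compile(r"^v(\d+)\.(\d+)\.(\d+)$")

for tag in ["v0.17.0", "0.17.0", "v0.17"]:
    print(tag, "->", bool(VERSION_TAG.match(tag)))  # True, False, False
```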
Follow these steps to create the package for PyPI.
#### 0. Prerequisites
- Dependencies:
- twine: `pip install build twine`
- Create an account in (and join the `trl` project):
- PyPI: https://pypi.org/
- Test PyPI: https://test.pypi.org/
#### 1. Ensure your local repository is up to date with the upstream repository
```bash
git checkout main
git pull origin main
```
> [!WARNING]
> Do not merge other pull requests into `main` until the release is done. This is to ensure that the release is stable and does not include any untested changes. Announce internally (#trl-internal) to other maintainers that you are doing a release and that they must not merge PRs until the release is done.
#### 2. Create a release branch from main
```bash
git checkout -b release-v{major}.{minor}
```
#### 3. Change the version in the following files
- `.github/workflows/tests_latest.yml`:
```diff
- with: { ref: v{major}.{minor-1}-release }
+ with: { ref: v{major}.{minor}-release }
```
- `CITATION.cff`
```diff
- version: {major}.{minor-1}
+ version: {major}.{minor}
```
- `__init__.py`
```diff
- __version__ = "{major}.{minor}.0.dev0"
+ __version__ = "{major}.{minor}.0"
```
- `setup.cfg`
```diff
- version = {major}.{minor}.0.dev0
+ version = {major}.{minor}.0
```
#### 4. Commit and push these changes
```shell
git commit -m 'Release: {major}.{minor}'
git push origin release-v{major}.{minor}
```
#### 5. Create a pull request
from `release-v{major}.{minor}` to `main`, named `Release: v{major}.{minor}`, wait for tests to pass, and request a review.
#### 6. Once the pull request is approved, merge it into `main`
#### 7. Add a tag in git to mark the release
```shell
git checkout main
git pull origin main
git tag -a v{major}.{minor}.0 -m 'Adds tag v{major}.{minor}.0 for PyPI'
git push origin v{major}.{minor}.0
```
#### 8. Create a branch `v{major}.{minor}-release` for future patch releases.
```shell
git checkout -b v{major}.{minor}-release
git push origin v{major}.{minor}-release
```
This ensures that future patch releases (`v{major}.{minor}.1`, `v{major}.{minor}.2`, etc.) can be made separately from `main`.
#### 9. Create the wheels for your release
These are the artifacts that will be uploaded to PyPI and installed by users via `pip install trl`.
Clean previous builds:
```shell
rm -rf build dist
```
At the root of your repo, run
```bash
python -m build .
```
This will create a folder named `dist` with the new versions of your package.
#### 10. Upload the package to PyPI Test
> [!IMPORTANT]
> Do not skip this step. It is important to test the package before uploading it to the main PyPI server.
```shell
twine upload dist/* -r testpypi
```
Then in a fresh environment containing all dependencies you need, try to install your new package from the PyPI test server.
```bash
pip install -i https://test.pypi.org/simple/ trl
```
You might get errors for missing dependencies since the PyPI test server does not contain all packages like PyPI does. To make sure you have everything, you can run:
```bash
pip install trl
pip uninstall trl
```
(the second line will remove trl but keep all its dependencies).
Also make sure you can actually use the package! Run the following line:
```bash
python -c "from trl import *"
```
along with anything that tests:
- the core feature of your package
- the new features you're adding in the release (a minimal smoke-test sketch follows)
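A minimal smoke-test script might look like the following (a sketch only; `SFTConfig` stands in for whichever feature your release touches, so adapt it accordingly):

```python
# smoke_test.py: minimal post-release check (sketch; extend with the features you shipped)
import trl

print("TRL version:", trl.__version__)

# Exercising a core entry point: instantiating a config should not raise
from trl import SFTConfig

config = SFTConfig(output_dir="smoke-test")
print("SFTConfig OK:", config.output_dir)
```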
#### 11. Publish on PyPI
> [!WARNING]
> This can't be reverted. Make sure you have tested everything before doing this step.
```shell
twine upload dist/*
```
#### 12. Create a GitHub Release
1. Go to the repo's [releases section](https://github.com/huggingface/trl/releases) on GitHub.
2. Click **Draft a new release**.
3. Select the `v{major}.{minor}.0` tag you just created in step 7.
4. Add a title (`v{major}.{minor}.0`) and a short description of what's new.
5. Click **Publish Release**.
#### 13. Bump to dev version
1. Create a branch `bump-dev-version-{major}.{minor+1}` from `main` and checkout to it.
```shell
git checkout -b bump-dev-version-{major}.{minor+1}
```
2. Change the version in the following files:
1. `__init__.py`
```diff
- __version__ = "{major}.{minor}.0"
+ __version__ = "{major}.{minor+1}.0.dev0"
```
2. `setup.cfg`
```diff
- version = {major}.{minor}.0
+ version = {major}.{minor+1}.0.dev0
```
3. Commit and push these changes
```shell
git add trl/__init__.py setup.cfg
git commit -m '⬆️ Bump dev version'
git push origin bump-dev-version-{major}.{minor+1}
```
4. Create a pull request from `bump-dev-version-{major}.{minor+1}` to `main`, named `⬆️ Bump dev version`, and request urgent review.
5. Once the pull request is approved, merge it into `main`.
6. The codebase is now ready for the next development cycle; inform the team in the #trl-internal channel.

@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
- Copyright [yyyy] [name of copyright owner]
+ Copyright 2020-2025 The HuggingFace Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

@@ -1,6 +1,6 @@
include settings.ini
include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__
include trl/templates/*.md
include trl/accelerate_configs/*.yaml

@@ -6,17 +6,14 @@ ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
COMMAND_FILES_PATH = `pwd`/commands
test:
- python -m pytest -n auto --dist=loadfile -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' ./tests/
+ pytest -n auto -m "not slow and not low-priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/
precommit:
- pre-commit run --all-files
+ python scripts/add_copyrights.py
+ pre-commit run --all-files
- tests_gpu:
- python -m pytest tests/test_* $(if $(IS_GITHUB_CI),--report-log "common_tests.log",)
slow_tests:
- python -m pytest tests/slow/test_* $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
+ pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
test_examples:
touch temp_results_sft_tests.txt

README.md
@@ -12,8 +12,9 @@
<p align="center">
<a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
- <a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
+ <a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website?label=documentation&url=https%3A%2F%2Fhuggingface.co%2Fdocs%2Ftrl%2Findex&down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
<a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
<a href="https://huggingface.co/trl-lib"><img alt="Hugging Face Hub" src="https://img.shields.io/badge/🤗%20Hub-trl--lib-yellow"></a>
</p>
## Overview
@@ -22,16 +23,14 @@ TRL is a cutting-edge library designed for post-training foundation models using
## Highlights
+ - **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer) and more.
- **Efficient and scalable**:
- - Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like DDP and DeepSpeed.
- - Full integration with [`PEFT`](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
- - Integrates [Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
+ - Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like [DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [DeepSpeed](https://github.com/deepspeedai/DeepSpeed).
+ - Full integration with [🤗 PEFT](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
+ - Integrates [🦥 Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
+ - **Command Line Interface (CLI)**: A simple interface lets you fine-tune and interact with models without needing to write code.
- - **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`ORPOTrainer`](https://huggingface.co/docs/trl/orpo_trainer) and more.
- - **AutoModels**: Use pre-defined model classes like [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) to simplify reinforcement learning (RL) with LLMs.
- - **Command Line Interface (CLI)**: A simple interface lets you fine-tune with models without needing to write code.
## Installation
@@ -59,60 +58,75 @@ If you want to use the examples you can clone the repository with the following
git clone https://github.com/huggingface/trl.git
```
- ## Command Line Interface (CLI)
+ ## Quick Start
You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), or vibe check your model with the chat CLI:
**SFT:**
```bash
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name trl-lib/Capybara \
--output_dir Qwen2.5-0.5B-SFT
```
**DPO:**
```bash
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
--dataset_name argilla/Capybara-Preferences \
--output_dir Qwen2.5-0.5B-DPO
```
**Chat:**
```bash
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
```
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
## How to use
For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
### `SFTTrainer`
- Here is a basic example of how to use the `SFTTrainer`:
+ Here is a basic example of how to use the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer):
```python
- from trl import SFTConfig, SFTTrainer
+ from trl import SFTTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/Capybara", split="train")
- training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
- args=training_args,
model="Qwen/Qwen2.5-0.5B",
train_dataset=dataset,
)
trainer.train()
```
### `GRPOTrainer`
[`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer) implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).
```python
from datasets import load_dataset
from trl import GRPOTrainer
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
```
### `DPOTrainer`
[`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer) implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train [Llama 3](https://huggingface.co/papers/2407.21783) and many other models. Here is a basic example of how to use the `DPOTrainer`:
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
model=model,
args=training_args,
train_dataset=dataset,
processing_class=tokenizer
)
trainer.train()
```
### `RewardTrainer`
- Here is a basic example of how to use the `RewardTrainer`:
+ Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):
```python
from trl import RewardConfig, RewardTrainer
@@ -137,47 +151,28 @@ trainer = RewardTrainer(
trainer.train()
```
- ### `GRPOTrainer`
+ ## Command Line Interface (CLI)
- `GRPOTrainer` implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).
+ You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):
- ```python
- from datasets import load_dataset
- from trl import GRPOConfig, GRPOTrainer
+ **SFT:**
- dataset = load_dataset("trl-lib/tldr", split="train")
- # Dummy reward function: rewards completions that are close to 20 characters
- def reward_len(completions, **kwargs):
-     return [-abs(20 - len(completion)) for completion in completions]
- training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
- trainer = GRPOTrainer(
-     model="Qwen/Qwen2-0.5B-Instruct",
-     reward_funcs=reward_len,
-     args=training_args,
-     train_dataset=dataset,
- )
- trainer.train()
+ ```bash
+ trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
+     --dataset_name trl-lib/Capybara \
+     --output_dir Qwen2.5-0.5B-SFT
```
- ### `DPOTrainer`
+ **DPO:**
- `DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the `DPOTrainer`:
- ```python
- from datasets import load_dataset
- from transformers import AutoModelForCausalLM, AutoTokenizer
- from trl import DPOConfig, DPOTrainer
- model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
- dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
- training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
- trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
- trainer.train()
+ ```bash
+ trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
+     --dataset_name argilla/Capybara-Preferences \
+     --output_dir Qwen2.5-0.5B-DPO
```
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
## Development
If you want to contribute to `trl` or customize it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install:

@@ -42,7 +42,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
- --max_seq_length $SEQ_LEN \
+ --max_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""

@@ -23,6 +23,8 @@
title: Reducing Memory Usage
- local: speeding_up_training
title: Speeding Up Training
- local: distributing_training
title: Distributing Training
- local: use_model
title: Using Trained Models
title: How-to guides
@@ -35,6 +37,8 @@
title: PEFT
- local: unsloth_integration
title: Unsloth
- local: vllm_integration
title: vLLM
title: Integrations
- sections:
- local: example_overview
@@ -47,10 +51,10 @@
title: Training StackLlama
- local: detoxifying_a_lm
title: Detoxifying a Language Model
- - local: learning_tools
-   title: Learning to Use Tools
- local: multi_adapter_rl
  title: Multi Adapter RLHF
+ - local: training_vlm_sft
+   title: Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
title: Examples
- sections:
- sections: # Sorted alphabetically
@@ -93,6 +97,8 @@
title: Trainers
- local: models
title: Model Classes
- local: model_utils
title: Model Utilities
- local: best_of_n
title: Best of N Sampling
- local: judges
@@ -101,8 +107,10 @@
title: Callbacks
- local: data_utils
title: Data Utilities
- - local: text_environments
-   title: Text Environments
+ - local: rewards
+   title: Reward Functions
- local: script_utils
title: Script Utilities
- local: others
title: Others
title: API

@@ -1,105 +1,231 @@
# Command Line Interfaces (CLIs)
- You can use TRL to fine-tune your Language Model with Supervised Fine-Tuning (SFT) or Direct Policy Optimization (DPO) or even chat with your model using the TRL CLIs.
+ TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly.
- Currently supported CLIs are:
+ Currently supported commands are:
- #### Training commands
+ #### Training Commands
- `trl dpo`: fine-tune a LLM with DPO
- `trl grpo`: fine-tune a LLM with GRPO
- `trl kto`: fine-tune a LLM with KTO
- `trl sft`: fine-tune a LLM with SFT
- #### Other commands
+ #### Other Commands
- `trl chat`: quickly spin up an LLM fine-tuned for chatting
- `trl env`: get the system information
- `trl vllm-serve`: serve a model with vLLM
- ## Fine-tuning with the CLI
+ ## Fine-Tuning with the TRL CLI
- Before getting started, pick up a Language Model from Hugging Face Hub. Supported models can be found with the filter "text-generation" within models. Also make sure to pick up a relevant dataset for your task.
+ ### Basic Usage
+ You can launch training directly from the CLI by specifying required arguments like the model and dataset:
+ <hfoptions id="command_line">
+ <hfoption id="SFT">
- Before using the `sft` or `dpo` commands make sure to run:
```bash
- accelerate config
+ trl sft \
+     --model_name_or_path Qwen/Qwen2.5-0.5B \
+     --dataset_name stanfordnlp/imdb
```
- and pick up the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
- We also recommend you passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with `trl sft` command.
+ </hfoption>
+ <hfoption id="DPO">
+ ```bash
+ trl dpo \
+     --model_name_or_path Qwen/Qwen2.5-0.5B \
+     --dataset_name anthropic/hh-rlhf
+ ```
+ </hfoption>
+ </hfoptions>
### Using Configuration Files
To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file:
<hfoptions id="config_file">
<hfoption id="SFT">
```yaml
- model_name_or_path:
-   Qwen/Qwen2.5-0.5B
- dataset_name:
-   stanfordnlp/imdb
- report_to:
-   none
- learning_rate:
-   0.0001
- lr_scheduler_type:
-   cosine
+ # sft_config.yaml
+ model_name_or_path: Qwen/Qwen2.5-0.5B
+ dataset_name: stanfordnlp/imdb
```
- Save that config in a `.yaml` and get started immediately! An example CLI config is available as `examples/cli_configs/example_config.yaml`. Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder:
+ Launch with:
```bash
- trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
+ trl sft --config sft_config.yaml
```
- Will force-use `cosine_with_restarts` for `lr_scheduler_type`.
+ </hfoption>
+ <hfoption id="DPO">
- ### Supported Arguments
+ ```yaml
+ # dpo_config.yaml
+ model_name_or_path: Qwen/Qwen2.5-0.5B
+ dataset_name: anthropic/hh-rlhf
+ ```
- We do support all arguments from `transformers.TrainingArguments`, for loading your model, we support all arguments from `~trl.ModelConfig`:
- [[autodoc]] ModelConfig
- You can pass any of these arguments either to the CLI or the YAML file.
- ### Supervised Fine-tuning (SFT)
- Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
+ Launch with:
```bash
- trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb
+ trl dpo --config dpo_config.yaml
```
- The SFT CLI is based on the `trl/scripts/sft.py` script.
+ </hfoption>
+ </hfoptions>
- ### Direct Policy Optimization (DPO)
+ ### Scaling Up with Accelerate
- To use the DPO CLI, you need to have a dataset in the TRL format such as
+ TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs, machines, or use advanced setups like DeepSpeed — all from the same CLI.
- * TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
- * TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
+ You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch).
- These datasets always have at least three columns `prompt, chosen, rejected`:
- * `prompt` is a list of strings.
- * `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
- * `rejected` is the rejected response [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
- To do a quick start, you can run the following command:
+ <hfoptions id="launch_args">
+ <hfoption id="SFT inline">
```bash
- trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
+ trl sft \
+     --model_name_or_path Qwen/Qwen2.5-0.5B \
+     --dataset_name stanfordnlp/imdb \
+     --num_processes 4
```
+ </hfoption>
+ <hfoption id="SFT w/ config file">
- The DPO CLI is based on the `trl/scripts/dpo.py` script.
+ ```yaml
+ # sft_config.yaml
+ model_name_or_path: Qwen/Qwen2.5-0.5B
+ dataset_name: stanfordnlp/imdb
+ num_processes: 4
+ ```
- #### Custom preference dataset
- Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
+ Launch with:
```bash
- python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
+ trl sft --config sft_config.yaml
```
- ## Chat interface
</hfoption>
<hfoption id="DPO inline">
```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf \
--num_processes 4
```
</hfoption>
<hfoption id="DPO w/ config file">
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
num_processes: 4
```
Launch with:
```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
</hfoptions>
### Using `--accelerate_config` for Accelerate Configuration
The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either:
* the name of a predefined config profile (built into TRL), or
* a path to a custom Accelerate YAML config file.
#### Predefined Config Profiles
TRL provides several ready-to-use Accelerate configs to simplify common training setups:
| Name | Description |
| ------------ | ----------------------------------- |
| `fsdp1` | Fully Sharded Data Parallel Stage 1 |
| `fsdp2` | Fully Sharded Data Parallel Stage 2 |
| `zero1` | DeepSpeed ZeRO Stage 1 |
| `zero2` | DeepSpeed ZeRO Stage 2 |
| `zero3` | DeepSpeed ZeRO Stage 3 |
| `multi_gpu` | Multi-GPU training |
| `single_gpu` | Single-GPU training |
To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_configs/`.
#### Example Usage
<hfoptions id="accelerate_config">
<hfoption id="SFT inline">
```bash
trl sft \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```
</hfoption>
<hfoption id="SFT w/ config file">
```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```
Launch with:
```bash
trl sft --config sft_config.yaml
```
</hfoption>
<hfoption id="DPO inline">
```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```
</hfoption>
<hfoption id="DPO w/ config file">
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```
Launch with:
```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
</hfoptions>
## Chat Interface
<Tip warning={true}>
The chat interface is deprecated and will be removed in TRL 0.19. Use `transformers-cli chat` instead. For more information, see the [Transformers documentation, chat with text generation models](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
</Tip>
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
@@ -124,7 +250,7 @@ Besides talking to the model there are a few commands you can use:
- `save` or `save {SAVE_NAME}`: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
- `exit`: closes the interface
- ## Getting the system information
+ ## Getting the System Information
You can get the system information by running the following command:
@@ -132,7 +258,7 @@ You can get the system information by running the following command:
trl env
```
- This will print out the system information including the GPU information, the CUDA version, the PyTorch version, the transformers version, and the TRL version, and any optional dependencies that are installed.
+ This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the transformers version, the TRL version, and any optional dependencies that are installed.
```txt
Copy-paste the following information when reporting an issue:
@@ -140,7 +266,7 @@ Copy-paste the following information when reporting an issue:
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.1
- - CUDA device: NVIDIA H100 80GB HBM3
+ - accelerator(s): NVIDIA H100 80GB HBM3
- Transformers version: 4.45.0.dev0
- Accelerate version: 0.34.2
- Accelerate config:
@@ -171,6 +297,7 @@ Copy-paste the following information when reporting an issue:
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.12.0
- vLLM version: not installed
```
- This information are required when reporting an issue.
+ This information is required when reporting an issue.

@@ -2,48 +2,6 @@
TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are some examples of how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers.
## Train on multiple GPUs / nodes
The trainers in TRL use 🤗 Accelerate to enable distributed training across multiple GPUs or nodes. To do so, first create an 🤗 Accelerate config file by running
```bash
accelerate config
```
and answering the questions according to your multi-gpu / multi-node setup. You can then launch distributed training by running:
```bash
accelerate launch your_script.py
```
We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
```shell
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
```
Refer to the [examples page](https://github.com/huggingface/trl/tree/main/examples) for more details.
### Distributed training with DeepSpeed
All of the trainers in TRL can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run:
```shell
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_your_script.py --all_arguments_of_the_script
```
Note that for ZeRO-3, a small tweak is needed to initialize your reward model on the correct device via the `zero3_init_context_manager()` context manager. In particular, this is needed to avoid DeepSpeed hanging after a fixed number of training steps. Here is a snippet of what is involved from the [`sentiment_tuning`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) example:
```python
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
with ds_plugin.zero3_init_context_manager(enable=False):
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
else:
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
```
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
## Use different optimizers and schedulers

@@ -12,6 +12,10 @@
[[autodoc]] maybe_apply_chat_template
## maybe_convert_to_chatml
[[autodoc]] maybe_convert_to_chatml
## extract_prompt
[[autodoc]] extract_prompt
@@ -31,3 +35,11 @@
## pack_examples
[[autodoc]] pack_examples
## pack_dataset
[[autodoc]] pack_dataset
## truncate_dataset
[[autodoc]] truncate_dataset

@@ -279,7 +279,7 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`PPOTrainer`] | Tokenized language modeling |
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
- | [`SFTTrainer`] | [Language modeling](#language-modeling) |
+ | [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
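For reference, a prompt-completion sample is just a pair of fields. Here is an illustrative sketch (made-up values) in both standard and conversational form:

```python
# One sample in the standard prompt-completion format (illustrative values)
sample = {"prompt": "The sky is", "completion": " blue."}

# The same idea in conversational form
conversational_sample = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "completion": [{"role": "assistant", "content": "It is blue."}],
}
```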
<Tip>

@@ -4,4 +4,36 @@
Section under construction. Feel free to contribute!
</Tip>
TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more.
DeepSpeed integrates the [Zero Redundancy Optimizer (ZeRO)](https://huggingface.co/papers/1910.02054), which allows scaling the model size in proportion to the number of devices with sustained high efficiency.
![ZeRO Stages](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/zero_stages.png)
## Installation
To use DeepSpeed with TRL, install it using the following command:
```bash
pip install deepspeed
```
## Running Training Scripts with DeepSpeed
No modifications to your training script are required. Simply run it with the DeepSpeed configuration file:
```bash
accelerate launch --config_file <ACCELERATE_WITH_DEEPSPEED_CONFIG_FILE.yaml> train.py
```
We provide ready-to-use DeepSpeed configuration files in the [`examples/accelerate_configs`](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) directory. For example, to run training with ZeRO Stage 2, use the following command:
```bash
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml train.py
```
## Additional Resources
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.

@@ -0,0 +1,60 @@
# Distributing Training
<Tip warning={true}>
Section under construction. Feel free to contribute!
</Tip>
## Multi-GPU Training with TRL
The trainers in TRL use [🤗 Accelerate](https://github.com/huggingface/accelerate) to enable distributed training across multiple GPUs or nodes. To do so, first create an [🤗 Accelerate](https://github.com/huggingface/accelerate) config file by running
```bash
accelerate config
```
and answering the questions according to your multi-GPU / multi-node setup. You can then launch distributed training by running:
```bash
accelerate launch train.py
```
We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
```shell
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml train.py <SCRIPT_ARGS>
```
This automatically distributes the workload across all available GPUs.
Under the hood, [🤗 Accelerate](https://github.com/huggingface/accelerate) creates one model per GPU. Each process:
- Processes its own batch of data
- Computes the loss and gradients for that batch
- Shares gradient updates across all GPUs
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/multi_gpu.png)
The effective batch size is calculated as:
$$
\text{Batch Size} = \text{per\_device\_train\_batch\_size} \times \text{num\_devices} \times \text{gradient\_accumulation\_steps}
$$
To maintain a consistent batch size when scaling to multiple GPUs, make sure to update `per_device_train_batch_size` and `gradient_accumulation_steps` accordingly.
For example, the following configurations are equivalent and should yield the same results (a quick arithmetic check follows the table):
| Number of GPUs | Per device batch size | Gradient accumulation steps | Comments |
| --- | --- | --- | --- |
| 1 | 32 | 1 | Possibly high memory usage, but faster training |
| 1 | 4 | 8 | Lower memory usage, slower training |
| 8 | 4 | 1 | Multi-GPU to get the best of both worlds |
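As a quick arithmetic check of the formula above (illustrative only, not TRL code):

```python
def effective_batch_size(per_device: int, num_devices: int, grad_accum: int) -> int:
    # Batch Size = per_device_train_batch_size * num_devices * gradient_accumulation_steps
    return per_device * num_devices * grad_accum

# (per_device_train_batch_size, num_devices, gradient_accumulation_steps)
for cfg in [(32, 1, 1), (4, 1, 8), (4, 8, 1)]:
    print(cfg, "->", effective_batch_size(*cfg))  # all three yield 32
```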
<Tip>
Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration.md) guide for more details.
</Tip>
## Multi-Node Training
We're working on a guide for multi-node training. Stay tuned! 🚀

@@ -61,9 +61,9 @@ Distributed across 8 GPUs, the training takes approximately 3 minutes. You can v
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/dpo-qwen2-reward-margin.png)
- To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
+ To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
- <pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-DPO
+ <pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-DPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?

@@ -33,26 +33,37 @@ Then, it is encouraged to launch jobs with `accelerate launch`!
Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly.
| File | Description |
| --- | --- |
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
| [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. |
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train an Outcome Reward Model (ORM) on your own dataset. |
| [`examples/scripts/rloo/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
| [`examples/scripts/rloo/rloo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo_tldr.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/sft_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model. |
| [`examples/scripts/sft_video_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_video_llm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Video Language Model. |
| [`examples/scripts/sft_vlm_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model on vision to text tasks. |
| [`examples/scripts/sft_vlm_smol_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a SmolVLM model. |
| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. |
| [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`XPOTrainer`] to fine-tune a model. |
Here are also some easier-to-run colab notebooks that you can use to get started with TRL:
| File | Description |
| --- | --- |
| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
We also have some other examples that are less maintained but can be used as a reference:

@@ -68,11 +68,17 @@ At each training step, we sample a batch of prompts and generate a set of \\( G
### Computing the advantage
For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:
$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$
This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.
<Tip>
It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`].
</Tip>
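As a minimal sketch of this normalization (plain Python for illustration, not TRL's internal implementation; the `scale_rewards` toggle mirrors the option described above, and the small epsilon is an assumed numerical guard):

```python
import statistics

def group_relative_advantages(rewards, scale_rewards=True):
    # A_hat_{i,t} = (r_i - mean(r)) / std(r); the same value applies to every token t of sequence i
    mean_r = statistics.mean(rewards)
    std_r = statistics.stdev(rewards) if scale_rewards else 1.0
    return [(r - mean_r) / (std_r + 1e-4) for r in rewards]  # epsilon guards against zero std

# Rewards of G = 4 completions sampled for the same prompt
print(group_relative_advantages([1.0, 0.5, 2.0, 0.0]))
```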
### Estimating the KL divergence
@@ -83,46 +89,215 @@
### Computing the loss
The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:
$$
- \mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
+ \mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$
where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.
- In the original paper, this formulation is generalized to account for multiple updates after each generation by leveraging the **clipped surrogate objective**:
+ <Tip>
+ Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that this introduces a response-level length bias. More details in [loss types](#loss-types).
+ </Tip>
+ In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:
$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$
where \\(\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) \\) ensures that updates do not deviate excessively from the old policy by bounding the policy ratio between \\( 1 - \epsilon \\) and \\( 1 + \epsilon \\).
When \\( \mu = 1 \\) (default in TRL), the clipped surrogate objective simplifies to the original objective.
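To make the clipping concrete, here is a per-token sketch of the surrogate term (the `eps_low`/`eps_high` names are illustrative; this is not TRL's exact code):
```python
import torch

def clipped_surrogate(logp_new, logp_old, advantages, eps_low=0.2, eps_high=0.2):
    # logp_new, logp_old, advantages: per-token tensors of the same shape
    ratio = torch.exp(logp_new - logp_old)  # r_{i,t}(theta)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - eps_low, 1 + eps_high) * advantages
    # the min makes the objective pessimistic: moving the ratio outside the
    # trust region can never increase the objective
    return torch.min(unclipped, clipped)
```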
#### Loss Types
Several formulations of the objective have been proposed in the literature. Initially, the objective of GRPO was defined as follows:
$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} l_{i,t},
$$
where
$$
l_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right].
$$
The DAPO paper highlights the limitations of the GRPO algorithm's sample-level loss in long-CoT scenarios, where longer responses are under-penalized, leading to poorer quality outputs. The proposed solution is a token-level normalization, which better handles longer sequences by assigning more balanced rewards to individual tokens, regardless of response length:
$$
\mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
$$
Furthermore, it was demonstrated in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that the initial GRPO formulation introduces a response length bias. They show that while the DAPO formulation reduces this bias, it does not eliminate it completely. To fully remove this bias, they propose dividing by a constant instead of the sequence length, resulting in the following formulation:
$$
\mathcal{L}_{\text{Dr. GRPO}}(\theta) = - \frac{1}{LG} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
$$
This constant is recommended to be the maximum completion length. To use this formulation, set `loss_type="dr_grpo"` in the [`GRPOConfig`].
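The three formulations differ only in how the per-token losses \\( l_{i,t} \\) are aggregated. Here is a rough sketch of that aggregation, assuming `per_token_loss` and a binary `completion_mask` of shape `(G, T)`; apart from `"dr_grpo"`, which the text above confirms, the `loss_type` labels here are illustrative rather than TRL's exact option names:
```python
import torch

def aggregate_loss(per_token_loss, completion_mask, loss_type, max_completion_length=None):
    masked = per_token_loss * completion_mask
    if loss_type == "grpo":  # original GRPO: per-sequence mean, then mean over the group
        return (masked.sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
    elif loss_type == "dapo":  # token-level: mean over all unmasked tokens in the group
        return masked.sum() / completion_mask.sum().clamp(min=1.0)
    elif loss_type == "dr_grpo":  # divide by a constant (e.g., the max completion length)
        return masked.sum() / (per_token_loss.size(0) * max_completion_length)
```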
## Logged metrics
The GRPO Trainer logs the following metrics:
- `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
- `completions/mean_length`: The average length of generated completions.
- `completions/min_length`: The minimum length of generated completions.
- `completions/max_length`: The maximum length of generated completions.
- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS.
- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS.
- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS.
- `completions/clipped_ratio` : The ratio of truncated (clipped) completions.
- `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
- `reward`: The overall average reward after applying reward weights.
- `reward_std`: The standard deviation of the overall reward within each batch after applying reward weights.
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
- `clip_ratio/region_mean`: The ratio of token probabilities where the GRPO objective is clipped to stay within the trust region:
$$
\text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,.
$$
A higher value means more tokens are clipped, which constrains how much the policy \\( \pi_\theta \\) can change.
- `clip_ratio/low_mean`: The average ratio of token probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
- `clip_ratio/low_min`: The minimum ratio of token probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
- `clip_ratio/high_mean`: The average ratio of token probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\)
- `clip_ratio/high_max`: The maximum ratio of token probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).
## Customization
### Speed up training with vLLM-powered generation
Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with
```shell
pip install trl[vllm]
```
We support two ways of using vLLM during training: **server mode** and **colocate mode**.
#### 🔌 Option 1: Server mode
In this mode, vLLM runs in a separate process (and on separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.
1. **Start the vLLM server**:
```bash
trl vllm-serve --model <model_name>
```
2. **Enable server mode in your training script**:
```python
from trl import GRPOConfig
training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_mode="server",  # default value, can be omitted
)
```
<Tip warning={true}>
Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable.
</Tip>
#### 🧩 Option 2: Colocate mode
In this mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs.
```python
from trl import GRPOConfig

training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_mode="colocate",
)
```
<Tip>
Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`GRPOConfig`] to avoid underutilization or out-of-memory errors.
</Tip>
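For example, in colocate mode you might start from a conservative value and adjust it (a sketch; the right value depends on your model, batch sizes, and hardware):
```python
from trl import GRPOConfig

training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_mode="colocate",
    vllm_gpu_memory_utilization=0.3,  # fraction of each GPU's memory reserved for vLLM
)
```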
For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
### GRPO at scale: train a 70B+ Model on multiple nodes
When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:
- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see [DeepSpeed Integration](deepspeed_integration).
- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training).
- **vLLM**: See the previous section on how to use vLLM to speed up generation.
Below is an example SLURM script to train a 70B model with GRPO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation.
```sh
#!/bin/bash
#SBATCH --nodes=5
#SBATCH --gres=gpu:8

# Get the list of allocated nodes
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))

# Assign the first 4 nodes for training and the 5th node for vLLM
TRAIN_NODES="${NODELIST[@]:0:4}"  # Nodes 0, 1, 2, 3 for training
VLLM_NODE="${NODELIST[4]}"        # Node 4 for vLLM

# Run training on the first 4 nodes (Group 1)
srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    --num_processes 32 \
    --num_machines 4 \
    --main_process_ip ${NODELIST[0]} \
    --machine_rank $SLURM_PROCID \
    --rdzv_backend c10d \
    train_grpo.py \
    --vllm_server_host $VLLM_NODE &

# Run vLLM server on the 5th node (Group 2)
srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &

wait
```
```python
import argparse

from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP")
    args = parser.parse_args()

    # Example dataset from TLDR
    dataset = load_dataset("trl-lib/tldr", split="train")

    # Dummy reward function: count the number of unique characters in the completions
    def reward_num_unique_chars(completions, **kwargs):
        return [len(set(c)) for c in completions]

    training_args = GRPOConfig(
        output_dir="Qwen2.5-72B-GRPO",
        per_device_train_batch_size=4,
        bf16=True,
        gradient_checkpointing=True,
        logging_steps=10,
        use_vllm=True,
        vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."),  # from ip-X-X-X-X to X.X.X.X
    )

    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-72B",
        args=training_args,
        reward_funcs=reward_num_unique_chars,
        train_dataset=dataset,
    )
    trainer.train()


if __name__ == "__main__":
    main()
```
### Using a custom reward function
@ -132,6 +307,7 @@ The [`GRPOTrainer`] supports using custom reward functions instead of dense rewa
- The function must accept the following as keyword arguments:
- `prompts` (contains the prompts),
- `completions` (contains the generated completions),
- `completions_ids` (contains the tokenized completions),
- All column names (except `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
@ -145,9 +321,29 @@ The [`GRPOTrainer`] supports using custom reward functions instead of dense rewa
Below is an example of a reward function for a standard format that rewards longer completions:
```python
def reward_func(completions_ids, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of token count)."""
    return [float(len(ids)) for ids in completions_ids]
```
You can test it as follows:
```python
>>> prompts = ["The sky is", "The sun is"] # not used in the reward function, but the trainer will pass it
>>> completions = [" blue.", " in the sky."] # not used in the reward function, but the trainer will pass it
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[2.0, 4.0]
```
#### Example 1.1: Reward longer completions (based on the number of characters)
Same as the previous example, but this time the reward function is based on the number of characters instead of tokens.
```python
def reward_func(completions, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of character count)."""
    return [float(len(completion)) for completion in completions]
```
@ -156,7 +352,8 @@ You can test it as follows:
```python
>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]] # not used in the reward function, but the trainer will pass it
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[6.0, 12.0]
```
@ -216,6 +413,67 @@ You can test this function as follows:
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
[1.0, 0.0]
```
#### Example 4: Multi-task reward functions
Below is an example of using multiple reward functions in the [`GRPOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works.
```python
from datasets import Dataset
from trl import GRPOTrainer

# Define a dataset that contains both math and coding problems
dataset = Dataset.from_list(
    [
        {"prompt": "What is 2+2?", "task": "math"},
        {"prompt": "Write a function that returns the sum of two numbers.", "task": "code"},
        {"prompt": "What is 3*4?", "task": "math"},
        {"prompt": "Write a function that returns the product of two numbers.", "task": "code"},
    ]
)

# Math-specific reward function
def math_reward_func(prompts, completions, task, **kwargs):
    rewards = []
    for prompt, completion, t in zip(prompts, completions, task):
        if t == "math":
            # Calculate math-specific reward (check_math_solution is a user-defined helper)
            correct = check_math_solution(prompt, completion)
            reward = 1.0 if correct else -1.0
            rewards.append(reward)
        else:
            # Return None for non-math tasks
            rewards.append(None)
    return rewards

# Coding-specific reward function
def coding_reward_func(prompts, completions, task, **kwargs):
    rewards = []
    for prompt, completion, t in zip(prompts, completions, task):
        if t == "code":
            # Calculate coding-specific reward (test_code_solution is a user-defined helper)
            works = test_code_solution(prompt, completion)
            reward = 1.0 if works else -1.0
            rewards.append(reward)
        else:
            # Return None for non-coding tasks
            rewards.append(None)
    return rewards

# Use both task-specific reward functions
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[math_reward_func, coding_reward_func],
    train_dataset=dataset,
)

trainer.train()
```
In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset determines which reward function applies to each problem. When a reward function is not relevant for a sample, it returns `None`, and the [`GRPOTrainer`] simply ignores that value, aggregating only the rewards from the functions that apply. This allows the [`GRPOTrainer`] to handle multiple reward functions with different applicability, so the model is trained on each task using only its relevant reward signal.
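To illustrate, here is a toy sketch of how rewards containing `None` can be combined by aggregating only the applicable entries (hypothetical values, not TRL's internal code):
```python
rewards_per_func = [
    [1.0, None, -1.0, None],  # math_reward_func: applies to samples 0 and 2
    [None, 1.0, None, -1.0],  # coding_reward_func: applies to samples 1 and 3
]

# Combine per-sample rewards, skipping the None entries
combined = [sum(r for r in sample if r is not None) for sample in zip(*rewards_per_func)]
print(combined)  # [1.0, 1.0, -1.0, -1.0]
```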
#### Passing the reward function to the trainer


@ -4,37 +4,35 @@
# TRL - Transformer Reinforcement Learning
TRL is a full stack library where we provide a set of tools to train transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more.
The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
You can also explore TRL-related models, datasets, and demos in the [TRL Hugging Face organization](https://huggingface.co/trl-lib).
## Learn
Learn post-training with TRL and other libraries in 🤗 [smol course](https://github.com/huggingface/smol-course).
## Contents
The documentation is organized into the following sections:
- **Getting Started**: installation and quickstart guide.
- **Conceptual Guides**: dataset formats, training FAQ, and understanding logs.
- **How-to Guides**: reducing memory usage, speeding up training, distributing training, etc.
- **Integrations**: DeepSpeed, Liger Kernel, PEFT, etc.
- **Examples**: example overview, community tutorials, etc.
- **API**: trainers, utils, etc.
## Blog posts
<div class="mt-10">
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/open-r1">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/open-r1/thumbnails.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on January 28, 2025</p>
<p class="text-gray-700">Open-R1: a fully open reproduction of DeepSeek-R1</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo_vlm">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/dpo_vlm/thumbnail.png" alt="thumbnail" class="mt-0">
<p class="text-gray-500 text-sm">Published on July 10, 2024</p>


@ -2,56 +2,138 @@
[![](https://img.shields.io/badge/All_models-Iterative_SFT-blue)](https://huggingface.co/models?other=iterative-sft,trl)
Iterative fine-tuning is a training method that enables you to perform custom actions (for example, generation and filtering) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code.
## Quickstart
To get started quickly, you can either pass a model identifier or a pre-instantiated model to the trainer:
```python
from trl import IterativeSFTConfig, IterativeSFTTrainer

# Using a model identifier
trainer = IterativeSFTTrainer(
    "facebook/opt-350m",
    args=IterativeSFTConfig(
        max_length=512,
        output_dir="./output",
    ),
)

# Or using a pre-instantiated model
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

trainer = IterativeSFTTrainer(
    model,
    args=IterativeSFTConfig(
        max_length=512,
        output_dir="./output",
    ),
    processing_class=tokenizer,
)
```
## Usage
The [`IterativeSFTTrainer`] supports two ways of providing input data to the `step` function:
### Using a list of tensors as input:
```python
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}

trainer.step(**inputs)
```
### Using a list of strings as input:
```python
inputs = {
    "texts": texts,
    "texts_labels": texts_labels,  # Optional, defaults to texts
}

trainer.step(**inputs)
```
For causal language models, labels will automatically be created from `input_ids` or from `texts`. When using sequence-to-sequence models, you will have to provide your own labels or `texts_labels`.
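For example (a sketch, assuming `trainer` was created as in the Quickstart above):
```python
# Causal LM: labels are created automatically from the texts
trainer.step(texts=["The quick brown fox", "jumps over the lazy dog"])

# Sequence-to-sequence model: provide explicit labels
trainer.step(
    texts=["Translate to French: Hello", "Translate to French: Goodbye"],
    texts_labels=["Bonjour", "Au revoir"],
)
```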
## Configuration
The [`IterativeSFTConfig`] class provides several parameters to customize the training:
```python
from trl import IterativeSFTConfig
config = IterativeSFTConfig(
    # Model initialization parameters
    model_init_kwargs={"torch_dtype": "bfloat16"},
    # Data preprocessing parameters
    max_length=512,
    truncation_mode="keep_end",
    # Training parameters
    output_dir="./output",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_steps=1000,
    logging_steps=10,
    save_steps=100,
    optim="adamw_torch",
    report_to="wandb",
)
```
### Model Initialization
You can control how the model is initialized by passing keyword arguments to `model_init_kwargs`:
```python
config = IterativeSFTConfig(
    model_init_kwargs={
        "torch_dtype": "bfloat16",
        "device_map": "auto",
        "trust_remote_code": True,
    }
)
```
### Data Preprocessing
The trainer supports two truncation modes:
- `keep_end`: Truncates from the start of the sequence
- `keep_start`: Truncates from the end of the sequence
```python
config = IterativeSFTConfig(
    max_length=512,
    truncation_mode="keep_end",  # or "keep_start"
)
```
### Training Optimization
You can optimize CUDA cache usage for more memory-efficient training:
```python
config = IterativeSFTConfig(
    optimize_device_cache=True,
)
```
## IterativeSFTTrainer
[[autodoc]] IterativeSFTTrainer
## IterativeSFTConfig
[[autodoc]] IterativeSFTConfig


@ -53,9 +53,9 @@ Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. Y
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kto-qwen2-reward-margin.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-KTO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?
@ -121,14 +121,14 @@ By default, they are both 1. However, if you have more of one or the other, then
While training and evaluating we record the following reward metrics:
- `rewards/chosen_sum`: the sum of log probabilities of the policy model for the chosen responses scaled by beta
- `rewards/rejected_sum`: the sum of log probabilities of the policy model for the rejected responses scaled by beta
- `logps/chosen_sum`: the sum of log probabilities of the chosen completions
- `logps/rejected_sum`: the sum of log probabilities of the rejected completions
- `logits/chosen_sum`: the sum of logits of the chosen completions
- `logits/rejected_sum`: the sum of logits of the rejected completions
- `count/chosen`: the count of chosen samples in a batch
- `count/rejected`: the count of rejected samples in a batch
## KTOTrainer


@ -1,233 +0,0 @@
# Learning Tools (Experimental 🧪)
Using Large Language Models (LLMs) with tools has been a popular topic recently, with awesome works such as [ToolFormer](https://huggingface.co/papers/2302.04761) and [ToolBench](https://huggingface.co/papers/2305.16504). In TRL, we provide a simple example of how to teach LLMs to use tools with reinforcement learning.
Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):
| File | Description |
|---|---|
| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. |
| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. |
| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. |
<Tip warning={true}>
Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
</Tip>
## Learning to Use a Calculator
The rough idea is as follows:
1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parses a text calculation like `"14 + 34"` and returns the calculated number:
```python
from transformers import AutoTokenizer, load_tool
tool = load_tool("ybelkada/simple-calculator")
tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places
```
1. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.
1. Create a prompt on how to use the tools
```python
# system prompt
prompt = """\
What is 13.1-3?
<request><SimpleCalculatorTool>13.1-3<call>10.1<response>
Result=10.1<submit>
What is 4*3?
<request><SimpleCalculatorTool>4*3<call>12<response>
Result=12<submit>
What is 12.1+1?
<request><SimpleCalculatorTool>12.1+1<call>13.1<response>
Result=13.1<submit>
What is 12.1-20?
<request><SimpleCalculatorTool>12.1-20<call>-7.9<response>
Result=-7.9<submit>"""
```
3. Create a `trl.TextEnvironment` with the model
```python
env = TextEnvironment(
    model,
    tokenizer,
    {"SimpleCalculatorTool": tool_fn},
    reward_fn,
    prompt,
    generation_kwargs=generation_kwargs,
)
```
4. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will visualize the tokens.
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/learning_tools.png)
1. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss; make sure to pass that argument to `step`. A short end-to-end sketch follows below.
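Putting the pieces together, here is a minimal sketch of one rollout-and-update cycle, assuming `env` and `ppo_trainer` were created as above:
```python
tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]

# Roll out the environment: the model can call the calculator tool during generation
queries, responses, masks, rewards, histories = env.run(tasks)

# PPO update; the masks let the trainer ignore tool-generated tokens in the loss
train_stats = ppo_trainer.step(queries, responses, rewards, masks)

# Inspect the first interaction with color-coded tool output
histories[0].show_text()
```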
## Experiment results
We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.
```
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
--command "python examples/research_projects/tools/calculator.py" \
--num-seeds 10 \
--start-seed 1 \
--workers 10 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 8 \
--slurm-template-path benchmark/trl.slurm_template
```
We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot.
```
# pip install openrlbenchmark==0.2.1a5
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
'wandb?tag=calculator_final&cl=calculator_mask' \
--env-ids trl \
--check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename static/0compare \
--scan-history
```
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/learning_tools_chart.png)
As we can see, while one or two experiments crashed for some reason, most of the runs achieved near-perfect proficiency on the calculator task.
## (Early Experiments 🧪): learning to use a wiki tool for question answering
The [ToolFormer](https://huggingface.co/papers/2302.04761) paper shows an interesting use case that utilizes a Wikipedia search tool to help answer questions. In this section, we attempt to perform similar experiments but use RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.
<Tip warning={true}>
**Note that many settings are different so the results are not directly comparable.**
</Tip>
### Building a search index
Since [ToolFormer](https://huggingface.co/papers/2302.04761) was not open-sourced, we first needed to replicate the search index. Their paper mentions that the authors built the search index using a BM25 retriever over the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT).
Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.
```python
from pyserini.search.lucene import LuceneSearcher
import json
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
def search(query):
    hits = searcher.search(query, k=1)
    hit = hits[0]
    contents = json.loads(hit.raw)['contents']
    return contents
print(search("tennis racket"))
```
```
Racket (sports equipment)
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.
The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
...
```
We then deployed this snippet as a Hugging Face Space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the Space as a `transformers.Tool` later.
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pyserini.png)
### Experiment settings
We use the following settings:
* use the `bigcode/starcoderbase` model as the base model
* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only use the first paragraphs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
* test if the response contains the answer string; if so, give a reward of 1, otherwise, give a reward of 0.
* note that this is a simplified evaluation criterion. In [ToolFormer](https://huggingface.co/papers/2302.04761), the authors check if the first 20 words of the response contain the correct answer.
* use the following prompt that demonstrates the usage of the wiki tool.
```python
prompt = """\
Answer the following question:
Q: In which branch of the arts is Patricia Neary famous?
A: Ballets
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
Result=Ballets<submit>
Q: Who won Super Bowl XX?
A: Chicago Bears
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
Result=Chicago Bears<submit>
Q: """
```
### Result and Discussion
Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves mostly trend upward, but one of the experiments did crash.
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/triviaqa_learning_curves.png)
Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.
Note that the correct rate of the trained model is on the low end, which could be due to the following reasons:
* **incorrect searches:** When given the question "What is Bruce Willis' real first name?", if the model searches for Bruce Willis, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988." But a correct search should return "Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles."
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/real_first_name.png)
* **unnecessarily long response**: The wiki tool by default sometimes outputs very long sequences. E.g., when the wiki tool searches for "Brown Act":
* Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
* [ToolFormer](https://huggingface.co/papers/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/brown_act.png)
## (Early Experiments 🧪): solving math puzzles with a Python interpreter
In this section, we attempt to teach the model to use a Python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:
```python
prompt = """\
Example of using a Python API to solve math questions.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
<request><PythonInterpreter>
def solution():
    money_initial = 23
    bagels = 5
    bagel_cost = 3
    money_spent = bagels * bagel_cost
    money_left = money_initial - money_spent
    result = money_left
    return result
print(solution())
<call>8<response>
Result = 8 <submit>
Q: """
```
The training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y.
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/gms8k_learning_curve.png)


@ -1,74 +1,99 @@
# Logging
As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
By default, TRL trainers like [`PPOTrainer`] and [`GRPOTrainer`] save a lot of relevant information to supported experiment trackers like Weights & Biases (wandb) or TensorBoard.
Upon initialization, pass the `report_to` argument to the respective configuration object (e.g., [`PPOConfig`] for `PPOTrainer`, or [`GRPOConfig`] for `GRPOTrainer`):
```python
# For PPOTrainer
ppo_config = PPOConfig(
    # ...,
    report_to="wandb"  # or "tensorboard"
)

# For GRPOTrainer
grpo_config = GRPOConfig(
    # ...,
    report_to="wandb"  # or "tensorboard"
)
```
If you want to log with TensorBoard, you might also need to specify logging directories, for example, by adding `logging_dir=PATH_TO_LOGS` to the configuration object (e.g., `PPOConfig` or `GRPOConfig`).
## PPO Logging
Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
* `eps`: Tracks the number of episodes per second.
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
* `objective/scores`: The mean scores returned by the reward model / environment.
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to `policy/clipfrac_avg` but for the value function.
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
* `lr`: The current learning rate used by the optimizer.
* `episode`: The current episode count in the training process.
### Crucial values
During training, many values are logged; here are the most important ones:
1. `objective/scores`: The mean scores returned by the reward model / environment.
1. `objective/rlhf_reward`: The mean RLHF reward. This is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
1. `objective/non_score_reward`: The mean reward from non-score-related sources (e.g., KL penalty).
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
1. `loss/value_avg`: The average value loss. It will spike / NaN when not going well.
1. `val/ratio`: The mean ratio of the current policy probability to the old policy probability. This number should float around 1.0. If this `ratio` is too high (e.g., 2.0 or 1000.0) or too small (e.g., 0.1), it means the updates between consecutive policies are too drastic.
1. `policy/clipfrac_avg` and `policy/approxkl_avg`: If `val/ratio` is too high, the `ratio` is going to get clipped, resulting in high `policy/clipfrac_avg` and high `policy/approxkl_avg` as well.
1. `objective/kl`: The mean KL divergence. It should stay positive and ideally not too large, so that the policy is not too far away from the reference policy.
## GRPO Logging
Here's a brief explanation for the logged metrics provided in the data for the GRPO trainer:
* `num_tokens`: Total number of input tokens processed during training so far.
**Completions:**
* `completions/mean_length`: Mean length of all generated completions (including those not ending with an EOS token).
* `completions/min_length`: Minimum length among all generated completions.
* `completions/max_length`: Maximum length among all generated completions.
* `completions/clipped_ratio`: The ratio of completions that did not end with an EOS token before reaching the maximum generation length (i.e., they were truncated).
* `completions/mean_terminated_length`: Mean length of only those completions that successfully ended with an EOS token.
* `completions/min_terminated_length`: Minimum length among completions that ended with an EOS token.
* `completions/max_terminated_length`: Maximum length among completions that ended with an EOS token.
**Rewards:**
* `rewards/{reward_func_name}/mean`: The mean reward obtained from a specific, named reward function (e.g., `rewards/my_custom_reward/mean`). This is logged for each reward function used.
* `rewards/{reward_func_name}/std`: The standard deviation of rewards from a specific, named reward function.
* `reward`: The overall mean of the (potentially weighted and, if `args.scale_rewards` is true, normalized) rewards, after group-wise normalization (advantages).
* `reward_std`: The standard deviation of the (potentially weighted) rewards *before* group-wise normalization for advantages.
**Policy and Loss Metrics:**
* `kl`: The mean Kullback-Leibler (KL) divergence between the current policy and the reference policy. This is logged only if `beta` (the KL coefficient in `GRPOConfig`) is non-zero.
* If Liger GRPOLoss is used (`use_liger_loss: True` in `GRPOConfig`):
* `clip_ratio`: The fraction of policy updates where the probability ratio was clipped according to the GRPO loss's epsilon bounds.
* If standard GRPOLoss is used (`use_liger_loss: False`):
* `clip_ratio/low_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the lower bound `1 - epsilon_low` (occurs when advantage is negative and ratio is below the bound).
* `clip_ratio/low_min`: The minimum observed fraction for `clip_ratio/low_mean` across batches/processes.
* `clip_ratio/high_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the upper bound `1 + epsilon_high` (occurs when advantage is positive and ratio is above the bound).
* `clip_ratio/high_max`: The maximum observed fraction for `clip_ratio/high_mean` across batches/processes.
* `clip_ratio/region_mean`: The mean fraction of instances where the probability ratio was clipped at either the lower or upper bound.
### Crucial GRPO values
During GRPO training, monitor these values for insights into performance and stability:
1. `reward`: This is the primary objective. It reflects the (group-wise normalized) rewards the policy is achieving. It should generally increase during successful training.
1. `kl`: If `beta > 0`, this tracks the divergence from the reference model. Keep an eye on it to ensure the policy doesn't stray too far, which can lead to instability.
1. `clip_ratio/*` (either `clip_ratio` for Liger loss or the more detailed `clip_ratio/...` metrics for standard loss): These indicate how often the policy updates are being constrained by the GRPO clipping mechanism. Very high values might suggest that the policy is trying to change too drastically (potentially due to large advantages or a learning rate that's too high) or that the epsilon clipping range is too restrictive.
1. `completions/clipped_ratio`: A high ratio here indicates that the model is frequently generating completions that are cut off by `max_completion_length` rather than naturally ending with an EOS token. This might suggest issues with learning sequence termination or that `max_completion_length` is too short.
1. `rewards/{reward_func_name}/mean`: Monitoring the mean of individual reward functions can help diagnose which aspects of the desired behavior the model is learning or struggling with, especially when using multiple reward sources.


@ -0,0 +1,5 @@
# Model Utilities
## get_act_offloading_ctx_manager
[[autodoc]] models.get_act_offloading_ctx_manager


@ -51,9 +51,9 @@ accelerate launch train_nash_md.py
Distributed across 8 GPUs, the training takes approximately 3 hours.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-NashMD
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?


@ -53,9 +53,9 @@ Distributed across 8 GPUs, the training takes approximately 1 hour. You can veri
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/online-dpo-qwen2.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-OnlineDPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?


@ -56,9 +56,9 @@ Distributed across 8 GPUs, the training takes approximately 30 minutes. You can
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/orpo-qwen2-reward-margin.png)
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-ORPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?

docs/source/others.md

@ -0,0 +1,9 @@
# Other
## profiling_decorator
[[autodoc]] extras.profiling.profiling_decorator
## profiling_context
[[autodoc]] extras.profiling.profiling_context

View File

@ -52,7 +52,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
* `lr`: The current learning rate used by the optimizer.
* `episode`: The current global step or episode count in the training process.
* `episode`: The current episode count in the training process.
## Cookbook

View File

@ -14,7 +14,7 @@ Sequence lengths in the dataset can vary widely. When data is batched, sequences
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/why_you_should_truncate.png" alt="Truncation prompt completion" width="600"/>
</div>
To reduce memory usage, its important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
To reduce memory usage, it's important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
<hfoptions id="dpo">
<hfoption id="DPO">
@ -44,7 +44,7 @@ training_args = DPOConfig(..., max_completion_length=...)
</hfoption>
<hfoption id="SFT">
SFT truncation is applied to the input sequence via the `max_seq_length` parameter.
SFT truncation is applied to the input sequence via the `max_length` parameter.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_input_ids.png" alt="Truncation input ids" width="600"/>
@ -55,7 +55,7 @@ To set the truncation parameter, use the following code snippet:
```python
from trl import SFTConfig
training_args = SFTConfig(..., max_seq_length=...)
training_args = SFTConfig(..., max_length=...)
```
</hfoption>
@ -85,7 +85,7 @@ Packing eliminates padding, preserves all sequence information, and allows for f
```python
from trl import SFTConfig
training_args = SFTConfig(..., packing=True, max_seq_length=512)
training_args = SFTConfig(..., packing=True, max_length=512)
```
<Tip warning={true}>
@ -94,6 +94,77 @@ Packing may cause batch contamination, where adjacent sequences influence one an
</Tip>
## Padding-free
Padding-free batching is an alternative approach for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/padding-free.png" alt="Padding-free batching" width="600"/>
</div>
<Tip warning={true}>
It's highly recommended to use padding-free batching with **Flash Attention 2**. Otherwise, you may encounter batch contamination issues.
</Tip>
<hfoptions id="padding-free">
<hfoption id="DPO">
```python
from trl import DPOConfig
training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
```
</hfoption>
<hfoption id="SFT">
```python
from trl import SFTConfig
training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
```
</hfoption>
</hfoptions>
## Activation offloading
Activation offloading is a memory efficiency technique that reduces GPU VRAM usage by temporarily moving activation tensors to CPU RAM during the forward pass and bringing them back only when needed for the backward pass. This significantly reduces peak memory usage at the cost of slightly increased training time.
To enable activation offloading in your SFT training configuration:
<hfoptions id="activation_offloading">
<hfoption id="SFT">
```python
from trl import SFTConfig
training_args = SFTConfig(..., activation_offloading=True)
```
</hfoption>
</hfoptions>
<Tip warning={true}>
When using activation offloading with models that use Liger kernels, you must disable Liger cross entropy due to compatibility issues. The issue occurs specifically with `use_liger_kernel=True` because Liger cross entropy performs in-place operations which conflict with activation offloading. The default setting (`use_liger_kernel=False`) works:
```python
# When using activation offloading with a model that uses Liger kernels:
from trl import SFTConfig
training_args = SFTConfig(
activation_offloading=True,
use_liger_kernel=False, # Disable Liger cross entropy
# Other parameters...
)
```
</Tip>
Under the hood, activation offloading implements PyTorch's [`saved_tensors_hooks`](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#hooks-for-autograd-saved-tensors) to intercept activations during the forward pass. It intelligently manages which tensors to offload based on size and context, avoiding offloading output tensors which would be inefficient. For performance optimization, it can optionally use CUDA streams to overlap computation with CPU-GPU transfers.
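For illustration, the core mechanism can be sketched in plain PyTorch. This is a simplified sketch, not TRL's exact implementation, which adds the size/context heuristics and optional CUDA streams mentioned above (`model` and `inputs` are placeholders):
```python
import torch

def pack_to_cpu(tensor):
    # Forward pass: copy the saved activation into pinned CPU memory
    cpu_copy = torch.empty(tensor.size(), dtype=tensor.dtype, pin_memory=True)
    cpu_copy.copy_(tensor)
    return cpu_copy

def unpack_from_cpu(cpu_tensor):
    # Backward pass: bring the activation back to the GPU on demand
    return cpu_tensor.to("cuda", non_blocking=True)

# Every tensor autograd saves for backward inside this context is offloaded
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    loss = model(inputs).loss
loss.backward()
```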
## Disabling model gathering for generation in online methods
When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).
@ -101,6 +172,15 @@ When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Onl
If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter:
<hfoptions id="ds3_gather_for_generation">
<hfoption id="GRPO">
```python
from trl import GRPOConfig
training_args = GRPOConfig(..., ds3_gather_for_generation=False)
```
</hfoption>
<hfoption id="Online DPO">
```python
from trl import OnlineDPOConfig
training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False)
```
</hfoption>
</hfoptions>
9
docs/source/rewards.md Normal file
View File

@ -0,0 +1,9 @@
# Reward Functions
This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`].
## Format rewards
### think_format_reward
[[autodoc]] rewards.think_format_reward
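A usage sketch (assuming the reward expects completions as lists of message dicts and returns one score per completion):
```python
from trl.rewards import think_format_reward

completions = [
    [{"role": "assistant", "content": "<think>\nSome reasoning.\n</think>\nThe answer."}],
    [{"role": "assistant", "content": "The answer, without reasoning tags."}],
]
print(think_format_reward(completions))  # expected: [1.0, 0.0]
```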

View File

@ -2,10 +2,7 @@
[![](https://img.shields.io/badge/All_models-SFT-blue)](https://huggingface.co/models?other=sft,trl) [![](https://img.shields.io/badge/smol_course-Chapter_1-yellow)](https://github.com/huggingface/smol-course/tree/main/1_instruction_tuning)
Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset.
Check out a complete flexible example at [`trl/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/trl/scripts/sft.py).
Experimental support for Vision Language Models is also included in the example [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_vlm.py).
Supervised fine-tuning (SFT) is the most common step in post-training foundation models, and also one of the most effective. In TRL, we provide a simple API to train models with SFT in a few lines of code; for a complete training script, check out [`trl/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/trl/scripts/sft.py). Experimental support for Vision Language Models is also included in [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_vlm.py).
## Quickstart
@ -19,7 +16,7 @@ from trl import SFTConfig, SFTTrainer
dataset = load_dataset("stanfordnlp/imdb", split="train")
training_args = SFTConfig(
max_seq_length=512,
max_length=512,
output_dir="/tmp",
)
trainer = SFTTrainer(
@ -29,7 +26,7 @@ trainer = SFTTrainer(
)
trainer.train()
```
Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
Make sure to pass the correct value for `max_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
You can also construct a model outside of the trainer and pass it as follows:
@ -53,123 +50,24 @@ trainer = SFTTrainer(
trainer.train()
```
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults pass in your modification to the `SFTConfig` constructor and pass them to the trainer via the `args` argument.
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults, pass in your modification to the `SFTConfig` constructor and pass it to the trainer via the `args` argument.
## Advanced usage
### Train on completions only
You can use the `DataCollatorForCompletionOnlyLM` to train your model on the generated prompts only. Note that this works only in the case when `packing=False`.
To instantiate that collator for instruction data, pass a response template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on completions only on the CodeAlpaca dataset:
To train on completions only, simply use a [prompt-completion](dataset_formats#prompt-completion) dataset. In this mode, loss is computed solely on the completion part.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
def formatting_prompts_func(example):
output_texts = []
for i in range(len(example['instruction'])):
text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
output_texts.append(text)
return output_texts
response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
trainer = SFTTrainer(
model,
train_dataset=dataset,
args=SFTConfig(output_dir="/tmp"),
formatting_func=formatting_prompts_func,
data_collator=collator,
)
trainer.train()
```
To instantiate that collator for assistant style conversation data, pass a response template, an instruction template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on assistant completions only on the Open Assistant Guanaco dataset:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
instruction_template = "### Human:"
response_template = "### Assistant:"
collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)
trainer = SFTTrainer(
model,
args=SFTConfig(output_dir="/tmp"),
train_dataset=dataset,
data_collator=collator,
)
trainer.train()
```
Make sure to have a `pad_token_id` that is different from `eos_token_id`; otherwise the model may not properly predict EOS (End of Sentence) tokens during generation.
#### Using token_ids directly for `response_template`
Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending on whether they have context or not. For example:
```python
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
def print_tokens_with_ids(txt):
tokens = tokenizer.tokenize(txt, add_special_tokens=False)
token_ids = tokenizer.encode(txt, add_special_tokens=False)
print(list(zip(tokens, token_ids)))
prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?"""
print_tokens_with_ids(prompt) # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...]
response_template = "### Assistant:"
print_tokens_with_ids(response_template) # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)]
```
In this case, and due to lack of context in `response_template`, the same string ("### Assistant:") is tokenized differently:
- Text (with context): `[2277, 29937, 4007, 22137, 29901]`
- `response_template` (without context): `[835, 4007, 22137, 29901]`
This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text:
```
RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])
```
To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
```python
response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:] # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]`
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
```
If you'd like to compute loss on both the prompt **and** the completion while still using a prompt-completion dataset, set `completion_only_loss=False` in the [`SFTConfig`]. This is equivalent to [converting the dataset to a language modeling](dataset_formats#from-prompt-completion-to-language-modeling-dataset) format.
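A minimal configuration sketch:
```python
from trl import SFTConfig

# Compute loss on both prompt and completion while keeping a prompt-completion dataset
training_args = SFTConfig(output_dir="/tmp", completion_only_loss=False)
```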
### Add Special Tokens for Chat Format
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system, and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
- Adds special tokens to the tokenizer, e.g., `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
- Resizes the model's embedding layer to accommodate the new tokens.
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g. 64. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g., `64`. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
@ -179,11 +77,13 @@ from trl import setup_chat_format
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# Set up the chat format with default 'chatml' format
# Set up the chat format with the default 'chatml' format
model, tokenizer = setup_chat_format(model, tokenizer)
```
> [!WARNING]
> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply `setup_chat_format()`, as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in `SFTConfig`; for example, for `Qwen/Qwen2.5-1.5B`, one should set `eos_token="<|im_end|>"`.
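For example:
```python
from trl import SFTConfig

# Qwen base models ship a ChatML template, so align the EOS token with it
training_args = SFTConfig(output_dir="/tmp", eos_token="<|im_end|>")
```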
With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
### Dataset format support
@ -226,7 +126,7 @@ trainer = SFTTrainer(
)
```
If the dataset is not in one of those format you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
If the dataset is not in one of those formats, you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
### Format your input prompts
@ -246,26 +146,23 @@ Let us assume your dataset has two fields, `question` and `answer`. Therefore yo
```python
...
def formatting_prompts_func(example):
output_texts = []
for i in range(len(example['question'])):
text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
output_texts.append(text)
return output_texts
return f"### Question: {example['question']}\n ### Answer: {example['answer']}"
trainer = SFTTrainer(
model,
args=training_args,
train_dataset=dataset,
formatting_func=formatting_prompts_func,
)
trainer.train()
```
To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
To properly format your input, make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on the alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
### Packing dataset ([`ConstantLengthDataset`])
### Packing dataset
[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTConfig`] constructor.
[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTConfig`] constructor.
```python
...
@ -280,12 +177,12 @@ trainer = SFTTrainer(
trainer.train()
```
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments, you will probably train your models for more than a few epochs, depending on the way you have configured the packed dataset and the training protocol. Double-check that you know and understand what you are doing.
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTConfig` init method.
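For example:
```python
from trl import SFTConfig

# Pack the training set but leave the evaluation set unpacked
training_args = SFTConfig(output_dir="/tmp", packing=True, eval_packing=False)
```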
#### Customize your prompts using packed dataset
If your dataset has several fields that you want to combine, for example if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:
If your dataset has several fields that you want to combine, for example, if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:
```python
def formatting_func(example):
@ -302,7 +199,6 @@ trainer = SFTTrainer(
trainer.train()
```
You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTConfig`] constructor. Please refer to that class' signature for more information.
### Control over the pretrained model
@ -360,7 +256,7 @@ trainer.train()
```
> [!WARNING]
> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsense generations. If the chat template doesn't contain special tokens (e.g. Alpaca), then the `modules_to_save` argument can be ignored or set to `None`.
> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsensical generations. If the chat template doesn't contain special tokens (e.g., Alpaca), then the `modules_to_save` argument can be ignored or set to `None`.
You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed.
@ -425,15 +321,15 @@ Once you have loaded your model, wrap the `trainer.train()` call under the `with
trainer.train()
```
Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration.
Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore, you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration.
Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB.
| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step |
| ---------------- | ----------------- | ----------- | ---------- | ---------------------- |
| x | facebook/opt-350m | 2048 | 8 | ~59.1s |
| | facebook/opt-350m | 2048 | 8 | **OOM** |
| x | facebook/opt-350m | 2048 | 4 | ~30.3s |
| | facebook/opt-350m | 2048 | 4 | ~148.9s |
### Using Flash Attention-2
@ -455,12 +351,12 @@ model = AutoModelForCausalLM.from_pretrained(
```
If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device.
After loading your model, you can either train it as it is, or attach adapters and train adapters on it in case your model is quantized.
After loading your model, you can either train it as it is or attach adapters and train adapters on it in case your model is quantized.
In contrast to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
### Using model creation utility
### Using the model creation utility
We included a utility function to create your model.
@ -495,17 +391,17 @@ trainer = SFTTrainer(
)
```
### Enhance the model's performances using NEFTune
### Enhance the model's performance using NEFTune
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) from Jain et al. It consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune.
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF, such as LLaMA-2-Chat, benefit from additional training with NEFTune.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/neft-screenshot.png">
</div>
To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer.
To use it in `SFTTrainer`, simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to revert to the original behaviour of the embedding layer.
```python
from datasets import load_dataset
@ -534,7 +430,7 @@ Note however, that the amount of performance gain is _dataset dependent_ and in
### Accelerate fine-tuning 2x using `unsloth`
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently, `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:
| 1 A100 40GB | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
| --------------- | --------- | --- | --------------------- | --------- | ------------ |
@ -543,19 +439,19 @@ You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [
| Mistral 7b | Slim Orca | 1x | 1.17x | **1.88x** | -65.9% |
| Tiny Llama 1.1b | Alpaca | 1x | 1.55x | **2.74x** | -57.8% |
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
First, install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
```python
import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number
max_length = 2048 # Supports automatic RoPE Scaling, so choose any number
# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/mistral-7b",
max_seq_length=max_seq_length,
max_seq_length=max_length,
dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
@ -581,7 +477,7 @@ model = FastLanguageModel.get_peft_model(
random_state=3407,
)
training_args = SFTConfig(output_dir="./output", max_seq_length=max_seq_length)
training_args = SFTConfig(output_dir="./output", max_length=max_length)
trainer = SFTTrainer(
model=model,
@ -593,28 +489,29 @@ trainer.train()
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
## Liger-Kernel: Increase 20% throughput and reduces 60% memory for multi-GPU training
## Liger-Kernel: Increase 20% throughput and reduce 60% memory for multi-GPU training
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face Compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face Compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
With great memory reduction, you can potentially turn off cpu_offloading or gradient checkpointing to further boost the performance.
With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost the performance.
| Speed Up | Memory Reduction |
|--------------------------|-------------------------|
| ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |
1. To use Liger-Kernel in `SFTTrainer`, first install by
1. To use Liger-Kernel in [`SFTTrainer`], first install it by:
```bash
pip install liger-kernel
```
2. Once installed, set `use_liger` in [`SFTConfig`]. No other changes are needed!
2. Once installed, set `use_liger_kernel` in [`SFTConfig`]. No other changes are needed!
```python
training_args = SFTConfig(
use_liger=True
use_liger_kernel=True,
...
)
```
@ -624,14 +521,14 @@ To learn more about Liger-Kernel, visit their [official repository](https://gith
Pay attention to the following best practices when training a model with that trainer:
- [`SFTTrainer`] always truncates by default the sequences to the `max_seq_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
- [`SFTTrainer`] always truncates by default the sequences to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT; hence, we advise users to use the `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For more memory-efficient training using adapters, you can load the base model in 8bit. To do so, simply add the `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
- If you create a model outside the trainer, make sure not to pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
## Multi-GPU Training
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work you must also check the following:
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work, you must also check the following:
- If you're using gradient_checkpointing, add the following to the TrainingArguments: `gradient_checkpointing_kwargs={'use_reentrant':False}` (more info [here](https://github.com/huggingface/transformers/issues/26969))
- Ensure that the model is placed on the correct device:
```python
from accelerate import PartialState
device_string = PartialState().process_index
model = AutoModelForCausalLM.from_pretrained(..., device_map={"": device_string})
```
@ -649,7 +546,7 @@ You may experience some issues with GPTQ Quantization after completing training.
## Extending `SFTTrainer` for Vision Language Models
`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py), which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
### Preparing the Data
@ -768,10 +665,6 @@ A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vs
## Datasets
In the SFTTrainer we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
In the SFTTrainer, we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to re-use it directly.
### ConstantLengthDataset
[[autodoc]] trainer.ConstantLengthDataset
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to reuse it directly.
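For example, a minimal pre-tokenized dataset looks like this (the token ids are placeholders):
```python
from datasets import Dataset

# Because the dataset has an `input_ids` column, it is used as-is:
# no further tokenization or packing is applied.
dataset = Dataset.from_dict({"input_ids": [[1, 2, 3, 4], [5, 6, 7]]})
```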

View File

@ -37,7 +37,13 @@ training_args = OnlineDPOConfig(..., use_vllm=True)
</hfoption>
<hfoption id="GRPO">
Then, enable it by passing `use_vllm=True` in the training arguments.
First, start a vLLM server by running:
```bash
trl vllm-serve --model <model_name>
```
Then, run the training script and pass `use_vllm=True` in the training arguments.
```python
from trl import GRPOConfig
@ -45,31 +51,23 @@ from trl import GRPOConfig
training_args = GRPOConfig(..., use_vllm=True)
```
The strategy here is to use a dedicated GPU for generation powered by vLLM, while using the remainder for training.
You can customize the server configuration by passing additional arguments. For more information, see [vLLM integration](vllm_integration).
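For example (the extra flags are indicative; check `trl vllm-serve --help` for the exact options):
```bash
trl vllm-serve --model <model_name> --tensor_parallel_size 2 --gpu_memory_utilization 0.9
```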
<Tip warning={true}>
When using vLLM, an additional GPU is required exclusively for generation. This means you need at least two available GPUs and must ensure that one remains unused by the trainer. To achieve this, run the training with `--num_processes <NUMBER_OF_GPUs - 1>`.
When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.
For example, if you have 4 GPUs, set `--num_processes 3` to allocate three GPUs for training while reserving one for generation.
```bash
accelerate launch --multi_gpu --num_processes 3 train_grpo.py
```
Set GPUs **0-3** for vLLM generation:
```sh
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name>
```
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/1_gpu_for_generation.png)
And GPUs **4-7** for training:
```sh
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
```
</Tip>
You can further tune the vLLM configuration by setting a specific `vllm_device` and `vllm_gpu_memory_utilization` in the [`GRPOConfig`].
```python
training_args = GRPOConfig(
...,
use_vllm=True,
vllm_device="cuda:4",
vllm_gpu_memory_utilization=0.7,
)
```
</hfoption>
</hfoptions>

View File

@ -1,197 +0,0 @@
# Text Environments
Text environments provide a learning ground for language agents. They allow a language model to use tools to accomplish a task, such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the model itself but can be trivial for the appropriate tools. A good example is arithmetic with large numbers, which becomes a simple copy-paste task once you have access to a calculator.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/textenv.png">
</div>
Let's dive into how text environments work and start with tools!
## Tools
One of the core building blocks of text environments are the tools the model can use to solve tasks. In general, tools can be any Python function that takes a string as input and returns a string. The `TextEnvironment` offers two options for tools: either use predefined tools from `transformers.Tool` or define your own function or class with a `__call__` method. Let's have a look at both!
### `transformers.Tool`
Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared.
```Python
from transformers import load_tool
# simple calculator tool that runs +-/* operations
calc_tool = load_tool("ybelkada/simple-calculator")
# python interpreter that executes program and returns outputs
py_tool = load_tool("lvwerra/python-interpreter")
# wikipedia search index that returns best search match
wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
```
These tools are either loaded from the hub or from a local folder. Using the tool is as simple as calling them with a text query:
```Python
calc_tool("1/2")
>>> "0.5"
```
Note that both input and return values are strings to enable easy usage with a language model.
### Custom Tools
The following is an example of a tool that adds two integers:
```Python
def add(text):
int_1, int_2 = text.split("+")
result = int(int_1) + int(int_2)
return str(result)
print(add("1+1"))
>>> "2"
```
We looked at basic examples, such as a calculator, but the principle holds for more complex tools as well, such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax.
### Call syntax
In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows:
```python
"<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"
```
There are a few special tokens involved, so let's decompose it: First, the model can signal that it wants to use a tool by emitting the `<request>` token. After that, we want to know the name of the tool to call, which is done by enclosing the tool name in `<>` brackets. Once we know which tool to call, the tool query follows in free-text form. The `<call>` token signifies the end of the query and stops the model generation. At this point, the model output is parsed and the query is sent to the tool. The environment appends the tool response to the string, followed by the `<response>` token to mark the end of the tool output.
Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later):
```python
"<request><Calculator>1/2<call>0.5<response>"
```
Finally, the episode is ended and generation stops when the model generates `<submit>` which marks the interaction as completed.
Now let's have a look at how we can create a new text environment!
## Create a `TextEnvironment`
```python
prompt = """\
What is 13-3?
<request><SimpleCalculatorTool>13-3<call>10.0<response>
Result=10<submit>
"""
def reward_fn(result, answer):
"""Simplified reward function returning 1 if result matches answer and 0 otherwise."""
result_parsed = result.split("=")[1].split("<")[0]
return int(result_parsed==answer)
text_env = TextEnvironment(
    model=model,
    tokenizer=tokenizer,
    tools={"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
    reward_fn=reward_fn,
    prompt=prompt,
    max_turns=1,
    max_tool_response=100,
    generation_kwargs={"do_sample": True},
)
```
```
Let's decompose the settings:
| Argument | Description |
|:-------------------|:----------------|
| `model` | Language model to interact with the environment and generate requests. |
| `tokenizer` | Tokenizer of language model handling tokenization of strings. |
| `tools` | `list` or `dict` of tools. If a list, the tool's name is inferred from its class name; if a dict, the keys define the tool names.|
| `reward_fn` | A function that takes a string as input and returns a reward. It can have extra arguments that are passed to `.run()`, such as the ground truth.|
| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. |
| `max_turns` | Maximum number of interactions between model and tools before episode ends.|
| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
| `max_length` | The maximum number of tokens to allow in an episode. |
| `generation_kwargs`| Generation settings used by the language model. |
You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!
## Run an Episode
To run a set of queries through the text environment one can simply use the `run` method.
```python
queries = ["What is 1/2?"]
answers = ["0.5"]
queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)
```
This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached, or the maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function.
There are five objects that are returned by `run`:
- `queries`: a list of the tokenized queries
- `responses`: all tokens that have been generated within the environment, including model and tool tokens
- `masks`: mask that indicates which tokens have been generated by the model and which tokens are generated by the tool
- `rewards`: a list of reward for each query/response
- `histories`: list of `TextHistory` objects, which are useful objects containing all the above and also the text equivalents
The masks are crucial for training, as we don't want to optimize tokens that the model has not generated, i.e. the tokens produced by the tools.
Next, we'll train a PPO step with the generated responses!
### Train
Training on episodes from the `TextEnvironment` is straightforward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method:
```python
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
```
## `TextHistory`
The `TextHistory` object stores the interactions between the model and the text environment. It stores tokens and text generated in each turn and their source in each turn (model or system) as well as rewards. Let's go through the class attributes and methods.
### Attributes
The following table summarises the available attributes of the `TextHistory` class:
| Attribute | Description |
|:-------------------|:----------------|
| `text` | The full string of the text generated in the text environment with both model and system generated text. |
| `text_spans` | A list of tuples with the spans for each model or system generated text segment. |
| `system_spans` | A list of boolean values indicating if the segment is model or system generated. |
| `tokens` | All tokens generated in the text environment, with both model and system generated tokens. |
| `token_spans` | Similar to `text_spans`, the `token_spans` indicate the boundaries of model and system generated tokens. |
| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. |
| `completed` | Indicates if the interaction with the environment has completed. |
| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. |
With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look!
### Visualization
When the model interacts inside the `TextEnvironment`, it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` library](https://github.com/Textualize/rich) (make sure to install it before using these methods).
You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue, and in addition to the pure text output, the reward is displayed as additional text in plum. Here is an example of `show_text`:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/textenv_show_text.png" width=600>
</div>
Sometimes there can be tricky tokenization-related issues that are hidden when showing the decoded text. Thus, `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/textenv_show_tokens.png" width=800>
</div>
Note that you can turn on the colour legend by passing `show_legend=True`.
## API Documentation
[[autodoc]] TextEnvironment
[[autodoc]] TextHistory

View File

@ -0,0 +1,381 @@
# Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
![VLM SFT training procedure](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_training_procedure.png)
## Overview
This guide walks you through the process of fine-tuning a multimodal language model (e.g., **Gemma 3**) using **Supervised Fine-Tuning (SFT)**. We cover two cases:
- **Single Image + Text**
- **Multi-Image + Text**
This guide serves as a **detailed walkthrough** and complements the existing [VLM SFT script](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py). If you're already familiar with the concepts, you can use the script directly.
We demonstrate the fine-tuning process using two datasets, but these principles extend to other **Vision-Language Models (VLMs)** and datasets.
## Understanding the Datasets
To address both **Single Image + Text** and **Multi-Image + Text** scenarios, we use two datasets that are well-suited for this task.
### HuggingFaceH4/llava-instruct-mix-vsft Dataset (Image + Text)
This dataset is a reformatted version of [LLaVA Instruct Mix](https://huggingface.co/datasets/theblackcat102/llava-instruct-mix). It consists of conversations where a user provides both **text** and a **single image** as input.
The model (referred to as the **"assistant"**) responds based on both the **visual and textual information** shared by the user. This dataset is particularly useful for training multimodal models to **understand and generate responses based on images and text**.
<iframe
src="https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft/embed/viewer/default/train"
frameborder="0"
width="100%"
height="560px"
></iframe>
### FanqingM/MMIU-Benchmark Dataset (Multi-Image + Text)
The **FanqingM/MMIU-Benchmark** dataset consists of:
- **Context:** Included in the system prompt.
- **Question:** Provided as part of the user's input.
- **Series of Images:** Multiple images related to the question.
- **Answer:** The model's expected response.
This dataset is designed for tasks where the model must reason over multiple images to generate an informed response based on both visual and textual inputs.
<iframe
src="https://huggingface.co/datasets/FanqingM/MMIU-Benchmark/embed/viewer/default/test"
frameborder="0"
width="100%"
height="560px"
></iframe>
## Developing a Fine-Tuning Script for Multimodal SFT
In this section, we build the script needed to fine-tune a multimodal model for both **Single Image + Text** and **Multi-Image + Text** use cases.
### Setting Up the Environment
Before fine-tuning, we need to install the required dependencies. Let's start by setting up the environment:
```bash
# Install the required libraries. Further details: https://huggingface.co/docs/trl/installation
pip install -U -q trl bitsandbytes peft hf_xet tensorboard
```
Once all dependencies are installed, we need to log in to the **Hugging Face Hub**. Since **Gemma 3** is a gated model, access permissions are required.
If you haven't requested access yet, visit the [Model Card](https://huggingface.co/google/gemma-3-4b-it) and request it.
To log in, you'll need to generate an [access token](https://huggingface.co/settings/tokens) from your Hugging Face account.
```bash
huggingface-cli login
```
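Alternatively, you can log in programmatically (the token value is a placeholder):
```python
from huggingface_hub import login

login(token="hf_...")  # paste your access token here
```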
### **Loading the Data**
As mentioned earlier, we will cover two possible use cases. While the specific procedure may vary based on the dataset, the core principles remain consistent.
This guide supports both use cases, so refer to the **Single Image + Text** or **Multi-Image + Text** sections depending on your specific scenario.
#### **Single Image + Text**
![Single Image + Text](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_training_procedure_single_image.png)
In this case, each sample in a batch consists of a **single image paired with text**. Since the dataset is already formatted for supervised fine-tuning (SFT), we can directly load it using `load_dataset`.
```python
from datasets import load_dataset
dataset_name = "HuggingFaceH4/llava-instruct-mix-vsft"
# Load Dataset
dataset = load_dataset(dataset_name)
```
#### **Multi-Image + Text (or Interleaving)**
![Multi-Image + Text](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/training_vlm_sft_training_procedure_multi_image.png)
Gemma 3 also supports **Multi-Image + Text** scenarios, where:
- The model receives a **list of images** alongside a user message.
- The model processes **interleaved images and text** within a conversation.
For this dataset, some preprocessing is required before training.
```python
from datasets import load_dataset
dataset_name = "FanqingM/MMIU-Benchmark"
# Load Dataset
dataset = load_dataset(dataset_name)
```
After loading the dataset, we need to preprocess and format it into a conversational structure. Here's an example of how the data might look:
```python
{"role": "system", "content": [{"type": "text", "text": "You are a judge in a photography competition, and now you are given the four images. Please examine the details and tell which one of them is most likely to be a real photograph.\nSelect from the following choices.\nA: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
{"role": "user", "content": images_list + [{"type": "text", "text": "Which image is most likely to be a real photograph?"}]},
{"role": "assistant", "content": [{"type": "text", "text": "A: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
```
Here, `images_list` is a list of images:
```python
images_list = [
{"type": "image", "image": <class 'PIL.Image.Image'>},
{"type": "image", "image": <class 'PIL.Image.Image'>},
{"type": "image", "image": <class 'PIL.Image.Image'>},
{"type": "image", "image": <class 'PIL.Image.Image'>},
{"type": "image", "image": <class 'PIL.Image.Image'>},
]
```
This structure can be translated into code like this:
```python
import os
import zipfile
import io
from typing import Any
from datasets import DatasetDict
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image
dataset_train_split = "test"
def format_data(samples: dict[str, Any]) -> dict[str, list]:
formatted_samples = {"messages": []}
for cont in range(len(samples["question"])):
images = []
for img_path in samples["input_image_path"][cont]:
try:
with open(img_path, "rb") as f:
img_bytes = f.read()
image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
images.append({"type": "image", "image": image})
except Exception as e:
print(f"Error processing image {img_path}: {e}")
continue
formatted_samples["messages"].append(
[
{"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
{"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
{"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
]
)
return formatted_samples
# For multi-image example
def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
all_files = list_repo_files(dataset_name, repo_type="dataset")
zip_files = [f for f in all_files if f.endswith(".zip")]
for zip_filename in zip_files:
zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
extract_folder = zip_filename.replace(".zip", "")
os.makedirs(extract_folder, exist_ok=True)
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extract_folder)
dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
return dataset
dataset = prepare_dataset(dataset, dataset_name, dataset_train_split)
```
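To confirm the mapping worked, you can print the structure of the first formatted conversation (a quick check under the assumptions above):
```python
first = dataset[dataset_train_split][0]["messages"]
for turn in first:
    print(turn["role"], "->", [c["type"] for c in turn["content"]])
```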
With this, your **Multi-Image + Text** dataset is now prepared for training.
### **Preparing for Training**
We start by loading the model and processor. In this example, we use `google/gemma-3-4b-it`, but the same process applies to its other variants and similar models.
To optimize memory usage, we configure a `BitsAndBytesConfig` to load the model in quantized 4-bit form.
```python
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
model_id = "google/gemma-3-4b-it"
# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_quant_storage=torch.bfloat16,
)
# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.bfloat16,
attn_implementation="eager", # Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934)
quantization_config=bnb_config
)
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = "right"
```
Next, we set up [Quantized Low-Rank Adaptation (QLoRA)](https://huggingface.co/papers/2305.14314), an efficient fine-tuning technique for Large Language Models (LLMs) and Vision-Language Models (VLMs).
```python
from peft import LoraConfig, get_peft_model
# Configure QLoRA
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.05,
r=16,
bias="none",
target_modules="all-linear",
task_type="CAUSAL_LM",
modules_to_save=[
"lm_head",
"embed_tokens",
],
)
```
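Although `SFTTrainer` will apply these adapters for us via `peft_config`, you can optionally wrap the model yourself to inspect how few parameters are actually trained (illustrative):
```python
peft_model = get_peft_model(model, peft_config)
peft_model.print_trainable_parameters()
```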
With QLoRA now set up, we need to define the training arguments for SFT. The [`SFTConfig`] class simplifies this process, providing an easy way to adjust parameters based on our specific needs.
```python
from trl import SFTConfig
training_args = SFTConfig(
output_dir="gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft", # Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets).
num_train_epochs=1, # Set the number of epochs to train the model.
per_device_train_batch_size=8, # Batch size for each device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1
gradient_accumulation_steps=4, # Number of steps before performing a backward/update pass to accumulate gradients. multi-image -> gradient_accumulation_steps=1
gradient_checkpointing=True, # Enable gradient checkpointing to reduce memory usage during training.
optim="adamw_torch_fused", # Use the fused AdamW optimizer for better performance.
logging_steps=10, # Frequency of logging training progress (log every 10 steps).
save_strategy="epoch", # Save checkpoints at the end of each epoch.
learning_rate=2e-05, # Learning rate for training.
bf16=True, # Enable bfloat16 precision for training to save memory and speed up computations.
push_to_hub=True, # Automatically push the fine-tuned model to Hugging Face Hub after training.
report_to="tensorboard", # Automatically report metrics to tensorboard.
gradient_checkpointing_kwargs={"use_reentrant": False}, # Set gradient checkpointing to non-reentrant to avoid issues.
dataset_kwargs={"skip_prepare_dataset": True}, # Skip dataset preparation to handle preprocessing manually.
remove_unused_columns=False, # Ensure unused columns are not removed in the collator (important for batch processing).
)
```
The `collate_fn` is responsible for processing and preparing individual examples to form a batch.
Each example in the batch undergoes the following steps:
1. The **chat template** is applied to the text.
2. The **processor tokenizes** both `texts` and `images`, encoding them into tensors.
3. The **labels** for training are set as the `input_ids` of the example.
4. Certain **special tokens** are **masked (ignored)** during loss computation:
- `pad_token_id`
- `<image_token_id>`
- `<image_soft_token>` (corresponding to ID `262144`)
This process is similar across different dataset types, with a minor variation in how images are handled:
- **Single Image + Text** → A **list of images** is directly processed.
- **Multi-Image + Text** → A **list of lists of images** is used, where each batch element contains multiple images.
```python
import io
from PIL import Image
# For multi-image cases
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
image_inputs = []
for msg in messages:
content = msg.get("content", [])
if not isinstance(content, list):
content = [content]
for element in content:
if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
if "image" in element:
image = element["image"]
else:
image = element
if image is not None:
image = Image.open(io.BytesIO(image["bytes"]))
image_inputs.append(image.convert("RGB"))
return image_inputs
def collate_fn(examples):
texts = [processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip() for example in examples]
if "images" in examples[0]: # single-image
images = [
[img.convert("RGB") for img in example["images"]]
for example in examples
]
else: # multi-image
images = [process_vision_info(example["messages"]) for example in examples]
# Tokenize the texts and process the images
batch = processor(
text=texts, images=images, return_tensors="pt", padding=True
) # Encode texts and images into tensors
    # The labels are the input_ids; positions that shouldn't contribute to the loss are set to -100
    labels = batch["input_ids"].clone()  # Clone input IDs for labels
    # Resolve the image token id ("boi_token" marks the beginning of an image)
    image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
    # Mask padding and image tokens so they are ignored in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100  # <image_soft_token> id
batch["labels"] = labels
return batch # Return the prepared batch
```
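As a quick sanity check (illustrative, single-image case), you can run the collator on two samples and confirm that padding and image positions are excluded from the loss:
```python
batch = collate_fn([dataset["train"][0], dataset["train"][1]])
print(batch["input_ids"].shape, batch["labels"].shape)
print((batch["labels"] == -100).sum().item(), "label positions masked out")
```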
### **Training the Model**
With all the components set up, we now configure the `SFTTrainer` using the previously defined settings and start the training process.
```python
# Training
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
args=training_args,
data_collator=collate_fn,
train_dataset=dataset["train"], # multi-image -> train_dataset=dataset["test"],
processing_class=processor,
peft_config=peft_config,
)
trainer.train()
# Save the final model
trainer.save_model()
```
We save the fine-tuned model to the Hub, making it easily accessible for future use. Additionally, TRL automatically logs the training results to **Weights & Biases (Wandb)** or **TensorBoard**, depending on the chosen configuration.
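Once the adapter is on the Hub, it can be reloaded on top of the base model for inference (a minimal sketch; the repository id below is a placeholder for your own):
```python
import torch
from peft import PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor

base_model = AutoModelForImageTextToText.from_pretrained(
    "google/gemma-3-4b-it", torch_dtype=torch.bfloat16, device_map="auto"
)
model = PeftModel.from_pretrained(base_model, "<your-username>/gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft")
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
```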
### Results
During and after training, we can inspect the results using **Weights & Biases (Wandb)** or **TensorBoard**. For example:
* [**gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft (Single Image + Text)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft)
* [**gemma-3-4b-it-trl-sft-MMIU-Benchmark (Multi-Image + Text or Interleaving)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-MMIU-Benchmark)
## Limitations
Currently, fine-tuning Gemma has some [known limitations](https://github.com/huggingface/trl/issues/3121). We recommend following the procedure outlined in this guide to ensure the best results.
## References
For further reading and complementary resources, check out the following:
- [Fine-Tuning Vision-Language Models with QLoRA](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora)
- [Fine-Tuning a Vision Language Model (Qwen2-VL-7B) with the Hugging Face Ecosystem (TRL)](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl)

View File

@ -43,7 +43,6 @@ To use the data efficiently, we use a technique called packing: instead of havin
With this approach the training is much more efficient, as each token that is passed through the model is also trained, in contrast to padding tokens, which are usually masked from the loss.
If you don't have much data and are more concerned about occasionally cutting off some tokens that are overflowing the context, you can also use a classical data loader.
The packing is handled by the `ConstantLengthDataset` and we can then use the `Trainer` after loading the model with `peft`. First, we load the model in int8, prepare it for training, and then add the LoRA adapters.
```python
# load model in 8bit

View File

@ -0,0 +1,185 @@
# vLLM Integration
This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first give a tl;dr on how to use vLLM with TRL, and then go into the details of how it works under the hood. Let's go! 🔥
## 🚀 How can I use vLLM with TRL to speed up training?
💡 **Note**: Resources required for this specific example: a single node with 8 GPUs.
First, install vLLM using the following command:
```bash
pip install "trl[vllm]"
```
Then run the server:
```sh
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2
```
Once the server is running, you can use it to generate completions for training. In the example below, we are using the `GRPOTrainer` to train a model using the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs.
In this example, we are sharding two copies of the model across 4 GPUs. Increasing data parallelism increases throughput, while increasing tensor parallelism allows for serving larger models. Then, run the training script by passing `use_vllm=True` in the training arguments as follows:
Sample of a simple `train.py` script:
```python
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
training_args = GRPOConfig(
output_dir="my_test",
use_vllm=True,
bf16=True,
gradient_checkpointing=True,
logging_steps=10,
)
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-7B",
args=training_args,
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
```
And the train command:
```sh
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
```
## 🎬 Flashback: Why do we need to use vLLM in online methods?
Online methods like GRPO or Online DPO require the model to generate completions during training, which are then used to compute reward signals. However, generation can be extremely time-consuming, especially with large or reasoning models. In the default setup (without vLLM), completions are generated using the [(unwrapped) model's `generate` method](https://github.com/huggingface/trl/blob/f3e8c2304428ef16e9ae5de9e5741ed84d533b7b/trl/trainer/grpo_trainer.py#L965C39-L965C66). This approach quickly becomes a major bottleneck — generation is slow and inefficient, particularly for large batches or models. As a result, training times increase significantly, and overall efficiency drops. To address this, we turn to vLLM, which enables much faster and more scalable generation, helping eliminate this bottleneck in online methods.
## 🤔 How does vLLM solve the slow generation issue?
If you've ever done autoregressive decoder training, you know all the input tokens to the LLM produce their attention key and value tensors, and these tensors are kept in GPU memory to later generate subsequent tokens based on them. These cached key and value tensors are often referred to as the KV cache. However, storing the KV cache occupies a lot of memory, so vLLM uses a technique called **PagedAttention** to solve this problem. PagedAttention, which is inspired by the operating system's virtual memory concept, stores continuous keys and values in **non-contiguous memory space**, which is much more efficient. The details of this are beyond the scope of this document, but in short, it allows the model to store the keys and values in a more efficient way, reducing the memory footprint and speeding up the generation process. If you are interested, make sure to check out the [vLLM PagedAttention blog post](https://blog.vllm.ai/2023/06/20/vllm.html) for more details.
## 🤔 What exactly happens when you run `trl vllm-serve --model <model_name>`?
When you run, for example,
```sh
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4
```
the following happens:
![vllm](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/vllm-doc.png)
1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (1 × 4).
Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation. A key point to understand is that these 4 workers are running in parallel, and each one is responsible for handling a subset of the total incoming load.
2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model's weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled. Meanwhile, data parallelism (controlled by `--data-parallel-size`) ensures that different sets of requests are processed independently across the workers. In short: tensor parallelism splits the model across GPUs, and data parallelism splits the batch of requests across different model replicas.
3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--data-parallel-size=4`, each GPU processes 2 prompts).
This GPU-to-GPU communication is managed efficiently by NVIDIA's NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests — it's lightweight and doesn't interfere with generation itself.
Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (like in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you would end up with 16 completions total — regardless of the number of GPUs or parallelism settings.
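In config terms (illustrative values), that completion count is controlled independently of the server's parallelism:
```python
from trl import GRPOConfig

# 8 prompts per step with num_generations=2 -> 16 completions per step,
# regardless of --tensor-parallel-size / --data-parallel-size on the server.
training_args = GRPOConfig(output_dir="my_test", use_vllm=True, num_generations=2)
```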
## 🥸 More detail on what happens under the hood when running the server
* The vLLM server starts by running the command: `trl vllm-serve --model Qwen/Qwen2.5-7B`.
* Once the server is running, the client (trainer) requests completions from it using `vllm_client.generate` [here](https://github.com/huggingface/trl/blob/cc044e35b285be7dc062764b3364e1e684db4c7c/trl/trainer/grpo_trainer.py#L1025-L1035).
* These completions are used to compute the reward signal.
* Based on the reward signal and the model's output, the loss is computed, and the backward pass is performed to update the model's weights.
* **Note**: The server only handles completion generation — it doesn't train the model. Therefore, the model's weights aren't updated on the server. Once the backward pass is complete, the client sends the updated weights to the server using `vllm_client.update_named_param(name, param.data)`.
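Conceptually, that weight sync boils down to a loop like the following (an illustrative sketch of what the trainer does internally; `vllm_client` is TRL's internal client object, not something you instantiate in a typical script):
```python
# After the optimizer step, push the updated weights to the vLLM server
# so that subsequent generations use the latest policy.
for name, param in model.named_parameters():
    vllm_client.update_named_param(name, param.data)
```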
When using vLLM, ensure the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation for training using `CUDA_VISIBLE_DEVICES`. See the example below:
* **Set GPUs *0–3* for vLLM generation:** Assume `CUDA_VISIBLE_DEVICES=0,1,2,3` are allocated for vLLM generation.
```sh
trl vllm-serve --model <model_name> --tensor-parallel-size 1 --data-parallel-size 4
```
* **And GPUs *4–7* for training:** If you do not set the `CUDA_VISIBLE_DEVICES` environment variable, the training script will use all available GPUs by default, which may lead to resource conflicts. To avoid this, you can specify which GPUs to use for training. For example, if you want to use GPUs 4–7 for training, set the environment variable as follows:
```sh
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
```
## 🍷 More customization options with vLLM?
You can customize the server configuration by passing additional arguments.
```
$ trl vllm-serve --help
usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE]
[--data_parallel_size DATA_PARALLEL_SIZE] [--host HOST] [--port PORT]
[--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE] [--max_model_len MAX_MODEL_LEN]
[--enable_prefix_caching ENABLE_PREFIX_CACHING] [--enforce_eager ENFORCE_EAGER] [--log_level LOG_LEVEL]
options:
-h, --help Show this help message and exit
--model MODEL Model name or path to load the model from. (default: None)
--revision REVISION Revision to use for the model. If not specified, the default branch will be used. (default: None)
--tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE
Number of tensor parallel workers to use. (default: 1)
--data_parallel_size DATA_PARALLEL_SIZE, --data-parallel-size DATA_PARALLEL_SIZE
Number of data parallel workers to use. (default: 1)
--host HOST Host address to run the server on. (default: 0.0.0.0)
--port PORT Port to run the server on. (default: 8000)
--gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the device
dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus improve the
model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors during
initialization. (default: 0.9)
--dtype DTYPE Data type to use for vLLM generation. If set to 'auto', the data type will be automatically determined based on
the model configuration. Find the supported values in the vLLM documentation. (default: auto)
--max_model_len MAX_MODEL_LEN, --max-model-len MAX_MODEL_LEN
If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model context
size, which might be much larger than the KV cache, leading to inefficiencies. (default: None)
--enable_prefix_caching ENABLE_PREFIX_CACHING, --enable-prefix-caching ENABLE_PREFIX_CACHING
Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support this
feature. (default: None)
--enforce_eager ENFORCE_EAGER, --enforce-eager ENFORCE_EAGER
Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the model
in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid. (default:
None)
--log_level LOG_LEVEL, --log-level LOG_LEVEL
Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', 'trace'. (default:
info)
```
## 🥳 Okay, now that we have the server running, how can we use it to generate completions?
Run the training script and pass `use_vllm=True` in the training arguments:
```python
from trl import GRPOConfig
training_args = GRPOConfig(..., use_vllm=True)
```
## 💆🏻‍♀️ What's the best distributed setup?
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_8_gpus.png)
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_4_gpus.png)
First and foremost, always remember that the optimal setup depends on:
* The model size
* The number of GPUs you have
* The GPU memory size
* The batch size you are using
* The number of requests you are sending to the server (prompts)
* The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size)
* The number of completions you are generating for each request (`num_generations`)
Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that:
* For reasonable-sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results.
* For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window.
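Expressed as server commands (model names illustrative):
```sh
# Reasonable-sized model, moderate context: favor data parallelism
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 8

# Larger model, longer context: trade some DP for tensor parallelism
trl vllm-serve --model Qwen/Qwen2.5-32B --tensor-parallel-size 2 --data-parallel-size 4
```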

View File

@ -50,9 +50,9 @@ accelerate launch train_xpo.py
Distributed across 8 GPUs, the training takes approximately 1 hour.
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-XPO
<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-XPO
<strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
What is the best programming language?

View File

@ -0,0 +1,28 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: true
fsdp_offload_params: false
fsdp_reshard_after_forward: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
fsdp_version: 1
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,25 @@
# Requires accelerate 1.7.0 or higher
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -1,25 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: true
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -4,4 +4,4 @@ This directory contains a collection of Jupyter notebooks that demonstrate how t
- [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO.
- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook.
- [`gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.
- [`gpt2-sentiment-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.

View File

@ -0,0 +1,15 @@
# LayerSkip Training Recipe
Implements the training recipe as described in the [LayerSkip paper](https://huggingface.co/papers/2404.16710).
## Run training
```
cd scripts
python layer_skip_sft.py
```
## Run benchmark
```
cd scripts
python benchmark_layer_skip.py
```
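After training, the same `assistant_early_exit` generation kwarg used in the benchmark enables self-speculative decoding with the fine-tuned checkpoint (a sketch; the hub id is a placeholder following `config.py`'s naming scheme):
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

ckpt = "<your-username>/layerskip-Llama-3.2-3B-top_v2"  # hypothetical hub id
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
inputs = tokenizer("### Instruction: What are my alarms for the rest of the day?\n ### Response: ", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```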

View File

@ -0,0 +1,77 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import config
import torch
from torch.utils import benchmark
from transformers import AutoModelForCausalLM, AutoTokenizer
def generate_tokens(model, inputs):
outputs = model.generate(
**inputs,
do_sample=False,
max_new_tokens=64,
)
return outputs
def generate_tokens_with_assistance(model, inputs, assistant_early_exit):
outputs = model.generate(
**inputs,
assistant_early_exit=assistant_early_exit,
do_sample=False,
max_new_tokens=64,
)
return outputs
if __name__ == "__main__":
ckpt = config.hub_model_id
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(ckpt)
prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: "
results = []
label = "Generation Times"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
results.append(
benchmark.Timer(
stmt="generate_tokens(model, inputs)",
setup="from __main__ import generate_tokens",
globals={"model": model, "inputs": inputs},
num_threads=torch.get_num_threads(),
label=label,
sub_label="no layer skip",
description="generation",
).blocked_autorange()
)
for i in range(1, model.config.num_hidden_layers):
results.append(
benchmark.Timer(
stmt="generate_tokens_with_assistance(model, inputs, assistant_early_exit)",
setup="from __main__ import generate_assistant_tokens",
globals={"model": model, "assistant_early_exit": i, "inputs": inputs},
num_threads=torch.get_num_threads(),
label=label,
sub_label=f"layer skip {i}",
description="generation",
).blocked_autorange()
)
benchmark.Compare(results).print()

View File

@ -0,0 +1,28 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from huggingface_hub import whoami
model_name = "unsloth/Llama-3.2-3B"
tokenizer_name = "unsloth/Llama-3.2-3B"
dataset_name = "WillHeld/top_v2"
output_root_dir = "./checkpoints/"
hub_model_id = f"{whoami()['name']}/layerskip-{model_name.split('/')[1]}-{dataset_name.split('/')[1]}"
output_dir = f"{output_root_dir}/{hub_model_id}"
per_device_train_batch_size = 8
gradient_accumulation_steps = 1
learning_rate = 2e-5

View File

@ -0,0 +1,48 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trl import SFTTrainer
class LayerSkipSFTTrainer(SFTTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.early_exit_layer = 0 # initialize with 0
self.always_last_layer = True
self.early_exit_loss_scale = 1.0
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
self.early_exit_layer = (
self.early_exit_layer % (model.config.num_hidden_layers - 1)
) + 1 # rotates between [1, num_hidden_layers-1]
bs, seqlen = inputs.input_ids.shape
labels = inputs.pop("labels")
outputs = model(**inputs, output_hidden_states=True)
hidden_state = outputs["hidden_states"][self.early_exit_layer].to(model.dtype)
if self.early_exit_layer != model.config.num_hidden_layers:
hidden_state = model.model.norm(hidden_state)
logits = model.lm_head(hidden_state)
loss_early = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size)
if self.always_last_layer:
loss_last = model.loss_function(logits=outputs["logits"], labels=labels, vocab_size=model.vocab_size)
loss = self.early_exit_loss_scale * loss_early.to(loss_last.device) + 1.0 * loss_last
# normalize loss scales
loss = loss / (1.0 + self.early_exit_loss_scale)
else:
loss = loss_early
return loss

View File

@ -0,0 +1,91 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import config
import torch
from custom_trainer import LayerSkipSFTTrainer
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM, SFTConfig
def formatting_prompts_func(example):
text = f"### Instruction: {example['utterance']}\n ### Response: {example['semantic_parse']}"
# Inject eos_token as a string before tokenization, because they are not always added
# See: https://github.com/huggingface/transformers/issues/22794 and
# https://github.com/huggingface/trl/issues/1623
    if tokenizer.eos_token:  # e.g. "</s>" for Llama-style tokenizers or "<|endoftext|>" for GPT-2
text += f"{tokenizer.eos_token}"
return text
if __name__ == "__main__":
# load the dataset
print("[INFO] loading the dataset...")
train_dataset = load_dataset(config.dataset_name, split="train")
print(f"output_root_dir: {config.output_root_dir}")
print(f"hub_model_id: {config.hub_model_id}")
# load the model and tokenizer
print("[INFO] loading the model and tokenizer...")
model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, add_eos_token=True)
# adding pad and eos tokens if not provided in the tokenizer
if tokenizer.pad_token is None:
# Add '[PAD]' token if it doesn't exist
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
if tokenizer.eos_token is None or tokenizer.eos_token == tokenizer.bos_token:
# Add '[EOS]' token if it doesn't exist
tokenizer.add_special_tokens({"eos_token": "[EOS]"})
model.resize_token_embeddings(len(tokenizer))
model.config.eos_token_id = tokenizer.eos_token_id
response_template = " ### Response:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
args = SFTConfig(
do_train=True,
bf16=True,
max_seq_length=None,
per_device_train_batch_size=config.per_device_train_batch_size,
gradient_accumulation_steps=config.gradient_accumulation_steps,
learning_rate=config.learning_rate,
packing=False,
num_train_epochs=1.0,
report_to="none",
push_to_hub=True,
hub_model_id=config.hub_model_id,
output_dir=config.output_dir,
logging_steps=500,
save_steps=1000,
save_total_limit=2,
)
trainer = LayerSkipSFTTrainer(
model,
train_dataset=train_dataset,
args=args,
formatting_func=formatting_prompts_func,
data_collator=collator,
)
trainer.train()

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -185,7 +185,7 @@ trainer = SFTTrainer(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
max_seq_length=None,
max_length=None,
formatting_func=prepare_sample_text,
processing_class=tokenizer,
args=training_args,

View File

@ -1,118 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import numpy as np
import torch
from transformers import AutoTokenizer, load_tool
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
def generate_data(n):
"""Generate random arithmetic tasks and answers."""
tasks, answers = [], []
for _ in range(n):
a = np.random.randint(0, 50)
b = np.random.randint(0, 50)
op = np.random.choice(["-", "+", "*"])
tasks.append(f"\n\nWhat is {a} {op} {b}?")
if op == "-":
answers.append(a - b)
elif op == "+":
answers.append(a + b)
else:
answers.append(a * b)
return tasks, answers
def exact_match_reward(responses, answers=None):
"""Reward if generated response contains correct answer."""
rewards = []
pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>" # generated by chatGPT
for response, answer in zip(responses, answers):
reward = 0.0
predicted_number = None
match_pattern = re.findall(pattern, response)
if match_pattern:
predicted_number = float(match_pattern[0])
if predicted_number is not None:
if np.abs(predicted_number - answer) < 0.01:
reward += 1.0
rewards.append(torch.tensor(reward))
return rewards
# set up models
model_id = "gpt2"
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
# system prompt
prompt = """\
What is 13-3?
<request><SimpleCalculatorTool>13-3<call>10.0<response>
Result=10<submit>
What is 4*3?
<request><SimpleCalculatorTool>4*3<call>12.0<response>
Result=12<submit>"""
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"eos_token_id": -1,
"max_new_tokens": 32,
}
# trainer
ppo_config = PPOConfig(
batch_size=256,
learning_rate=1.41e-5,
mini_batch_size=64,
log_with="wandb",
)
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)
# text env
text_env = TextEnvironment(
model,
tokenizer,
{"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
exact_match_reward,
prompt,
generation_kwargs=generation_kwargs,
)
# main training loop
for _step in range(100):
tasks, answers = generate_data(ppo_config.batch_size)
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
response_texts = [tokenizer.decode(response) for response in responses]
query_texts = [tokenizer.decode(query) for query in queries]
texts = {"query": [qt.split("<submit>")[-1].strip() for qt in query_texts], "response": response_texts}
ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"])
ppo_trainer.save_pretrained(model_id + "-calculator")

View File

@ -1,193 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, HfArgumentParser, load_tool
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@dataclass
class ScriptArguments:
model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"})
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
gradient_accumulation_steps: Optional[int] = field(
default=16, metadata={"help": "the number of gradient accumulation steps"}
)
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"})
ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"})
n_epochs: Optional[int] = field(default=32, metadata={"help": "max number of ppo epochs"})
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
def exact_match_reward(responses, answers=None):
"""Reward if generated response contains correct answer."""
rewards = []
pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>" # generated by chatGPT
for response, answer in zip(responses, answers):
reward = 0.0
try:
predicted_number = None
match_pattern = re.findall(pattern, response)
if match_pattern:
predicted_number = float(match_pattern[0])
if predicted_number is not None:
if np.abs(predicted_number - float(answer)) < 0.1:
reward += 1.0
except Exception:
pass
rewards.append(torch.tensor(reward))
return rewards
def evaluate(test_dataloader, text_env, ppo_trainer):
test_rewards = []
for test_batch in test_dataloader:
_, _, _, rewards, _ = text_env.run(test_batch["query"], answers=test_batch["answer"])
test_rewards.extend(rewards)
test_rewards = ppo_trainer.accelerator.gather_for_metrics(
torch.stack(test_rewards).to(ppo_trainer.accelerator.device)
)
return test_rewards.mean()
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["c_proj", "c_attn", "q_attn"],
)
# set up models
model = AutoModelForCausalLMWithValueHead.from_pretrained(
script_args.model_name,
use_auth_token=True,
load_in_4bit=True,
peft_config=lora_config,
)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_auth_token=True)
tokenizer.pad_token = tokenizer.eos_token
ds = load_dataset("openai/gsm8k", "main", split="train")
ds = ds.rename_columns({"question": "query"})
ds = ds.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
ds = ds.select(range(1, len(ds))) # skip the first sample which is used in prompt
ds_test = load_dataset("openai/gsm8k", "main", split="test")
ds_test = ds_test.rename_columns({"question": "query"})
ds_test = ds_test.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
test_dataloader = torch.utils.data.DataLoader(ds_test, batch_size=script_args.batch_size)
# prompt
prompt = """\
Example of using a Python API to solve math questions.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
<request><PythonInterpreter>
def solution():
money_initial = 23
bagels = 5
bagel_cost = 3
money_spent = bagels * bagel_cost
money_left = money_initial - money_spent
result = money_left
return result
print(solution())
<call>72<response>
Result = 72 <submit>
Q: """
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"eos_token_id": -1,
"max_new_tokens": script_args.max_new_tokens,
}
# trainer
ppo_config = PPOConfig(
batch_size=script_args.batch_size,
learning_rate=script_args.learning_rate,
mini_batch_size=script_args.mini_batch_size,
ppo_epochs=script_args.ppo_epochs,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
log_with="wandb",
tracker_project_name="trl-gsm8k",
remove_unused_columns=False,
optimize_cuda_cache=True,
)
ppo_trainer = PPOTrainer(args=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
test_dataloader = ppo_trainer.accelerator.prepare(test_dataloader)
# text env
text_env = TextEnvironment(
model,
tokenizer,
[load_tool("lvwerra/python-interpreter")],
exact_match_reward,
prompt,
max_turns=2,
generation_kwargs=generation_kwargs,
)
# main training loop
for epoch in range(script_args.n_epochs):
for step, batch in enumerate(ppo_trainer.dataloader):
if (step == 0) and (epoch % 4 == 0): # evaluate every 4 epochs
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
else:
reward_mean_test = None
queries, responses, masks, rewards, histories = text_env.run(batch["query"], answers=batch["answer"])
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
# logging
if reward_mean_test is not None:
train_stats["env/reward_mean_test"] = reward_mean_test
texts = {
"query": batch["query"],
"response": [tokenizer.decode(response) for response in responses],
"answer": batch["answer"],
}
ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"])
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
ppo_trainer.save_pretrained(f"model/{script_args.model_name}-gsm8k")

View File

@ -1,192 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import Optional
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, HfArgumentParser, load_tool
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@dataclass
class ScriptArguments:
model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"})
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
gradient_accumulation_steps: Optional[int] = field(
default=16, metadata={"help": "the number of gradient accumulation steps"}
)
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"})
ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"})
iterations: Optional[int] = field(default=1000, metadata={"help": "the number of iterations"})
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["c_proj", "c_attn", "q_attn"],
)
# set up models
model = AutoModelForCausalLMWithValueHead.from_pretrained(
script_args.model_name,
use_auth_token=True,
trust_remote_code=True,
load_in_4bit=True,
peft_config=lora_config,
)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_auth_token=True)
tokenizer.pad_token = tokenizer.eos_token
# system prompt
prompt = """\
Answer the following question:
Q: In which branch of the arts is Patricia Neary famous?
A: Ballets
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
Result=Ballets<submit>
Q: Who won Super Bowl XX?
A: Chicago Bears
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
Result=Chicago Bears<submit>
Q: """
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"eos_token_id": -1,
"max_new_tokens": script_args.max_new_tokens,
}
# trainer
config = PPOConfig(
batch_size=script_args.batch_size,
model_name=script_args.model_name,
learning_rate=script_args.learning_rate,
log_with=script_args.log_with,
mini_batch_size=script_args.mini_batch_size,
ppo_epochs=script_args.ppo_epochs,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
seed=script_args.seed,
optimize_cuda_cache=True,
)
ppo_trainer = PPOTrainer(args=config, model=model, tokenizer=tokenizer)
dataset = load_dataset("mandarjoshi/trivia_qa", "rc", split="train")
local_seed = script_args.seed + ppo_trainer.accelerator.process_index * 100003 # Prime
dataset = dataset.shuffle(local_seed)
def data_generator():
for i in range(len(dataset)):
yield dataset[i]["question"], list(dataset[i]["answer"]["normalized_aliases"])
gen = data_generator()
gen = iter(gen)
def generate_data(n):
tasks, answers = [], []
for _i in range(n):
q, a = next(gen)
tasks.append(q)
answers.append(a)
return tasks, answers
def exact_match_reward(responses, answers=None):
"""Reward if generated response contains correct answer."""
rewards = []
for response, answer in zip(responses, answers):
reward = 0.0
for a in answer:
if a.lower() in response.lower():
reward += 1.0
break
rewards.append(torch.tensor(reward))
return rewards
def tool_fn(x):
# limit the amount of tokens
return tool(x).split("\n")[1][:600]
# text env
tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
text_env = TextEnvironment(
model,
tokenizer,
{"Wiki": tool_fn},
exact_match_reward,
prompt,
generation_kwargs=generation_kwargs,
max_tool_reponse=400,
)
def print_trainable_parameters(model):
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
print(
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
)
print_trainable_parameters(model)
# main training loop
for i in range(script_args.iterations):
tasks, answers = generate_data(config.batch_size)
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
response_texts = [tokenizer.decode(response) for response in responses]
query_texts = [tokenizer.decode(query) for query in queries]
texts = {
"query": [qt.split("<submit>")[-1].strip() for qt in query_texts],
"response": response_texts,
"answer": [", ".join(item) for item in answers],
}
all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device))
ppo_trainer.log_stats(train_stats, texts, list(all_rewards), columns_to_log=["query", "response", "answer"])
if i % 100 == 0:
ppo_trainer.save_pretrained(f"models/{script_args.model_name}_{script_args.seed}_{i}_triviaqa")

View File

@ -1,4 +1,4 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

(Eight more files show only this same copyright-header change.)
View File

@@ -1,4 +1,4 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -99,4 +99,4 @@ else:
     completions = [[c0, c1] for c0, c1 in zip(reference_completions, model_completions)]
     best_idxs = judge.judge(prompts, completions)
     model_win_rate = best_idxs.count(1) / len(best_idxs)
-    print(f"Model win rate: {model_win_rate*100:.2f}%")
+    print(f"Model win rate: {model_win_rate * 100:.2f}%")

View File

@@ -1,4 +1,4 @@
-# Copyright 2025 The HuggingFace Team. All rights reserved.
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -45,7 +45,6 @@ python examples/scripts/gkd.py \
     --lora_alpha 16
 """

-from accelerate import PartialState
 from datasets import load_dataset
 from transformers import AutoTokenizer, GenerationConfig
@@ -106,14 +105,6 @@ if __name__ == "__main__":
     ################
     dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
-    with PartialState().local_main_process_first():
-        dataset = dataset.map(
-            lambda x: {
-                "prompt": tokenizer.apply_chat_template(x["prompt"], tokenize=False, add_generation_prompt=True)
-            },
-            num_proc=training_args.dataset_num_proc,
-        )
-
     ################
     # Training
     ################
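
The mapping removed here had pre-applied the chat template to the conversational "prompt" column. My reading of this fix is that the trainer is now expected to consume raw conversational prompts directly, i.e. entries like the following (hypothetical example) are passed through untouched:

example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}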

(Ten more files show only the same copyright-header change.)
View File

@@ -0,0 +1,62 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Train Gemma-3 on the Codeforces COTS dataset.
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/sft_gemma3.py
"""
from datasets import load_dataset
from transformers import AutoModelForImageTextToText
from trl import SFTConfig, SFTTrainer
def main():
# Load dataset
train_dataset = load_dataset("open-r1/codeforces-cots", split="train")
train_dataset = train_dataset.remove_columns("prompt")
# Load model
model_id = "google/gemma-3-12b-it"
model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager")
# Train model
training_args = SFTConfig(
output_dir=f"{model_id}-codeforces-SFT",
logging_steps=10,
bf16=True,
use_liger_kernel=True,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
max_length=8192,
per_device_train_batch_size=1,
gradient_accumulation_steps=8,
dataset_num_proc=32,
num_train_epochs=1,
)
trainer = SFTTrainer(
args=training_args,
model=model,
train_dataset=train_dataset,
)
trainer.train()
# Push to hub
trainer.push_to_hub(dataset_name="open-r1/codeforces-cots")
if __name__ == "__main__":
main()

(Two more files show only the same copyright-header change.)
View File

@@ -0,0 +1,223 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Train Gemma-3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).
accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/sft_vlm_gemma3.py \
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
--model_name_or_path google/gemma-3-4b-it \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--output_dir gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft \
--bf16 \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules all-linear \
--attn_implementation eager
Train Gemma-3 on the FanqingM/MMIU-Benchmark dataset (multi-image).
accelerate launch \
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
examples/scripts/sft_vlm_gemma3.py \
--dataset_name FanqingM/MMIU-Benchmark \
--dataset_train_split test \
--model_name_or_path google/gemma-3-4b-it \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark \
--bf16 \
--torch_dtype bfloat16 \
--use_peft \
--lora_target_modules all-linear
--attn_implementation eager
"""
import io
import os
import zipfile
import torch
from datasets import DatasetDict, load_dataset
from huggingface_hub import hf_hub_download, list_repo_files
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor
from trl import (
ModelConfig,
ScriptArguments,
SFTConfig,
SFTTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)


# For multi-image example
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
    image_inputs = []
    for msg in messages:
        content = msg.get("content", [])
        if not isinstance(content, list):
            content = [content]
        for element in content:
            if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
                if "image" in element:
                    image = element["image"]
                else:
                    image = element
                if image is not None:
                    image = Image.open(io.BytesIO(image["bytes"]))
                    image_inputs.append(image.convert("RGB"))
    return image_inputs


def format_data(samples: dict[str, Any]) -> dict[str, list]:
    formatted_samples = {"messages": []}
    for cont in range(len(samples["question"])):
        images = []
        for img_path in samples["input_image_path"][cont]:
            try:
                with open(img_path, "rb") as f:
                    img_bytes = f.read()
                image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                images.append({"type": "image", "image": image})
            except Exception as e:
                print(f"Error processing image {img_path}: {e}")
                continue

        formatted_samples["messages"].append(
            [
                {"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
                {"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
                {"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
            ]
        )
    return formatted_samples


# For multi-image example
def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
    all_files = list_repo_files(dataset_name, repo_type="dataset")
    zip_files = [f for f in all_files if f.endswith(".zip")]

    for zip_filename in zip_files:
        zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
        extract_folder = zip_filename.replace(".zip", "")
        os.makedirs(extract_folder, exist_ok=True)

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_folder)

    dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
    return dataset


def main():
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
    training_args.remove_unused_columns = False
    training_args.dataset_kwargs = {"skip_prepare_dataset": True}

    ################
    # Model, Tokenizer & Processor
    ################
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
    )
    processor.tokenizer.padding_side = "right"

    model = AutoModelForImageTextToText.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
    )

    def collate_fn(examples):
        texts = [
            processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip()
            for example in examples
        ]
        if "images" in examples[0]:  # single-image
            images = [[img.convert("RGB") for img in example["images"]] for example in examples]
        else:  # multi-image
            images = [process_vision_info(example["messages"]) for example in examples]

        # Tokenize the texts and process the images
        batch = processor(
            text=texts, images=images, return_tensors="pt", padding=True
        )  # Encode texts and images into tensors

        # The labels are the input_ids, and we mask the padding tokens in the loss computation
        labels = batch["input_ids"].clone()  # Clone input IDs for labels

        # Mask image tokens
        image_token_id = [
            processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
        ]
        # Mask tokens that should not contribute to the loss
        labels[labels == processor.tokenizer.pad_token_id] = -100
        labels[labels == image_token_id] = -100
        labels[labels == 262144] = -100  # hard-coded id of the Gemma-3 image soft tokens
        batch["labels"] = labels
        return batch  # Return the prepared batch

    ################
    # Dataset
    ################
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    if script_args.dataset_name == "FanqingM/MMIU-Benchmark":
        dataset = prepare_dataset(dataset, script_args.dataset_name, script_args.dataset_train_split)

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        data_collator=collate_fn,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        processing_class=processor.tokenizer,
        peft_config=get_peft_config(model_args),
    )
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
        if trainer.accelerator.is_main_process:
            processor.push_to_hub(training_args.hub_model_id)


if __name__ == "__main__":
    main()
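
To make the label masking in collate_fn concrete, a toy sketch (all ids are invented; in the real collator the pad and image token ids come from the processor):

import torch

PAD, IMG = 0, 262144  # hypothetical pad id; 262144 stands in for the image soft token id
input_ids = torch.tensor([[7, 12, IMG, IMG, 9, PAD, PAD]])
labels = input_ids.clone()
labels[labels == PAD] = -100  # padding does not contribute to the loss
labels[labels == IMG] = -100  # image placeholder positions are not trained as text
print(labels)  # tensor([[   7,   12, -100, -100,    9, -100, -100]])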

Some files were not shown because too many files have changed in this diff.