Compare commits

...

64 Commits

Author SHA1 Message Date
2c49300910 Release: 0.18.1 2025-05-29 19:01:07 +00:00
e530486c26 📚 Fix doc building by removing vLLM from dev dependencies in setup.cfg (#3511) 2025-05-29 18:49:51 +00:00
1bae58c292 📎 Fix clip ratio logging (#3506) 2025-05-29 18:49:41 +00:00
ef4b0b225c Release: v0.18 (#3504) 2025-05-27 18:43:58 -07:00
8e8e62b380 ✂️ [DPO] Fix truncation keep_end leading to zero'd out samples (#3398)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-27 16:36:01 -07:00
824100ce25 🏰 [vllm] Support base_url parameter for vLLM client initialization (#3324)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-27 16:05:40 -07:00
4e7f0a5eb9 🤧 LD-DPO support (#3458)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-27 16:05:30 -07:00
17a9069710 📏 Completion length logging fix + remainder logging fix (#3482)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-27 14:31:03 -07:00
cb07c44920 Forgotten commit from #3502 2025-05-27 20:02:22 +00:00
0b6a1874f1 🔭 [GRPO] Log advantages and fraction of samples with an std of zero (#3502)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-27 12:58:41 -07:00
ac18c9d532 🐌 Clean two-sided clipping (#3499) 2025-05-27 09:39:37 -07:00
d1174adc5b 🛠️ Initialize reward_kwargs to prevent UnboundLocalError in GRPOTrainer (#3459)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-26 18:28:27 -07:00
cd838417e4 👇 Update grpo.py to fix bugs for cli grpo --reward_funcs my_lib.my_reward (#3454)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-26 17:59:57 -07:00
c7e3f096a5 [GKD] fix the gkd script (#3497) 2025-05-26 20:22:15 +02:00
5c08897570 [GRPO] disabling top_k sampling default (#3494)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-26 11:32:07 +02:00
3ef9faf257 [Docs] sync logging doc to current metrics (#3478) 2025-05-25 17:46:28 +02:00
9ac614fb08 Fix mis-aligned prompts and completions in colocate mode (#3491)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-24 16:50:45 -06:00
29401e790e [Doc][SFT] Update sft_trainer.md. link prompt-completion dataset example (#3486)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-24 19:13:00 +02:00
31bf3f9244 Fix typo (#3489) 2025-05-24 13:24:15 +02:00
7f32792c07 [CI] fix sampler api to make the CI green (#3488) 2025-05-23 17:32:23 +02:00
3d8727918a [SFT] update minimal liger version (#3483) 2025-05-23 13:44:20 +02:00
65245f6be8 Update .pre-commit-config.yaml (#3479) 2025-05-22 16:08:23 +02:00
a528b9c465 [NashMD] fix the edge case where the model is a peft model (#3473)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-20 17:02:04 +02:00
e0dd525021 🙅 PPO value_model can't be None, so it shouldn't be Optional (#3300) 2025-05-19 17:01:08 -07:00
64aa06499b enable activation offloading on XPU (#3444)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-19 11:56:14 +02:00
be93a0c30c enable vllm c-s tests on XPU (#3445)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-19 11:55:57 +02:00
f9fbd91ea9 [CI] fix CI failure of transformer dev (#3457) 2025-05-19 10:08:42 +02:00
54d4f6b13a 🎁 Reward submodule (#3430) 2025-05-15 19:10:22 -07:00
05bc43e960 feat: Implement Two-Sided Clipping for GRPO Trainer (#3434)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-13 20:36:39 +02:00
d3dc8ff654 use device agnostic empty_cache in ppo & rloo (#3439)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-13 20:10:14 +02:00
21738c3732 enable trl env on xpu (#3438)
Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-05-13 11:36:01 +02:00
eab175d434 🏹 Support kv_cache_dtype to quantize kv-cache in vllm (#3422)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-08 17:11:16 -07:00
4da4dc9117 Update README.md 2025-05-07 20:49:35 -07:00
6b3a02385d Update README.md (#3420) 2025-05-07 20:48:22 -07:00
abbbb93d6a 🧪 Testing support for Qwen3 tiny (#3415) 2025-05-07 19:32:42 -07:00
cafa663c84 [Models] Activation checkpointing from TorchTune (#2954)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: DanFosing <danfoss12340@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Robert <robert.veres00@gmail.com>
Co-authored-by: Robert Veres <robert.veres@languagetool.org>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Mathew Shen <datahonor@gmail.com>
Co-authored-by: Ishan Kumar <ishankumar216@gmail.com>
Co-authored-by: Huazhong Ji <hzji210@gmail.com>
Co-authored-by: tpoisonooo <khj.application@aliyun.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-07 12:36:11 +02:00
fd04a5461a 🐍 Support Python 3.13 (#2593) 2025-05-06 21:38:23 -07:00
56e5766205 🎁 Reward takes completion ids (#3272)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-06 10:34:50 -07:00
89d44caece 📝 vLLM-integration documentation (#3376)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-06 09:37:02 -06:00
adfa7fd59a 🎲 [GRPO] Shuffle mini batches (#3391)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-06 11:09:00 +02:00
cf5183db7f 💔 [GRPO] Decouple gradient accumulation from the number of minibatches generated (#3388)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-06 09:59:32 +02:00
1954c02d86 🤝 Compatibility of the TRL CLI with accelerate arguments (#3409)
Co-authored-by: Lewis Tunstall <lewis.c.tunstall@gmail.com>
2025-05-06 00:09:23 -07:00
45f4c58832 ✌️ Add support for FSDP2 (#3317)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-06 08:29:11 +02:00
cc044e35b2 🕊️ Un-restrict diffusers (#3407) 2025-05-02 15:06:53 -07:00
999acd53ec 🕺 Migrate setup configuration from setup.py to setup.cfg and make rich an optional dep (#3403)
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-05-02 11:03:57 -07:00
8606b1ad09 🪪 Remove license classifier (#3402) 2025-05-02 10:03:39 -07:00
a673da5773 👉 [DPO] Model forward pass padding side fix (#3307)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-01 20:37:55 -07:00
00b8e311aa 🦁 Fix liger initialization (#3401)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
2025-05-01 20:36:46 -07:00
c163cf5081 💔 [SFT] Raise error when formatting_func is used with completion_only_loss (#3385)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-01 16:23:27 -07:00
bc9c019c43 [IterativeSFT] Small refresher (#3378) 2025-05-01 16:18:41 -07:00
18596cf232 🧑‍🤝‍🧑 Co-Locating vLLM w/ training to for higher throughput and GPU utilization (#3394)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-01 16:17:26 -07:00
280d35301b 🌊 Add MLflow metrics in profiling context (#3400)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-05-01 16:15:38 -07:00
13fa8402a3 [GRPO] Reference model initialization bug fix (#3397) 2025-05-01 17:31:21 +02:00
09b669fbf7 [🐯+GRPO] Support FSDP + Fix bug when using LigerGRPO with DDP (#3260)
Co-authored-by: Ubuntu <azureuser@liger-ci-h100-vm.kvghai4yzzmufguwws3040dwlf.dx.internal.cloudapp.net>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
2025-04-30 22:49:45 +02:00
01d0be15cb Deprecate TextEnvironment and tools (#3389) 2025-04-29 20:25:36 +02:00
3a42af1c78 DPO fixes for evaluations (#3377) 2025-04-29 17:16:30 +02:00
aaf39604ba PEFT support for Liger GRPO (#3355)
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
2025-04-29 17:05:35 +02:00
2bf48478e8 📋 Allow calling trl cli in sft mode with config file (#3380)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-28 14:23:42 -07:00
a8cfca6d01 ⚰️ Remove deprecated (#3364) 2025-04-26 11:11:35 -07:00
1bca49515e Better guards for DeepSpeed imports (#3351) 2025-04-26 10:18:11 +02:00
39e96394a9 🎭 Fix train and eval mode checking in GRPOTrainer and SFTTrainer (#3337)
Co-authored-by: Jiaming Ma <jiaming.ma@connect.polyu.hk>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-04-25 17:42:43 -07:00
8e6ed93dfd 🥸🔢 Adding pad_multiple to SFT trainer (#3365) 2025-04-25 18:12:35 -06:00
29c5e05e3a 🔢 Pad to multiple of (#3362) 2025-04-25 09:53:20 -07:00
a9b27f82d6 ⬆️ Bump dev version (#3357) 2025-04-24 16:22:12 -07:00
91 changed files with 3970 additions and 2168 deletions

View File

@ -36,7 +36,7 @@ jobs:
name: Tests
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
fail-fast: false
runs-on:
group: aws-g4dn-2xlarge

View File

@ -24,7 +24,7 @@ jobs:
steps:
- name: Git checkout
uses: actions/checkout@v4
with: { ref: v0.17-release }
with: { ref: v0.18-release }
- name: Set up Python 3.12
uses: actions/setup-python@v5

View File

@ -1,8 +1,8 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.3
rev: v0.11.10
hooks:
- id: ruff
- id: ruff-check
types_or: [ python, pyi ]
args: [ --fix ]
- id: ruff-format

View File

@ -31,4 +31,4 @@ keywords:
- pytorch
- transformers
license: Apache-2.0
version: 0.17
version: 0.18

View File

@ -456,3 +456,193 @@ Warnings play a critical role in guiding users toward resolving potential issues
```
By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.
## Making a release
> [!NOTE]
> VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We need to follow this convention to be able to retrieve versioned scripts.
To create the package for PyPI.
#### 0. Prerequisites
- Dependencies:
- twine: `pip install build twine`
- Create an account on (and join the `trl` project):
- PyPI: https://pypi.org/
- Test PyPI: https://test.pypi.org/
#### 1. Ensure your local repository is up to date with the upstream repository
```bash
git checkout main
git pull origin main
```
> [!WARNING]
> Do not merge other pull requests into `main` until the release is done. This is to ensure that the release is stable and does not include any untested changes. Announce internally (#trl-internal) to other maintainers that you are doing a release and that they must not merge PRs until the release is done.
#### 2. Create a release branch from main
```bash
git checkout -b release-v{major}.{minor}
```
#### 3. Change the version in the following files
- `.github/workflows/tests_latest.yml`:
```diff
- with: { ref: v{major}.{minor-1}-release }
+ with: { ref: v{major}.{minor}-release }
```
- `CITATION.cff`
```diff
- version: {major}.{minor-1}
+ version: {major}.{minor}
```
- `__init__.py`
```diff
- __version__ = "{major}.{minor}.0.dev0"
+ __version__ = "{major}.{minor}.0"
```
- `setup.cfg`
```diff
- version = {major}.{minor}.0.dev0
+ version = {major}.{minor}.0
```
#### 4. Commit and push these changes
```shell
git commit -m 'Release: {major}.{minor}'
git push origin release-v{major}.{minor}
```
#### 5. Create a pull request
from `release-v{major}.{minor}` to `main`, named `Release: v{major}.{minor}`, wait for tests to pass, and request a review.
#### 6. Once the pull request is approved, merge it into `main`
#### 7. Add a tag in git to mark the release
```shell
git checkout main
git pull origin main
git tag -a v{major}.{minor}.0 -m 'Adds tag v{major}.{minor}.0 for PyPI'
git push origin v{major}.{minor}.0
```
#### 8. Create a branch `v{major}.{minor}-release` for future patch releases.
```shell
git checkout -b v{major}.{minor}-release
git push origin v{major}.{minor}-release
```
This ensures that future patch releases (`v{major}.{minor}.1`, `v{major}.{minor}.2`, etc.) can be made separately from `main`.
#### 9. Create the wheels for your release
These are the artifacts that will be uploaded to PyPI and installed by users via `pip install trl`.
Clean previous builds:
```shell
rm -rf build dist
```
At the root of your repo, run
```bash
python -m build .
```
This will create a folder named `dist` with the new version of your package.
#### 10. Upload the package to PyPI Test
> [!IMPORTANT]
> Do not skip this step. It is important to test the package before uploading it to the main PyPI server.
```shell
twine upload dist/* -r testpypi
```
Then in a fresh environment containing all dependencies you need, try to install your new package from the PyPI test server.
```bash
pip install -i https://test.pypi.org/simple/ trl
```
You might get errors for missing dependencies since the PyPI test server does not contain all packages like PyPI does. To make sure you have everything you can do:
```bash
pip install trl
pip uninstall trl
```
(the second line will remove trl but keep all its dependencies).
Also make sure you can actually use the package! Run the following line:
```bash
python -c "from trl import *"
```
along with anything that tests:
- the core feature of your package
- the new features you're adding in the release
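For example, a minimal smoke test could look like the following (the version string and imported classes are illustrative; adapt it to the features shipped in the release):
```python
# Minimal post-install smoke test: adjust to the features shipped in this release.
import trl

print(trl.__version__)  # should match the freshly released version, e.g. 0.18.0

from trl import SFTConfig, SFTTrainer  # core classes should import without errors
print(SFTConfig, SFTTrainer)
```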
#### 11. Publish on PyPI
> [!WARNING]
> This can't be reverted. Make sure you have tested everything before doing this step.
```shell
twine upload dist/*
```
#### 12. Create a GitHub Release
1. Go to the repo's [releases section](https://github.com/huggingface/trl/releases) on GitHub.
2. Click **Draft a new release**.
3. Select the `v{major}.{minor}.0` tag you just created in step 7.
4. Add a title (`v{major}.{minor}.0`) and a short description of what's new.
5. Click **Publish Release**.
#### 13. Bump to dev version
1. Create a branch `bump-dev-version-{major}.{minor+1}` from `main` and checkout to it.
```shell
git checkout -b bump-dev-version-{major}.{minor+1}
```
2. Change the version in the following files:
1. `__init__.py`
```diff
- __version__ = "{major}.{minor}.0"
+ __version__ = "{major}.{minor+1}.0.dev0"
```
2. `setup.cfg`
```diff
- version = {major}.{minor}.0
+ version = {major}.{minor+1}.0.dev0
```
3. Commit and push these changes
```shell
git add trl/__init__.py setup.cfg
git commit -m '⬆️ Bump dev version'
git push origin bump-dev-version-{major}.{minor+1}
```
4. Create a pull request from `bump-dev-version-{major}.{minor+1}` to `main`, named `⬆️ Bump dev version`, and request urgent review.
5. Once the pull request is approved, merge it into `main`.
6. The codebase is now ready for the next development cycle; inform the team in the #trl-internal channel.

View File

@ -1,6 +1,6 @@
include settings.ini
include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__
include trl/templates/*.md
include trl/templates/*.md
include trl/accelerate_configs/*.yaml

View File

@ -12,8 +12,9 @@
<p align="center">
<a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
<a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
<a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website?label=documentation&url=https%3A%2F%2Fhuggingface.co%2Fdocs%2Ftrl%2Findex&down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
<a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
<a href="https://huggingface.co/trl-lib"><img alt="Hugging Face Hub" src="https://img.shields.io/badge/🤗%20Hub-trl--lib-yellow"></a>
</p>
## Overview

View File

@ -51,8 +51,6 @@
title: Training StackLlama
- local: detoxifying_a_lm
title: Detoxifying a Language Model
- local: learning_tools
title: Learning to Use Tools
- local: multi_adapter_rl
title: Multi Adapter RLHF
- local: training_vlm_sft
@ -99,6 +97,8 @@
title: Trainers
- local: models
title: Model Classes
- local: model_utils
title: Model Utilities
- local: best_of_n
title: Best of N Sampling
- local: judges
@ -107,8 +107,8 @@
title: Callbacks
- local: data_utils
title: Data Utilities
- local: text_environments
title: Text Environments
- local: rewards
title: Reward Functions
- local: script_utils
title: Script Utilities
- local: others

View File

@ -1,105 +1,225 @@
# Command Line Interfaces (CLIs)
You can use TRL to fine-tune your language model with methods like Supervised Fine-Tuning (SFT) or Direct Policy Optimization (DPO) using the command line interface (CLI).
TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly.
Currently supported CLIs are:
Currently supported commands are:
#### Training commands
#### Training Commands
- `trl dpo`: fine-tune an LLM with DPO
- `trl grpo`: fine-tune an LLM with GRPO
- `trl kto`: fine-tune an LLM with KTO
- `trl sft`: fine-tune an LLM with SFT
#### Other commands
#### Other Commands
- `trl env`: get the system information
- `trl vllm-serve`: serve a model with vLLM
## Fine-tuning with the CLI
## Fine-Tuning with the TRL CLI
Before getting started, pick up a Language Model from Hugging Face Hub. Supported models can be found with the filter "text-generation" within models. Also make sure to pick up a relevant dataset for your task.
### Basic Usage
You can launch training directly from the CLI by specifying required arguments like the model and dataset:
<hfoptions id="command_line">
<hfoption id="SFT">
Before using the `sft` or `dpo` commands make sure to run:
```bash
accelerate config
trl sft \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb
```
and pick up the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
We also recommend passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with the `trl sft` command.
</hfoption>
<hfoption id="DPO">
```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf
```
</hfoption>
</hfoptions>
### Using Configuration Files
To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file:
<hfoptions id="config_file">
<hfoption id="SFT">
```yaml
model_name_or_path:
Qwen/Qwen2.5-0.5B
dataset_name:
stanfordnlp/imdb
report_to:
none
learning_rate:
0.0001
lr_scheduler_type:
cosine
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
```
Save that config in a `.yaml` and get started immediately! An example CLI config is available as `examples/cli_configs/example_config.yaml`. Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder:
Launch with:
```bash
trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
trl sft --config sft_config.yaml
```
Will force-use `cosine_with_restarts` for `lr_scheduler_type`.
</hfoption>
<hfoption id="DPO">
### Supported Arguments
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
```
We support all arguments from `transformers.TrainingArguments`; for loading your model, we support all arguments from `~trl.ModelConfig`:
[[autodoc]] ModelConfig
You can pass any of these arguments either to the CLI or the YAML file.
### Supervised Fine-tuning (SFT)
Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
Launch with:
```bash
trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb
trl dpo --config dpo_config.yaml
```
The SFT CLI is based on the `trl/scripts/sft.py` script.
</hfoption>
</hfoptions>
### Direct Policy Optimization (DPO)
### Scaling Up with Accelerate
To use the DPO CLI, you need to have a dataset in the TRL format such as
TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs, machines, or use advanced setups like DeepSpeed — all from the same CLI.
* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch).
These datasets always have at least three columns `prompt, chosen, rejected`:
* `prompt` is a list of strings.
* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
* `rejected` is the rejected response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
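As a rough, hypothetical sketch of that three-column layout (the exact schema, plain strings vs. chat-format message lists, varies by dataset, so treat the linked datasets as the authoritative reference):
```python
from datasets import Dataset

# Hypothetical toy preference dataset with the prompt/chosen/rejected columns
# described above; real datasets may differ in the exact column types.
preference_data = Dataset.from_dict({
    "prompt": ["What color is the sky?"],
    "chosen": [[{"role": "assistant", "content": "The sky is blue."}]],
    "rejected": [[{"role": "assistant", "content": "The sky is green."}]],
})
print(preference_data[0])
```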
To do a quick start, you can run the following command:
<hfoptions id="launch_args">
<hfoption id="SFT inline">
```bash
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
trl sft \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb \
--num_processes 4
```
</hfoption>
<hfoption id="SFT w/ config file">
The DPO CLI is based on the `trl/scripts/dpo.py` script.
```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
num_processes: 4
```
#### Custom preference dataset
Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
Launch with:
```bash
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
trl sft --config sft_config.yaml
```
## Chat interface
</hfoption>
<hfoption id="DPO inline">
```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf \
--num_processes 4
```
</hfoption>
<hfoption id="DPO w/ config file">
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
num_processes: 4
```
Launch with:
```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
</hfoptions>
### Using `--accelerate_config` for Accelerate Configuration
The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either:
* the name of a predefined config profile (built into TRL), or
* a path to a custom Accelerate YAML config file.
#### Predefined Config Profiles
TRL provides several ready-to-use Accelerate configs to simplify common training setups:
| Name | Description |
| ------------ | ----------------------------------- |
| `fsdp1` | Fully Sharded Data Parallel Stage 1 |
| `fsdp2` | Fully Sharded Data Parallel Stage 2 |
| `zero1` | DeepSpeed ZeRO Stage 1 |
| `zero2` | DeepSpeed ZeRO Stage 2 |
| `zero3` | DeepSpeed ZeRO Stage 3 |
| `multi_gpu` | Multi-GPU training |
| `single_gpu` | Single-GPU training |
To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_configs/`.
#### Example Usage
<hfoptions id="accelerate_config">
<hfoption id="SFT inline">
```bash
trl sft \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```
</hfoption>
<hfoption id="SFT w/ config file">
```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```
Launch with:
```bash
trl sft --config sft_config.yaml
```
</hfoption>
<hfoption id="DPO inline">
```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```
</hfoption>
<hfoption id="DPO w/ config file">
```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```
Launch with:
```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
</hfoptions>
## Chat Interface
<Tip warning={true}>
@ -130,7 +250,7 @@ Besides talking to the model there are a few commands you can use:
- `save` or `save {SAVE_NAME}`: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
- `exit`: closes the interface
## Getting the system information
## Getting the System Information
You can get the system information by running the following command:
@ -138,7 +258,7 @@ You can get the system information by running the following command:
trl env
```
This will print out the system information including the GPU information, the CUDA version, the PyTorch version, the transformers version, and the TRL version, and any optional dependencies that are installed.
This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the transformers version, the TRL version, and any optional dependencies that are installed.
```txt
Copy-paste the following information when reporting an issue:
@ -146,7 +266,7 @@ Copy-paste the following information when reporting an issue:
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.1
- CUDA device: NVIDIA H100 80GB HBM3
- accelerator(s): NVIDIA H100 80GB HBM3
- Transformers version: 4.45.0.dev0
- Accelerate version: 0.34.2
- Accelerate config:
@ -177,6 +297,7 @@ Copy-paste the following information when reporting an issue:
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.12.0
- vLLM version: not installed
```
This information are required when reporting an issue.
This information is required when reporting an issue.

View File

@ -168,6 +168,10 @@ The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterativ
The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
### LD-DPO loss
The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the portion of the response that exceeds the desired length into two components — human-like preferences and verbosity preference — based on a mixing coefficient \\( \alpha \\). To use this method, set the `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`.
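As a minimal sketch of how these options are set (the values and output directory are illustrative, not recommendations):
```python
from trl import DPOConfig

# Illustrative values only: ld_alpha enables the LD-DPO decomposition described
# above (the paper suggests a value between 0.0 and 1.0), and use_weighting
# enables the WPO reweighting from the previous section.
training_args = DPOConfig(
    output_dir="dpo-example",  # hypothetical output directory
    ld_alpha=0.8,
    use_weighting=True,
)
```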
### For Mixture of Experts Models: Enabling the auxiliary loss
MOEs are the most efficient if the load is about equally distributed between experts.

View File

@ -155,6 +155,7 @@ This constant is recommended to be the maximum completion length. To use this fo
- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
- `reward`: The overall average reward after applying reward weights.
- `reward_std`: The standard deviation of the overall reward within each batch after applying reward weights.
- `frac_reward_zero_std`: The fraction of samples in the generation batch with a reward std of zero, implying there is little diversity for that prompt (all answers are correct or all are incorrect).
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
- `clip_ratio/region_mean`: The ratio of token probabilities where the GRPO objective is clipped to stay within the trust region:
$$
@ -170,26 +171,59 @@ A higher value means more tokens are clipped, which constrains how much the poli
### Speed up training with vLLM-powered generation
Generation is often the main bottleneck that makes training slow with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation. To enable it, first install the package with
Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with
```shell
pip install trl[vllm]
```
Then, start the vLLM server with the desired model:
We support two ways of using vLLM during training: **server mode** and **colocate mode**.
```bash
trl vllm-serve --model <model_name>
```
#### 🔌 Option 1: Server mode
Then, pass `use_vllm=True` in the training arguments and run the training script:
In this mode, vLLM runs in a separate process (and on separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.
1. **Start the vLLM server**:
```bash
trl vllm-serve --model <model_name>
```
2. **Enable server mode in your training script**:
```python
from trl import GRPOConfig
training_args = GRPOConfig(
...,
use_vllm=True,
vllm_mode="server", # default value, can be omitted
)
```
<Tip warning={true}>
Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable.
</Tip>
#### 🧩 Option 2: Colocate mode
In this mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs.
```python
from trl import GRPOConfig
training_args = GRPOConfig(..., use_vllm=True)
training_args = GRPOConfig(
...,
use_vllm=True,
vllm_mode="colocate",
)
```
<Tip>
Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`GRPOConfig`] to avoid underutilization or out-of-memory errors.
</Tip>
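For instance, a sketch of lowering the memory fraction reserved by vLLM in colocate mode (the value and output directory are purely illustrative and should be tuned to your setup):
```python
from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-colocate-example",  # hypothetical output directory
    use_vllm=True,
    vllm_mode="colocate",
    # Illustrative value: fraction of GPU memory vLLM may reserve alongside training.
    vllm_gpu_memory_utilization=0.3,
)
```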
For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
### GRPO at scale: train a 70B+ Model on multiple nodes
@ -274,6 +308,7 @@ The [`GRPOTrainer`] supports using custom reward functions instead of dense rewa
- The function must accept the following as keyword arguments:
- `prompts` (contains the prompts),
- `completions` (contains the generated completions),
- `completions_ids` (contains the tokenized completions),
- All column names (except `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
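For example, a minimal sketch of a reward function that relies on `**kwargs` and reads a hypothetical `ground_truth` column could look like this:
```python
# Sketch only: `ground_truth` is a hypothetical dataset column used for illustration;
# any extra columns and arguments passed by the trainer are absorbed by **kwargs.
def exact_match_reward(completions, ground_truth, **kwargs):
    return [float(c.strip() == gt.strip()) for c, gt in zip(completions, ground_truth)]
```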
@ -287,9 +322,29 @@ The [`GRPOTrainer`] supports using custom reward functions instead of dense rewa
Below is an example of a reward function for a standard format that rewards longer completions:
```python
def reward_func(completions_ids, **kwargs):
"""Reward function that assigns higher scores to longer completions (in terms of token count)."""
return [float(len(ids)) for ids in completions_ids]
```
You can test it as follows:
```python
>>> prompts = ["The sky is", "The sun is"] # not used in the reward function, but the trainer will pass it
>>> completions = [" blue.", " in the sky."] # not used in the reward function, but the trainer will pass it
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[2.0, 4.0]
```
#### Example 1.1: Reward longer completions (based on the number of characters)
Same as the previous example, but this time the reward function is based on the number of characters instead of tokens.
```python
def reward_func(completions, **kwargs):
"""Reward function that gives higher scores to longer completions."""
"""Reward function that assigns higher scores to longer completions (in terms of character count)."""
return [float(len(completion)) for completion in completions]
```
@ -298,7 +353,8 @@ You can test it as follows:
```python
>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> print(reward_func(prompts=prompts, completions=completions))
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]] # not used in the reward function, but the trainer will pass it
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[6.0, 12.0]
```

View File

@ -2,56 +2,138 @@
[![](https://img.shields.io/badge/All_models-Iterative_SFT-blue)](https://huggingface.co/models?other=iterative-sft,trl)
Iterative fine-tuning is a training method that enables performing custom actions (generation and filtering, for example) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code.
## Usage
## Quickstart
To get started quickly, instantiate a model and a tokenizer.
To get started quickly, you can either pass a model identifier or a pre-instantiated model to the trainer:
```python
from trl import IterativeSFTConfig, IterativeSFTTrainer
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Using a model identifier
trainer = IterativeSFTTrainer(
"facebook/opt-350m",
args=IterativeSFTConfig(
max_length=512,
output_dir="./output",
),
)
# Or using a pre-instantiated model
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
trainer = IterativeSFTTrainer(
model,
tokenizer
args=IterativeSFTConfig(
max_length=512,
output_dir="./output",
),
processing_class=tokenizer,
)
```
You have the choice to either provide a list of strings or a list of tensors to the step function.
## Usage
#### Using a list of tensors as input:
The [`IterativeSFTTrainer`] supports two ways of providing input data to the `step` function:
### Using a list of tensors as input:
```python
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask
"attention_mask": attention_mask,
}
trainer.step(**inputs)
```
#### Using a list of strings as input:
### Using a list of strings as input:
```python
inputs = {
"texts": texts
"texts": texts,
"texts_labels": texts_labels, # Optional, defaults to texts
}
trainer.step(**inputs)
```
For causal language models, labels will automatically be created from input_ids or from texts. When using sequence to sequence models you will have to provide your own labels or text_labels.
For causal language models, labels will automatically be created from `input_ids` or from `texts`. When using sequence to sequence models you will have to provide your own labels or `text_labels`.
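To make the iterative idea concrete, here is a minimal sketch of a custom loop; `generate_candidates` and `keep` are hypothetical placeholders for your own generation and filtering logic, and `trainer` is assumed to be set up as shown above:
```python
# Hypothetical placeholders for your own generation and filtering logic.
def generate_candidates():
    return ["My custom generated text ...", "Another candidate ..."]

def keep(text):
    return len(text) > 10  # arbitrary filtering rule, for illustration only

# Iterative loop: generate new texts, filter them, then take an optimization step
# on what survives (assumes `trainer` was created as shown above).
for _ in range(3):
    texts = [t for t in generate_candidates() if keep(t)]
    trainer.step(texts=texts)
```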
## IterativeTrainer
## Configuration
The [`IterativeSFTConfig`] class provides several parameters to customize the training:
```python
from trl import IterativeSFTConfig
config = IterativeSFTConfig(
# Model initialization parameters
model_init_kwargs={"torch_dtype": "bfloat16"},
# Data preprocessing parameters
max_length=512,
truncation_mode="keep_end",
# Training parameters
output_dir="./output",
learning_rate=2e-5,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
max_steps=1000,
logging_steps=10,
save_steps=100,
optim="adamw_torch",
report_to="wandb",
)
```
### Model Initialization
You can control how the model is initialized by passing keyword arguments to `model_init_kwargs`:
```python
config = IterativeSFTConfig(
model_init_kwargs={
"torch_dtype": "bfloat16",
"device_map": "auto",
"trust_remote_code": True,
}
)
```
### Data Preprocessing
The trainer supports two truncation modes:
- `keep_end`: Truncates from the start of the sequence
- `keep_start`: Truncates from the end of the sequence
```python
config = IterativeSFTConfig(
max_length=512,
truncation_mode="keep_end", # or "keep_start"
)
```
### Training Optimization
You can optimize CUDA cache usage for more memory-efficient training:
```python
config = IterativeSFTConfig(
optimize_device_cache=True,
)
```
## IterativeSFTTrainer
[[autodoc]] IterativeSFTTrainer
## IterativeSFTConfig
[[autodoc]] IterativeSFTConfig

View File

@ -1,233 +0,0 @@
# Learning Tools (Experimental 🧪)
Using Large Language Models (LLMs) with tools has been a popular topic recently with awesome works such as [ToolFormer](https://huggingface.co/papers/2302.04761) and [ToolBench](https://huggingface.co/papers/2305.16504). In TRL, we provide a simple example of how to teach an LLM to use tools with reinforcement learning.
Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):
| File | Description |
|---|---|
| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. |
| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. |
| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. |
<Tip warning={true}>
Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
</Tip>
## Learning to Use a Calculator
The rough idea is as follows:
1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parses a text calculation like `"14 + 34"` and returns the calculated number:
```python
from transformers import AutoTokenizer, load_tool
tool = load_tool("ybelkada/simple-calculator")
tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places
```
1. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.
1. Create a prompt on how to use the tools
```python
# system prompt
prompt = """\
What is 13.1-3?
<request><SimpleCalculatorTool>13.1-3<call>10.1<response>
Result=10.1<submit>
What is 4*3?
<request><SimpleCalculatorTool>4*3<call>12<response>
Result=12<submit>
What is 12.1+1?
<request><SimpleCalculatorTool>12.1+1<call>13.1<response>
Result=13.1<submit>
What is 12.1-20?
<request><SimpleCalculatorTool>12.1-20<call>-7.9<response>
Result=-7.9<submit>"""
```
3. Create a `trl.TextEnvironment` with the model
```python
env = TextEnvironment(
model,
tokenizer,
{"SimpleCalculatorTool": tool_fn},
reward_fn,
prompt,
generation_kwargs=generation_kwargs,
)
```
4. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will visualize the tokens.
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/learning_tools.png)
1. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss; make sure to pass that argument to `step`.
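Putting steps 4 and 5 together, a minimal sketch of the loop (note that the `TextEnvironment` API is deprecated as of this release and this is shown only to mirror the description above):
```python
# Mirrors the calls described in steps 4-5; `env` and `ppo_trainer` are assumed to be
# the objects constructed earlier in this walkthrough.
tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]
queries, responses, masks, rewards, histories = env.run(tasks)
# Pass the masks so the tool output is ignored when computing the loss.
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
histories[0].show_text()  # color-coded view of the model/tool interaction
```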
## Experiment results
We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.
```
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
--command "python examples/research_projects/tools/calculator.py" \
--num-seeds 10 \
--start-seed 1 \
--workers 10 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 8 \
--slurm-template-path benchmark/trl.slurm_template
```
We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot.
```
# pip install openrlbenchmark==0.2.1a5
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
'wandb?tag=calculator_final&cl=calculator_mask' \
--env-ids trl \
--check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename static/0compare \
--scan-history
```
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/learning_tools_chart.png)
As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near perfect proficiency in the calculator task.
## (Early Experiments 🧪): learning to use a wiki tool for question answering
The [ToolFormer](https://huggingface.co/papers/2302.04761) paper shows an interesting use case that utilizes a Wikipedia search tool to help answer questions. In this section, we attempt to perform similar experiments but use RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.
<Tip warning={true}>
**Note that many settings are different so the results are not directly comparable.**
</Tip>
### Building a search index
Since [ToolFormer](https://huggingface.co/papers/2302.04761) was not open-sourced, we needed to first replicate the search index. Their paper mentions that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT).
Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.
```python
from pyserini.search.lucene import LuceneSearcher
import json
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
def search(query):
hits = searcher.search(query, k=1)
hit = hits[0]
contents = json.loads(hit.raw)['contents']
return contents
print(search("tennis racket"))
```
```
Racket (sports equipment)
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.
The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
...
```
We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later.
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pyserini.png)
### Experiment settings
We use the following settings:
* use the `bigcode/starcoderbase` model as the base model
* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only use the first paragraphs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
* test if the response contains the answer string; if so, give a reward of 1, otherwise, give a reward of 0.
* note that this is a simplified evaluation criterion. In [ToolFormer](https://huggingface.co/papers/2302.04761), the authors check if the first 20 words of the response contain the correct answer.
* use the following prompt that demonstrates the usage of the wiki tool.
```python
prompt = """\
Answer the following question:
Q: In which branch of the arts is Patricia Neary famous?
A: Ballets
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
Result=Ballets<submit>
Q: Who won Super Bowl XX?
A: Chicago Bears
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
Result=Chicago Bears<submit>
Q: """
```
### Result and Discussion
Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves mostly go up, but one of the experiments did crash.
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/triviaqa_learning_curves.png)
Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.
Note that the correct rate of the trained model is on the low end, which could be due to the following reasons:
* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"`, if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988." But a correct search should return "Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles."
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/real_first_name.png)
* **unnecessarily long response**: The wiki tool by default sometimes outputs very long sequences. E.g., when the wiki tool searches for "Brown Act"
* Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
* [ToolFormer](https://huggingface.co/papers/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/brown_act.png)
## (Early Experiments 🧪): solving math puzzles with python interpreter
In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:
```python
prompt = """\
Example of using a Python API to solve math questions.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
<request><PythonInterpreter>
def solution():
money_initial = 23
bagels = 5
bagel_cost = 3
money_spent = bagels * bagel_cost
money_left = money_initial - money_spent
result = money_left
return result
print(solution())
<call>8<response>
Result = 8 <submit>
Q: """
```
Training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/gms8k_learning_curve.png)

View File

@ -1,74 +1,99 @@
# Logging
As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
By default, the TRL [`PPOTrainer`] saves a lot of relevant information to wandb or tensorboard.
By default, TRL trainers like [`PPOTrainer`] and [`GRPOTrainer`] save a lot of relevant information to supported experiment trackers like Weights & Biases (wandb) or TensorBoard.
Upon initialization, pass one of these two options to the [`PPOConfig`]:
Upon initialization, pass the `report_to` argument to the respective configuration object (e.g., [`PPOConfig`] for `PPOTrainer`, or [`GRPOConfig`] for `GRPOTrainer`):
```
training_args = PPOConfig(..., report_to="wandb") # or "tensorboard"
```python
# For PPOTrainer
ppo_config = PPOConfig(
# ...,
report_to="wandb" # or "tensorboard"
)
# For GRPOTrainer
grpo_config = GRPOConfig(
# ...,
report_to="wandb" # or "tensorboard"
)
```
If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.
If you want to log with TensorBoard, you might also need to specify logging directories, for example, by adding `logging_dir=PATH_TO_LOGS` to the configuration object (e.g., `PPOConfig` or `GRPOConfig`).
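For example, a sketch of a TensorBoard setup (the paths are illustrative; `logging_dir` comes from `transformers.TrainingArguments`, which the TRL configs inherit from):
```python
from trl import PPOConfig

# Illustrative: send logs to TensorBoard and write them under ./runs/ppo
ppo_config = PPOConfig(
    output_dir="ppo-example",   # hypothetical output directory
    report_to="tensorboard",
    logging_dir="./runs/ppo",   # hypothetical TensorBoard log directory
)
```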
## PPO Logging
Here's a brief explanation for the logged metrics provided in the data:
Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is used to specifically monitor the reward model.
1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias `ppo/std_scores`, which is used to specifically monitor the reward model.
1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment.
1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function.
1. `objective/kl_dist`: The histogram distribution of the `objective/kl`.
1. `objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function.
1. `ppo/mean_non_score_reward`: The **KL penalty** calculated by `objective/kl * objective/kl_coef` as the total reward for optimization to prevent the new policy from deviating too far from the old policy.
1. `objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration.
Training stats:
1. `ppo/learning_rate`: The learning rate for the PPO algorithm.
1. `ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy.
1. `ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process.
1. `ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html
1. `ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html
1. `ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.
1. `ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.
1. `ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`.
1. `ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details.
1. `ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance.
1. `ppo/val/mean`: The mean of the values, used to monitor the value function's performance.
1. `ppo/val/var` : The variance of the values, used to monitor the value function's performance.
1. `ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance.
1. `ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped.
1. `ppo/val/vpred`: The predicted values from the value function.
1. `ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance.
1. `ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm.
1. `ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards.
1. `ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss.
Stats on queries, responses, and logprobs:
1. `tokens/queries_len_mean`: The average length of the queries tokens.
1. `tokens/queries_len_std`: The standard deviation of the length of the queries tokens.
1. `tokens/queries_dist`: The histogram distribution of the length of the queries tokens.
1. `tokens/responses_len_mean`: The average length of the responses tokens.
1. `tokens/responses_len_std`: The standard deviation of the length of the responses tokens.
1. `tokens/responses_dist`: The histogram distribution of the length of the responses tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`)
1. `objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model.
1. `objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model.
* `eps`: Tracks the number of episodes per second.
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
* `objective/scores`: The mean scores returned by the reward model / environment.
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to `policy/clipfrac_avg` but for the value function.
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
* `lr`: The current learning rate used by the optimizer.
* `episode`: The current episode count in the training process.
### Crucial values
During training, many values are logged. Here are the most important ones:
1. `env/reward_mean`, `env/reward_std`, `env/reward_dist`: The properties of the reward distribution from the "environment" / reward model.
1. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
1. `objective/scores`: The mean scores returned by the reward model / environment.
1. `objective/rlhf_reward`: The mean RLHF reward. This is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
1. `objective/non_score_reward`: The mean reward from non-score-related sources (e.g., KL penalty).
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning your hyperparameters):
1. `ppo/loss/value`: It will spike or go to NaN when training is not going well.
1. `ppo/policy/ratio`: A `ratio` of 1 is the baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high (e.g., 200), it means the probability of sampling a token is 200 times higher under the new policy than under the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well.
1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy.
1. `objective/kl_coef`: The KL coefficient adapted by [`AdaptiveKLController`] toward the target KL. It often increases before numerical instabilities appear.
1. `loss/value_avg`: The average value loss. It will spike or go to NaN when training is not going well.
1. `val/ratio`: The mean ratio of the current policy probability to the old policy probability. This number should float around 1.0. If this `ratio` is too high (e.g., 2.0 or 1000.0) or too small (e.g., 0.1), it means the updates between consecutive policies are too drastic.
1. `policy/clipfrac_avg` and `policy/approxkl_avg`: If `val/ratio` is too high, the `ratio` is going to get clipped, resulting in high `policy/clipfrac_avg` and high `policy/approxkl_avg` as well.
1. `objective/kl`: The mean KL divergence. It should stay positive and ideally not too large, so that the policy is not too far away from the reference policy.
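To make the relationship between the ratio and the KL estimators quoted above more concrete, here is an illustrative sketch (dummy tensors, not TRL's code) of the `k1` and `k2` estimators computed from per-token log-probabilities:
```python
import torch

# Illustrative sketch (dummy tensors, not TRL's code) of the k1 and k2 KL estimators
# referenced above, computed from per-token log-probabilities.
def masked_mean(values, mask):
    return (values * mask).sum() / mask.sum()

logprobs = torch.randn(4, 16)                          # log-probs under the new policy
old_logprobs = logprobs + 0.05 * torch.randn(4, 16)    # log-probs under the old policy
mask = torch.ones_like(logprobs)                       # 1 for response tokens, 0 for padding

approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)  # k2 estimator
policykl = masked_mean(old_logprobs - logprobs, mask)               # k1 estimator
ratio = torch.exp(logprobs - old_logprobs)             # should hover around 1.0 when training is stable
```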
## GRPO Logging
Here's a brief explanation of the metrics logged by the GRPO trainer:
* `num_tokens`: Total number of input tokens processed during training so far.
**Completions:**
* `completions/mean_length`: Mean length of all generated completions (including those not ending with an EOS token).
* `completions/min_length`: Minimum length among all generated completions.
* `completions/max_length`: Maximum length among all generated completions.
* `completions/clipped_ratio`: The ratio of completions that did not end with an EOS token before reaching the maximum generation length (i.e., they were truncated).
* `completions/mean_terminated_length`: Mean length of only those completions that successfully ended with an EOS token.
* `completions/min_terminated_length`: Minimum length among completions that ended with an EOS token.
* `completions/max_terminated_length`: Maximum length among completions that ended with an EOS token.
**Rewards:**
* `rewards/{reward_func_name}/mean`: The mean reward obtained from a specific, named reward function (e.g., `rewards/my_custom_reward/mean`). This is logged for each reward function used.
* `rewards/{reward_func_name}/std`: The standard deviation of rewards from a specific, named reward function.
* `reward`: The overall mean of the (potentially weighted and, if `args.scale_rewards` is true, normalized) rewards, after group-wise normalization (advantages).
* `reward_std`: The standard deviation of the (potentially weighted) rewards *before* group-wise normalization for advantages.
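A rough sketch of the group-wise normalization mentioned above (assumed shapes and epsilon; not the trainer's exact code):
```python
import torch

# Rough sketch of group-wise normalization (assumed shapes and epsilon; not the
# trainer's exact code). With num_generations completions per prompt, rewards are
# grouped per prompt and advantages are (reward - group mean) / (group std + eps).
num_generations = 4
rewards = torch.randn(2 * num_generations)             # 2 prompts x 4 completions each
grouped = rewards.view(-1, num_generations)
mean_grouped = grouped.mean(dim=1, keepdim=True)
std_grouped = grouped.std(dim=1, keepdim=True)
advantages = (grouped - mean_grouped) / (std_grouped + 1e-4)
```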
**Policy and Loss Metrics:**
* `kl`: The mean Kullback-Leibler (KL) divergence between the current policy and the reference policy. This is logged only if `beta` (the KL coefficient in `GRPOConfig`) is non-zero.
* If Liger GRPOLoss is used (`use_liger_loss: True` in `GRPOConfig`):
* `clip_ratio`: The fraction of policy updates where the probability ratio was clipped according to the GRPO loss's epsilon bounds.
* If standard GRPOLoss is used (`use_liger_loss: False`):
* `clip_ratio/low_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the lower bound `1 - epsilon_low` (occurs when advantage is negative and ratio is below the bound).
* `clip_ratio/low_min`: The minimum observed fraction for `clip_ratio/low_mean` across batches/processes.
* `clip_ratio/high_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the upper bound `1 + epsilon_high` (occurs when advantage is positive and ratio is above the bound).
* `clip_ratio/high_max`: The maximum observed fraction for `clip_ratio/high_mean` across batches/processes.
* `clip_ratio/region_mean`: The mean fraction of instances where the probability ratio was clipped at either the lower or upper bound.
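As an illustration of how these clipping fractions can be derived (a sketch with made-up tensors, not TRL's implementation):
```python
import torch

# Sketch with made-up tensors (not TRL's implementation) of how the clipping
# fractions above can be derived from per-token ratios and advantages.
ratio = torch.tensor([0.7, 0.95, 1.05, 1.4])           # r_t(θ)
advantages = torch.tensor([-1.0, 0.5, -0.2, 2.0])      # group-normalized advantages
epsilon_low, epsilon_high = 0.2, 0.2

is_low_clipped = (ratio < 1 - epsilon_low) & (advantages < 0)
is_high_clipped = (ratio > 1 + epsilon_high) & (advantages > 0)
clip_ratio_low = is_low_clipped.float().mean()                         # -> clip_ratio/low_mean
clip_ratio_high = is_high_clipped.float().mean()                       # -> clip_ratio/high_mean
clip_ratio_region = (is_low_clipped | is_high_clipped).float().mean()  # -> clip_ratio/region_mean
```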
### Crucial GRPO values
During GRPO training, monitor these values for insights into performance and stability:
1. `reward`: This is the primary objective. It reflects the (group-wise normalized) rewards the policy is achieving. It should generally increase during successful training.
1. `kl`: If `beta > 0`, this tracks the divergence from the reference model. Keep an eye on it to ensure the policy doesn't stray too far, which can lead to instability.
1. `clip_ratio/*` (either `clip_ratio` for Liger loss or the more detailed `clip_ratio/...` metrics for standard loss): These indicate how often the policy updates are being constrained by the GRPO clipping mechanism. Very high values might suggest that the policy is trying to change too drastically (potentially due to large advantages or a learning rate that's too high) or that the epsilon clipping range is too restrictive.
1. `completions/clipped_ratio`: A high ratio here indicates that the model is frequently generating completions that are cut off by `max_completion_length` rather than naturally ending with an EOS token. This might suggest issues with learning sequence termination or that `max_completion_length` is too short.
1. `rewards/{reward_func_name}/mean`: Monitoring the mean of individual reward functions can help diagnose which aspects of the desired behavior the model is learning or struggling with, especially when using multiple reward sources.

View File

@ -0,0 +1,5 @@
# Model Utilities
## get_act_offloading_ctx_manager
[[autodoc]] models.get_act_offloading_ctx_manager

View File

@ -14,7 +14,7 @@ Sequence lengths in the dataset can vary widely. When data is batched, sequences
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/why_you_should_truncate.png" alt="Truncation prompt completion" width="600"/>
</div>
To reduce memory usage, it’s important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
To reduce memory usage, it's important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
<hfoptions id="dpo">
<hfoption id="DPO">
@ -129,6 +129,42 @@ training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_imple
</hfoption>
</hfoptions>
## Activation offloading
Activation offloading is a memory efficiency technique that reduces GPU VRAM usage by temporarily moving activation tensors to CPU RAM during the forward pass and bringing them back only when needed for the backward pass. This significantly reduces peak memory usage at the cost of slightly increased training time.
To enable activation offloading in your SFT training configuration:
<hfoptions id="sft">
<hfoption id="SFT">
```python
from trl import SFTConfig
training_args = SFTConfig(..., activation_offloading=True)
```
</hfoption>
</hfoptions>
<Tip warning={true}>
When using activation offloading with models that use Liger kernels, you must disable Liger cross entropy due to compatibility issues. The issue occurs specifically with `use_liger_kernel=True` because Liger cross entropy performs in-place operations which conflict with activation offloading. The default setting (`use_liger_kernel=False`) works:
```python
# When using activation offloading with a model that uses Liger kernels:
from trl import SFTConfig
training_args = SFTConfig(
activation_offloading=True,
use_liger_kernel=False, # Disable Liger cross entropy
# Other parameters...
)
```
</Tip>
Under the hood, activation offloading implements PyTorch's [`saved_tensors_hooks`](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#hooks-for-autograd-saved-tensors) to intercept activations during the forward pass. It intelligently manages which tensors to offload based on size and context, avoiding offloading output tensors which would be inefficient. For performance optimization, it can optionally use CUDA streams to overlap computation with CPU-GPU transfers.
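As a minimal sketch of the idea (not TRL's implementation, and assuming a CUDA device is available), `saved_tensors_hooks` can be used like this:
```python
import torch
from torch.autograd.graph import saved_tensors_hooks

# Minimal sketch of the idea (not TRL's implementation), assuming a CUDA device:
# tensors saved for backward are moved to CPU when packed and restored to GPU when unpacked.
def pack_to_cpu(tensor):
    return tensor.device, tensor.to("cpu", non_blocking=True)

def unpack_from_cpu(packed):
    device, tensor = packed
    return tensor.to(device, non_blocking=True)

model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

with saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    loss = model(x).sum()   # activations saved for backward are offloaded to CPU here
loss.backward()             # ...and brought back to the GPU when needed
```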
## Disabling model gathering for generation in online methods
When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).

docs/source/rewards.md Normal file
View File

@ -0,0 +1,9 @@
# Reward Functions
This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`].
## Format rewards
### think_format_reward
[[autodoc]] rewards.think_format_reward
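As an illustration of the general shape of such format rewards (hypothetical code, not the `think_format_reward` implementation documented above), a format reward typically checks each completion against a pattern and returns one score per completion:
```python
import re

# Hypothetical example of a format-style reward function for GRPOTrainer
# (illustrative only; not the library's think_format_reward implementation).
def my_think_format_reward(completions, **kwargs):
    """Reward 1.0 for completions that start with a <think>...</think> block, else 0.0."""
    pattern = re.compile(r"^<think>.*?</think>", re.DOTALL)
    contents = [c if isinstance(c, str) else c[0]["content"] for c in completions]
    return [1.0 if pattern.match(content) else 0.0 for content in contents]

print(my_think_format_reward(["<think>reasoning</think> The answer is 42.", "The answer is 42."]))
# [1.0, 0.0]
```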

View File

@ -50,24 +50,24 @@ trainer = SFTTrainer(
trainer.train()
```
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults pass in your modification to the `SFTConfig` constructor and pass them to the trainer via the `args` argument.
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults, pass in your modification to the `SFTConfig` constructor and pass it to the trainer via the `args` argument.
## Advanced usage
### Train on completions only
To train on completions only, simply use a [prompt-completion](#prompt-completion) dataset. In this mode, loss is computed solely on the completion part.
To train on completions only, simply use a [prompt-completion](dataset_formats#prompt-completion) dataset. In this mode, loss is computed solely on the completion part.
If you’d like to compute loss on both the prompt **and** the completion while still using a prompt-completion dataset, set `completion_only_loss=False` in the [`SFTConfig`]. This is equivalent to [converting the dataset to a language modeling](#from-prompt-completion-to-language-modeling-dataset) format.
If you’d like to compute loss on both the prompt **and** the completion while still using a prompt-completion dataset, set `completion_only_loss=False` in the [`SFTConfig`]. This is equivalent to [converting the dataset to a language modeling](dataset_formats#from-prompt-completion-to-language-modeling-dataset) format.
### Add Special Tokens for Chat Format
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system, and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
- Adds special tokens to the tokenizer, e.g., `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
- Resizes the model’s embedding layer to accommodate the new tokens.
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g. 64. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g., `64`. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
@ -77,12 +77,12 @@ from trl import setup_chat_format
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
# Set up the chat format with default 'chatml' format
# Set up the chat format with the default 'chatml' format
model, tokenizer = setup_chat_format(model, tokenizer)
```
> [!WARNING]
> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases it is not necessary to apply `setup_chat_format()`, as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in `SFTConfig`; for example, for `Qwen/Qwen2.5-1.5B` one should set `eos_token="<|im_end|>"`.
> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply `setup_chat_format()`, as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in `SFTConfig`; for example, for `Qwen/Qwen2.5-1.5B`, one should set `eos_token="<|im_end|>"`.
With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
@ -126,7 +126,7 @@ trainer = SFTTrainer(
)
```
If the dataset is not in one of those format you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
If the dataset is not in one of those formats, you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
### Format your input prompts
@ -158,7 +158,7 @@ trainer = SFTTrainer(
trainer.train()
```
To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
To properly format your input, make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on the alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
### Packing dataset
@ -177,12 +177,12 @@ trainer = SFTTrainer(
trainer.train()
```
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments, you will probably train your models for more than a few epochs, depending on the way you have configured the packed dataset and the training protocol. Double-check that you know and understand what you are doing.
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTConfig` init method.
#### Customize your prompts using packed dataset
If your dataset has several fields that you want to combine, for example if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:
If your dataset has several fields that you want to combine, for example, if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:
```python
def formatting_func(example):
@ -256,7 +256,7 @@ trainer.train()
```
> [!WARNING]
> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsense generations. If the chat template doesn't contain special tokens (e.g. Alpaca), then the `modules_to_save` argument can be ignored or set to `None`.
> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsensical generations. If the chat template doesn't contain special tokens (e.g., Alpaca), then the `modules_to_save` argument can be ignored or set to `None`.
You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed.
@ -321,7 +321,7 @@ Once you have loaded your model, wrap the `trainer.train()` call under the `with
trainer.train()
```
Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration.
Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore, you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration.
Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB.
@ -351,12 +351,12 @@ model = AutoModelForCausalLM.from_pretrained(
```
If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device.
After loading your model, you can either train it as it is, or attach adapters and train adapters on it in case your model is quantized.
After loading your model, you can either train it as it is or attach adapters and train adapters on it in case your model is quantized.
In contrast to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
### Using model creation utility
### Using the model creation utility
We included a utility function to create your model.
@ -391,17 +391,17 @@ trainer = SFTTrainer(
)
```
### Enhance the model's performances using NEFTune
### Enhance the model's performance using NEFTune
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) from Jain et al. It consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune.
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF, such as LLaMA-2-Chat, benefit from additional training with NEFTune.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/neft-screenshot.png">
</div>
To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer.
To use it in `SFTTrainer`, simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to revert to the original behaviour of the embedding layer.
```python
from datasets import load_dataset
@ -430,7 +430,7 @@ Note however, that the amount of performance gain is _dataset dependent_ and in
### Accelerate fine-tuning 2x using `unsloth`
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently, `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:
| 1 A100 40GB | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
| --------------- | --------- | --- | --------------------- | --------- | ------------ |
@ -439,7 +439,7 @@ You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [
| Mistral 7b | Slim Orca | 1x | 1.17x | **1.88x** | -65.9% |
| Tiny Llama 1.1b | Alpaca | 1x | 1.55x | **2.74x** | -57.8% |
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
First, install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
```python
import torch
@ -489,18 +489,18 @@ trainer.train()
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
## Liger-Kernel: Increase 20% throughput and reduces 60% memory for multi-GPU training
## Liger-Kernel: Increase 20% throughput and reduce 60% memory for multi-GPU training
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face Compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face Compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
With great memory reduction, you can potentially turn off cpu_offloading or gradient checkpointing to further boost the performance.
With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost the performance.
| Speed Up | Memory Reduction |
|--------------------------|-------------------------|
| ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |
1. To use Liger-Kernel in [`SFTTrainer`], first install by
1. To use Liger-Kernel in [`SFTTrainer`], first install it by:
```bash
pip install liger-kernel
@ -510,7 +510,8 @@ pip install liger-kernel
```python
training_args = SFTConfig(
use_liger_kernel=True
use_liger_kernel=True,
...
)
```
@ -523,11 +524,11 @@ Pay attention to the following best practices when training a model with that tr
- [`SFTTrainer`] always truncates sequences by default to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
- If you create a model outside the trainer, make sure not to pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
## Multi-GPU Training
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work you must also check the following:
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work, you must also check the following:
- If you're using gradient_checkpointing, add the following to the TrainingArguments: `gradient_checkpointing_kwargs={'use_reentrant':False}` (more info [here](https://github.com/huggingface/transformers/issues/26969))
- Ensure that the model is placed on the correct device:
```python
@ -545,7 +546,7 @@ You may experience some issues with GPTQ Quantization after completing training.
## Extending `SFTTrainer` for Vision Language Models
`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py), which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
### Preparing the Data
@ -664,6 +665,6 @@ A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vs
## Datasets
In the SFTTrainer we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
In the SFTTrainer, we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to re-use it directly.
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to reuse it directly.
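For example, here is a minimal sketch of streaming a dataset into `SFTTrainer` (assumed dataset and model names; `max_steps` is set explicitly because an iterable dataset has no length):
```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Minimal sketch (assumed dataset and model names): streaming a large corpus as an
# IterableDataset so it is tokenized and processed on the fly.
dataset = load_dataset("trl-lib/tldr", split="train", streaming=True)

training_args = SFTConfig(output_dir="sft-streaming", max_steps=1000)
trainer = SFTTrainer(
    model="facebook/opt-350m",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```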

View File

@ -1,197 +0,0 @@
# Text Environments
Text environments provide a learning ground for language agents. They allow a language model to use tools to accomplish a task, such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the model itself but can be trivial for the appropriate tools. A good example is arithmetic with large numbers, which becomes a simple copy-paste task once you have access to a calculator.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/textenv.png">
</div>
Let's dive into how text environments work and start with tools!
## Tools
One of the core building blocks of text environments is the set of tools the model can use to solve tasks. In general, tools can be any Python function that takes a string as input and returns a string. The `TextEnvironment` offers two options for tools: either go with predefined tools from `transformers.Tool` or define your own function or class with a `__call__` method. Let's have a look at both!
### `transformers.Tool`
Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared.
```Python
from transformers import load_tool
# simple calculator tool that runs +-/* operations
calc_tool = load_tool("ybelkada/simple-calculator")
# python interpreter that executes program and returns outputs
py_tool = load_tool("lvwerra/python-interpreter")
# wikipedia search index that returns best search match
wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
```
These tools are either loaded from the hub or from a local folder. Using the tool is as simple as calling them with a text query:
```Python
calc_tool("1/2")
>>> "0.5"
```
Note that both input and return values are strings to enable easy usage with a language model.
### Custom Tools
The following is an example of a tool that adds two integers:
```Python
def add(text):
int_1, int_2 = text.split("+")
result = int(int_1) + int(int_2)
return str(result)
print(add("1+1"))
>>> "2"
```
We looked at basic examples, such as a calculator, but the principle holds for more complex tools as well, such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax.
### Call syntax
In order to have a unified way for the model to call a tool, we created a simple syntax that looks as follows:
```python
"<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"
```
There are a few special tokens involved, so let's decompose it: first, the model can signal that it wants to use a tool by emitting the `<request>` token. After that, the name of the tool to call is given, enclosed in `<>` brackets. Then the tool query follows in free-text form. The `<call>` token signifies the end of the query and stops the model generation. At this point, the model output is parsed and the query is sent to the tool. The environment appends the tool response to the string, followed by the `<response>` token to mark the end of the tool output.
Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later):
```python
"<request><Calculator>1/2<call>0.5<response>"
```
Finally, the episode is ended and generation stops when the model generates `<submit>` which marks the interaction as completed.
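As a small illustration (not TRL's parser), the tool name and query can be extracted from a generation that follows this syntax:
```python
# Illustrative sketch (not TRL's parser): extracting the tool name and query from a
# generation that follows the call syntax described above.
def parse_tool_call(text):
    request = text.split("<request>")[1]
    tool_name = request.split("<")[1].split(">")[0]
    query = request.split(">", 1)[1].split("<call>")[0]
    return tool_name, query

print(parse_tool_call("<request><Calculator>1/2<call>"))
# ('Calculator', '1/2')
```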
Now let's have a look at how we can create a new text environment!
## Create a `TextEnvironment`
```python
prompt = """\
What is 13-3?
<request><SimpleCalculatorTool>13-3<call>10.0<response>
Result=10<submit>
"""
def reward_fn(result, answer):
"""Simplified reward function returning 1 if result matches answer and 0 otherwise."""
result_parsed = result.split("=")[1].split("<")[0]
return int(result_parsed==answer)
text_env = TextEnvironment(
    model=model,
    tokenizer=tokenizer,
    tools={"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
    reward_fn=reward_fn,
    prompt=prompt,
    max_turns=1,
    max_tool_response=100,
    generation_kwargs={"do_sample": True},
)
```
Let's decompose the settings:
| Argument | Description |
|:-------------------|:----------------|
| `model` | Language model to interact with the environment and generate requests. |
| `tokenizer` | Tokenizer of language model handling tokenization of strings. |
| `tools` | A `list` or `dict` of tools. If a `list`, the tool names are inferred from the class names; if a `dict`, the keys are used as the tool names.|
| `reward_fn` | A function that takes a string as input and returns a reward. It can have extra arguments that are passed to `.run()`, such as the ground truth.|
| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. |
| `max_turns` | Maximum number of interactions between model and tools before episode ends.|
| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
| `max_length` | The maximum number of tokens to allow in an episode. |
| `generation_kwargs`| Generation settings used by the language model. |
You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!
## Run an Episode
To run a set of queries through the text environment, one can simply use the `run` method.
```python
queries = ["What is 1/2?"]
answers = ["0.5"]
queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)
```
This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached, or the maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function.
There are five objects that are returned by `run`:
- `queries`: a list of the tokenized queries
- `responses`: all tokens that have been generated within the environment, including both model and tool tokens
- `masks`: masks that indicate which tokens were generated by the model and which were generated by the tools
- `rewards`: a list of rewards, one for each query/response
- `histories`: list of `TextHistory` objects, which are useful objects containing all the above and also the text equivalents
The masks are crucial for training, as we don't want to optimize tokens that the model has not generated, which are the tokens produced by the tools.
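A tiny sketch of what this masking means in practice (illustrative tensors only):
```python
import torch

# Tiny sketch (illustrative tensors only): masking tool-generated tokens so that only
# model-generated tokens contribute to the optimization.
token_logprobs = torch.randn(10)
mask = torch.tensor([1, 1, 1, 0, 0, 0, 1, 1, 1, 1])  # 0 = tokens inserted by a tool
masked_logprobs = token_logprobs * mask               # tool tokens contribute nothing
```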
Next, we'll train a PPO step with the generated responses!
### Train
Training on episodes from the `TextEnvironment` is straightforward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method:
```python
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
```
## `TextHistory`
The `TextHistory` object stores the interactions between the model and the text environment. It stores the tokens and text generated in each turn, their source (model or system), as well as the rewards. Let's go through the class attributes and methods.
### Attributes
The following table summarises the available attributes of the `TextHistory` class:
| Attribute | Description |
|:-------------------|:----------------|
| `text` | The full string of the text generated in the text environment with both model and system generated text. |
| `text_spans` | A list of tuples with the spans for each model or system generated text segment. |
| `system_spans` | A list of boolean values indicating if the segment is model or system generated. |
| `tokens` | All tokens generated in the text environment, including both model and system generated tokens. |
| `token_spans` | Similar to `text_spans`, the `token_spans` indicate the boundaries of model and system generated tokens. |
| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. |
| `completed` | Indicates if the interaction with the environment has completed. |
| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. |
With these attributes, you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look!
### Visualization
When the model interacts inside the `TextEnvironment`, it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose, there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens, respectively, and highlight the various segments using the [`rich` library](https://github.com/Textualize/rich) (make sure to install it before using these methods).
You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue, and in addition to the pure text output, the reward is displayed as additional text in plum. Here is an example of `show_text`:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/textenv_show_text.png" width=600>
</div>
Sometimes there can be tricky tokenization-related issues that are hidden when showing the decoded text. Thus, `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/textenv_show_tokens.png" width=800>
</div>
Note that you can turn on the colour legend by passing `show_legend=True`.
## API Documentation
[[autodoc]] TextEnvironment
[[autodoc]] TextHistory

View File

@ -1,20 +1,125 @@
# vLLM Integration
<Tip warning={true}>
This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. Let's go! 🔥
Section under construction. Feel free to contribute!
## 🚀 How can I use vLLM with TRL to speed up training?
</Tip>
💡 **Note**: Resources required for this specific example: a single node with 8 GPUs.
## TRL vLLM server
First, install vLLM using the following command:
TRL provides a way to speed up generation using a dedicated vLLM server.
```bash
pip install "trl[vllm]"
```
Then run the server:
```sh
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2
```
Once the server is running, you can use it to generate completions for training. In the example below, we are using the `GRPOTrainer` to train a model using the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs.
In this example, we are sharding two copies of the model across 4 GPUs. Increasing data parallelism increases throughput, while increasing tensor parallelism allows for serving larger models. Then, run the training script by passing `use_vllm=True` in the training arguments as follows:
Sample of a simple `train.py` script:
```python
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
dataset = load_dataset("trl-lib/tldr", split="train")
# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
return [len(set(c)) for c in completions]
training_args = GRPOConfig(
output_dir="my_test",
use_vllm=True,
bf16=True,
gradient_checkpointing=True,
logging_steps=10,
)
trainer = GRPOTrainer(
model="Qwen/Qwen2.5-7B",
args=training_args,
reward_funcs=reward_num_unique_chars,
train_dataset=dataset,
)
trainer.train()
```
And the train command:
```sh
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
```
## 🎬 Flashback: Why do we need to use vLLM in online methods?
Online methods like GRPO or Online DPO require the model to generate completions during training, which are then used to compute reward signals. However, generation can be extremely time-consuming, especially with large or reasoning models. In the default setup (without vLLM), completions are generated using the [(unwrapped) model's `generate` method](https://github.com/huggingface/trl/blob/f3e8c2304428ef16e9ae5de9e5741ed84d533b7b/trl/trainer/grpo_trainer.py#L965C39-L965C66). This approach quickly becomes a major bottleneck — generation is slow and inefficient, particularly for large batches or models. As a result, training times increase significantly, and overall efficiency drops. To address this, we turn to vLLM, which enables much faster and more scalable generation, helping eliminate this bottleneck in online methods.
## 🤔 How does vLLM solve the slow generation issue?
If you've ever done autoregressive decoder training, you know all the input tokens to the LLM produce their attention key and value tensors, and these tensors are kept in GPU memory to later generate subsequent tokens based on them. These cached key and value tensors are often referred to as the KV cache. However, storing the KV cache occupies a lot of memory, so vLLM uses a technique called **PagedAttention** to solve this problem. PagedAttention, which is inspired by the OS’s virtual memory concept, stores continuous keys and values in **non-contiguous memory space**, which is much more efficient. The details of this are beyond the scope of this document, but in short, it allows the model to store the keys and values in a more efficient way, reducing the memory footprint and speeding up the generation process. If you are interested, make sure to check out the [vLLM PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html) for more details.
## 🤔 What exactly happens when you run `trl vllm-serve --model <model_name>`?
When you run, for example,
```sh
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4
```
the following happens:
![vllm](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/vllm-doc.png)
1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (1 × 4).
Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation. A key point to understand is that these 4 workers are running in parallel, and each one is responsible for handling a subset of the total incoming load.
2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model’s weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled. Meanwhile, data parallelism (controlled by `--data-parallel-size`) ensures that different sets of requests are processed independently across the workers. In short: tensor parallelism splits the model across GPUs, and data parallelism splits the batch of requests across different model replicas.
3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--data-parallel-size=4`, each GPU processes 2 prompts).
This GPU-to-GPU communication is managed efficiently by NVIDIA’s NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests — it’s lightweight and doesn’t interfere with generation itself.
Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (like in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you would end up with 16 completions total — regardless of the number of GPUs or parallelism settings.
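As a small sketch tying this to the configuration (illustrative values only), `num_generations` is set in the GRPO config and is independent of the tensor/data parallel settings:
```python
from trl import GRPOConfig

# Sketch (illustrative values): num_generations controls how many completions the
# server returns per prompt, independently of --tensor-parallel-size and
# --data-parallel-size. With 8 prompts and num_generations=2, you get 16 completions.
training_args = GRPOConfig(
    output_dir="my_test",
    use_vllm=True,
    num_generations=2,
)
```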
## 🥸 More detail on what happens under the hood when running the server
* The vLLM server starts by running the command: `trl vllm-serve --model Qwen/Qwen2.5-7B`.
* Once the server is running, it generates completions based on requests from the client (trainer) using `vllm_client.generate` [here](https://github.com/huggingface/trl/blob/cc044e35b285be7dc062764b3364e1e684db4c7c/trl/trainer/grpo_trainer.py#L1025-L1035).
* The client (trainer) then requests these completions from the server.
* These completions are used to compute the reward signal.
* Based on the reward signal and the models output, the loss is computed, and the backward pass is performed to update the models weights.
* **Note**: The server only handles completion generation — it doesn’t train the model. Therefore, the model’s weights aren’t updated on the server. Once the backward pass is complete, the client sends the updated weights to the server using `vllm_client.update_named_param(name, param.data)`.
When using vLLM, ensure the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation for training using `CUDA_VISIBLE_DEVICES`. See the example below:
* **Set GPUs *0–3* for vLLM generation:** Assume `CUDA_VISIBLE_DEVICES=0,1,2,3` are allocated for vLLM generation.
```sh
trl vllm-serve --model <model_name> --tensor-parallel-size 1 --data-parallel-size 4
```
* **And GPUs *4–7* for training:** If you do not set the `CUDA_VISIBLE_DEVICES` environment variable, the training script will use all available GPUs by default, which may lead to resource conflicts. To avoid this, you can specify which GPUs to use for training. For example, if you want to use GPUs 4–7 for training, set the environment variable as follows:
```sh
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
```
## 🍷 More customization options with vLLM?
You can customize the server configuration by passing additional arguments.
```
$ trl vllm-serve --help
usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE] [--host HOST]
[--port PORT] [--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE]
[--max_model_len MAX_MODEL_LEN] [--enable_prefix_caching ENABLE_PREFIX_CACHING]
usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE]
[--data_parallel_size DATA_PARALLEL_SIZE] [--host HOST] [--port PORT]
[--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE] [--max_model_len MAX_MODEL_LEN]
[--enable_prefix_caching ENABLE_PREFIX_CACHING] [--enforce_eager ENFORCE_EAGER] [--log_level LOG_LEVEL]
options:
-h, --help Show this help message and exit
@ -22,6 +127,8 @@ options:
--revision REVISION Revision to use for the model. If not specified, the default branch will be used. (default: None)
--tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE
Number of tensor parallel workers to use. (default: 1)
--data_parallel_size DATA_PARALLEL_SIZE, --data-parallel-size DATA_PARALLEL_SIZE
Number of data parallel workers to use. (default: 1)
--host HOST Host address to run the server on. (default: 0.0.0.0)
--port PORT Port to run the server on. (default: 8000)
--gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION
@ -38,10 +145,41 @@ options:
--enable_prefix_caching ENABLE_PREFIX_CACHING, --enable-prefix-caching ENABLE_PREFIX_CACHING
Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support this
feature. (default: None)
--enforce_eager ENFORCE_EAGER, --enforce-eager ENFORCE_EAGER
Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the model
in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid. (default:
None)
--log_level LOG_LEVEL, --log-level LOG_LEVEL
Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', 'trace'. (default:
info)
```
### Find the best distributed setup
## 🥳 Okay, now that we have the server running, how can we use it to generate completions?
Run the training script and pass `use_vllm=True` in the training arguments:
```python
from trl import GRPOConfig
training_args = GRPOConfig(..., use_vllm=True)
```
## 💆🏻‍♀️ What's the best distributed setup?
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_8_gpus.png)
![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_4_gpus.png)
First and foremost, always remember that the optimal setup depends on:
* The model size
* The number of GPUs you have
* The GPU memory size
* The batch size you are using
* The number of requests you are sending to the server (prompts)
* The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size)
* The number of completions you are generating for each request (`num_generations`)
Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that:
* For reasonably sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput: the setup `(tp=1, dp=8)` yields the best results.
* For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window (see the launch sketch after this list).
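Based on these observations, a launch sketch for the 32B case could look like the following. The model name and `max_model_len` value are illustrative; `tp=2 × dp=4` uses all 8 GPUs.
```sh
# Large model, long context: trade some data parallelism for tensor parallelism
trl vllm-serve \
  --model Qwen/Qwen2.5-32B-Instruct \
  --tensor_parallel_size 2 \
  --data_parallel_size 4 \
  --max_model_len 16384
```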

View File

@ -0,0 +1,28 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: true
fsdp_offload_params: false
fsdp_reshard_after_forward: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
fsdp_version: 1
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,25 @@
# Requires accelerate 1.7.0 or higher
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -1,25 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: false
fsdp_offload_params: true
fsdp_sharding_strategy: FULL_SHARD
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -4,4 +4,4 @@ This directory contains a collection of Jupyter notebooks that demonstrate how t
- [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO.
- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook.
- [`gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.
- [`gpt2-sentiment-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.

View File

@ -1,118 +0,0 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import numpy as np
import torch
from transformers import AutoTokenizer, load_tool
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
def generate_data(n):
"""Generate random arithmetic tasks and answers."""
tasks, answers = [], []
for _ in range(n):
a = np.random.randint(0, 50)
b = np.random.randint(0, 50)
op = np.random.choice(["-", "+", "*"])
tasks.append(f"\n\nWhat is {a} {op} {b}?")
if op == "-":
answers.append(a - b)
elif op == "+":
answers.append(a + b)
else:
answers.append(a * b)
return tasks, answers
def exact_match_reward(responses, answers=None):
"""Reward if generated response contains correct answer."""
rewards = []
pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>" # generated by chatGPT
for response, answer in zip(responses, answers):
reward = 0.0
predicted_number = None
match_pattern = re.findall(pattern, response)
if match_pattern:
predicted_number = float(match_pattern[0])
if predicted_number is not None:
if np.abs(predicted_number - answer) < 0.01:
reward += 1.0
rewards.append(torch.tensor(reward))
return rewards
# set up models
model_id = "gpt2"
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
# system prompt
prompt = """\
What is 13-3?
<request><SimpleCalculatorTool>13-3<call>10.0<response>
Result=10<submit>
What is 4*3?
<request><SimpleCalculatorTool>4*3<call>12.0<response>
Result=12<submit>"""
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"eos_token_id": -1,
"max_new_tokens": 32,
}
# trainer
ppo_config = PPOConfig(
batch_size=256,
learning_rate=1.41e-5,
mini_batch_size=64,
log_with="wandb",
)
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)
# text env
text_env = TextEnvironment(
model,
tokenizer,
{"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
exact_match_reward,
prompt,
generation_kwargs=generation_kwargs,
)
# main training loop
for _step in range(100):
tasks, answers = generate_data(ppo_config.batch_size)
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
response_texts = [tokenizer.decode(response) for response in responses]
query_texts = [tokenizer.decode(query) for query in queries]
texts = {"query": [qt.split("<submit>")[-1].strip() for qt in query_texts], "response": response_texts}
ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"])
ppo_trainer.save_pretrained(model_id + "-calculator")

View File

@ -1,193 +0,0 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, HfArgumentParser, load_tool
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@dataclass
class ScriptArguments:
model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"})
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
gradient_accumulation_steps: Optional[int] = field(
default=16, metadata={"help": "the number of gradient accumulation steps"}
)
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"})
ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"})
n_epochs: Optional[int] = field(default=32, metadata={"help": "max number of ppo epochs"})
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
def exact_match_reward(responses, answers=None):
"""Reward if generated response contains correct answer."""
rewards = []
pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>" # generated by chatGPT
for response, answer in zip(responses, answers):
reward = 0.0
try:
predicted_number = None
match_pattern = re.findall(pattern, response)
if match_pattern:
predicted_number = float(match_pattern[0])
if predicted_number is not None:
if np.abs(predicted_number - float(answer)) < 0.1:
reward += 1.0
except Exception:
pass
rewards.append(torch.tensor(reward))
return rewards
def evaluate(test_dataloader, text_env, ppo_trainer):
test_rewards = []
for test_batch in test_dataloader:
_, _, _, rewards, _ = text_env.run(test_batch["query"], answers=test_batch["answer"])
test_rewards.extend(rewards)
test_rewards = ppo_trainer.accelerator.gather_for_metrics(
torch.stack(test_rewards).to(ppo_trainer.accelerator.device)
)
return test_rewards.mean()
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["c_proj", "c_attn", "q_attn"],
)
# set up models
model = AutoModelForCausalLMWithValueHead.from_pretrained(
script_args.model_name,
use_auth_token=True,
load_in_4bit=True,
peft_config=lora_config,
)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_auth_token=True)
tokenizer.pad_token = tokenizer.eos_token
ds = load_dataset("openai/gsm8k", "main", split="train")
ds = ds.rename_columns({"question": "query"})
ds = ds.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
ds = ds.select(range(1, len(ds))) # skip the first sample which is used in prompt
ds_test = load_dataset("openai/gsm8k", "main", split="test")
ds_test = ds_test.rename_columns({"question": "query"})
ds_test = ds_test.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
test_dataloader = torch.utils.data.DataLoader(ds_test, batch_size=script_args.batch_size)
# prompt
prompt = """\
Example of using a Python API to solve math questions.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
<request><PythonInterpreter>
def solution():
money_initial = 23
bagels = 5
bagel_cost = 3
money_spent = bagels * bagel_cost
money_left = money_initial - money_spent
result = money_left
return result
print(solution())
<call>72<response>
Result = 72 <submit>
Q: """
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"eos_token_id": -1,
"max_new_tokens": script_args.max_new_tokens,
}
# trainer
ppo_config = PPOConfig(
batch_size=script_args.batch_size,
learning_rate=script_args.learning_rate,
mini_batch_size=script_args.mini_batch_size,
ppo_epochs=script_args.ppo_epochs,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
log_with="wandb",
tracker_project_name="trl-gsm8k",
remove_unused_columns=False,
optimize_cuda_cache=True,
)
ppo_trainer = PPOTrainer(args=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
test_dataloader = ppo_trainer.accelerator.prepare(test_dataloader)
# text env
text_env = TextEnvironment(
model,
tokenizer,
[load_tool("lvwerra/python-interpreter")],
exact_match_reward,
prompt,
max_turns=2,
generation_kwargs=generation_kwargs,
)
# main training loop
for epoch in range(script_args.n_epochs):
for step, batch in enumerate(ppo_trainer.dataloader):
if (step == 0) and (epoch % 4 == 0): # evaluate every 4 epochs
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
else:
reward_mean_test = None
queries, responses, masks, rewards, histories = text_env.run(batch["query"], answers=batch["answer"])
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
# logging
if reward_mean_test is not None:
train_stats["env/reward_mean_test"] = reward_mean_test
texts = {
"query": batch["query"],
"response": [tokenizer.decode(response) for response in responses],
"answer": batch["answer"],
}
ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"])
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
ppo_trainer.save_pretrained(f"model/{script_args.model_name}-gsm8k")

View File

@ -1,192 +0,0 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import Optional
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, HfArgumentParser, load_tool
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@dataclass
class ScriptArguments:
model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"})
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
gradient_accumulation_steps: Optional[int] = field(
default=16, metadata={"help": "the number of gradient accumulation steps"}
)
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"})
ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"})
iterations: Optional[int] = field(default=1000, metadata={"help": "the number of iterations"})
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["c_proj", "c_attn", "q_attn"],
)
# set up models
model = AutoModelForCausalLMWithValueHead.from_pretrained(
script_args.model_name,
use_auth_token=True,
trust_remote_code=True,
load_in_4bit=True,
peft_config=lora_config,
)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_auth_token=True)
tokenizer.pad_token = tokenizer.eos_token
# system prompt
prompt = """\
Answer the following question:
Q: In which branch of the arts is Patricia Neary famous?
A: Ballets
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
Result=Ballets<submit>
Q: Who won Super Bowl XX?
A: Chicago Bears
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
Result=Chicago Bears<submit>
Q: """
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"eos_token_id": -1,
"max_new_tokens": script_args.max_new_tokens,
}
# trainer
config = PPOConfig(
batch_size=script_args.batch_size,
model_name=script_args.model_name,
learning_rate=script_args.learning_rate,
log_with=script_args.log_with,
mini_batch_size=script_args.mini_batch_size,
ppo_epochs=script_args.ppo_epochs,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
seed=script_args.seed,
optimize_cuda_cache=True,
)
ppo_trainer = PPOTrainer(args=config, model=model, tokenizer=tokenizer)
dataset = load_dataset("mandarjoshi/trivia_qa", "rc", split="train")
local_seed = script_args.seed + ppo_trainer.accelerator.process_index * 100003 # Prime
dataset = dataset.shuffle(local_seed)
def data_generator():
for i in range(len(dataset)):
yield dataset[i]["question"], list(dataset[i]["answer"]["normalized_aliases"])
gen = data_generator()
gen = iter(gen)
def generate_data(n):
tasks, answers = [], []
for _i in range(n):
q, a = next(gen)
tasks.append(q)
answers.append(a)
return tasks, answers
def exact_match_reward(responses, answers=None):
"""Reward if generated response contains correct answer."""
rewards = []
for response, answer in zip(responses, answers):
reward = 0.0
for a in answer:
if a.lower() in response.lower():
reward += 1.0
break
rewards.append(torch.tensor(reward))
return rewards
def tool_fn(x):
# limit the amount of tokens
return tool(x).split("\n")[1][:600]
# text env
tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
text_env = TextEnvironment(
model,
tokenizer,
{"Wiki": tool_fn},
exact_match_reward,
prompt,
generation_kwargs=generation_kwargs,
max_tool_response=400,
)
def print_trainable_parameters(model):
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
print(
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
)
print_trainable_parameters(model)
# main training loop
for i in range(script_args.iterations):
tasks, answers = generate_data(config.batch_size)
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
response_texts = [tokenizer.decode(response) for response in responses]
query_texts = [tokenizer.decode(query) for query in queries]
texts = {
"query": [qt.split("<submit>")[-1].strip() for qt in query_texts],
"response": response_texts,
"answer": [", ".join(item) for item in answers],
}
all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device))
ppo_trainer.log_stats(train_stats, texts, list(all_rewards), columns_to_log=["query", "response", "answer"])
if i % 100 == 0:
ppo_trainer.save_pretrained(f"models/{script_args.model_name}_{script_args.seed}_{i}_triviaqa")

View File

@ -45,7 +45,6 @@ python examples/scripts/gkd.py \
--lora_alpha 16
"""
from accelerate import PartialState
from datasets import load_dataset
from transformers import AutoTokenizer, GenerationConfig
@ -106,14 +105,6 @@ if __name__ == "__main__":
################
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
with PartialState().local_main_process_first():
dataset = dataset.map(
lambda x: {
"prompt": tokenizer.apply_chat_template(x["prompt"], tokenize=False, add_generation_prompt=True)
},
num_proc=training_args.dataset_num_proc,
)
################
# Training
################

View File

@ -59,6 +59,9 @@ from transformers import (
Qwen2Config,
Qwen2ForCausalLM,
Qwen2ForSequenceClassification,
Qwen3Config,
Qwen3ForCausalLM,
Qwen3ForSequenceClassification,
SiglipVisionConfig,
T5Config,
T5ForConditionalGeneration,
@ -120,6 +123,7 @@ for model_id, config_class, model_class, suffix in [
("microsoft/Phi-3.5-mini-instruct", Phi3Config, Phi3ForCausalLM, None),
("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForCausalLM, "2.5"),
("Qwen/Qwen2.5-Coder-0.5B", Qwen2Config, Qwen2ForCausalLM, "2.5-Coder"),
("Qwen/Qwen3-4B", Qwen3Config, Qwen3ForCausalLM, None),
]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = config_class(
@ -134,7 +138,7 @@ for model_id, config_class, model_class, suffix in [
push_to_hub(model, tokenizer, "tiny", suffix)
# A slightly bigger model, required for vLLM testing
# Two slightly bigger models, required for vLLM testing
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct")
config = Qwen2Config(
vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
@ -147,11 +151,23 @@ config = Qwen2Config(
model = Qwen2ForCausalLM(config)
push_to_hub(model, tokenizer, "small", "2.5")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
config = Qwen3Config(
vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
hidden_size=128, # increase hidden size so that hidden_size // num_attention_heads = 32, required for vLLM
num_attention_heads=4,
num_key_value_heads=2,
num_hidden_layers=2,
intermediate_size=32,
)
model = Qwen3ForCausalLM(config)
push_to_hub(model, tokenizer, "small")
# Reward models
for model_id, config_class, model_class, suffix in [
("meta-llama/Llama-3.2-1B-Instruct", LlamaConfig, LlamaForSequenceClassification, "3.2"),
("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForSequenceClassification, "2.5"),
("Qwen/Qwen3-4B", Qwen3Config, Qwen3ForSequenceClassification, None),
]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
config = config_class(

View File

@ -1,2 +1,99 @@
[metadata]
license_file = LICENSE
name = trl
version = 0.18.1
description = Train transformer language models with reinforcement learning.
long_description = file: README.md
long_description_content_type = text/markdown
author = Leandro von Werra
author_email = leandro.vonwerra@gmail.com
url = https://github.com/huggingface/trl
keywords = transformers, huggingface, language modeling, post-training, rlhf, sft, dpo, grpo
license_file = LICENSE
classifiers =
Development Status :: 2 - Pre-Alpha
Intended Audience :: Developers
Intended Audience :: Science/Research
Natural Language :: English
Operating System :: OS Independent
Programming Language :: Python :: 3
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
Programming Language :: Python :: 3.11
Programming Language :: Python :: 3.12
Programming Language :: Python :: 3.13
[options]
packages = find:
python_requires = >=3.9
include_package_data = True
install_requires =
accelerate>=0.34.0
datasets>=3.0.0
transformers>=4.50.0
[options.packages.find]
exclude =
tests*
[options.package_data]
trl =
templates/*.md
accelerate_configs/*.yaml
[options.extras_require]
bco =
scikit-learn
joblib
deepspeed =
deepspeed>=0.14.4
diffusers =
diffusers>=0.18.0
judges =
openai>=1.23.2
llm-blender>=0.0.2
liger =
liger-kernel>=0.5.9
mergekit =
mergekit>=0.0.5.1
peft =
peft>=0.8.0
quantization =
bitsandbytes
scikit =
scikit-learn
test =
parameterized
pytest-cov
pytest-rerunfailures
pytest-xdist
pytest
vllm =
# vLLM package does not yet support Python 3.13. These constraints can be lifted once support is added:
# see https://github.com/vllm-project/vllm/pull/13164
vllm>=0.8.3; python_version < "3.13"
fastapi; python_version < "3.13"
pydantic; python_version < "3.13"
requests; python_version < "3.13"
uvicorn; python_version < "3.13"
vlm =
Pillow
dev =
%(bco)s
%(deepspeed)s
%(diffusers)s
%(judges)s
%(liger)s
%(mergekit)s
%(peft)s
%(quantization)s
%(scikit)s
%(test)s
%(vlm)s
[options.entry_points]
console_scripts =
trl = trl.cli:main
[coverage:run]
branch = True

setup.py
View File

@ -12,124 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""trl is an open library for RL with transformer models.
Note:
VERSION needs to be formatted following the MAJOR.MINOR.PATCH convention
(we need to follow this convention to be able to retrieve versioned scripts)
Simple check list for release from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
To create the package for PyPI.
0. Prerequisites:
- Dependencies:
- twine: "pip install twine"
- Create an account in (and join the 'trl' project):
- PyPI: https://pypi.org/
- Test PyPI: https://test.pypi.org/
1. Change the version in:
- __init__.py
- setup.py
2. Commit these changes: "git commit -m 'Release: VERSION'"
3. Add a tag in git to mark the release: "git tag VERSION -m 'Add tag VERSION for PyPI'"
Push the tag to remote: git push --tags origin main
4. Build both the sources and the wheel. Do not change anything in setup.py between
creating the wheel and the source distribution (obviously).
First, delete any "build" directory that may exist from previous builds.
For the wheel, run: "python setup.py bdist_wheel" in the top level directory.
(this will build a wheel for the python version you use to build it).
For the sources, run: "python setup.py sdist"
You should now have a /dist directory with both .whl and .tar.gz source versions.
5. Check that everything looks correct by uploading the package to the PyPI test server:
twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
Check that you can install it in a virtualenv/notebook by running:
pip install -i https://testpypi.python.org/pypi trl
6. Upload the final version to actual PyPI:
twine upload dist/* -r pypi
7. Fill release notes in the tag in github once everything is looking hunky-dory.
8. Change the version in __init__.py and setup.py to X.X.X+1.dev0 (e.g. VERSION=1.18.3 -> 1.18.4.dev0).
Then push the change with a message 'set dev version'
"""
from setuptools import find_packages, setup
from setuptools import setup
__version__ = "0.17.0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
REQUIRED_PKGS = [
"accelerate>=0.34.0",
"datasets>=3.0.0",
"rich", # rich shouldn't be a required package for trl, we should remove it from here
"transformers>=4.46.0",
]
EXTRAS = {
"deepspeed": ["deepspeed>=0.14.4"],
"diffusers": ["diffusers>=0.18.0,<0.33.0"], # Temp set <0.33.0 due to ftfy optional dep issue breaking doc builds
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
"liger": ["liger-kernel>=0.5.6"],
"mergekit": ["mergekit>=0.0.5.1"],
"peft": ["peft>=0.8.0"],
"quantization": ["bitsandbytes"],
"scikit": ["scikit-learn"],
"bco": ["scikit-learn", "joblib"],
"test": ["parameterized", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "pytest"],
"vllm": ["vllm>=0.8.3", "fastapi", "pydantic", "requests", "uvicorn"],
"vlm": ["Pillow"],
}
EXTRAS["dev"] = []
for reqs in EXTRAS.values():
EXTRAS["dev"].extend(reqs)
setup(
name="trl",
license="Apache 2.0",
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: Apache Software License",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
url="https://github.com/huggingface/trl",
entry_points={
"console_scripts": ["trl=trl.cli:main"],
},
include_package_data=True,
package_data={
"trl": ["templates/*.md"],
},
packages=find_packages(exclude={"tests", "tests.slow", "trl.templates"}),
install_requires=REQUIRED_PKGS,
extras_require=EXTRAS,
python_requires=">=3.9",
long_description=open("README.md", encoding="utf-8").read(),
long_description_content_type="text/markdown",
zip_safe=False,
version=__version__,
description="Train transformer language models with reinforcement learning.",
keywords="transformers, huggingface, language modeling, post-training, rlhf, sft, dpo, grpo",
author="Leandro von Werra",
author_email="leandro.vonwerra@gmail.com",
)
setup()

View File

@ -22,7 +22,7 @@ from accelerate.utils.memory import release_memory
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_liger_kernel, require_torch_accelerator
from transformers.testing_utils import require_liger_kernel, require_peft, require_torch_accelerator
from trl import GRPOConfig, GRPOTrainer
@ -81,3 +81,67 @@ class GRPOTrainerSlowTester(unittest.TestCase):
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
release_memory(model, trainer)
@parameterized.expand(MODELS_TO_TEST)
@require_liger_kernel
@require_peft
def test_training_with_liger_grpo_loss_and_peft(self, model_name):
from peft import LoraConfig, TaskType
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=3,
num_generations=3,
use_liger_loss=True,
max_completion_length=self.max_length,
report_to="none",
logging_strategy="no",
)
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
# Configure PEFT with LoRA
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
inference_mode=False,
r=8,
lora_alpha=32,
lora_dropout=0.1,
target_modules=["q_proj", "v_proj"],
)
trainer = GRPOTrainer(
model=model,
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
processing_class=tokenizer,
peft_config=peft_config,
)
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
assert isinstance(trainer.liger_grpo_loss, LigerFusedLinearGRPOLoss)
# Verify PEFT adapter is properly initialized
from peft import PeftModel
self.assertTrue(isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT")
# Store adapter weights before training
previous_trainable_params = {
n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad
}
self.assertTrue(len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model")
trainer.train()
# Verify adapter weights have changed after training
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
release_memory(model, trainer)

View File

@ -420,3 +420,39 @@ class SFTTrainerSlowTester(unittest.TestCase):
trainer.train()
release_memory(trainer.model, trainer)
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
@require_torch_accelerator
def test_train_offloading(self, model_name, packing):
"""Test that activation offloading works with SFTTrainer."""
with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(
output_dir=tmp_dir,
activation_offloading=True,
report_to="none",
per_device_train_batch_size=2,
max_steps=2,
packing=packing,
max_length=self.max_length,
)
trainer = SFTTrainer(
model=model_name, args=training_args, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
release_memory(trainer.model, trainer)

View File

@ -0,0 +1,156 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from torch import nn
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_peft, require_torch_accelerator, torch_device
from transformers.utils import is_peft_available
from trl.models.activation_offloading import NoOpManager, OffloadActivations
if is_peft_available():
from peft import LoraConfig, get_peft_model
class TestActivationOffloading(unittest.TestCase):
@require_torch_accelerator
@require_peft
def test_offloading_with_peft_models(self) -> None:
"""Test that activation offloading works with PEFT models."""
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
peft_config = LoraConfig(
lora_alpha=16,
lora_dropout=0.1,
r=8,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)
inp = torch.randint(0, 100, (2, 10), device=torch_device)
# First forward-backward pass without offloading
torch.manual_seed(42)
loss = model(inp, labels=inp).loss
loss.backward()
# Store gradients - only from trainable parameters
grads_original = []
for name, param in model.named_parameters():
if param.requires_grad and param.grad is not None:
grads_original.append((name, param.grad.clone()))
# Reset gradients
for p in model.parameters():
if p.grad is not None:
p.grad = None
# Second forward-backward pass with offloading
torch.manual_seed(42)
with OffloadActivations():
loss_c = model(inp, labels=inp).loss
loss_c.backward()
# Compare gradients - only trainable parameters
for name_orig, grad_orig in grads_original:
for name_param, param in model.named_parameters():
if name_param == name_orig and param.requires_grad and param.grad is not None:
self.assertTrue(
torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5),
f"Gradient mismatch for {name_orig}",
)
@require_torch_accelerator
def test_noop_manager_with_offloading(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
inp = torch.randint(0, 100, (2, 10), device=torch_device)
# Run with offloading but disable for specific section
with OffloadActivations():
# First forward-backward with normal offloading
torch.manual_seed(42)
out1 = model(inp, labels=inp)
out1.loss.backward()
grads1 = [p.grad.clone() for p in model.parameters()]
# Reset grads
for p in model.parameters():
p.grad = None
# Second forward-backward with NoOpManager
with NoOpManager():
torch.manual_seed(42)
out2 = model(inp, labels=inp)
out2.loss.backward()
grads2 = [p.grad.clone() for p in model.parameters()]
# Gradients should match as NoOpManager should have prevented offloading
for g1, g2 in zip(grads1, grads2):
self.assertTrue(torch.allclose(g1, g2, rtol=1e-4, atol=1e-5))
@require_torch_accelerator
def test_min_offload_size(self):
"""Test that tensors smaller than min_offload_size aren't offloaded"""
model = nn.Sequential(
nn.Linear(5, 5), # Small layer that shouldn't be offloaded
nn.Linear(5, 1000), # Large layer that should be offloaded
).to(torch_device)
inp = torch.randn(2, 5, device=torch_device)
with OffloadActivations(min_offload_size=1000):
out = model(inp)
out.sum().backward()
# The test passes if no errors occur, as we're mainly testing
# that the logic handles both offloaded and non-offloaded tensors
@require_torch_accelerator
def test_real_hf_model(self):
"""Test with an actual HuggingFace model"""
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
# Create small input
inp = torch.randint(0, 100, (2, 10), device=torch_device)
# Baseline without offloading
torch.manual_seed(42)
out1 = model(inp, labels=inp).loss
out1.backward()
grads1 = [p.grad.clone() for p in model.parameters()]
# Reset grads
for p in model.parameters():
p.grad = None
# With offloading
with OffloadActivations():
torch.manual_seed(42)
out2 = model(inp, labels=inp).loss
out2.backward()
grads2 = [p.grad.clone() for p in model.parameters()]
# Check outputs and gradients match
self.assertTrue(torch.allclose(out1, out2, rtol=1e-5))
for g1, g2 in zip(grads1, grads2):
self.assertTrue(torch.allclose(g1, g2, rtol=1e-5))

View File

@ -13,12 +13,15 @@
# limitations under the License.
import os
import sys
import tempfile
import unittest
from io import StringIO
from unittest.mock import patch
import yaml
@unittest.skipIf(
sys.version_info < (3, 10),
@ -67,6 +70,33 @@ class TestCLI(unittest.TestCase):
with patch("sys.argv", command.split(" ")):
main()
def test_sft_config_file(self):
from trl.cli import main
with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory
output_dir = os.path.join(tmp_dir, "output")
# Create a temporary config file
config_path = os.path.join(tmp_dir, "config.yaml")
config_content = {
"model_name_or_path": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
"dataset_name": "trl-internal-testing/zen",
"dataset_config": "standard_language_modeling",
"report_to": "none",
"output_dir": output_dir,
"lr_scheduler_type": "cosine_with_restarts",
}
with open(config_path, "w") as config_file:
yaml.dump(config_content, config_file)
# Test the CLI with config file
command = f"trl sft --config {config_path}"
with patch("sys.argv", command.split(" ")):
main()
# Verify that output directory was created
self.assertTrue(os.path.exists(output_dir))
if __name__ == "__main__":
unittest.main()

View File

@ -163,3 +163,80 @@ class TestTrlParser(unittest.TestCase):
self.assertIsInstance(result_args[0], MyDataclass)
self.assertEqual(result_args[0].arg1, 2)
self.assertEqual(result_args[1], ["--remaining_string_in_config", "abc", "--remaining_string_in_args", "def"])
@patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value"))
@patch("yaml.safe_load")
def test_subparsers_with_config_defaults(self, mock_yaml_load):
"""Test that config defaults are applied to all subparsers."""
mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"}
# Create the main parser
parser = TrlParser()
# Add subparsers
subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser)
# Create a subparser for a specific command
subparsers.add_parser("subcommand", dataclass_types=[MyDataclass])
# Parse with config file
args = ["subcommand", "--config", "config.yaml"]
result_args = parser.parse_args_and_config(args)
# Check main parser arguments
self.assertEqual(len(result_args), 1)
# Check that config values were applied to the subparser
self.assertEqual(result_args[0].arg1, 2) # Default from config
self.assertEqual(result_args[0].arg2, "config_value") # Default from config
@patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value"))
@patch("yaml.safe_load")
def test_subparsers_with_config_defaults_and_arg_override(self, mock_yaml_load):
"""Test that config defaults are applied to all subparsers."""
mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"}
# Create the main parser
parser = TrlParser()
# Add subparsers
subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser)
# Create a subparser for a specific command
subparsers.add_parser("subcommand", dataclass_types=[MyDataclass])
# Test with command line arguments overriding config
args = ["subcommand", "--arg1", "3", "--config", "config.yaml"]
result_args = parser.parse_args_and_config(args)
# Command line arguments should override config
self.assertEqual(result_args[0].arg1, 3)
self.assertEqual(result_args[0].arg2, "config_value") # Still from config
@patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value"))
@patch("yaml.safe_load")
def test_subparsers_multiple_with_config_defaults(self, mock_yaml_load):
"""Test that config defaults are applied to all subparsers."""
mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"}
# Create the main parser
parser = TrlParser()
# Add subparsers
subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser)
# Create a subparser for a specific command
subparsers.add_parser("subcommand0", dataclass_types=[MyDataclass])
subparsers.add_parser("subcommand1", dataclass_types=[MyDataclass])
for idx in range(2):
# Parse with config file
args = [f"subcommand{idx}", "--config", "config.yaml"]
result_args = parser.parse_args_and_config(args)
# Check main parser arguments
self.assertEqual(len(result_args), 1)
# Check that config values were applied to the subparser
self.assertEqual(result_args[0].arg1, 2) # Default from config
self.assertEqual(result_args[0].arg2, "config_value") # Default from config

View File

@ -102,6 +102,7 @@ class ApplyChatTemplateTester(unittest.TestCase):
"trl-internal-testing/tiny-MistralForCausalLM-0.2",
"trl-internal-testing/tiny-Phi3ForCausalLM",
"trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
"trl-internal-testing/tiny-Qwen3ForCausalLM",
]
conversational_examples = [

View File

@ -1258,6 +1258,37 @@ class DPOTrainerTester(unittest.TestCase):
self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)
def test_train_with_length_desensitization(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
tokenizer = AutoTokenizer.from_pretrained(model_id)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
learning_rate=9e-1,
ld_alpha=0.5,
report_to="none",
)
trainer = DPOTrainer(
model=model_id,
args=training_args,
processing_class=tokenizer,
train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
@require_vision
class DPOVisionTrainerTester(unittest.TestCase):
@ -1344,13 +1375,15 @@ class DPOVisionTrainerTester(unittest.TestCase):
"trl-internal-testing/tiny-LlavaForConditionalGeneration",
"trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
] and (
n.startswith("vision_tower.vision_model.encoder.layers.1")
or n == "vision_tower.vision_model.post_layernorm.weight"
"vision_tower.vision_model.encoder.layers.1" in n
or "vision_tower.vision_model.post_layernorm.weight" in n
):
# For some reason, these params are not updated. This is probably not related to TRL, but to
# the model itself. We should investigate this further, but for now we just skip these params.
continue
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
self.assertFalse(
torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
)
if __name__ == "__main__":

View File

@ -24,7 +24,7 @@ from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
from trl import GRPOConfig, GRPOTrainer
from trl.trainer.grpo_trainer import RepeatSampler
from trl.trainer.grpo_trainer import RepeatSampler, shuffle_tensor_dict, split_tensor_dict
from .testing_utils import require_vllm
@ -33,6 +33,77 @@ if is_peft_available():
from peft import LoraConfig, PeftModel
class SplitTensorDictTester(unittest.TestCase):
def test_split_equal_chunks(self):
x = torch.arange(12).reshape(6, 2)
y = torch.arange(6).reshape(6, 1)
tensor_dict = {"x": x, "y": y}
result = split_tensor_dict(tensor_dict, 3)
expected_x_chunks = torch.chunk(x, 3, dim=0)
expected_y_chunks = torch.chunk(y, 3, dim=0)
self.assertEqual(len(result), 3)
for i in range(3):
self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i]))
self.assertTrue(torch.equal(result[i]["y"], expected_y_chunks[i]))
def test_with_none_tensor(self):
x = torch.arange(12).reshape(6, 2)
tensor_dict = {"x": x, "y": None}
result = split_tensor_dict(tensor_dict, 2)
expected_x_chunks = torch.chunk(x, 2, dim=0)
self.assertEqual(len(result), 2)
for i in range(2):
self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i]))
self.assertIsNone(result[i]["y"])
class ShuffleTensorDictTester(unittest.TestCase):
def test_shuffle_preserves_shape(self):
x = torch.arange(6).reshape(3, 2)
y = torch.arange(3).reshape(3, 1)
tensor_dict = {"x": x.clone(), "y": y.clone()}
shuffled = shuffle_tensor_dict(tensor_dict)
self.assertEqual(shuffled["x"].shape, x.shape)
self.assertEqual(shuffled["y"].shape, y.shape)
def test_shuffle_consistent_across_tensors(self):
# Use known patterns to check alignment
x = torch.tensor([[10, 11], [20, 21], [30, 31]])
y = torch.tensor([[1], [2], [3]])
tensor_dict = {"x": x.clone(), "y": y.clone()}
shuffled = shuffle_tensor_dict(tensor_dict)
# Build a reverse map from shuffled x rows to y values
for i in range(3):
x_row = shuffled["x"][i]
y_val = shuffled["y"][i].item()
if torch.equal(x_row, torch.tensor([10, 11])):
self.assertEqual(y_val, 1)
elif torch.equal(x_row, torch.tensor([20, 21])):
self.assertEqual(y_val, 2)
elif torch.equal(x_row, torch.tensor([30, 31])):
self.assertEqual(y_val, 3)
else:
self.fail("Unexpected x row in shuffled output.")
def test_none_tensor_remains_none(self):
x = torch.arange(6).reshape(3, 2)
tensor_dict = {"x": x.clone(), "y": None}
shuffled = shuffle_tensor_dict(tensor_dict)
self.assertIsNone(shuffled["y"])
self.assertEqual(shuffled["x"].shape, x.shape)
class RepeatRandomSamplerTester(unittest.TestCase):
def test_sampler(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
@ -1076,3 +1147,34 @@ class GRPOTrainerTester(unittest.TestCase):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
def test_training_delta_clipping(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
delta=2.0, # set delta to a non-None value
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

View File

@ -160,6 +160,38 @@ class TestNashMDTrainer(unittest.TestCase):
# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])
@require_peft
def test_training_pre_pefted_model_implicit_ref_with_reward_model(self):
lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM")
# self.model from setUp is a base AutoModelForCausalLM
peft_model_instance = get_peft_model(self.model, lora_config)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = NashMDConfig(
output_dir=tmp_dir,
per_device_train_batch_size=1, # Keep small for quick test
max_steps=2, # Few steps
learning_rate=5.0e-7,
eval_strategy="no",
report_to="none",
remove_unused_columns=False, # Important for the dummy dataset
)
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"]
trainer = NashMDTrainer(
model=peft_model_instance, # Pass the already PEFT model
ref_model=None, # Implicit reference from peft_model_instance's base
reward_model=self.reward_model, # To trigger GeometricMixtureWrapper path
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset,
# peft_config is not passed, as model is already PEFT
)
trainer.train()
self.assertIn("train_loss", trainer.state.log_history[-1])
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
@require_llm_blender
def test_nash_md_trainer_judge_training(self, config_name):

tests/test_rewards.py (new file)
View File

@ -0,0 +1,65 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from trl.rewards import think_format_reward
class ThinkFormatRewardTester(unittest.TestCase):
def test_valid_format(self):
completions = [
"<think>This is my reasoning.</think>This is my answer.", # Simple, one-line reasoning
"<think>\nThis is my reasoning.\n</think>\nThis is my answer.", # Multiline reasoning
"<think>\nThis is\nmy reasoning.\n</think>\nThis is my answer.", # Multiline reasoning
"<think>\nThis is <some tag> my reasoning.</think>\nThis is my answer.", # Reasoning including other tags
"<think></think>\nThis is my answer.", # Empty reasoning
]
completions = [[{"content": completion}] for completion in completions]
expected_rewards = [1.0, 1.0, 1.0, 1.0, 1.0] # All should be valid
rewards = think_format_reward(completions)
self.assertEqual(rewards, expected_rewards)
def test_invalid_format(self):
completions = [
"<think>\nThis is my reasoning.\nThis is my answer.", # No closing </think>
"<think>This is my reasoning.\nThis is my answer.", # No closing </think>
"This is my reasoning. This is my answer.", # No <think> tags
"This is my reasoning.\nThis is my answer.", # No <think> tags
"This is my reasoning.</think>\nThis is my answer.", # No opening <think>
"This is my reasoning.</think>This is my answer.", # No opening <think>
"This<think>is my reasoning.</think>\nThis is my answer.", # <think> tag in the middle
"<think>This is<think>my reasoning.</think></think>This is my answer.", # Nested <think> tags
"<think>This is</think>\nmy\n<think>reasoning.</think>\nThis is my answer.", # Multiline <think>
]
completions = [[{"content": completion}] for completion in completions]
expected_rewards = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] # All should be invalid
rewards = think_format_reward(completions)
self.assertEqual(rewards, expected_rewards)
def test_mixed_format(self):
completions = [
"<think>This is my reasoning.</think>This is my answer.", # Valid
"<think>\nThis is my reasoning.\n</think>\nThis is my answer.", # Valid
"<think>This is my reasoning.\nThis is my answer.", # Invalid
"This is my reasoning. This is my answer.", # Invalid
]
completions = [[{"content": completion}] for completion in completions]
expected_rewards = [1.0, 1.0, 0.0, 0.0]
rewards = think_format_reward(completions)
self.assertEqual(rewards, expected_rewards)
if __name__ == "__main__":
unittest.main()

View File

@ -22,6 +22,8 @@ from transformers import Trainer, TrainingArguments
from trl.trainer.callbacks import RichProgressCallback
from .testing_utils import require_rich
class DummyModel(nn.Module):
def __init__(self):
@ -32,6 +34,7 @@ class DummyModel(nn.Module):
return self.a * x
@require_rich
class TestRichProgressCallback(unittest.TestCase):
def setUp(self):
self.dummy_model = DummyModel()

View File

@ -370,7 +370,6 @@ class TrainerArgTester(unittest.TestCase):
packing=True,
max_length=256,
dataset_num_proc=4,
dataset_batch_size=512,
neftune_noise_alpha=0.1,
model_init_kwargs={"trust_remote_code": True},
dataset_kwargs={"append_concat_token": True, "skip_prepare_dataset": True},
@ -381,7 +380,6 @@ class TrainerArgTester(unittest.TestCase):
self.assertEqual(trainer.args.packing, True)
self.assertEqual(trainer.args.max_length, 256)
self.assertEqual(trainer.args.dataset_num_proc, 4)
self.assertEqual(trainer.args.dataset_batch_size, 512)
self.assertEqual(trainer.args.neftune_noise_alpha, 0.1)
self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True})
self.assertIn("append_concat_token", trainer.args.dataset_kwargs)

View File

@ -32,6 +32,7 @@ from trl.trainer.utils import (
batch_generation,
decode_and_strip_padding,
flush_left,
flush_right,
generate_model_card,
get_peft_config,
pad,
@ -39,6 +40,8 @@ from trl.trainer.utils import (
selective_log_softmax,
)
from .testing_utils import require_rich
if is_peft_available():
from peft import LoraConfig
@ -95,6 +98,38 @@ class TestPad(unittest.TestCase):
)
self.assertTrue(torch.equal(output, expected))
def test_pad_to_multiple_of_1(self):
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5])
# Max length is 3, pad to multiple of 4
output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4)
expected = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
self.assertTrue(torch.equal(output, expected))
def test_pad_to_multiple_of_2(self):
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.tensor([6, 7, 8])
# Max length is 3, pad to multiple of 4
output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4)
expected = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0], [6, 7, 8, 0, 0, 0, 0, 0]])
self.assertTrue(torch.equal(output, expected))
def test_pad_to_multiple_of_side_left(self):
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.tensor([6, 7, 8])
# Max length is 3, pad to multiple of 4
output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4)
expected = torch.tensor([[0, 0, 0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 6, 7, 8]])
self.assertTrue(torch.equal(output, expected))
def test_pad_to_multiple_of_no_extra_padding(self):
x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([5, 6, 7, 8])
# Already multiple of 4
output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4)
expected = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
self.assertTrue(torch.equal(output, expected))
@require_peft
class TestGetPEFTConfig(unittest.TestCase):
@ -440,12 +475,12 @@ class TestFlushLeft(unittest.TestCase):
self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
def test_no_shift_needed(self):
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0]])
tensor1 = torch.tensor([[5, 6, 0, 0], [7, 8, 0, 0]])
mask = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]])
tensor1 = torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]])
new_mask, new_tensor1 = flush_left(mask, tensor1)
expected_mask = torch.tensor([[1, 1], [1, 1]])
expected_tensor1 = torch.tensor([[5, 6], [7, 8]])
expected_mask = torch.tensor([[1, 1], [1, 0]])
expected_tensor1 = torch.tensor([[5, 6], [7, 0]])
self.assertTrue(torch.equal(new_mask, expected_mask))
self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
@ -453,9 +488,51 @@ class TestFlushLeft(unittest.TestCase):
def test_no_tensors(self):
mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]])
new_mask = flush_left(mask)
expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
self.assertTrue(torch.equal(new_mask, expected_mask))
class TestFlushRight(unittest.TestCase):
def test_basic_case(self):
mask = torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]])
tensor1 = torch.tensor([[2, 3, 4, 0, 0], [0, 0, 5, 6, 0]])
tensor2 = torch.tensor([[7, 8, 9, 0, 0], [0, 0, 10, 11, 0]])
new_mask, new_tensor1, new_tensor2 = flush_right(mask, tensor1, tensor2)
expected_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])
expected_tensor1 = torch.tensor([[2, 3, 4], [0, 5, 6]])
expected_tensor2 = torch.tensor([[7, 8, 9], [0, 10, 11]])
self.assertTrue(torch.equal(new_mask, expected_mask))
self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
self.assertTrue(torch.equal(new_tensor2, expected_tensor2))
def test_single_row(self):
mask = torch.tensor([[1, 1, 0, 0]])
tensor1 = torch.tensor([[2, 3, 0, 0]])
new_mask, new_tensor1 = flush_right(mask, tensor1)
expected_mask = torch.tensor([[1, 1]])
expected_tensor1 = torch.tensor([[2, 3]])
self.assertTrue(torch.equal(new_mask, expected_mask))
self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
def test_no_shift_needed(self):
mask = torch.tensor([[0, 0, 1, 1], [0, 0, 0, 1]])
tensor1 = torch.tensor([[0, 0, 5, 6], [0, 0, 0, 7]])
new_mask, new_tensor1 = flush_right(mask, tensor1)
expected_mask = torch.tensor([[1, 1], [0, 1]])
expected_tensor1 = torch.tensor([[5, 6], [0, 7]])
self.assertTrue(torch.equal(new_mask, expected_mask))
self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
def test_no_tensors(self):
mask = torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]])
new_mask = flush_right(mask)
expected_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])
self.assertTrue(torch.equal(new_mask, expected_mask))
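As a rough reference for the semantics exercised above (a sketch assuming 2-D masks, not TRL's actual implementation): `flush_right` packs each row's attended positions against the right edge and drops columns that end up all-padding, while `flush_left` is the mirror image.

```python
import torch

def flush_right_sketch(mask: torch.Tensor, *tensors: torch.Tensor):
    # Shift the non-masked entries of every row to the right, then trim the
    # leading columns that are padding in every row. Illustrative only.
    out_mask = torch.zeros_like(mask)
    outs = [torch.zeros_like(t) for t in tensors]
    for i in range(mask.size(0)):
        keep = mask[i].bool()
        n = int(keep.sum())
        out_mask[i, mask.size(1) - n:] = mask[i, keep]
        for o, t in zip(outs, tensors):
            o[i, mask.size(1) - n:] = t[i, keep]
    first = int(out_mask.any(dim=0).int().argmax())  # first column attended by any row
    if not tensors:
        return out_mask[:, first:]
    return (out_mask[:, first:], *(o[:, first:] for o in outs))

mask = torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]])
t = torch.tensor([[2, 3, 4, 0, 0], [0, 0, 5, 6, 0]])
new_mask, new_t = flush_right_sketch(mask, t)  # matches test_basic_case above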
@ -480,28 +557,30 @@ class TestSelectiveLogSoftmax(unittest.TestCase):
torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5)
@require_rich
class TestPrintPromptCompletionsSample(unittest.TestCase):
@patch("sys.stdout", new_callable=StringIO)
def test_print_output(self, mock_stdout):
prompts = ["The sky is", "The sun is"]
completions = [" blue.", " in the sky."]
rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]}
advantages = [0.987, 0.654]
step = 42
print_prompt_completions_sample(prompts, completions, rewards, step)
print_prompt_completions_sample(prompts, completions, rewards, advantages, step)
output = mock_stdout.getvalue()
expected_output = textwrap.dedent("""\
╭────────────────────── Step 42 ───────────────────────╮
│ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ │
│ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━┩ │
│ │ The sky is │ blue. │ 0.12 │ 0.79 │ │
│ ├────────────┼──────────────┼─────────────┼────────┤ │
│ │ The sun is │ in the sky. │ 0.46 │ 0.10 │ │
│ └────────────┴──────────────┴─────────────┴────────┘ │
╰──────────────────────────────────────────────────────╯
╭──────────────────────────── Step 42 ─────────────────────────────╮
│ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │
│ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │
│ │ The sky is │ blue. │ 0.12 │ 0.79 │ 0.99 │ │
│ ├────────────┼──────────────┼─────────────┼────────┼───────────┤ │
│ │ The sun is │ in the sky. │ 0.46 │ 0.10 │ 0.65 │ │
│ └────────────┴──────────────┴─────────────┴────────┴───────────┘ │
╰──────────────────────────────────────────────────────────────────╯
""")
self.assertEqual(output, expected_output)
@ -510,29 +589,30 @@ class TestPrintPromptCompletionsSample(unittest.TestCase):
prompts = ["A", "B"]
completions = ["1", "2"]
rewards = {"Score": [0.1, 0.2]}
advantages = [0.3, 0.4]
step = 10
print_prompt_completions_sample(prompts, completions, rewards, step, num_samples=1)
print_prompt_completions_sample(prompts, completions, rewards, advantages, step, num_samples=1)
output = mock_stdout.getvalue()
possible_outputs = [
textwrap.dedent("""\
╭──────────── Step 10 ────────────╮
│ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Score ┃ │
│ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩ │
│ │ A │ 1 │ 0.10 │ │
│ └────────┴────────────┴───────┘ │
╰─────────────────────────────────╯
╭────────────────── Step 10 ──────────────────╮
│ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │
│ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │
│ │ A │ 1 │ 0.10 │ 0.30 │ │
│ └────────┴────────────┴───────┴───────────┘ │
╰─────────────────────────────────────────────╯
"""),
textwrap.dedent("""\
╭──────────── Step 10 ────────────╮
│ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Score ┃ │
│ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩ │
│ │ B │ 2 │ 0.20 │ │
│ └────────┴────────────┴───────┘ │
╰─────────────────────────────────╯
╭────────────────── Step 10 ──────────────────╮
│ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │
│ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │
│ │ B │ 2 │ 0.20 │ 0.40 │ │
│ └────────┴────────────┴───────┴───────────┘ │
╰─────────────────────────────────────────────╯
"""),
]
self.assertIn(output, possible_outputs)
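In short, the rich table now carries a per-sample Advantage column, so callers pass the advantages list alongside the reward columns. An illustrative call (requires `rich`; the `trl.trainer.utils` import path is assumed from this test module, and the values are made up):

```python
from trl.trainer.utils import print_prompt_completions_sample

prompts = ["The sky is", "The sun is"]
completions = [" blue.", " in the sky."]
rewards = {"Correctness": [0.123, 0.456]}
advantages = [0.987, 0.654]
print_prompt_completions_sample(prompts, completions, rewards, advantages, 42)  # 42 is the global step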

View File

@ -20,12 +20,12 @@ import unittest
import psutil
import pytest
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_torch_multi_gpu
from transformers.testing_utils import require_torch_multi_accelerator, torch_device
from trl.extras.vllm_client import VLLMClient
from trl.scripts.vllm_serve import chunk_list
from .testing_utils import require_3_gpus
from .testing_utils import require_3_accelerators
class TestChunkList(unittest.TestCase):
@ -55,15 +55,16 @@ class TestChunkList(unittest.TestCase):
@pytest.mark.slow
@require_torch_multi_gpu
@require_torch_multi_accelerator
class TestVLLMClientServer(unittest.TestCase):
model_id = "Qwen/Qwen2.5-1.5B"
@classmethod
def setUpClass(cls):
# We want the server to run on GPU 1, so we set CUDA_VISIBLE_DEVICES to "1"
# We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = "1" # Restrict to GPU 1
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
env[VISIBLE_DEVICES] = "1" # Restrict to accelerator 1
# Start the server process
cls.server_process = subprocess.Popen(
@ -107,7 +108,86 @@ class TestVLLMClientServer(unittest.TestCase):
self.assertLessEqual(len(seq), 32)
def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
self.client.update_model_params(model)
def test_reset_prefix_cache(self):
# Test resetting the prefix cache
self.client.reset_prefix_cache()
@classmethod
def tearDownClass(cls):
super().tearDownClass()
# Close the client
cls.client.close_communicator()
# vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
# kill the server process and its children explicitly.
parent = psutil.Process(cls.server_process.pid)
children = parent.children(recursive=True)
for child in children:
child.send_signal(signal.SIGTERM)
cls.server_process.terminate()
cls.server_process.wait()
# Same as above but using base_url to instantiate the client.
@pytest.mark.slow
@require_torch_multi_accelerator
class TestVLLMClientServerBaseURL(unittest.TestCase):
model_id = "Qwen/Qwen2.5-1.5B"
@classmethod
def setUpClass(cls):
# We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
env = os.environ.copy()
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
env[VISIBLE_DEVICES] = "1" # Restrict to accelerator 1
# Start the server process
cls.server_process = subprocess.Popen(
["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
)
# Initialize the client
cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=240)
cls.client.init_communicator()
def test_generate(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts)
# Check that the output is a list
self.assertIsInstance(outputs, list)
# Check that the number of generated sequences is equal to the number of prompts
self.assertEqual(len(outputs), len(prompts))
# Check that the generated sequences are lists of integers
for seq in outputs:
self.assertTrue(all(isinstance(tok, int) for tok in seq))
def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)
# Check that the output is a list
self.assertIsInstance(outputs, list)
# Check that the number of generated sequences is 2 times the number of prompts
self.assertEqual(len(outputs), 2 * len(prompts))
# Check that the generated sequences are lists of integers
for seq in outputs:
self.assertTrue(all(isinstance(tok, int) for tok in seq))
# Check that the length of the generated sequences is less than or equal to 32
for seq in outputs:
self.assertLessEqual(len(seq), 32)
def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
self.client.update_model_params(model)
def test_reset_prefix_cache(self):
@ -132,15 +212,16 @@ class TestVLLMClientServer(unittest.TestCase):
@pytest.mark.slow
@require_3_gpus
@require_3_accelerators
class TestVLLMClientServerTP(unittest.TestCase):
model_id = "Qwen/Qwen2.5-1.5B"
@classmethod
def setUpClass(cls):
# We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2"
# We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2"
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = "1,2" # Restrict to GPU 1 and 2
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
env[VISIBLE_DEVICES] = "1,2" # Restrict to accelerator 1 and 2
# Start the server process
cls.server_process = subprocess.Popen(
@ -169,7 +250,7 @@ class TestVLLMClientServerTP(unittest.TestCase):
self.assertTrue(all(isinstance(tok, int) for tok in seq))
def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
self.client.update_model_params(model)
def test_reset_prefix_cache(self):
@ -194,15 +275,16 @@ class TestVLLMClientServerTP(unittest.TestCase):
@pytest.mark.slow
@require_3_gpus
@require_3_accelerators
class TestVLLMClientServerDP(unittest.TestCase):
model_id = "Qwen/Qwen2.5-1.5B"
@classmethod
def setUpClass(cls):
# We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2"
# We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2"
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = "1,2" # Restrict to GPU 1 and 2
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
env[VISIBLE_DEVICES] = "1,2" # Restrict to accelerator 1 and 2
# Start the server process
cls.server_process = subprocess.Popen(
@ -230,7 +312,7 @@ class TestVLLMClientServerDP(unittest.TestCase):
self.assertTrue(all(isinstance(tok, int) for tok in seq))
def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
self.client.update_model_params(model)
def test_reset_prefix_cache(self):

View File

@ -160,6 +160,36 @@ class TestXPOTrainer(unittest.TestCase):
# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])
@require_peft
def test_training_pre_pefted_model_implicit_ref(self):
lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM")
peft_model_instance = get_peft_model(self.model, lora_config)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = XPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=1,
max_steps=2,
learning_rate=5.0e-7,
eval_strategy="no",
report_to="none",
remove_unused_columns=False,
)
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"]
trainer = XPOTrainer(
model=peft_model_instance,
ref_model=None,
reward_model=self.reward_model, # Using reward_model to ensure _generate_completions is used as expected
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset,
)
trainer.train()
self.assertIn("train_loss", trainer.state.log_history[-1])
@require_llm_blender
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
def test_xpo_trainer_judge_training(self, config_name):

View File

@ -17,6 +17,8 @@ import unittest
import torch
from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available
from transformers.testing_utils import torch_device
from transformers.utils import is_rich_available
from trl import BaseBinaryJudge, BasePairwiseJudge
from trl.import_utils import (
@ -65,6 +67,13 @@ def require_mergekit(test_case):
return unittest.skipUnless(is_mergekit_available(), "test requires mergekit")(test_case)
def require_rich(test_case):
"""
Decorator marking a test that requires rich. Skips the test if rich is not available.
"""
return unittest.skipUnless(is_rich_available(), "test requires rich")(test_case)
def require_sklearn(test_case):
"""
Decorator marking a test that requires sklearn. Skips the test if sklearn is not available.
@ -86,11 +95,14 @@ def require_no_wandb(test_case):
return unittest.skipUnless(not is_wandb_available(), "test requires no wandb")(test_case)
def require_3_gpus(test_case):
def require_3_accelerators(test_case):
"""
Decorator marking a test that requires at least num_gpus GPUs. Skips the test if num_gpus is not available.
Decorator marking a test that requires at least 3 accelerators. Skips the test if 3 accelerators are not available.
"""
return unittest.skipUnless(torch.cuda.device_count() > 3, "test requires at least 3 GPUs")(test_case)
torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
return unittest.skipUnless(
torch_accelerator_module.device_count() > 3, f"test requires at least 3 {torch_device}s"
)(test_case)
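A quick sketch of what the accelerator-module lookup above resolves to on different hosts (the hard-coded `torch_device` value is only for illustration; in the test suite it comes from `transformers.testing_utils`):

```python
import torch

torch_device = "cuda"  # would be "xpu" on an Intel XPU host
torch_accelerator_module = getattr(torch, torch_device, torch.cuda)  # torch.cuda here, torch.xpu on XPU
print(torch_accelerator_module.device_count())  # the decorator skips the test unless this is > 3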
class RandomBinaryJudge(BaseBinaryJudge):

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
__version__ = "0.17.0"
__version__ = "0.18.1"
from typing import TYPE_CHECKING
@ -66,6 +66,7 @@ _import_structure = {
"GRPOConfig",
"GRPOTrainer",
"HfPairwiseJudge",
"IterativeSFTConfig",
"IterativeSFTTrainer",
"KTOConfig",
"KTOTrainer",
@ -161,6 +162,7 @@ if TYPE_CHECKING:
GRPOConfig,
GRPOTrainer,
HfPairwiseJudge,
IterativeSFTConfig,
IterativeSFTTrainer,
KTOConfig,
KTOTrainer,

View File

@ -0,0 +1,28 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: true
fsdp_offload_params: false
fsdp_reshard_after_forward: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
fsdp_version: 1
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,25 @@
# Requires accelerate 1.7.0 or higher
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: "NO"
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
gradient_accumulation_steps: 1
zero3_init_flag: false
zero_stage: 1
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.resources as resources
import os
import sys
import warnings
@ -45,8 +46,37 @@ def main():
make_sft_parser(subparsers)
make_vllm_serve_parser(subparsers)
# Parse the arguments
args = parser.parse_args()
# Parse the arguments; the remaining ones (`launch_args`) are passed to the 'accelerate launch' subparser.
# Duplicates may occur if the same argument is provided in both the config file and CLI.
# For example: launch_args = `["--num_processes", "4", "--num_processes", "8"]`.
# Deduplication and precedence (CLI over config) are handled later by launch_command_parser.
args, launch_args = parser.parse_args_and_config(return_remaining_strings=True)
# Replace `--accelerate_config foo` with `--config_file trl/accelerate_configs/foo.yaml` if it is present in the
# launch_args. It allows the user to use predefined accelerate configs from the `trl` package.
if "--accelerate_config" in launch_args:
# Get the index of the '--accelerate_config' argument and the corresponding config name
config_index = launch_args.index("--accelerate_config")
config_name = launch_args[config_index + 1]
# If the config_name corresponds to a path in the filesystem, we don't want to override it
if os.path.isfile(config_name):
accelerate_config_path = config_name
elif resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml").exists():
# Get the predefined accelerate config path from the package resources
accelerate_config_path = resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml")
else:
raise ValueError(
f"Accelerate config {config_name} is neither a file nor a valid config in the `trl` package. "
"Please provide a valid config name or a path to a config file."
)
# Remove '--accelerate_config' and its corresponding config name
launch_args.pop(config_index)
launch_args.pop(config_index)
# Insert '--config_file' and the absolute path to the front of the list
launch_args = ["--config_file", str(accelerate_config_path)] + launch_args
if args.command == "chat":
(chat_args,) = parser.parse_args_and_config()
@ -54,8 +84,8 @@ def main():
if args.command == "dpo":
# Get the default args for the launch command
dpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "dpo.py")
args = launch_command_parser().parse_args([dpo_training_script])
dpo_training_script = resources.files("trl.scripts").joinpath("dpo.py")
args = launch_command_parser().parse_args([str(dpo_training_script)])
# Feed the args to the launch command
args.training_script_args = sys.argv[2:] # remove "trl" and "dpo"
@ -66,8 +96,8 @@ def main():
elif args.command == "grpo":
# Get the default args for the launch command
grpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "grpo.py")
args = launch_command_parser().parse_args([grpo_training_script])
grpo_training_script = resources.files("trl.scripts").joinpath("grpo.py")
args = launch_command_parser().parse_args([str(grpo_training_script)])
# Feed the args to the launch command
args.training_script_args = sys.argv[2:] # remove "trl" and "grpo"
@ -75,20 +105,22 @@ def main():
elif args.command == "kto":
# Get the default args for the launch command
kto_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "kto.py")
args = launch_command_parser().parse_args([kto_training_script])
kto_training_script = resources.files("trl.scripts").joinpath("kto.py")
args = launch_command_parser().parse_args([str(kto_training_script)])
# Feed the args to the launch command
args.training_script_args = sys.argv[2:] # remove "trl" and "kto"
launch_command(args) # launch training
elif args.command == "sft":
# Get the default args for the launch command
sft_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "sft.py")
args = launch_command_parser().parse_args([sft_training_script])
# Get the path to the training script
sft_training_script = resources.files("trl.scripts").joinpath("sft.py")
# Feed the args to the launch command
args.training_script_args = sys.argv[2:] # remove "trl" and "sft"
# This simulates running: `accelerate launch <launch args> sft.py <training script args>`.
# Note that the training script args may include launch-related arguments (e.g., `--num_processes`),
# but we rely on the script to ignore any that don't apply to it.
training_script_args = sys.argv[2:] # Remove "trl" and "sft"
args = launch_command_parser().parse_args(launch_args + [str(sft_training_script)] + training_script_args)
launch_command(args) # launch training
elif args.command == "vllm-serve":
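Taken together, the changes above let a config bundled with TRL be selected by name on the command line (for example `trl sft --accelerate_config zero3 ...`, where `zero3` is an assumed name for one of the packaged YAML files) and rewritten into `--config_file <path inside trl.accelerate_configs>`. A sketch of that lookup in isolation, mirroring the CLI logic:

```python
import importlib.resources as resources
import os

def resolve_accelerate_config(name_or_path: str) -> str:
    # An existing file path wins; otherwise fall back to a YAML shipped in the
    # trl.accelerate_configs package. Illustrative helper, not part of TRL.
    if os.path.isfile(name_or_path):
        return name_or_path
    candidate = resources.files("trl.accelerate_configs").joinpath(f"{name_or_path}.yaml")
    if candidate.exists():
        return str(candidate)
    raise ValueError(f"Accelerate config {name_or_path} is neither a file nor a bundled config.")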

View File

@ -13,13 +13,13 @@
# limitations under the License.
import re
import warnings
from typing import Optional
import torch
from accelerate.utils import extract_model_from_parallel
from transformers import StoppingCriteria, StoppingCriteriaList
from ..import_utils import is_rich_available
from transformers.utils import is_rich_available
if is_rich_available():
@ -241,6 +241,10 @@ class TextEnvironment:
max_length (Optional[int]): The maximum number of tokens to allow in an episode.
generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method.
"""
warnings.warn(
"This class is deprecated and will be removed in version 0.21.0. To enable tool use with LLMs, check out smolagents (https://huggingface.co/docs/smolagents/index)",
DeprecationWarning,
)
self.model = model
self.tokenizer = tokenizer
self.prompt = prompt

View File

@ -17,17 +17,22 @@ import functools
import time
from collections.abc import Generator
from transformers import Trainer, is_wandb_available
from transformers import Trainer
from transformers.integrations import is_mlflow_available, is_wandb_available
if is_wandb_available():
import wandb
if is_mlflow_available():
import mlflow
@contextlib.contextmanager
def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
"""
A context manager function for profiling a block of code. Results are logged to Weights & Biases if enabled.
A context manager function for profiling a block of code. Results are logged to Weights & Biases or MLflow
depending on the trainer's configuration.
Args:
trainer (`~transformers.Trainer`):
@ -54,8 +59,12 @@ def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None
end_time = time.perf_counter()
duration = end_time - start_time
profiling_metrics = {f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration}
if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
wandb.log({f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration})
wandb.log(profiling_metrics)
if "mlflow" in trainer.args.report_to and mlflow.run is not None and trainer.accelerator.is_main_process:
mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)
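A usage sketch of `profiling_context` (the `trl.extras.profiling` import path and the block being timed are assumptions for illustration):

```python
from trl.extras.profiling import profiling_context  # assumed module path

def score_batch(trainer, completions):
    # `trainer` is any transformers.Trainer; the elapsed time is reported as
    # "profiling/Time taken: <TrainerClass>.score_batch" to W&B and/or MLflow,
    # depending on trainer.args.report_to.
    with profiling_context(trainer, "score_batch"):
        return [float(len(c)) for c in completions]  # stand-in for real scoring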
def profiling_decorator(func: callable) -> callable:

View File

@ -14,8 +14,10 @@
import atexit
import logging
import socket
import time
from typing import Optional
from urllib.parse import urlparse
import torch
from torch import nn
@ -47,10 +49,13 @@ class VLLMClient:
weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.
Args:
base_url (`str` or `None`, *optional*, defaults to `None`):
Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are
ignored.
host (`str`, *optional*, defaults to `"0.0.0.0"`):
IP address of the vLLM server.
IP address of the vLLM server. Ignored if `base_url` is provided.
server_port (`int`, *optional*, defaults to `8000`):
Port number of the vLLM server.
Port number of the vLLM server. Ignored if `base_url` is provided.
group_port (`int`, *optional*, defaults to `51216`):
Port number for the weight update group.
connection_timeout (`float`, *optional*, defaults to `0.0`):
@ -81,10 +86,24 @@ class VLLMClient:
>>> client.init_communicator()
>>> client.update_model_params(model)
```
There are several ways to initialize the client:
```python
VLLMClient(base_url="http://localhost:8000")
VLLMClient(base_url="http://192.168.1.100:8000")
VLLMClient(host="localhost", server_port=8000)
VLLMClient(host="192.168.1.100", server_port=8000)
```
"""
def __init__(
self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0
self,
base_url: Optional[str] = None,
host: str = "0.0.0.0",
server_port: int = 8000,
group_port: int = 51216,
connection_timeout: float = 0.0,
):
if not is_requests_available():
raise ImportError("requests is not installed. Please install it with `pip install requests`.")
@ -92,8 +111,17 @@ class VLLMClient:
raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")
self.session = requests.Session()
self.host = host
self.server_port = server_port
if base_url is not None:
# Parse the base_url to extract host and port
parsed_url = urlparse(base_url)
self.host = socket.gethostbyname(parsed_url.hostname)
scheme = parsed_url.scheme or "http"
self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
else:
self.host = host
self.server_port = server_port
self.base_url = f"http://{self.host}:{self.server_port}"
self.group_port = group_port
self.check_server(connection_timeout) # check server and fail after timeout
@ -108,7 +136,7 @@ class VLLMClient:
total_timeout (`float`, *optional*, defaults to `0.0`):
Total timeout duration in seconds.
"""
url = f"http://{self.host}:{self.server_port}/health/"
url = f"{self.base_url}/health/"
start_time = time.time() # Record the start time
while True:
@ -119,11 +147,13 @@ class VLLMClient:
elapsed_time = time.time() - start_time
if elapsed_time >= total_timeout:
raise ConnectionError(
f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
"seconds. Make sure the server is running by running `trl vllm-serve`."
f"The vLLM server can't be reached at {self.base_url} after {total_timeout} seconds. Make "
"sure the server is running by running `trl vllm-serve`."
) from exc
else:
if response.status_code == 200:
if "X-Forwarded-For" in response.headers:
self.host = response.headers["X-Forwarded-For"]
logger.info("Server is up!")
return None
@ -170,7 +200,7 @@ class VLLMClient:
`list[list[int]]`:
List of lists of token IDs representing the model-generated completions for each prompt.
"""
url = f"http://{self.host}:{self.server_port}/generate/"
url = f"{self.base_url}/generate/"
response = self.session.post(
url,
json={
@ -195,7 +225,7 @@ class VLLMClient:
Initializes the weight update group in a distributed setup for model synchronization.
"""
# Get the world size from the server
url = f"http://{self.host}:{self.server_port}/get_world_size/"
url = f"{self.base_url}/get_world_size/"
response = requests.get(url)
if response.status_code == 200:
vllm_world_size = response.json()["world_size"]
@ -206,7 +236,7 @@ class VLLMClient:
self.rank = vllm_world_size # the client's rank is the last process
# Initialize weight update group
url = f"http://{self.host}:{self.server_port}/init_communicator/"
url = f"{self.base_url}/init_communicator/"
# In the server side, the host is set to 0.0.0.0
response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
if response.status_code != 200:
@ -235,7 +265,7 @@ class VLLMClient:
Tensor containing the updated weights.
"""
dtype, shape = str(weights.dtype), tuple(weights.shape)
url = f"http://{self.host}:{self.server_port}/update_named_param/"
url = f"{self.base_url}/update_named_param/"
response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape})
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
@ -260,7 +290,7 @@ class VLLMClient:
"""
Resets the prefix cache for the model.
"""
url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/"
url = f"{self.base_url}/reset_prefix_cache/"
response = self.session.post(url)
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
@ -269,7 +299,7 @@ class VLLMClient:
"""
Closes the weight update group and cleans up the communication group.
"""
url = f"http://{self.host}:{self.server_port}/close_communicator/"
url = f"{self.base_url}/close_communicator/"
try:
response = self.session.post(url)

View File

@ -22,7 +22,7 @@ from packaging import version
from transformers.utils.import_utils import _is_package_available
LIGER_KERNEL_MIN_VERSION = "0.5.6"
LIGER_KERNEL_MIN_VERSION = "0.5.8"
# Use same as transformers.utils.import_utils
_deepspeed_available = _is_package_available("deepspeed")
@ -33,7 +33,6 @@ _llm_blender_available = _is_package_available("llm_blender")
_mergekit_available = _is_package_available("mergekit")
_pydantic_available = _is_package_available("pydantic")
_requests_available = _is_package_available("requests")
_rich_available = _is_package_available("rich")
_unsloth_available = _is_package_available("unsloth")
_uvicorn_available = _is_package_available("uvicorn")
_vllm_available = _is_package_available("vllm")
@ -73,10 +72,6 @@ def is_requests_available() -> bool:
return _requests_available
def is_rich_available() -> bool:
return _rich_available
def is_unsloth_available() -> bool:
return _unsloth_available

View File

@ -18,6 +18,7 @@ from ..import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffu
_import_structure = {
"activation_offloading": ["get_act_offloading_ctx_manager"],
"modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"],
"modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
"utils": [
@ -43,6 +44,7 @@ else:
]
if TYPE_CHECKING:
from .activation_offloading import get_act_offloading_ctx_manager
from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
from .utils import (

View File

@ -0,0 +1,462 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of https://github.com/pytorch/torchtune.
import warnings
import psutil
import torch
from torch import nn
from torch.autograd.graph import saved_tensors_hooks
class OffloadActivations(saved_tensors_hooks):
"""
Context manager under which activation tensors created in the forward pass will be offloaded.
Enable the memory efficiency technique of activation offloading, where activations bigger than `min_offload_size`
bytes will be offloaded to CPU in the forward and brought back in the backward. This is in contrast to maintaining
the activation on GPU VRAM throughout the program.
This manager contains the option of using one additional CUDA stream to handle the communication between CUDA and
CPU, which is intended to overlap with the default computation stream to improve runtime. We designed
synchronization with a few heuristics for optimizing the tradeoff between runtime and memory usage.
Args:
use_pin_memory (`bool`, *optional*, defaults to `True`):
Whether the offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to
be moved back onto GPU more quickly but is a limited resource.
use_streams (`bool`, *optional*, defaults to `True`):
Whether to use streams for performance optimization where the communications get overlapped with the
computation. Requires a torch build after torch-2.5.0.
min_offload_size (`int`, *optional*, defaults to `1024`):
Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we
do not want to waste bandwidth and resources moving it to CPU and back.
max_fwd_stash_size (`int`, *optional*, defaults to `5`):
Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping
alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing
runtime.
Raises:
ValueError: if `max_fwd_stash_size` is not at least `1`.
Example:
>>> with OffloadActivations():
>>> outputs = model(inputs, labels=labels)
>>> loss = outputs.loss
>>> loss.backward()
"""
def __init__(
self,
use_pin_memory: bool = True,
use_streams: bool = True,
min_offload_size: int = 1024,
max_fwd_stash_size: int = 5,
) -> None:
self.use_streams = use_streams
self.min_tensor_size_bytes = min_offload_size # we don't want to bother with small tensors
self.tracker = {} # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where
self.tensor_id = 0
self.is_first_forward_call = True
self.is_first_backward_call = True
self.is_first_forward_pass = True
# Managing cpu memory
self.use_pin_memory = use_pin_memory
self.virtual_memory_safe_pct = 60 # we should not exceed this percentage of memory
self.accelerator_type = (
torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
)
# NOTE: xpu doesn't have `default_stream` API, use `current_stream` instead
self.s0 = (
torch.xpu.current_stream() if self.accelerator_type == "xpu" else torch.cuda.default_stream()
) # comp stream
# For streaming
if self.use_streams:
self.s1 = torch.Stream() if self.accelerator_type == "xpu" else torch.cuda.Stream() # comms stream
self.fwd_stash = {} # tensor_id => (activation, ev1)
if max_fwd_stash_size < 1:
raise ValueError(f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}")
self.max_fwd_stash_size = max_fwd_stash_size
self.bwd_tensor_stash = {} # tensor_id => activation
self.bwd_ev_stash = {} # tensor_id => ev0
self.curr_graph_id = None
self.curr_autograd_node = None
# -------- platform util functions -------- #
def verify_sufficient_virtual_memory():
curr_pct = get_cpu_ram_pct()
if curr_pct > self.virtual_memory_safe_pct:
warnings.warn(f"{curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used")
def get_cpu_ram_pct() -> float:
# get the percentage of memory used by the system
return psutil.virtual_memory().percent
def get_tensor_id() -> int:
# create a unique id for each tensor we are managing
self.tensor_id += 1
return self.tensor_id
def get_num_bytes_tensor(x: torch.Tensor) -> int:
# get the number of bytes in a tensor, for memory management purposes
return x.element_size() * x.nelement() # x.element_size() * x._base_storage().nbytes()
# -------- core pack / unpack work -------- #
def pack_tensor(activation: torch.Tensor) -> int:
# activations are passed in during forward pass - from here we take over and return a unique id
if self.is_first_forward_call:
if len(self.tracker) != 0:
raise ValueError("Backward pass should have cleared tracker of all tensors")
# set training phase trackers
self.is_first_forward_call = False
self.is_first_backward_call = True
# query for basic tensor info
num_bytes = get_num_bytes_tensor(activation)
tensor_id = get_tensor_id()
# only offload hefty bois if they're activations on an accelerator (CUDA/XPU); our heuristic
# for that is to check that they're not params or buffers!
if (
activation.device.type in ["cuda", "xpu"]
and num_bytes >= self.min_tensor_size_bytes
and (
not isinstance(activation, torch.nn.Parameter)
and not (hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer))
)
):
if self.use_streams:
# First, sync back and dereference previously offloaded tensors
# as the offloading should be done sufficiently long ago.
for id in list(self.fwd_stash.keys()):
if id <= tensor_id - self.max_fwd_stash_size:
_, ev = self.fwd_stash[id]
self.s0.wait_event(ev)
del self.fwd_stash[id]
else:
break
# Sync in, offload, and add an event to sync back later
self.s1.wait_stream(self.s0)
stream = self.s1 if self.use_streams else self.s0
with stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream):
cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu")
cpu_tensor.copy_(activation, non_blocking=True)
self.tracker[tensor_id] = (
cpu_tensor,
True, # True = (in future) modified
)
if self.use_streams:
event = self.s1.record_event()
# Stash to keep activation alive til s1 is done
self.fwd_stash[tensor_id] = (activation, event)
else:
self.tracker[tensor_id] = (
activation,
False,
) # False = not modified, tensor is as is
return tensor_id
def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
# backward pass - we are called with the tensor_id, which
# we will use to retrieve the saved/offloaded tensor
if self.is_first_backward_call:
if self.is_first_forward_pass:
self.is_first_forward_pass = False
if self.use_pin_memory:
verify_sufficient_virtual_memory()
self.is_first_backward_call = False
self.is_first_forward_call = True
if unpack_tensor_id not in self.tracker:
raise ValueError(f"Untracked tensor with id {unpack_tensor_id}")
maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id]
if modified:
accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
maybe_accelerator_tensor = accelerator_tensor
# clear tensor from tracking
del self.tracker[unpack_tensor_id]
return maybe_accelerator_tensor
def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
# backward pass - we are called with the tensor_id, which
# we will use to retrieve the saved/offloaded tensor
if self.is_first_backward_call:
self.curr_graph_id = torch._C._current_graph_task_id()
def wait_and_del_remaining_references() -> None:
for id in list(self.bwd_tensor_stash.keys()):
event = self.bwd_ev_stash[id]
self.s1.wait_event(event)
del self.bwd_tensor_stash[id]
# Register a callback to the end of autograd to clean everything up
torch.autograd.variable.Variable._execution_engine.queue_callback(wait_and_del_remaining_references)
if self.is_first_forward_pass:
self.is_first_forward_pass = False
if self.use_pin_memory:
verify_sufficient_virtual_memory()
self.is_first_backward_call = False
self.is_first_forward_call = True
if unpack_tensor_id not in self.tracker:
raise ValueError(f"untracked tensor with id {unpack_tensor_id}")
maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id]
if modified:
# Get data on the current autograd node
graph_id = torch._C._current_graph_task_id()
node = torch._C._current_autograd_node()
prev_node_ids = []
# If we're on a new node, mark prev node's tensors to be freed later
if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
self.curr_autograd_node = node
prev_node_ids = list(self.bwd_tensor_stash.keys())
brought_back_from_cpu = True
if unpack_tensor_id in self.fwd_stash:
maybe_accelerator_tensor = self.fwd_stash[unpack_tensor_id][0]
brought_back_from_cpu = False
else:
# Kick off the process to bring tensors back
with self.s1 if self.accelerator_type == "xpu" else torch.cuda.stream(self.s1):
accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
maybe_accelerator_tensor = accelerator_tensor
# Tell comp stream to wait for the info to be loaded before executing
self.s0.wait_stream(self.s1)
# Stash the tensor to keep memory alive until compute stream is complete
self.bwd_tensor_stash[unpack_tensor_id] = maybe_accelerator_tensor
# Note: [Track views of the unpacked]
# Why do we get the use count of the unpacked tensor here? We want an
# initial count to compare to later, during the post-hook of the
# backward node, when we need to decide whether we're allowed to free
# the tensor yet. In what obscure cases must we delay freeing the
# tensor (and thus call record_stream)?
# 1. Any of the outputs of the backward node is a view of the unpacked
# tensor.
# 2. In the case that this unpacked tensor will be used in a
# checkpointed region, if one of the recomputed saved tensors ends
# up as a view of the unpacked tensor.
# 3. The user abuses the system somehow and manually relies on the
# unpacked tensor to exist after the backward node has executed.
storage_refcount = torch._C._storage_Use_Count(maybe_accelerator_tensor.untyped_storage()._cdata)
def hook(outputs, inputs):
# create events for the current node inputs/outputs if they were streamed in
if brought_back_from_cpu:
# See Note: [Track views of the unpacked]
# IF any of the outputs is a view of the tensor, OR if a view of
# the tensor has been saved as a part of checkpoint's recompute
# process, OR the user has abusively incurred a reference on the
# unpacked tensor, THEN the tensor might be used later and we
# cannot presume to delete it after only the current node is
# done! So we use our frenemy, record_stream, to ensure the
# Tensor stays unmessed with until it's done getting used in the
# compute stream (s0 here). Note that the con here is we introduce
# non-deterministic (thus higher) memory usage, but this case
# should not happen often.
unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
if torch._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata) > storage_refcount:
unpacked_tensor.record_stream(self.s0)
del self.bwd_tensor_stash[unpack_tensor_id]
else:
event = self.s0.record_event()
self.bwd_ev_stash[unpack_tensor_id] = event
# if there are still things in the fwd_stash, get rid of them as we're in bwd now
for id in list(self.fwd_stash.keys()):
_, ev = self.fwd_stash[id]
self.s0.wait_event(ev)
del self.fwd_stash[id]
# wait on prev node's events and del those
for id in prev_node_ids:
event = self.bwd_ev_stash[id]
self.s1.wait_event(event)
del self.bwd_tensor_stash[id]
return outputs
node.register_hook(hook)
# clear tensor from tracking
del self.tracker[unpack_tensor_id]
return maybe_accelerator_tensor
unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream
super().__init__(pack_tensor, unpack_tensor)
class NoOpManager(saved_tensors_hooks):
"""
A `saved_tensors_hook` manager used to disable any other `saved_tensors_hook` manager applied before. This relies
on the behavior that only the most recently registered `saved_tensors_hook` will run.
One example usage is to opt a local region of code out of activations offloading, which is usually applied globally
to best track state.
"""
def __init__(self) -> None:
def noop(tensor):
return tensor
super().__init__(noop, noop)
def get_act_offloading_ctx_manager(
model: nn.Module,
use_pin_memory: bool = True,
use_streams: bool = True,
min_offload_size: int = 1024,
max_fwd_stash_size: int = 5,
warn_if_no_head: bool = True,
) -> OffloadActivations:
"""
Returns the activation offloading context manager for the model. All but the last output Linear in every step will
be offloaded.
The returned manager is an `OffloadActivations` context; the detected output head (and any Liger modules) are
wrapped in a `NoOpManager` so their activations stay on the accelerator instead of being offloaded.
Args:
model (`nn.Module`):
Model to wrap with the activation offloading context manager.
use_pin_memory (`bool`, *optional*, defaults to `True`):
Whether the offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to
be moved back onto GPU more quickly but is a limited resource.
use_streams (`bool`, *optional*, defaults to `True`):
Whether to use streams for performance optimization where the communications get overlapped with the
computation. Requires a torch build after torch-2.5.0.
min_offload_size (`int`, *optional*, defaults to `1024`):
Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we
do not want to waste bandwidth and resources moving it to CPU and back.
max_fwd_stash_size (`int`, *optional*, defaults to `5`):
Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping
alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing
runtime.
warn_if_no_head (`bool`, *optional*, defaults to `True`):
Whether to warn if no output head is detected. If set to `False`, no warning will be raised if no output
head is detected.
Returns:
`contextlib.ContextDecorator`:
Activation offloading context manager for the model.
"""
activations_handling_ctx = OffloadActivations(
use_pin_memory=use_pin_memory,
use_streams=use_streams,
min_offload_size=min_offload_size,
max_fwd_stash_size=max_fwd_stash_size,
)
# Below is our hack to disable offloading the last output Linear in every
step, as offloading the activation and then bringing it back soon after is expensive.
output_head_detected = False
noop_ctx = NoOpManager()
# Try to get the actual model if it's wrapped
unwrapped_model = model
if hasattr(unwrapped_model, "module"):
unwrapped_model = unwrapped_model.module
# check for PEFT models
if hasattr(unwrapped_model, "base_model") and hasattr(unwrapped_model, "peft_config"):
unwrapped_model = unwrapped_model.base_model
# Check for different types of output heads
if hasattr(unwrapped_model, "output"):
if isinstance(unwrapped_model.output, nn.Module):
unwrapped_model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
unwrapped_model.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
elif hasattr(unwrapped_model.output, "linear") and isinstance(unwrapped_model.output.linear, nn.Module):
unwrapped_model.output.linear.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
unwrapped_model.output.linear.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
# Check for HuggingFace model output heads
elif hasattr(unwrapped_model, "lm_head"):
unwrapped_model.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
unwrapped_model.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
# Check for decoder-based models
elif hasattr(unwrapped_model, "decoder"):
decoder = unwrapped_model.decoder
if hasattr(decoder, "output"):
decoder.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
decoder.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
# Some models have lm_head in the decoder
elif hasattr(decoder, "lm_head"):
decoder.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
decoder.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
# Check for transformer models with final layer norm
elif hasattr(unwrapped_model, "final_layer_norm") or hasattr(unwrapped_model, "ln_f"):
final_norm = getattr(unwrapped_model, "final_layer_norm", None) or unwrapped_model.ln_f
final_norm.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
final_norm.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
# Check for models with head module
elif hasattr(unwrapped_model, "head") and isinstance(unwrapped_model.head, nn.Module):
unwrapped_model.head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
unwrapped_model.head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
output_head_detected = True
if not output_head_detected and warn_if_no_head:
warnings.warn(
"During activation offloading, no output head was detected. If your model has an output head, it will be "
"offloaded. This usually greatly slows training, given the large vocabulary size. To change this "
"behavior, set your output head as model.output and make it an nn.Module. You can disable this warning by "
"passing `warn_if_no_head=False`."
)
# Disable offloading for any Liger modules
for name, module in unwrapped_model.named_modules():
if "liger" in name.lower():
module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
module.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
return activations_handling_ctx
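An end-to-end sketch of how these pieces fit together, using a toy model as a placeholder and following the `with` pattern from the `OffloadActivations` docstring above (assumes a CUDA or XPU device is available, since the offloading streams are created on the current accelerator):

```python
import torch
from torch import nn

# Toy stand-in for a causal LM; real usage would pass the training model.
device = "cuda" if torch.cuda.is_available() else "xpu"
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 2)).to(device)

ctx = get_act_offloading_ctx_manager(model, warn_if_no_head=False)
inputs = torch.randn(8, 256, device=device)
labels = torch.randint(0, 2, (8,), device=device)
with ctx:
    loss = nn.functional.cross_entropy(model(inputs), labels)
    loss.backward()  # offloaded activations are streamed back in during the backward pass
```

Wrapping a sub-region in `NoOpManager()` inside the same context (as the output-head hooks above do) keeps that region's activations on the accelerator instead of offloading them.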

View File

@ -16,9 +16,9 @@ import itertools
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from accelerate.utils import is_deepspeed_available
import torch.nn as nn
from packaging import version
from transformers import PreTrainedModel, PreTrainedTokenizer
@ -30,12 +30,10 @@ SUPPORTED_ARCHITECTURES = (
AutoModelForSeq2SeqLMWithValueHead,
)
if is_deepspeed_available():
import deepspeed
if TYPE_CHECKING:
from accelerate import Accelerator
from deepspeed.runtime.engine import DeepSpeedEngine
from torch.nn import Module
from torch.nn.parallel.distributed import DistributedDataParallel
@ -167,6 +165,8 @@ def iter_params(module, recurse=False):
def add_hooks(model: "DeepSpeedEngine") -> None:
"""Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
import deepspeed
if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
return
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
@ -214,6 +214,8 @@ def unwrap_model_for_generation(
if not gather_deepspeed3_params:
yield accelerator.unwrap_model(model)
else:
import deepspeed
with deepspeed.zero.GatheredParameters(model.parameters()):
remove_hooks(model)
yield accelerator.unwrap_model(model)
@ -222,8 +224,13 @@ def unwrap_model_for_generation(
yield unwrapped_model
def prepare_deepspeed(model, accelerator):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
def prepare_deepspeed(model: "Module", accelerator: "Accelerator"):
"""Prepares the model for DeepSpeed inference or evaluation by initializing it with the appropriate configuration.
Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
"""
import deepspeed # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm): https://github.com/deepspeedai/DeepSpeed/issues/7252
deepspeed_plugin = accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
stage = config_kwargs["zero_optimization"]["stage"]
@ -266,7 +273,7 @@ def prepare_fsdp(model, accelerator):
accelerator.state.fsdp_plugin.set_auto_wrap_policy(model)
fsdp_plugin = accelerator.state.fsdp_plugin
kwargs = {
"sharding_strategy": fsdp_plugin.sharding_strategy,
"sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,
"cpu_offload": fsdp_plugin.cpu_offload,
"auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
"mixed_precision": fsdp_plugin.mixed_precision_policy,
@ -282,3 +289,53 @@ def prepare_fsdp(model, accelerator):
model = FSDP(model, **kwargs)
model.eval()
return model
class _ForwardRedirection:
"""Implements the `forward-redirection`.
Taken from Pytorch-lightning: https://github.com/Lightning-AI/pytorch-lightning/blob/02311d03fb982560246eead7c08104481fac9579/src/lightning/pytorch/strategies/strategy.py#L602
A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead.
"""
def __call__(
self, wrapper_module: nn.Module, original_module: nn.Module, method: callable, *args: Any, **kwargs: Any
):
"""Reroutes a method call through the `wrapper_module`'s `forward` method.
Args:
wrapper_module: The module that has `original_module` wrapped.
original_module: The module that was wrapped inside `wrapper_module`.
method: The method that should be called on the `original_module` after inputs get redirected through
the `wrapper_module`'s `forward` method.
*args: The positional arguments to `method`. They will get passed to a patched `forward` method instead.
**kwargs: The keyword arguments to `method`. They will get passed to a patched `forward` method instead.
"""
original_forward = original_module.forward
def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
# Unpatch ourselves immediately before calling the method `method_name`
# because itself may want to call the real `forward`
original_module.forward = original_forward # type: ignore[method-assign]
# Call the actual method e.g. `.training_step(...)`
out = method(*_args, **_kwargs)
self.on_after_inner_forward(wrapper_module, original_module)
return out
# Patch the original_module's forward so we can redirect the arguments back to the real method
original_module.forward = wrapped_forward # type: ignore[method-assign]
wrapper_output = wrapper_module(*args, **kwargs)
self.on_after_outer_forward(wrapper_module, original_module)
return wrapper_output
def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
pass
def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
pass
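A sketch of the redirection pattern this class implements; the wrapper and the `score` method below are invented placeholders, not TRL or FSDP classes:

```python
import torch
import torch.nn as nn

class Inner(nn.Module):
    def forward(self, x):
        return x + 1

    def score(self, x):
        # A non-forward method we want routed through the wrapper.
        return self(x) * 2

class Wrapper(nn.Module):
    # Stand-in for an FSDP/DeepSpeed-style wrapper whose forward adds extra handling.
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, *args, **kwargs):
        return self.inner(*args, **kwargs)

inner = Inner()
wrapped = Wrapper(inner)
out = _ForwardRedirection()(wrapped, inner, inner.score, torch.tensor(1.0))
# `score` ran, but the call entered through Wrapper.forward; out == tensor(4.)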

trl/rewards/__init__.py Normal file
View File

@ -0,0 +1,32 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import TYPE_CHECKING
from ..import_utils import _LazyModule
_import_structure = {
"format_rewards": ["think_format_reward"],
}
if TYPE_CHECKING:
from .format_rewards import think_format_reward
else:
sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__)
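For reference, the lazy structure above means the submodule is only materialized on first attribute access; a small usage sketch (the laziness claim follows the usual `_LazyModule` behaviour in this codebase and is not spelled out in this diff):
```python
import trl.rewards

# The `format_rewards` submodule is only imported when the attribute is first touched,
# which keeps `import trl` cheap; afterwards it behaves like a regular import.
reward_fn = trl.rewards.think_format_reward
print(reward_fn([[{"content": "<think>\nreasoning\n</think>\nanswer"}]]))  # [1.0]
```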

View File

@ -0,0 +1,49 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
def think_format_reward(completions: list[list[dict[str, str]]], **kwargs) -> list[float]:
r"""
Reward function that checks if the reasoning process is enclosed within `"<think>"` and `"</think>"` tags. The
function returns a reward of 1.0 if the format is correct, otherwise 0.0.
Args:
completions (`list[list[dict[str, str]]]`):
List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary
containing the key `"content"` with the value being the text of the completion.
**kwargs:
Additional keyword arguments. This function does not use them, but they are required in the function
signature to ensure compatibility with trainers like [`GRPOTrainer`].
Returns:
`list[float]`:
A list of rewards, where each reward is 1.0 if the completion matches the expected format, otherwise 0.0.
Example:
```python
>>> from trl.rewards import think_format_reward
>>> completions = [
... [{"content": "<think>\nThis is my reasoning.\n</think>\nThis is my answer."}],
... [{"content": "<think>\nThis is my reasoning.\nThis is my answer."}],
... ]
>>> think_format_reward(completions)
[1.0, 0.0]
```
"""
pattern = r"^<think>(?!.*<think>)(.*?)</think>.*$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]
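A hedged sketch of plugging the reward above into GRPO training; the model and dataset identifiers are placeholders, not part of this diff:
```python
from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer
from trl.rewards import think_format_reward

dataset = load_dataset("trl-lib/tldr", split="train")  # placeholder prompt dataset
training_args = GRPOConfig(output_dir="Qwen2.5-0.5B-GRPO", num_generations=4)
trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=think_format_reward,  # a list mixing functions and reward-model ids also works
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```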

View File

@ -26,15 +26,18 @@ from typing import Optional
import torch
import yaml
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.utils import is_rich_available
from trl import TrlParser, init_zero_verbose
from trl.trainer.utils import get_quantization_config
if is_rich_available():
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
init_zero_verbose()
HELP_STRING = """\
@ -195,6 +198,9 @@ class ChatArguments:
class RichInterface:
def __init__(self, model_name=None, user_name=None):
if not is_rich_available():
raise ImportError("Rich is not available. Please install it with `pip install rich`.")
self._console = Console()
if model_name is None:
self.model_name = "assistant"

View File

@ -33,8 +33,13 @@ from .utils import get_git_commit_hash
def print_env():
devices = None
if torch.cuda.is_available():
devices = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
elif torch.backends.mps.is_available():
devices = ["MPS"]
elif torch.xpu.is_available():
devices = [torch.xpu.get_device_name(i) for i in range(torch.xpu.device_count())]
accelerate_config = accelerate_config_str = "not found"
@ -55,7 +60,7 @@ def print_env():
"Python version": platform.python_version(),
"TRL version": f"{__version__}+{commit_hash[:7]}" if commit_hash else __version__,
"PyTorch version": version("torch"),
"CUDA device(s)": ", ".join(devices) if torch.cuda.is_available() else "not available",
"accelerator(s)": ", ".join(devices) if devices is not None else "cpu",
"Transformers version": version("transformers"),
"Accelerate version": version("accelerate"),
"Accelerate config": accelerate_config_str,

View File

@ -13,6 +13,9 @@
# limitations under the License.
import argparse
import importlib
import os
import sys
from dataclasses import dataclass, field
from typing import Optional
@ -20,6 +23,12 @@ from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
from trl.rewards import think_format_reward
reward_funcs_registry = {
"think_format_reward": think_format_reward,
}
@dataclass
@ -28,9 +37,12 @@ class GRPOScriptArguments(ScriptArguments):
Script arguments for the GRPO training script.
Args:
reward_model_name_or_path (`str` or `None`):
reward_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
Reward model id of a pretrained model hosted inside a model repo on huggingface.co or local path to a
directory containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`].
reward_funcs (`list[str]` or `None`, *optional*, defaults to `None`):
Reward functions to use. It can be either one of `"think_format_reward"`, or a dotted import path
(e.g., `'my_lib.rewards.custom_reward'`).
"""
reward_model_name_or_path: Optional[str] = field(
@ -40,6 +52,13 @@ class GRPOScriptArguments(ScriptArguments):
"local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`."
},
)
reward_funcs: Optional[list[str]] = field(
default=None,
metadata={
"help": "Reward functions to use. It can be either one of 'think_format_reward', or a dotted "
"import path (e.g., 'my_lib.rewards.custom_reward')."
},
)
def main(script_args, training_args, model_args):
@ -50,9 +69,30 @@ def main(script_args, training_args, model_args):
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
)
reward_model = AutoModelForSequenceClassification.from_pretrained(
script_args.reward_model_name_or_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
)
# Get the reward models and functions
reward_funcs = []
if script_args.reward_model_name_or_path:
reward_model = AutoModelForSequenceClassification.from_pretrained(
script_args.reward_model_name_or_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
)
reward_funcs.append(reward_model)
if script_args.reward_funcs:
for func_name in script_args.reward_funcs:
if func_name in reward_funcs_registry:
reward_funcs.append(reward_funcs_registry[func_name])
elif "." in func_name:
module_path, func_name = func_name.rsplit(".", 1)
sys.path.insert(0, os.getcwd())
module = importlib.import_module(module_path)
reward_func = getattr(module, func_name)
reward_funcs.append(reward_func)
else:
raise ValueError(
f"Could not load reward function '{func_name}'. Expected one of "
f"{list(reward_funcs_registry.keys())} or a valid import path."
)
# Load the dataset
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
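With the registry and dynamic import above, the GRPO script can mix a built-in reward with a user-defined one. A sketch of such a user module; `my_lib/rewards.py` and `custom_reward` are hypothetical names that mirror the docstring example:
```python
# my_lib/rewards.py  (hypothetical module living in the current working directory)
def custom_reward(completions, **kwargs):
    """Toy reward: 1.0 when the completion is non-empty, 0.0 otherwise."""
    return [1.0 if completion[0]["content"].strip() else 0.0 for completion in completions]
```
It would then be picked up with something like `trl grpo --reward_funcs think_format_reward my_lib.rewards.custom_reward ...`, since the loader above inserts the current working directory into `sys.path` before performing the dotted import.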

View File

@ -142,5 +142,8 @@ def make_parser(subparsers: argparse._SubParsersAction = None):
if __name__ == "__main__":
parser = make_parser()
script_args, training_args, model_args = parser.parse_args_and_config()
# When using the trl cli, this script may be run with additional arguments, corresponding to the accelerate launcher arguments.
# To ensure that their parsing does not interfere with the script arguments, parse the arguments with
# `return_remaining_strings=True`, then ignore the remaining strings.
script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True)
main(script_args, training_args, model_args)
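A minimal sketch of the parsing behaviour relied on above, assuming `parse_args_and_config` accepts an explicit argument list as its first parameter:
```python
from trl import ScriptArguments, TrlParser

parser = TrlParser((ScriptArguments,))
# Unknown strings (e.g. accelerate launcher flags) are handed back instead of raising.
script_args, remaining = parser.parse_args_and_config(
    ["--dataset_name", "trl-lib/tldr", "--num_machines", "1"],
    return_remaining_strings=True,
)
print(remaining)  # expected: ['--num_machines', '1']
```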

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import importlib
import inspect
import logging
@ -25,6 +26,7 @@ from typing import Optional, Union
import yaml
from transformers import HfArgumentParser
from transformers.hf_argparser import DataClass, DataClassType
from transformers.utils import is_rich_available
logger = logging.getLogger(__name__)
@ -51,7 +53,7 @@ class ScriptArguments:
type, inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992.
"""
dataset_name: str = field(metadata={"help": "Dataset name."})
dataset_name: Optional[str] = field(default=None, metadata={"help": "Dataset name."})
dataset_config: Optional[str] = field(
default=None,
metadata={
@ -78,14 +80,21 @@ class ScriptArguments:
def init_zero_verbose():
"""
Perform zero verbose init - use this method on top of the CLI modules to make
logging and warning output cleaner. Uses a Rich handler if available, and falls back to the
standard `logging.StreamHandler` otherwise.
"""
import logging
import warnings
from rich.logging import RichHandler
FORMAT = "%(message)s"
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.ERROR)
if is_rich_available():
from rich.logging import RichHandler
handler = RichHandler()
else:
handler = logging.StreamHandler()
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[handler], level=logging.ERROR)
# Custom warning handler to redirect warnings to the logging system
def warning_handler(message, category, filename, lineno, file=None, line=None):
@ -207,20 +216,33 @@ class TrlParser(HfArgumentParser):
def set_defaults_with_config(self, **kwargs) -> list[str]:
"""
Overrides the parser's default values with those provided via keyword arguments.
Overrides the parser's default values with those provided via keyword arguments, including for subparsers.
Any argument with an updated default will also be marked as not required
if it was previously required.
Returns a list of strings that were not consumed by the parser.
"""
# If an argument is in the kwargs, update its default and set it as not required
for action in self._actions:
if action.dest in kwargs:
action.default = kwargs.pop(action.dest)
action.required = False
remaining_strings = [item for key, value in kwargs.items() for item in [f"--{key}", str(value)]]
return remaining_strings
def apply_defaults(parser, kw):
used_keys = set()
for action in parser._actions:
# Handle subparsers recursively
if isinstance(action, argparse._SubParsersAction):
for subparser in action.choices.values():
used_keys.update(apply_defaults(subparser, kw))
elif action.dest in kw:
action.default = kw[action.dest]
action.required = False
used_keys.add(action.dest)
return used_keys
used_keys = apply_defaults(self, kwargs)
# Remaining args not consumed by the parser
remaining = [
item for key, value in kwargs.items() if key not in used_keys for item in (f"--{key}", str(value))
]
return remaining
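A sketch of how the recursive default propagation is meant to behave with a subcommand parser, assuming the CLI-style construction below (building a bare `TrlParser` and passing `dataclass_types` through `add_parser`) is supported; the argument names are illustrative:
```python
from dataclasses import dataclass, field

from trl import TrlParser


@dataclass
class TrainArgs:
    learning_rate: float = field(default=3e-4, metadata={"help": "Learning rate."})


parser = TrlParser()  # the CLI builds the top-level parser without dataclasses
subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser)
subparsers.add_parser("train", dataclass_types=[TrainArgs])

# `learning_rate` lives inside the `train` subparser, so the recursion above finds it and
# updates its default; `unknown_key` is not consumed and is returned as remaining strings.
remaining = parser.set_defaults_with_config(learning_rate=1e-5, unknown_key="foo")
print(remaining)  # expected: ['--unknown_key', 'foo']
```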
def get_git_commit_hash(package_name):

View File

@ -184,6 +184,8 @@ class ScriptArguments:
enforce_eager (`bool` or `None`, *optional*, defaults to `None`):
Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the
model in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid.
kv_cache_dtype (`str`, *optional*, defaults to `"auto"`):
Data type to use for KV cache. If set to `"auto"`, the dtype will default to the model data type.
log_level (`str`, *optional*, defaults to `"info"`):
Log level for uvicorn. Possible choices: `"critical"`, `"error"`, `"warning"`, `"info"`, `"debug"`,
`"trace"`.
@ -251,6 +253,12 @@ class ScriptArguments:
"execution in hybrid."
},
)
kv_cache_dtype: str = field(
default="auto",
metadata={
"help": "Data type to use for KV cache. If set to 'auto', the dtype will default to the model data type."
},
)
log_level: str = field(
default="info",
metadata={
@ -280,6 +288,7 @@ def llm_worker(
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=script_args.enable_prefix_caching,
kv_cache_dtype=script_args.kv_cache_dtype,
max_model_len=script_args.max_model_len,
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
)
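The new flag is passed straight into vLLM's `LLM(...)` constructor above. A hedged sketch of the equivalent direct call; the model name is a placeholder and the accepted dtype strings depend on the installed vLLM version:
```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    kv_cache_dtype="fp8",        # quantize the KV cache to save memory; "auto" keeps the model dtype
    enable_prefix_caching=True,
    max_model_len=2048,
)
```
From the CLI this corresponds to something like `trl vllm-serve --model Qwen/Qwen2.5-0.5B-Instruct --kv_cache_dtype fp8`.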

View File

@ -38,6 +38,7 @@ _import_structure = {
"gkd_trainer": ["GKDTrainer"],
"grpo_config": ["GRPOConfig"],
"grpo_trainer": ["GRPOTrainer"],
"iterative_sft_config": ["IterativeSFTConfig"],
"iterative_sft_trainer": ["IterativeSFTTrainer"],
"judges": [
"AllTrueJudge",
@ -109,7 +110,7 @@ if TYPE_CHECKING:
from .gkd_trainer import GKDTrainer
from .grpo_config import GRPOConfig
from .grpo_trainer import GRPOTrainer
from .iterative_sft_trainer import IterativeSFTTrainer
from .iterative_sft_trainer import IterativeSFTConfig, IterativeSFTTrainer
from .judges import (
AllTrueJudge,
BaseBinaryJudge,

View File

@ -19,7 +19,6 @@ import textwrap
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from operator import itemgetter
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
@ -32,7 +31,7 @@ import torch.nn.functional as F
import transformers
from accelerate import PartialState
from accelerate.logging import get_logger
from accelerate.utils import is_deepspeed_available, tqdm
from accelerate.utils import tqdm
from datasets import Dataset
from packaging import version
from torch.utils.data import DataLoader, SequentialSampler
@ -56,7 +55,7 @@ from transformers.utils import is_peft_available
from ..data_utils import maybe_apply_chat_template
from ..import_utils import is_joblib_available
from ..models import PreTrainedModelWrapper, create_reference_model
from ..models import create_reference_model, prepare_deepspeed
from .bco_config import BCOConfig
from .utils import (
DPODataCollatorWithPadding,
@ -83,9 +82,6 @@ if is_sklearn_available():
if is_joblib_available():
import joblib
if is_deepspeed_available():
import deepspeed
if TYPE_CHECKING:
from transformers import PreTrainedModel, PreTrainedTokenizer
@ -712,7 +708,7 @@ class BCOTrainer(Trainer):
)
else:
if self.is_deepspeed_enabled:
self.ref_model = self._prepare_deepspeed(self.ref_model)
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
@ -846,37 +842,6 @@ class BCOTrainer(Trainer):
return all_embeddings
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
if model is not None:
if hasattr(model, "config"):
hidden_size = (
max(model.config.hidden_sizes)
if getattr(model.config, "hidden_sizes", None)
else getattr(model.config, "hidden_size", None)
)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)
# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model
def _save_optimizer_and_scheduler(self, output_dir):
output_dir = output_dir if output_dir is not None else self.args.output_dir
super()._save_optimizer_and_scheduler(output_dir)
@ -1310,10 +1275,12 @@ class BCOTrainer(Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
if dataset is None:
dataset = self.train_dataset
if dataset is None or not has_length(dataset):
return None
return SequentialSampler(self.train_dataset)
return SequentialSampler(dataset)
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""

View File

@ -19,11 +19,7 @@ import pandas as pd
import torch
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils import gather_object, is_comet_ml_available, is_deepspeed_available, is_wandb_available
from rich.console import Console, Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import Progress
from accelerate.utils import gather_object, is_wandb_available
from transformers import (
GenerationConfig,
PreTrainedModel,
@ -35,6 +31,7 @@ from transformers import (
TrainingArguments,
)
from transformers.trainer_utils import has_length
from transformers.utils import is_rich_available
from ..data_utils import maybe_apply_chat_template
from ..import_utils import is_mergekit_available
@ -44,11 +41,11 @@ from .judges import BasePairwiseJudge
from .utils import log_table_to_comet_experiment
if is_deepspeed_available():
import deepspeed
if is_comet_ml_available():
pass
if is_rich_available():
from rich.console import Console, Group
from rich.live import Live
from rich.panel import Panel
from rich.progress import Progress
if is_wandb_available():
import wandb
@ -115,6 +112,8 @@ class SyncRefModelCallback(TrainerCallback):
def sync_target_model(model, target_model, alpha):
deepspeed_plugin = AcceleratorState().deepspeed_plugin
if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
import deepspeed
with deepspeed.zero.GatheredParameters(
list(model.parameters()) + list(target_model.parameters()), modifier_rank=0
):
@ -138,6 +137,9 @@ class RichProgressCallback(TrainerCallback):
"""
def __init__(self):
if not is_rich_available():
raise ImportError("RichProgressCallback requires the `rich` extra. To install, run `pip install rich`.")
self.training_bar = None
self.prediction_bar = None

View File

@ -131,14 +131,19 @@ class DPOConfig(TrainingArguments):
Whether to ignore the provided reference model and implicitly use a reference model that assigns equal
probability to all responses.
label_smoothing (`float`, *optional*, defaults to `0.0`):
Robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report and
Robust DPO label smoothing parameter from the [cDPO report](https://ericmitchell.ai/cdpo.pdf) and
[Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`.
use_weighting (`bool`, *optional*, defaults to `False`):
Whether to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
Whether to weight the loss as done in the [WPO paper](https://huggingface.co/papers/2406.11827).
rpo_alpha (`float`, *optional*, defaults to `None`):
α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the
α parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) (v3), which controls the
weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the
DPO loss. The paper recommends `rpo_alpha=1.0`.
ld_alpha (`float` or `None`, *optional*, defaults to `None`):
α parameter from the [LD-DPO paper](https://huggingface.co/papers/2409.06411), which controls the weighting
of the verbose token log-probabilities in responses. If `None`, no weighting is applied to the verbose
part, and the loss is equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between
`0.0` and `1.0`.
discopop_tau (`float`, *optional*, defaults to `0.05`):
τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls
the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`.
@ -346,6 +351,14 @@ class DPOConfig(TrainingArguments):
"`rpo_alpha=1.0`."
},
)
ld_alpha: Optional[float] = field(
default=None,
metadata={
"help": "α parameter from the LD-DPO paper, which controls the weighting of the verbose token "
"log-probabilities in responses. If `None`, no weighting is applied to the verbose part, and the loss is "
"equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between `0.0` and `1.0`.",
},
)
discopop_tau: float = field(
default=0.05,
metadata={

View File

@ -19,7 +19,6 @@ import textwrap
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Callable, Literal, Optional, Union
@ -30,7 +29,7 @@ import torch.nn as nn
import torch.nn.functional as F
import transformers
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available, tqdm
from accelerate.utils import tqdm
from datasets import Dataset, IterableDataset
from packaging import version
from torch.utils.data import DataLoader
@ -53,7 +52,7 @@ from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_xpu_available
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from ..models import PreTrainedModelWrapper, create_reference_model
from ..models import create_reference_model, prepare_deepspeed
from ..models.utils import prepare_fsdp
from .callbacks import SyncRefModelCallback
from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
@ -63,6 +62,7 @@ from .utils import (
disable_dropout_in_model,
empty_cache,
flush_left,
flush_right,
generate_model_card,
get_comet_experiment_url,
log_table_to_comet_experiment,
@ -80,9 +80,6 @@ if is_peft_available():
if is_wandb_available():
import wandb
if is_deepspeed_available():
import deepspeed
@dataclass
class DataCollatorForPreference(DataCollatorMixin):
@ -184,7 +181,6 @@ class DPOTrainer(Trainer):
Processing class used to process the data. If provided, will be used to automatically process the inputs
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
reuse the fine-tuned model.
This supercedes the `tokenizer` argument, which is now deprecated.
model_init (`Callable[[], transformers.PreTrainedModel]`):
The model initializer to use for training. If None is specified, the default model initializer will be used.
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
@ -510,7 +506,7 @@ class DPOTrainer(Trainer):
)
else:
if self.is_deepspeed_enabled:
self.ref_model = self._prepare_deepspeed(self.ref_model)
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
elif self.is_fsdp_enabled:
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
else:
@ -558,7 +554,7 @@ class DPOTrainer(Trainer):
dataset = dataset.map(
self.tokenize_row if not self.is_vision_model else self.process_row,
remove_columns=["prompt", "chosen", "rejected"],
remove_columns=["chosen", "rejected"],
fn_kwargs={
"processing_class": processing_class,
"max_prompt_length": args.max_prompt_length,
@ -676,37 +672,6 @@ class DPOTrainer(Trainer):
return output
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
if model is not None:
if hasattr(model, "config"):
hidden_size = (
max(model.config.hidden_sizes)
if getattr(model.config, "hidden_sizes", None)
else getattr(model.config, "hidden_size", None)
)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)
# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model
def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
@ -840,9 +805,9 @@ class DPOTrainer(Trainer):
with torch.no_grad(), compte_ref_context_manager:
if self.ref_model is None:
with self.null_ref_context():
ref_model_output = self.concatenated_forward(self.model, batch)
ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True)
else:
ref_model_output = self.concatenated_forward(self.ref_model, batch)
ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True)
return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"]
@staticmethod
@ -1102,16 +1067,28 @@ class DPOTrainer(Trainer):
return losses, chosen_rewards, rejected_rewards
def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
def concatenated_forward(
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False
):
"""
Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
We do this to avoid doing two forward passes, because it's faster for FSDP.
Args:
model:
Model to run the forward pass on.
batch:
Batch of input data.
is_ref_model:
Whether this method is being called for the reference model. If `True`, length desensitization is not
applied.
"""
num_examples = batch["prompt_input_ids"].shape[0]
concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value)
model_kwargs = {}
model_kwargs = {"use_cache": False}
if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True
@ -1148,26 +1125,35 @@ class DPOTrainer(Trainer):
dim=1,
)
# Flush left to reduce the memory usage
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
# Truncate right
if self.max_length is not None:
if self.truncation_mode == "keep_end":
# Flush and truncate
if self.max_length is not None and self.max_length < attention_mask.size(1):
if self.truncation_mode == "keep_start":
# Flush left to reduce the memory usage
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
attention_mask = attention_mask[:, : self.max_length]
input_ids = input_ids[:, : self.max_length]
loss_mask = loss_mask[:, : self.max_length]
elif self.truncation_mode == "keep_end":
# Flush right before truncating left, then flush left
# [[0, 0, x, x, x, x], -> [[0, 0, x, x],
# [0, x, x, x, 0, 0]] [0, x, x, x]]
attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
input_ids = input_ids[:, -self.max_length :]
attention_mask = attention_mask[:, -self.max_length :]
loss_mask = loss_mask[:, -self.max_length :]
elif self.truncation_mode == "keep_start":
input_ids = input_ids[:, : self.max_length]
attention_mask = attention_mask[:, : self.max_length]
loss_mask = loss_mask[:, : self.max_length]
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
else:
raise ValueError(
f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
"'keep_start']."
)
else:
# Flush left to reduce the memory usage
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
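A small worked example of the `keep_end` path above, assuming `flush_left`/`flush_right` behave as the ASCII diagrams describe (illustrative tensors only):
```python
import torch

# Two sequences, padded with zeros in the attention mask:
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1],
                               [0, 1, 1, 1, 0, 0]])
# keep_end with max_length=4:
# 1) flush_right                -> [[0, 0, 1, 1, 1, 1],
#                                   [0, 0, 0, 1, 1, 1]]
# 2) take the last 4 columns    -> [[1, 1, 1, 1],
#                                   [0, 1, 1, 1]]
# 3) flush_left to re-left-align -> [[1, 1, 1, 1],
#                                    [1, 1, 1, 0]]
```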
if self.use_logits_to_keep:
# Compute logits_to_keep based on loss_mask pattern:
@ -1254,6 +1240,28 @@ class DPOTrainer(Trainer):
if self.loss_type == "ipo":
all_logps = all_logps / loss_mask.sum(-1)
if self.args.ld_alpha is not None and not is_ref_model:
# Compute response lengths based on loss_mask
completion_lengths = loss_mask.sum(dim=1)
chosen_lengths = completion_lengths[:num_examples]
rejected_lengths = completion_lengths[num_examples:]
public_lengths = torch.min(chosen_lengths, rejected_lengths) # l_p in the paper
public_lengths = torch.cat([public_lengths, public_lengths], dim=0)
seq_len = per_token_logps.size(1)
position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
ld_mask = position_ids < public_lengths.unsqueeze(1)
mask = position_ids < completion_lengths.unsqueeze(1)
front_mask = (ld_mask & mask).float()
rear_mask = (~ld_mask & mask).float()
front_logps = (per_token_logps * front_mask).sum(dim=1)
rear_logps = (per_token_logps * rear_mask).sum(dim=1)
all_logps = front_logps + self.args.ld_alpha * rear_logps
output["chosen_logps"] = all_logps[:num_examples]
output["rejected_logps"] = all_logps[num_examples:]
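To actually enable the length-desensitized weighting above, `ld_alpha` is set on the config. A hedged sketch; the model, dataset, and `ld_alpha` value are placeholders:
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import DPOConfig, DPOTrainer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = DPOConfig(
    output_dir="Qwen2.5-0.5B-LD-DPO",
    ld_alpha=0.3,  # down-weights log-probs of the "verbose" tail beyond the shared public length l_p
)
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=dataset)
trainer.train()
```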
@ -1488,7 +1496,7 @@ class DPOTrainer(Trainer):
)
],
)
if "wandb" in self.args.report_to:
if "wandb" in self.args.report_to and self.accelerator.is_main_process:
wandb.log({"game_log": wandb.Table(data=table)})
if "comet_ml" in self.args.report_to:

View File

@ -15,13 +15,11 @@
import os
import random
import textwrap
from copy import deepcopy
from typing import Any, Callable, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
from transformers import (
AutoModelForCausalLM,
@ -38,7 +36,7 @@ from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available
from ..models import PreTrainedModelWrapper
from ..models import prepare_deepspeed
from ..models.utils import unwrap_model_for_generation
from .gkd_config import GKDConfig
from .sft_trainer import SFTTrainer
@ -51,10 +49,6 @@ from .utils import (
)
if is_deepspeed_available():
import deepspeed
if is_peft_available():
from peft import PeftConfig
@ -124,7 +118,7 @@ class GKDTrainer(SFTTrainer):
disable_dropout_in_model(self.model)
if self.is_deepspeed_enabled:
self.teacher_model = self._prepare_deepspeed(teacher_model)
self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
else:
self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
@ -311,37 +305,6 @@ class GKDTrainer(SFTTrainer):
loss = super().training_step(model, inputs, num_items_in_batch)
return loss
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
if model is not None:
if hasattr(model, "config"):
hidden_size = (
max(model.config.hidden_sizes)
if getattr(model.config, "hidden_sizes", None)
else getattr(model.config, "hidden_size", None)
)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)
# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model
def create_model_card(
self,
model_name: Optional[str] = None,

View File

@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass, field
from typing import Optional, Union
@ -64,14 +64,19 @@ class GRPOConfig(TrainingArguments):
> Parameters that control generation
temperature (`float`, defaults to `0.9`):
generation_batch_size (`int` or `None`, *optional*, defaults to `None`):
Batch size to use for generation. If `None`, it defaults to the effective training batch size:
`per_device_train_batch_size * num_processes * gradient_accumulation_steps`.
steps_per_generation (`int` or `None`, *optional*, defaults to `None`):
Number of optimization steps per generation. If `None`, it defaults to `gradient_accumulation_steps`.
temperature (`float`, defaults to `1.0`):
Temperature for sampling. The higher the temperature, the more random the completions.
top_p (`float`, *optional*, defaults to `1.0`):
Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
`1.0` to consider all tokens.
top_k (`int` or `None`, *optional*, defaults to `50`):
top_k (`int` or `None`, *optional*, defaults to `None`):
Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
disabled.
disabled and all tokens are considered.
min_p (`float` or `None`, *optional*, defaults to `None`):
Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
@ -85,18 +90,42 @@ class GRPOConfig(TrainingArguments):
> Parameters that control generation acceleration powered by vLLM
use_vllm (`bool`, *optional*, defaults to `False`):
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
Host of the vLLM server to connect to.
vllm_server_port (`int`, *optional*, defaults to `8000`):
Port of the vLLM server to connect to.
vllm_server_timeout (`float`, *optional*, defaults to `120.0`):
Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
timeout, a `ConnectionError` is raised.
Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation
instead of the default model.generate(). Requires `vllm` to be installed.
vllm_mode (`str`, *optional*, defaults to `"server"`):
Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or
`"colocate"`.
- `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM
server is running (start with `trl vllm-serve`).
- `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a
separate server but may cause resource contention with training.
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
> Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
vllm_server_base_url (`str` or `None`, *optional*, defaults to `None`):
Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and
`vllm_server_port` are ignored.
vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
vllm_server_port (`int`, *optional*, defaults to `8000`):
Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
timeout, a `ConnectionError` is raised.
> Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`):
Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to
`"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
launching the vLLM server via the `--vllm_gpu_memory_utilization` flag.
vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to
`"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
launching the vLLM server via the `--vllm_tensor_parallel_size` flag.
> Parameters that control the training
learning_rate (`float`, *optional*, defaults to `1e-6`):
@ -109,6 +138,10 @@ class GRPOConfig(TrainingArguments):
Number of iterations per batch (denoted as μ in the algorithm).
epsilon (`float`, *optional*, defaults to `0.2`):
Epsilon value for clipping.
delta (`float` or `None`, *optional*, defaults to `None`):
Enables the upper clipping bound in two-sided GRPO loss when set to a float. If `None` (default), standard
GRPO clipping is used. Recommended to be greater than `1 + ε` when enabled. This method is introduced in
the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291).
epsilon_high (`float` or `None`, *optional*, defaults to `None`):
Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
@ -139,7 +172,7 @@ class GRPOConfig(TrainingArguments):
[DAPO](https://huggingface.co/papers/2503.14476) paper, this is a good practice for training stability.
sync_ref_model (`bool`, *optional*, defaults to `False`):
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
the `ref_model_mixup_alpha` parameter. This synchronization originites from the
the `ref_model_mixup_alpha` parameter. This synchronization originates from the
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
@ -230,8 +263,21 @@ class GRPOConfig(TrainingArguments):
)
# Parameters that control generation
generation_batch_size: Optional[int] = field(
default=None,
metadata={
"help": "Batch size to use for generation. If `None`, it defaults to the effective training batch size: "
"`per_device_train_batch_size * num_processes * gradient_accumulation_steps`."
},
)
steps_per_generation: Optional[int] = field(
default=None,
metadata={
"help": "Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps."
},
)
temperature: float = field(
default=0.9,
default=1.0,
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
)
top_p: float = field(
@ -242,10 +288,10 @@ class GRPOConfig(TrainingArguments):
},
)
top_k: Optional[int] = field(
default=50,
default=None,
metadata={
"help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
"top-k-filtering is disabled."
"top-k-filtering is disabled and all tokens are considered."
},
)
min_p: Optional[float] = field(
@ -272,17 +318,40 @@ class GRPOConfig(TrainingArguments):
use_vllm: bool = field(
default=False,
metadata={
"help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a vLLM server is "
"running. To run the server, install vLLM (`pip install vllm`) and run `trl vllm-serve`."
"help": "Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for "
"generation instead of the default model.generate(). Requires `vllm` to be installed."
},
)
vllm_server_base_url: Optional[str] = field(
default=None,
metadata={
"help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` "
"and `vllm_server_port` are ignored."
},
)
vllm_mode: str = field(
default="server",
metadata={
"help": "Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `'server'` or "
"`'colocate'`. `'server'`: The trainer will send generation requests to a separate vLLM server. Make sure a "
"TRL vLLM server is running (start with `trl vllm-serve`). `'colocate'`: vLLM will run in the same "
"process and share the training GPUs. This avoids the need for a separate server but may cause resource "
"contention with training."
},
)
vllm_guided_decoding_regex: Optional[str] = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
)
# Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
vllm_server_host: str = field(
default="0.0.0.0",
metadata={"help": "Host of the vLLM server to connect to."},
metadata={"help": "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
)
vllm_server_port: int = field(
default=8000,
metadata={"help": "Port of the vLLM server to connect to."},
metadata={"help": "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
)
vllm_server_timeout: float = field(
default=240.0,
@ -291,9 +360,23 @@ class GRPOConfig(TrainingArguments):
"after the timeout, a `ConnectionError` is raised."
},
)
vllm_guided_decoding_regex: Optional[str] = field(
default=None,
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
# Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
vllm_gpu_memory_utilization: float = field(
default=0.3,
metadata={
"help": "Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set "
"to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when "
"launching the vLLM server via the `--vllm_gpu_memory_utilization` flag."
},
)
vllm_tensor_parallel_size: int = field(
default=1,
metadata={
"help": "Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set "
"to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when "
"launching the vLLM server via the `--vllm_tensor_parallel_size` flag."
},
)
# Parameters that control the training
@ -319,6 +402,14 @@ class GRPOConfig(TrainingArguments):
default=0.2,
metadata={"help": "Epsilon value for clipping."},
)
delta: Optional[float] = field(
default=None,
metadata={
"help": "Enables the upper clipping bound in two-sided GRPO loss when set to a float. If `None` "
"(default), standard GRPO clipping is used. Recommended to be greater than `1 + ε` when enabled. This "
"method is introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)."
},
)
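For context, a hedged sketch of how `delta` changes the clipped objective; this mirrors the two-sided loss described in the INTELLECT-2 report, while the actual runtime code lives in `grpo_trainer.py`, whose diff is suppressed below as too large:
```python
import torch


def two_sided_grpo_loss(ratio, advantages, epsilon_low=0.2, epsilon_high=0.2, delta=None):
    # ratio: per-token importance ratio pi_theta / pi_old
    # advantages: group-normalized rewards, broadcast to the same shape as `ratio`
    if delta is not None:
        # Upper-bound the unclipped term so that very large ratios cannot dominate the update
        unclipped = torch.clamp(ratio, max=delta) * advantages
    else:
        unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - epsilon_low, 1 + epsilon_high) * advantages
    return -torch.min(unclipped, clipped)
```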
epsilon_high: Optional[float] = field(
default=None,
metadata={
@ -413,82 +504,58 @@ class GRPOConfig(TrainingArguments):
},
)
# Deprecated parameters
vllm_device: Optional[str] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. To use vLLM, start a vLLM "
"server with the `trl vllm-serve` command."
},
)
vllm_gpu_memory_utilization: Optional[float] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control the GPU memory "
"utilization for vLLM, you should now use the `gpu_memory_utilization` parameter in the vLLM server "
"configuration."
},
)
vllm_dtype: Optional[str] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control the data type for "
"vLLM generation, you should now use the `dtype` parameter in the vLLM server configuration."
},
)
vllm_max_model_len: Optional[int] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control the "
"`max_model_len` for vLLM, you should now use the `max_model_len` parameter in the vLLM server "
"configuration."
},
)
vllm_enable_prefix_caching: Optional[bool] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control prefix caching in "
"vLLM, you should now use the `enable_prefix_caching` parameter in the vLLM server configuration."
},
)
def __post_init__(self):
super().__post_init__()
if self.vllm_device is not None:
warnings.warn(
"`vllm_device` is deprecated and will be removed in version 0.18.0. To use vLLM, start a vLLM server "
"with the `trl vllm-serve` command.",
DeprecationWarning,
num_processes = self.world_size
# The current default effective batch size
if self.generation_batch_size is not None and self.steps_per_generation is not None:
raise ValueError(
"'generation_batch_size' and 'steps_per_generation' cannot both be set at the same time"
)
if self.vllm_gpu_memory_utilization is not None:
warnings.warn(
"`vllm_gpu_memory_utilization` is deprecated and will be removed in v0.18. To control the GPU memory "
"utilization for vLLM, you should now use the `gpu_memory_utilization` parameter in the vLLM server "
"configuration.",
DeprecationWarning,
if self.steps_per_generation is None:
self.steps_per_generation = self.gradient_accumulation_steps
if self.generation_batch_size is None:
self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
raise ValueError(
f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size "
f"({self.per_device_train_batch_size * num_processes})."
)
if self.vllm_dtype is not None:
warnings.warn(
"`vllm_dtype` is deprecated and will be removed in version 0.18.0. To control the data type for vLLM "
"generation, you should now use the `dtype` parameter in the vLLM server configuration.",
DeprecationWarning,
)
self.steps_per_generation = self.generation_batch_size // (self.per_device_train_batch_size * num_processes)
if self.vllm_max_model_len is not None:
warnings.warn(
"`vllm_max_model_len` is deprecated and will be removed in version 0.18.0. To control the "
"`max_model_len` for vLLM, you should now use the `max_model_len` parameter in the vLLM server "
"configuration.",
DeprecationWarning,
# Check if the effective batch size can be divided by the number of generations
if self.num_generations < 2:
raise ValueError(
"GRPO requires at least 2 generations per prompt to calculate the advantages. You provided "
f"{self.num_generations}, which is less than the minimum required."
)
possible_values = [
n_gen for n_gen in range(2, self.generation_batch_size + 1) if (self.generation_batch_size) % n_gen == 0
]
if self.vllm_enable_prefix_caching is not None:
warnings.warn(
"`vllm_enable_prefix_caching` is deprecated and will be removed in version 0.18.0. To control prefix "
"caching in vLLM, you should now use the `enable_prefix_caching` parameter in the vLLM server "
"configuration.",
DeprecationWarning,
if self.num_generations not in possible_values:
raise ValueError(
f"The effective train batch size ({num_processes} x {self.per_device_train_batch_size} x "
f"{self.steps_per_generation}) must be evenly divisible by the number of generations per "
f"prompt ({self.num_generations}). Given the current effective train batch size, the valid values for "
f"the number of generations are: {possible_values}."
)
if self.eval_strategy != "no":
global_eval_batch_size = self.per_device_eval_batch_size * num_processes
possible_values = [
n_gen for n_gen in range(2, global_eval_batch_size + 1) if (global_eval_batch_size) % n_gen == 0
]
if self.num_generations not in possible_values:
raise ValueError(
f"The global eval batch size ({num_processes} x {self.per_device_eval_batch_size}) must be "
f"evenly divisible by the number of generations per prompt ({self.num_generations}). Given the "
"current global eval batch size, the valid values for the number of generations are: "
f"{possible_values}."
)
if self.delta is not None and self.use_liger_loss:
raise ValueError("Liger loss does not support two-sided GRPO loss yet.")
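A worked example of the bookkeeping above, under an assumed 8-GPU setup:
```python
# Assumed setup: 8 processes, per_device_train_batch_size=4, gradient_accumulation_steps=2,
# and neither generation_batch_size nor steps_per_generation set explicitly.
num_processes = 8
per_device_train_batch_size = 4
gradient_accumulation_steps = 2

steps_per_generation = gradient_accumulation_steps  # 2
generation_batch_size = per_device_train_batch_size * num_processes * steps_per_generation  # 64

# num_generations must divide the generation batch size; the valid choices here are the
# divisors of 64 that are >= 2: 2, 4, 8, 16, 32, 64.
assert generation_batch_size % 8 == 0  # e.g. num_generations=8 is accepted
```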

File diff suppressed because it is too large

View File

@ -0,0 +1,79 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Optional
from transformers import TrainingArguments
@dataclass
class IterativeSFTConfig(TrainingArguments):
r"""
Configuration class for the [`IterativeSFTTrainer`].
Only the parameters specific to iterative SFT training are listed here. For details on other parameters, refer to the
[`~transformers.TrainingArguments`] documentation.
Using [`~transformers.HfArgumentParser`] we can turn this class into
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
command line.
Parameters:
> Parameters that control the model
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
argument of the [`IterativeSFTTrainer`] is provided as a string.
> Parameters that control the data preprocessing
max_length (`int` or `None`, *optional*, defaults to `None`):
Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated.
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
The truncation mode to use, either `"keep_end"` or `"keep_start"`.
optimize_device_cache (`bool`, *optional*, defaults to `False`):
Whether to optimize CUDA cache for slightly more memory-efficient training.
"""
# Parameters that control the model
model_init_kwargs: Optional[dict[str, Any]] = field(
default=None,
metadata={
"help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
"the `IterativeSFTTrainer` is provided as a string."
},
)
# Parameters that control the data preprocessing
max_length: Optional[int] = field(
default=None,
metadata={
"help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated."
},
)
truncation_mode: str = field(
default="keep_end",
metadata={"help": "The truncation mode to use, either 'keep_end' or 'keep_start'."},
)
optimize_device_cache: bool = field(
default=False,
metadata={"help": "Whether to optimize CUDA cache for slightly more memory-efficient training."},
)
def __post_init__(self):
super().__post_init__()
if self.truncation_mode not in ["keep_end", "keep_start"]:
raise ValueError(f"truncation_mode must be either 'keep_end' or 'keep_start', got {self.truncation_mode}")
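A minimal sketch of configuring the new class; the values are illustrative and the import path follows the `trl/trainer/__init__.py` change shown below:
```python
from trl.trainer import IterativeSFTConfig

config = IterativeSFTConfig(
    output_dir="iterative-sft-run",
    max_length=1024,
    truncation_mode="keep_start",
    optimize_device_cache=True,
    max_steps=100,  # setting max_steps avoids the no-scheduler check raising
)
```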

View File

@ -20,6 +20,8 @@ import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BaseImageProcessor,
DataCollator,
DataCollatorForLanguageModeling,
@ -36,6 +38,7 @@ from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available
from ..core import PPODecorators
from .iterative_sft_config import IterativeSFTConfig
from .utils import generate_model_card, get_comet_experiment_url
@ -52,39 +55,49 @@ class IterativeSFTTrainer(Trainer):
The IterativeSFTTrainer can be used to finetune models with methods that require intermediate steps between optimization steps.
Args:
model (`PreTrainedModel`):
Model to be optimized, either an 'AutoModelForCausalLM' or an 'AutoModelForSeq2SeqLM'.
Check the documentation of `PreTrainedModel` for more details.
args (`transformers.TrainingArguments`):
The arguments to use for training.
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
Processing class used to process the data. If provided, will be used to automatically process the inputs
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
reuse the fine-tuned model.
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
The optimizer and scheduler to use for training.
data_collator (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], *optional*):
Data collator to be used for training and passed along the dataloader.
model (`Union[str, PreTrainedModel]`):
Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
a path to a *directory* containing model weights saved using
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
in `args.model_init_kwargs`.
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
args ([`IterativeSFTConfig`], *optional*, defaults to `None`):
Configuration for this trainer. If `None`, a default configuration is used.
data_collator (`DataCollator`, *optional*):
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, an instance
of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or
tokenizer.
eval_dataset (`datasets.Dataset`):
The dataset to use for evaluation.
max_length (`int`, defaults to `None`):
The maximum length of the input.
truncation_mode (`str`, defaults to `keep_end`):
The truncation mode to use, either `keep_end` or `keep_start`.
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
Processing class used to process the data. If `None`, the processing class is loaded from the model's name
with [`~transformers.AutoTokenizer.from_pretrained`].
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
The optimizer and scheduler to use for training.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
The function to use to preprocess the logits before computing the metrics.
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
The function to use to compute the metrics. Must take an `EvalPrediction` and return a dictionary mapping strings to metric values.
optimize_device_cache (`bool`, *optional*, defaults to `False`):
Optimize CUDA cache for slightly more memory-efficient training.
max_length (`int`, *optional*, deprecated):
Maximum length of the tokenized sequence. Use `args.max_length` instead.
truncation_mode (`str`, *optional*, deprecated):
The truncation mode to use. Use `args.truncation_mode` instead.
optimize_device_cache (`bool`, *optional*, deprecated):
Whether to optimize CUDA cache. Use `args.optimize_device_cache` instead.
"""
_tag_names = ["trl", "iterative-sft"]
def __init__(
self,
model: Optional[PreTrainedModel] = None,
args: Optional[TrainingArguments] = None,
model: Union[str, PreTrainedModel],
args: Optional[Union[IterativeSFTConfig, TrainingArguments]] = None,
data_collator: Optional[DataCollator] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
@ -92,35 +105,74 @@ class IterativeSFTTrainer(Trainer):
None,
None,
),
data_collator: Optional[DataCollator] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
max_length: Optional[int] = None,
truncation_mode: Optional[str] = "keep_end",
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
optimize_device_cache: Optional[bool] = False,
# Deprecated parameters
max_length: Optional[int] = None,
truncation_mode: Optional[str] = None,
optimize_device_cache: Optional[bool] = None,
):
# Step 0: check positional arguments validity
if not isinstance(processing_class, (PreTrainedTokenizerBase)):
raise ValueError(
f"processing_class must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(processing_class)}"
)
if not isinstance(model, PreTrainedModel):
raise ValueError(f"model must be a PreTrainedModel, got {type(model)}")
if not model.can_generate():
# Handle deprecated parameters
deprecated_params = {}
if max_length is not None:
deprecated_params["max_length"] = max_length
warnings.warn(
f"The current model class {type(model)} is not compatible with `.generate()`"
"Please make sure that this is intended."
"The `max_length` parameter is deprecated and will be removed in version 0.20. "
"Pass it through the `args` parameter using `IterativeSFTConfig(max_length=...)` instead.",
DeprecationWarning,
)
if optimizers[1] is None and args.max_steps == -1:
raise ValueError(
"When no scheduler is provided, you need to set the total number of training steps to perform `max_steps`"
if truncation_mode is not None:
deprecated_params["truncation_mode"] = truncation_mode
warnings.warn(
"The `truncation_mode` parameter is deprecated and will be removed in version 0.20. "
"Pass it through the `args` parameter using `IterativeSFTConfig(truncation_mode=...)` instead.",
DeprecationWarning,
)
if optimize_device_cache is not None:
deprecated_params["optimize_device_cache"] = optimize_device_cache
warnings.warn(
"The `optimize_device_cache` parameter is deprecated and will be removed in version 0.20 "
"Pass it through the `args` parameter using `IterativeSFTConfig(optimize_device_cache=...)` instead.",
DeprecationWarning,
)
self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
# Args
model_id = model if isinstance(model, str) else model.config._name_or_path
if args is None:
model_name = model_id.split("/")[-1]
args = IterativeSFTConfig(f"{model_name}-IterativeSFT")
elif isinstance(args, TrainingArguments) and not isinstance(args, IterativeSFTConfig):
dict_args = args.to_dict()
dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
dict_args.pop("push_to_hub_token")
args = IterativeSFTConfig(**dict_args)
# Update args with deprecated parameters if provided
if deprecated_params:
for key, value in deprecated_params.items():
setattr(args, key, value)
# Handle the tokenizer
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(model_id)
# Model
if args.model_init_kwargs is not None and not isinstance(model, str):
warnings.warn(
"You passed model_init_kwargs to the `IterativeSFTConfig`, but your model is already instantiated. "
"The `model_init_kwargs` will be ignored."
)
if isinstance(model, str):
model = self._create_model_from_path(model, args)
# PEFT configuration and model wrapping
if is_peft_available() and isinstance(model, PeftModel):
self.is_peft_model = True
else:
self.is_peft_model = False
self.processing_class = processing_class
self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)
if data_collator is None:
if self.is_encoder_decoder:
@ -132,9 +184,9 @@ class IterativeSFTTrainer(Trainer):
else:
self.data_collator = data_collator
self.max_length = max_length
self.truncation_mode = truncation_mode
self.optimize_device_cache = optimize_device_cache
self.max_length = args.max_length
self.truncation_mode = args.truncation_mode
self.optimize_device_cache = args.optimize_device_cache
super().__init__(
model=model,
@ -167,6 +219,11 @@ class IterativeSFTTrainer(Trainer):
PPODecorators.optimize_device_cache = self.optimize_device_cache
def _create_model_from_path(self, model_path: str, args: IterativeSFTConfig) -> PreTrainedModel:
"""Creates a model from a path or model identifier."""
model_init_kwargs = args.model_init_kwargs or {}
return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
if attention_mask is None:
attention_mask = [torch.ones_like(ids) for ids in input_ids]
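As a reading aid for the deprecation handling above, here is a minimal, hedged sketch of how a caller would migrate from the deprecated keyword arguments to the new `IterativeSFTConfig`; the output directory and model id are placeholders, and the top-level `trl` import path is an assumption.

from trl import IterativeSFTConfig, IterativeSFTTrainer

# The three formerly standalone keyword arguments now live on the config object.
config = IterativeSFTConfig(
    output_dir="iterative-sft-out",   # placeholder output directory
    max_steps=100,                    # an explicit step budget is needed when no scheduler is passed
    max_length=512,                   # was the deprecated `max_length` kwarg
    truncation_mode="keep_end",       # was the deprecated `truncation_mode` kwarg
    optimize_device_cache=True,       # was the deprecated `optimize_device_cache` kwarg
)

# Passing a model id as a string is supported; the trainer loads the model with
# AutoModelForCausalLM.from_pretrained using `config.model_init_kwargs` and the
# tokenizer with AutoTokenizer.from_pretrained. The model id below is a placeholder.
trainer = IterativeSFTTrainer(model="Qwen/Qwen2-0.5B", args=config)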

View File

@ -19,7 +19,6 @@ import textwrap
import warnings
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from copy import deepcopy
from operator import itemgetter
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
@ -31,7 +30,7 @@ import torch.nn as nn
import torch.nn.functional as F
import transformers
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available, tqdm
from accelerate.utils import tqdm
from datasets import Dataset, concatenate_datasets
from packaging import version
from torch.utils.data import DataLoader, SequentialSampler
@ -54,7 +53,7 @@ from transformers.utils import is_peft_available
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset
from ..import_utils import is_liger_kernel_available
from ..models import PreTrainedModelWrapper, create_reference_model
from ..models import create_reference_model, prepare_deepspeed
from .kto_config import KTOConfig
from .utils import (
DPODataCollatorWithPadding,
@ -68,9 +67,6 @@ from .utils import (
)
if is_deepspeed_available():
import deepspeed
if is_liger_kernel_available():
from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
@ -779,7 +775,7 @@ class KTOTrainer(Trainer):
)
else:
if self.is_deepspeed_enabled:
self.ref_model = self._prepare_deepspeed(self.ref_model)
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
@ -808,37 +804,6 @@ class KTOTrainer(Trainer):
ignore_index=self.label_pad_token_id, beta=self.beta, use_ref_model=(self.ref_model is not None)
)
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
if model is not None:
if hasattr(model, "config"):
hidden_size = (
max(model.config.hidden_sizes)
if getattr(model.config, "hidden_sizes", None)
else getattr(model.config, "hidden_size", None)
)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)
# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model
@contextmanager
def null_ref_context(self):
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
@ -1497,10 +1462,12 @@ class KTOTrainer(Trainer):
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
if dataset is None:
dataset = self.train_dataset
if dataset is None or not has_length(dataset):
return None
return SequentialSampler(self.train_dataset)
return SequentialSampler(dataset)
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
"""Generate samples from the model and reference model for the given batch of inputs."""

View File

@ -32,7 +32,7 @@ from transformers import (
)
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import OptimizerNames
from transformers.utils import is_apex_available
from transformers.utils import is_apex_available, is_peft_available
from ..data_utils import is_conversational, maybe_apply_chat_template
from ..models.modeling_base import GeometricMixtureWrapper
@ -59,6 +59,10 @@ if is_wandb_available():
import wandb
if is_peft_available():
from peft import PeftModel
class NashMDTrainer(OnlineDPOTrainer):
r"""
Initialize NashMDTrainer as a subclass of [`OnlineDPOTrainer`].
@ -170,28 +174,50 @@ class NashMDTrainer(OnlineDPOTrainer):
return self._mixture_coef
def _generate_completions(self, model, prompts):
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
model_output = unwrapped_model.generate(
# Generate completions from the policy model.
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_for_gen_ctx:
model_output = unwrapped_policy_for_gen_ctx.generate(
input_ids=prompts["input_ids"],
attention_mask=prompts["attention_mask"],
generation_config=self.generation_config,
)
ref_model = model if self.ref_model is None else self.ref_model
with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
mixture_model = GeometricMixtureWrapper(
model=unwrapped_model,
ref_model=unwrapped_ref_model,
generation_config=self.generation_config,
mixture_coef=self.mixture_coef,
device=self.accelerator.device,
)
# Get the DDP/FSDP unwrapped version of the main model.
# This will be the policy model for GeometricMixtureWrapper (PEFT adapters active if PEFT is used).
policy_model_for_gmw = self.accelerator.unwrap_model(model)
mixture_output = mixture_model.generate(
input_ids=prompts["input_ids"],
attention_mask=prompts["attention_mask"],
generation_config=self.generation_config,
)
# Determine the correct reference model for GeometricMixtureWrapper.
# This also needs to be DDP/FSDP unwrapped.
ref_model_for_gmw: torch.nn.Module
if self.ref_model is None:
# No explicit ref_model is provided.
# Use the base of the main `model` if it's a PEFT model.
# policy_model_for_gmw is already DDP-unwrapped.
if is_peft_available() and isinstance(policy_model_for_gmw, PeftModel):
ref_model_for_gmw = policy_model_for_gmw.get_base_model()
else:
# Not a PEFT model (or PEFT not available), or already a base model.
# Use the DDP-unwrapped policy model itself as the reference.
ref_model_for_gmw = policy_model_for_gmw
else:
# An explicit ref_model is provided. Unwrap it for DDP/FSDP.
ref_model_for_gmw = self.accelerator.unwrap_model(self.ref_model)
# Both models given to GeometricMixtureWrapper (policy_model_for_gmw and ref_model_for_gmw) are DDP-unwrapped.
with torch.no_grad(): # Ensure no_grad context for mixture model generation
mixture_model = GeometricMixtureWrapper(
model=policy_model_for_gmw,
ref_model=ref_model_for_gmw,
generation_config=self.generation_config,
mixture_coef=self.mixture_coef,
device=self.accelerator.device,
)
mixture_output = mixture_model.generate(
input_ids=prompts["input_ids"],
attention_mask=prompts["attention_mask"],
generation_config=self.generation_config,
)
return model_output, mixture_output

View File

@ -19,7 +19,6 @@ import textwrap
import warnings
from collections import defaultdict
from contextlib import nullcontext
from copy import deepcopy
from typing import Any, Callable, Literal, Optional, Union
import numpy as np
@ -30,7 +29,6 @@ import torch.nn as nn
import torch.nn.functional as F
import transformers
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
from packaging import version
from torch.utils.data import DataLoader
@ -52,7 +50,6 @@ from transformers.trainer_utils import EvalLoopOutput
from transformers.utils import is_peft_available, is_torch_fx_proxy
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
from ..models import PreTrainedModelWrapper
from .orpo_config import ORPOConfig
from .utils import (
DPODataCollatorWithPadding,
@ -75,9 +72,6 @@ if is_peft_available():
if is_wandb_available():
import wandb
if is_deepspeed_available():
import deepspeed
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
@ -358,37 +352,6 @@ class ORPOTrainer(Trainer):
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
if model is not None:
if hasattr(model, "config"):
hidden_size = (
max(model.config.hidden_sizes)
if getattr(model.config, "hidden_sizes", None)
else getattr(model.config, "hidden_size", None)
)
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
config_kwargs.update(
{
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
}
)
# If ZeRO-3 is used, we shard both the active and reference model.
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
if config_kwargs["zero_optimization"]["stage"] != 3:
config_kwargs["zero_optimization"]["stage"] = 0
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
model.eval()
return model
def build_tokenized_answer(self, prompt, answer):
"""
Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.

View File

@ -44,7 +44,7 @@ from transformers import (
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils import is_peft_available
from transformers.utils import is_peft_available, is_rich_available
from ..core import masked_mean, masked_whiten
from ..models import create_reference_model
@ -54,6 +54,7 @@ from .utils import (
OnlineTrainerState,
batch_generation,
disable_dropout_in_model,
empty_cache,
exact_div,
first_true_indices,
forward,
@ -107,7 +108,7 @@ class PPOTrainer(Trainer):
ref_model: Optional[nn.Module],
reward_model: nn.Module,
train_dataset: Dataset,
value_model: Optional[nn.Module] = None,
value_model: nn.Module,
data_collator: Optional[DataCollatorWithPadding] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
# less commonly used
@ -437,7 +438,7 @@ class PPOTrainer(Trainer):
logits = logitss[i : i + args.local_rollout_forward_batch_size]
logprob = selective_log_softmax(logits, response)
del logits
torch.cuda.empty_cache()
empty_cache()
if ref_policy is None:
with self.null_ref_context():
@ -448,7 +449,7 @@ class PPOTrainer(Trainer):
ref_logits /= args.temperature + 1e-7
ref_logprob = selective_log_softmax(ref_logits, response)
del ref_output, ref_logits
torch.cuda.empty_cache()
empty_cache()
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response
@ -484,7 +485,7 @@ class PPOTrainer(Trainer):
scores = torch.cat(scores, 0)
values = torch.cat(values, 0)
del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
torch.cuda.empty_cache()
empty_cache()
gc.collect()
# Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
@ -531,7 +532,7 @@ class PPOTrainer(Trainer):
returns = advantages + values
advantages = masked_whiten(advantages, ~padding_mask)
advantages = torch.masked_fill(advantages, padding_mask, 0)
torch.cuda.empty_cache()
empty_cache()
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
for ppo_epoch_idx in range(args.num_ppo_epochs):
@ -612,7 +613,7 @@ class PPOTrainer(Trainer):
mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
)
# fmt: on
torch.cuda.empty_cache()
empty_cache()
with torch.no_grad():
mean_kl = kl.sum(1).mean()
mean_entropy = (-logprobs).sum(1).mean()
@ -649,12 +650,12 @@ class PPOTrainer(Trainer):
self._save_checkpoint(model, trial=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
torch.cuda.empty_cache()
empty_cache()
gc.collect()
if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
self.generate_completions(sampling=True)
torch.cuda.empty_cache()
empty_cache()
del (
query_responses,
responses,
@ -674,7 +675,7 @@ class PPOTrainer(Trainer):
advantages,
returns,
)
torch.cuda.empty_cache()
empty_cache()
# HF trainer specifics
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
@ -732,7 +733,8 @@ class PPOTrainer(Trainer):
df = pd.DataFrame(table)
if self.accelerator.is_main_process:
print_rich_table(df.iloc[0 : 0 + 5])
if is_rich_available():
print_rich_table(df.iloc[0 : 0 + 5])
if "wandb" in args.report_to:
import wandb
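The `torch.cuda.empty_cache()` calls above are replaced by the device-agnostic `empty_cache` helper imported from the trainer utils. Below is a minimal sketch of what such a helper can look like; it is an illustration under stated assumptions, not the library's actual implementation.

import torch

def empty_cache() -> None:
    # Release cached allocator blocks on whichever accelerator backend is active.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        torch.mps.empty_cache()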

View File

@ -38,7 +38,7 @@ from transformers import (
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available
from transformers.utils import is_peft_available, is_rich_available
from ..data_utils import maybe_apply_chat_template
from .reward_config import RewardConfig
@ -350,7 +350,8 @@ class RewardTrainer(Trainer):
break
df = pd.DataFrame(table)
if self.accelerator.process_index == 0:
print_rich_table(df[:num_print_samples])
if is_rich_available():
print_rich_table(df[:num_print_samples])
if "wandb" in self.args.report_to:
import wandb

View File

@ -43,6 +43,7 @@ from transformers import (
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils import is_rich_available
from ..models.utils import unwrap_model_for_generation
from ..trainer.utils import (
@ -59,7 +60,7 @@ from ..trainer.utils import (
truncate_response,
)
from .rloo_config import RLOOConfig
from .utils import generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment
from .utils import empty_cache, generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment
if is_wandb_available():
@ -332,14 +333,14 @@ class RLOOTrainer(Trainer):
logits = logitss[i : i + args.local_rollout_forward_batch_size]
logprob = selective_log_softmax(logits, response)
del logits
torch.cuda.empty_cache()
empty_cache()
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_logprob = selective_log_softmax(ref_logits, response)
del ref_output, ref_logits
torch.cuda.empty_cache()
empty_cache()
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response
@ -380,7 +381,7 @@ class RLOOTrainer(Trainer):
sequence_lengths = torch.cat(sequence_lengths, 0)
scores = torch.cat(scores, 0)
del (logprob, ref_logprob, score)
torch.cuda.empty_cache()
empty_cache()
gc.collect()
# Response Processing 3. filter response. Ensure that the sample contains stop_token_id
@ -438,7 +439,7 @@ class RLOOTrainer(Trainer):
if args.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
torch.cuda.empty_cache()
empty_cache()
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
for ppo_epoch_idx in range(args.num_ppo_epochs):
@ -514,7 +515,7 @@ class RLOOTrainer(Trainer):
mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
)
# fmt: on
torch.cuda.empty_cache()
empty_cache()
# Compute metrics
with torch.no_grad():
@ -551,7 +552,7 @@ class RLOOTrainer(Trainer):
if self.control.should_save:
self._save_checkpoint(model, trial=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
torch.cuda.empty_cache()
empty_cache()
gc.collect()
if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
@ -625,7 +626,8 @@ class RLOOTrainer(Trainer):
df = pd.DataFrame(table)
if self.accelerator.is_main_process:
print_rich_table(df.iloc[0 : 0 + 5])
if is_rich_available():
print_rich_table(df.iloc[0 : 0 + 5])
if "wandb" in args.report_to:
import wandb

View File

@ -62,6 +62,8 @@ class SFTConfig(TrainingArguments):
continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
supported with the `flash_attention_2` attention implementation, which can efficiently handle the flattened
batch structure.
pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
If set, the sequences will be padded to a multiple of this value.
eval_packing (`bool` or `None`, *optional*, defaults to `None`):
Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
@ -76,6 +78,9 @@ class SFTConfig(TrainingArguments):
`False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset:
loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on
the full sequence for [language modeling](#language-modeling) datasets.
activation_offloading (`bool`, *optional*, defaults to `False`):
Whether to offload the activations to the CPU.
"""
# Parameters that control the model
@ -140,6 +145,10 @@ class SFTConfig(TrainingArguments):
"handle the flattened batch structure."
},
)
pad_to_multiple_of: Optional[int] = field(
default=None,
metadata={"help": "If set, the sequences will be padded to a multiple of this value."},
)
eval_packing: Optional[bool] = field(
default=None,
metadata={"help": "Whether to pack the eval dataset. If `None`, uses the same value as `packing`."},
@ -165,79 +174,25 @@ class SFTConfig(TrainingArguments):
)
},
)
activation_offloading: bool = field(
default=False,
metadata={"help": "Whether to offload the activations to the CPU."},
)
# Deprecated parameters
dataset_batch_size: Optional[int] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. You can safely remove this "
"parameter from your configuration."
},
)
num_of_sequences: Optional[int] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. Use `max_length` instead, "
"which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which referred "
"to string sequences."
},
)
chars_per_token: Optional[float] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. If you want to customize the "
"packing length, use `max_length`."
},
)
max_seq_length: Optional[int] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.20.0. Use `max_length` instead."
},
)
use_liger: Optional[bool] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. Use `use_liger_kernel` "
"instead."
},
)
def __post_init__(self):
super().__post_init__()
if self.dataset_batch_size is not None:
warnings.warn(
"`dataset_batch_size` is deprecated and will be removed in version 0.18.0. You can safely remove this "
"parameter from your configuration.",
DeprecationWarning,
)
if self.num_of_sequences is not None:
warnings.warn(
"`num_of_sequences` is deprecated and will be removed in version 0.18.0. Use `max_length` instead, "
"which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which "
"referred to string sequences.",
DeprecationWarning,
)
if self.chars_per_token is not None:
warnings.warn(
"`chars_per_token` is deprecated and will be removed in version 0.18.0. If you want to customize the "
"packing length, use `max_length`.",
DeprecationWarning,
)
if self.max_seq_length is not None:
warnings.warn(
"`max_seq_length` is deprecated and will be removed in version 0.20.0. Use `max_length` instead.",
DeprecationWarning,
)
self.max_length = self.max_seq_length
if self.use_liger is not None:
warnings.warn(
"`use_liger` is deprecated and will be removed in version 0.18.0. Use `use_liger_kernel` instead.",
DeprecationWarning,
)
self.use_liger_kernel = self.use_liger
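For reference, a short hedged sketch of the two options added to `SFTConfig` in this file; the output directory is a placeholder and the remaining arguments keep their defaults.

from trl import SFTConfig

config = SFTConfig(
    output_dir="sft-out",        # placeholder
    pad_to_multiple_of=8,        # pad each batch up to a multiple of 8 tokens
    activation_offloading=True,  # offload activations to the CPU during training
)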

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import dataclasses
import os
import warnings
@ -51,6 +52,7 @@ from ..data_utils import (
pack_dataset,
truncate_dataset,
)
from ..models import get_act_offloading_ctx_manager
from .sft_config import SFTConfig
from .utils import (
ConstantLengthDataset,
@ -80,7 +82,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
Token ID to use for padding.
completion_only_loss (`bool`, *optional*, defaults to `True`):
When the input contains a completion mask (`completion_mask`), the labels are set to -100 for the tokens
that are not in the completion.
pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
If set, the sequences will be padded to a multiple of this value.
return_tensors (`str`, *optional*, defaults to `"pt"`):
Type of Tensor to return. Only `"pt"` is currently supported.
@ -116,6 +120,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
pad_token_id: int
completion_only_loss: bool = True
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
@ -128,11 +133,22 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
# Pad
output = {}
output["input_ids"] = pad(input_ids, padding_value=self.pad_token_id, padding_side="right")
output["attention_mask"] = pad(attention_mask, padding_value=0, padding_side="right")
output["labels"] = pad(labels, padding_value=-100, padding_side="right")
output["input_ids"] = pad(
input_ids,
padding_value=self.pad_token_id,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
)
output["attention_mask"] = pad(
attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"] = pad(
labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
completion_mask = pad(
completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"][completion_mask == 0] = -100 # mask everything that is not in the completion
return output
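A hedged usage sketch of the collator with the new `pad_to_multiple_of` argument, assuming it stays importable from `trl.trainer.sft_trainer`: with a multiple of 4, the longest example (3 tokens) is right-padded to 4 columns.

from trl.trainer.sft_trainer import DataCollatorForLanguageModeling

# pad_token_id is usually the tokenizer's pad token id; 0 is a placeholder here.
collator = DataCollatorForLanguageModeling(pad_token_id=0, pad_to_multiple_of=4)
batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])
print(batch["input_ids"].shape)  # expected: torch.Size([2, 4])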
@ -210,7 +226,8 @@ class SFTTrainer(Trainer):
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
formatting_func (`Optional[Callable]`):
Formatting function applied to the dataset before tokenization.
Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly
converts the dataset into a [language modeling](#language-modeling) type.
"""
_tag_names = ["trl", "sft"]
@ -315,11 +332,21 @@ class SFTTrainer(Trainer):
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
"in the vocabulary before using it as a padding token."
)
data_collator = DataCollatorForLanguageModeling(pad_token_id, self.completion_only_loss)
data_collator = DataCollatorForLanguageModeling(
pad_token_id, self.completion_only_loss, args.pad_to_multiple_of
)
# Dataset
preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
if preprocess_dataset:
if self.completion_only_loss and formatting_func:
raise ValueError(
"A formatting function was provided while `completion_only_loss=True`, which is incompatible. "
"Using a formatter converts the dataset to a language modeling type, conflicting with "
"completion-only loss. To resolve this, apply your formatting function before passing the "
"dataset, or disable `completion_only_loss` in `SFTConfig`."
)
train_dataset = self._prepare_dataset(
train_dataset, processing_class, args, args.packing, formatting_func, "train"
)
@ -370,6 +397,12 @@ class SFTTrainer(Trainer):
**super_init_kwargs,
)
# Initialize activation offloading context
if self.args.activation_offloading:
self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
else:
self.maybe_activation_offload_context = contextlib.nullcontext()
# Add tags for models that have been loaded with the correct transformers version
if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names)
@ -650,7 +683,7 @@ class SFTTrainer(Trainer):
"""
Compute training loss and additionally compute token accuracies
"""
mode = "eval" if self.control.should_evaluate else "train"
mode = "train" if self.model.training else "eval"
(loss, outputs) = super().compute_loss(
model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
)
@ -694,8 +727,13 @@ class SFTTrainer(Trainer):
return (loss, outputs) if return_outputs else loss
# Override training step to add activation offloading context.
def training_step(self, *args, **kwargs):
with self.maybe_activation_offload_context:
return super().training_step(*args, **kwargs)
def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
mode = "eval" if self.control.should_evaluate else "train"
mode = "train" if self.model.training else "eval"
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`

View File

@ -45,12 +45,12 @@ from transformers import (
)
from transformers.utils import (
is_peft_available,
is_rich_available,
is_torch_mlu_available,
is_torch_npu_available,
is_torch_xpu_available,
)
from ..import_utils import is_rich_available
from ..trainer.model_config import ModelConfig
@ -415,7 +415,12 @@ class RewardDataCollatorWithPadding:
return batch
def pad(tensors: list[torch.Tensor], padding_value: int = 0, padding_side: str = "right") -> torch.Tensor:
def pad(
tensors: list[torch.Tensor],
padding_value: int = 0,
padding_side: str = "right",
pad_to_multiple_of: Optional[int] = None,
) -> torch.Tensor:
"""
Pads a list of tensors to the same shape along the first dimension.
@ -426,6 +431,8 @@ def pad(tensors: list[torch.Tensor], padding_value: int = 0, padding_side: str =
Value to use for padding. Default is 0.
padding_side (`str`):
Side on which to add padding. Must be 'left' or 'right'. Default is 'right'.
pad_to_multiple_of (`int`, *optional*, defaults to `None`):
If set, will pad the sequence to a multiple of the provided value.
Returns:
`torch.Tensor`:
@ -446,18 +453,25 @@ def pad(tensors: list[torch.Tensor], padding_value: int = 0, padding_side: str =
# Determine the maximum shape for each dimension
output_shape = np.max([t.shape for t in tensors], 0).tolist()
# Apply pad_to_multiple_of to the first (sequence) dimension
if pad_to_multiple_of is not None:
remainder = output_shape[0] % pad_to_multiple_of
if remainder != 0:
output_shape[0] += pad_to_multiple_of - remainder
# Create an output tensor filled with the padding value
output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device)
for i, t in enumerate(tensors):
# Determine the slice for the sequence dimension
if padding_side == "left":
seq_slice = slice(output_shape[0] - t.shape[0], output_shape[0])
seq_start = output_shape[0] - t.shape[0]
elif padding_side == "right":
seq_slice = slice(0, t.shape[0])
seq_start = 0
else:
raise ValueError("padding_side must be 'left' or 'right'")
# Define the slices
seq_slice = slice(seq_start, seq_start + t.shape[0])
slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:])
output[i][slices] = t
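A quick hedged example of the updated `pad` helper, assuming it remains importable from `trl.trainer.utils`: two sequences of lengths 3 and 5, right-padded to a multiple of 8, produce a (2, 8) tensor.

import torch
from trl.trainer.utils import pad

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6, 7, 8])]
out = pad(seqs, padding_value=0, padding_side="right", pad_to_multiple_of=8)
print(out.shape)  # expected: torch.Size([2, 8])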
@ -964,7 +978,11 @@ def cap_exp(value, cap=-1):
return torch.exp(torch.clamp(value, max=cap))
def print_rich_table(df: pd.DataFrame) -> Table:
def print_rich_table(df: pd.DataFrame) -> None:
if not is_rich_available():
raise ImportError(
"The function `print_rich_table` requires the `rich` library. Please install it with `pip install rich`."
)
console = Console()
table = Table(show_lines=True)
for column in df.columns:
@ -1600,7 +1618,7 @@ def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None:
experiment.log_table(tabular_data=table, filename=name)
def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
"""
Shift non-zero elements in the mask and corresponding tensors to the left.
@ -1642,28 +1660,59 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
[5, 6, 0]])
```
"""
_, M = mask.shape
# Create copy of mask and tensors
mask = mask.clone()
mask_copy = mask.clone()
tensors = [t.clone() for t in tensors]
# Shift non-zero values to the left
for i in range(mask.size(0)):
first_one_idx = torch.nonzero(mask[i])[0].item()
mask[i] = torch.roll(mask[i], shifts=-first_one_idx)
for tensor in tensors:
tensor[i] = torch.roll(tensor[i], shifts=-first_one_idx)
first_non_zero = mask_copy.argmax(dim=1)
pos = torch.arange(M, device=mask_copy.device).unsqueeze(0)
idx_roll = (pos + first_non_zero.unsqueeze(1)) % M
mask_roll = mask_copy.gather(1, idx_roll)
rolled_tensors = [t.gather(1, idx_roll) for t in tensors]
# Get the first column idx that is all zeros and remove every column after that
empty_cols = torch.sum(mask, dim=0) == 0
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else mask.size(1)
mask = mask[:, :first_empty_col]
for i, tensor in enumerate(tensors):
tensors[i] = tensor[:, :first_empty_col]
# Truncate trailing columns that are all zeros in mask_roll
col_sums = mask_roll.sum(dim=0)
empty_cols = col_sums == 0
first_empty_col = int(empty_cols.to(torch.int8).argmax()) if empty_cols.any() else M
flushed_mask = mask_roll[:, :first_empty_col]
flushed_tensors = [t[:, :first_empty_col] for t in rolled_tensors]
if not tensors:
return mask
else:
return mask, *tensors
if not flushed_tensors:
return flushed_mask
return flushed_mask, *flushed_tensors
def flush_right(mask: torch.Tensor, *tensors: torch.Tensor) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
"""
Shift non-zero elements in the mask and corresponding tensors to the right. See `flush_left` for details.
"""
_, M = mask.shape
# Create copy of mask and tensors
mask_copy = mask.clone()
tensors = [t.clone() for t in tensors]
# Shift non-zero values to the right
flipped_mask = torch.fliplr(mask_copy)
first_non_zero = flipped_mask.argmax(dim=1)
pos = torch.arange(M, device=mask_copy.device).unsqueeze(0)
idx_roll = (pos - first_non_zero.unsqueeze(1)) % M
mask_roll = mask_copy.gather(1, idx_roll)
rolled_tensors = [t.gather(1, idx_roll) for t in tensors]
# Truncate leading columns that are all zeros in mask_roll
col_sums = mask_roll.sum(dim=0)
non_empty_cols = col_sums != 0
first_non_empty_col = int(non_empty_cols.to(torch.int8).argmax()) if non_empty_cols.any() else M
flushed_mask = mask_roll[:, first_non_empty_col:]
flushed_tensors = [t[:, first_non_empty_col:] for t in rolled_tensors]
if not flushed_tensors:
return flushed_mask
return flushed_mask, *flushed_tensors
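To illustrate the vectorized rewrite above, a small hedged example (assuming both helpers stay importable from `trl.trainer.utils`): left-flushing drops the leading all-zero columns, right-flushing the trailing ones.

import torch
from trl.trainer.utils import flush_left, flush_right

mask = torch.tensor([[0, 0, 1, 1, 1],
                     [0, 1, 1, 0, 0]])
ids = torch.tensor([[0, 0, 2, 3, 4],
                    [0, 5, 6, 0, 0]])

left_mask, left_ids = flush_left(mask, ids)
# left_mask -> [[1, 1, 1], [1, 1, 0]]; left_ids -> [[2, 3, 4], [5, 6, 0]]

right_mask, right_ids = flush_right(mask, ids)
# right_mask -> [[1, 1, 1], [0, 1, 1]]; right_ids -> [[2, 3, 4], [0, 5, 6]]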
def selective_log_softmax(logits, index):
@ -1702,7 +1751,12 @@ def selective_log_softmax(logits, index):
def print_prompt_completions_sample(
prompts: list[str], completions: list[str], rewards: dict[str, list[float]], step: int, num_samples: int = None
prompts: list[str],
completions: list[str],
rewards: dict[str, list[float]],
advantages: list[float],
step: int,
num_samples: int = None,
) -> None:
"""
Print out a sample of model completions to the console with multiple reward metrics.
@ -1717,6 +1771,8 @@ def print_prompt_completions_sample(
List of completions corresponding to the prompts.
rewards (`dict[str, list[float]]`):
Dictionary where keys are reward names and values are lists of rewards.
advantages (`list[float]`):
List of advantages corresponding to the prompts and completions.
step (`int`):
Current training step number, used in the output title.
num_samples (`int` or `None`, *optional*, defaults to `None`):
@ -1728,18 +1784,24 @@ def print_prompt_completions_sample(
>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]}
>>> print_prompt_completions_sample(prompts, completions, rewards, 42)
╭────────────────────── Step 42 ───────────────────────╮
│ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┓ │
│ ┃ Prompt     ┃ Completion   ┃ Correctness ┃ Format ┃ │
│ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━┩ │
│ │ The sky is │  blue.       │        0.12 │   0.79 │ │
│ ├────────────┼──────────────┼─────────────┼────────┤ │
│ │ The sun is │  in the sky. │        0.46 │   0.10 │ │
│ └────────────┴──────────────┴─────────────┴────────┘ │
╰───────────────────────────────────────────────────────╯
>>> advantages = [0.987, 0.654]
>>> print_prompt_completions_sample(prompts, completions, rewards, advantages, 42)
╭──────────────────────────── Step 42 ─────────────────────────────╮
│ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │
│ ┃ Prompt     ┃ Completion   ┃ Correctness ┃ Format ┃ Advantage ┃ │
│ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │
│ │ The sky is │  blue.       │        0.12 │   0.79 │      0.99 │ │
│ ├────────────┼──────────────┼─────────────┼────────┼───────────┤ │
│ │ The sun is │  in the sky. │        0.46 │   0.10 │      0.65 │ │
│ └────────────┴──────────────┴─────────────┴────────┴───────────┘ │
╰───────────────────────────────────────────────────────────────────╯
```
"""
if not is_rich_available():
raise ImportError(
"The function `print_prompt_completions_sample` requires the `rich` library. Please install it with "
"`pip install rich`."
)
console = Console()
table = Table(show_header=True, header_style="bold white", expand=True)
@ -1748,6 +1810,7 @@ def print_prompt_completions_sample(
table.add_column("Completion", style="bright_green")
for reward_name in rewards.keys():
table.add_column(reward_name, style="bold cyan", justify="right")
table.add_column("Advantage", style="bold magenta", justify="right")
# Some basic input validation
if num_samples is not None:
@ -1762,10 +1825,11 @@ def print_prompt_completions_sample(
prompts = [prompts[i] for i in indices]
completions = [completions[i] for i in indices]
rewards = {key: [val[i] for i in indices] for key, val in rewards.items()}
advantages = [advantages[i] for i in indices]
for i in range(len(prompts)):
reward_values = [f"{rewards[key][i]:.2f}" for key in rewards.keys()] # 2 decimals
table.add_row(Text(prompts[i]), Text(completions[i]), *reward_values)
table.add_row(Text(prompts[i]), Text(completions[i]), *reward_values, f"{advantages[i]:.2f}")
table.add_section() # Adds a separator between rows
panel = Panel(table, expand=False, title=f"Step {step}", border_style="bold white")

View File

@ -33,6 +33,7 @@ from transformers import (
)
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import OptimizerNames
from transformers.utils import is_peft_available
from ..data_utils import is_conversational, maybe_apply_chat_template
from ..models.utils import unwrap_model_for_generation
@ -58,6 +59,10 @@ if is_wandb_available():
import wandb
if is_peft_available():
from peft import PeftModel
class XPOTrainer(OnlineDPOTrainer):
r"""
Initialize XPOTrainer as a subclass of [`OnlineDPOConfig`].
@ -174,16 +179,26 @@ class XPOTrainer(OnlineDPOTrainer):
return self._alpha
def _generate_completions(self, prompts, model):
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
model_output = unwrapped_model.generate(
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_model_for_gen:
model_output = unwrapped_policy_model_for_gen.generate(
input_ids=prompts["input_ids"],
attention_mask=prompts["attention_mask"],
generation_config=self.generation_config,
)
ref_model = model if self.ref_model is None else self.ref_model
with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
ref_output = unwrapped_ref_model.generate(
actual_model_for_ref_generation: torch.nn.Module
if self.ref_model is None:
unwrapped_main_model_for_ref_logic = self.accelerator.unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_main_model_for_ref_logic, PeftModel):
actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic.get_base_model()
else:
actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic
else:
actual_model_for_ref_generation = self.accelerator.unwrap_model(self.ref_model)
with unwrap_model_for_generation(actual_model_for_ref_generation, self.accelerator) as final_ref_model_for_gen:
ref_output = final_ref_model_for_gen.generate(
input_ids=prompts["input_ids"],
attention_mask=prompts["attention_mask"],
generation_config=self.generation_config,