Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)

Compare commits: 64 commits
Commits in this comparison (by SHA1):

2c49300910, e530486c26, 1bae58c292, ef4b0b225c, 8e8e62b380, 824100ce25, 4e7f0a5eb9, 17a9069710, cb07c44920, 0b6a1874f1, ac18c9d532, d1174adc5b, cd838417e4, c7e3f096a5, 5c08897570, 3ef9faf257, 9ac614fb08, 29401e790e, 31bf3f9244, 7f32792c07, 3d8727918a, 65245f6be8, a528b9c465, e0dd525021, 64aa06499b, be93a0c30c, f9fbd91ea9, 54d4f6b13a, 05bc43e960, d3dc8ff654, 21738c3732, eab175d434, 4da4dc9117, 6b3a02385d, abbbb93d6a, cafa663c84, fd04a5461a, 56e5766205, 89d44caece, adfa7fd59a, cf5183db7f, 1954c02d86, 45f4c58832, cc044e35b2, 999acd53ec, 8606b1ad09, a673da5773, 00b8e311aa, c163cf5081, bc9c019c43, 18596cf232, 280d35301b, 13fa8402a3, 09b669fbf7, 01d0be15cb, 3a42af1c78, aaf39604ba, 2bf48478e8, a8cfca6d01, 1bca49515e, 39e96394a9, 8e6ed93dfd, 29c5e05e3a, a9b27f82d6
.github/workflows/tests.yml (vendored, 2 changes)

@@ -36,7 +36,7 @@ jobs:
    name: Tests
    strategy:
      matrix:
        python-version: ['3.9', '3.10', '3.11', '3.12']
        python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
      fail-fast: false
    runs-on:
      group: aws-g4dn-2xlarge
.github/workflows/tests_latest.yml (vendored, 2 changes)

@@ -24,7 +24,7 @@ jobs:
    steps:
      - name: Git checkout
        uses: actions/checkout@v4
        with: { ref: v0.17-release }
        with: { ref: v0.18-release }

      - name: Set up Python 3.12
        uses: actions/setup-python@v5
.pre-commit-config.yaml

@@ -1,8 +1,8 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.3
    rev: v0.11.10
    hooks:
      - id: ruff
      - id: ruff-check
        types_or: [ python, pyi ]
        args: [ --fix ]
      - id: ruff-format
CITATION.cff

@@ -31,4 +31,4 @@ keywords:
  - pytorch
  - transformers
license: Apache-2.0
version: 0.17
version: 0.18
CONTRIBUTING.md (190 changes)

@@ -456,3 +456,193 @@ Warnings play a critical role in guiding users toward resolving potential issues

By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.


## Making a release

> [!NOTE]
> VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We need to follow this convention to be able to retrieve versioned scripts.

To create the package for PyPI:

#### 0. Prerequisites

- Dependencies:
  - build and twine: `pip install build twine`
- Create an account on (and join the `trl` project):
  - PyPI: https://pypi.org/
  - Test PyPI: https://test.pypi.org/
#### 1. Ensure your local repository is up to date with the upstream repository

```bash
git checkout main
git pull origin main
```

> [!WARNING]
> Do not merge other pull requests into `main` until the release is done. This is to ensure that the release is stable and does not include any untested changes. Announce internally (#trl-internal) to other maintainers that you are doing a release and that they must not merge PRs until the release is done.

#### 2. Create a release branch from main

```bash
git checkout -b release-v{major}.{minor}
```

#### 3. Change the version in the following files

- `.github/workflows/tests_latest.yml`:
  ```diff
  - with: { ref: v{major}.{minor-1}-release }
  + with: { ref: v{major}.{minor}-release }
  ```
- `CITATION.cff`
  ```diff
  - version: {major}.{minor-1}
  + version: {major}.{minor}
  ```
- `__init__.py`
  ```diff
  - __version__ = "{major}.{minor}.0.dev0"
  + __version__ = "{major}.{minor}.0"
  ```
- `setup.cfg`
  ```diff
  - version = {major}.{minor}.0.dev0
  + version = {major}.{minor}.0
  ```

#### 4. Commit and push these changes

```shell
git commit -m 'Release: {major}.{minor}'
git push origin release-v{major}.{minor}
```

#### 5. Create a pull request

from `release-v{major}.{minor}` to `main`, named `Release: v{major}.{minor}`, wait for tests to pass, and request a review.

#### 6. Once the pull request is approved, merge it into `main`

#### 7. Add a tag in git to mark the release

```shell
git checkout main
git pull origin main
git tag -a v{major}.{minor}.0 -m 'Adds tag v{major}.{minor}.0 for PyPI'
git push origin v{major}.{minor}.0
```

#### 8. Create a branch `v{major}.{minor}-release` for future patch releases

```shell
git checkout -b v{major}.{minor}-release
git push origin v{major}.{minor}-release
```

This ensures that future patch releases (`v{major}.{minor}.1`, `v{major}.{minor}.2`, etc.) can be made separately from `main`.

#### 9. Create the wheels for your release

These are the artifacts that will be uploaded to PyPI and installed by users via `pip install trl`.

Clean previous builds:

```shell
rm -rf build dist
```
At the root of your repo, run:

```bash
python -m build .
```

This will create a folder named `dist` with the new version of your package.

#### 10. Upload the package to PyPI Test

> [!IMPORTANT]
> Do not skip this step. It is important to test the package before uploading it to the main PyPI server.

```shell
twine upload dist/* -r testpypi
```

Then, in a fresh environment containing all dependencies you need, try to install your new package from the PyPI test server.

```bash
pip install -i https://test.pypi.org/simple/ trl
```

You might get errors for missing dependencies since the PyPI test server does not contain all packages like PyPI does. To make sure you have everything, you can run:

```bash
pip install trl
pip uninstall trl
```

(the second line will remove trl but keep all its dependencies).

Also make sure you can actually use the package! Run the following line:

```bash
python -c "from trl import *"
```

along with anything that tests:

- the core feature of your package
- the new features you're adding in the release

#### 11. Publish on PyPI

> [!WARNING]
> This can't be reverted. Make sure you have tested everything before doing this step.

```shell
twine upload dist/*
```

#### 12. Create a GitHub Release

1. Go to the repo's [releases section](https://github.com/huggingface/trl/releases) on GitHub.
2. Click **Draft a new release**.
3. Select the `v{major}.{minor}.0` tag you just created in step 7.
4. Add a title (`v{major}.{minor}.0`) and a short description of what's new.
5. Click **Publish Release**.

#### 13. Bump to dev version

1. Create a branch `bump-dev-version-{major}.{minor+1}` from `main` and check it out.

   ```shell
   git checkout -b bump-dev-version-{major}.{minor+1}
   ```

2. Change the version in the following files:
   1. `__init__.py`
      ```diff
      - __version__ = "{major}.{minor}.0"
      + __version__ = "{major}.{minor+1}.0.dev0"
      ```
   2. `setup.cfg`
      ```diff
      - version = {major}.{minor}.0
      + version = {major}.{minor+1}.0.dev0
      ```

3. Commit and push these changes

   ```shell
   git add trl/__init__.py setup.cfg
   git commit -m '⬆️ Bump dev version'
   git push origin bump-dev-version-{major}.{minor+1}
   ```

4. Create a pull request from `bump-dev-version-{major}.{minor+1}` to `main`, named `⬆️ Bump dev version`, and request urgent review.

5. Once the pull request is approved, merge it into `main`.

6. The codebase is now ready for the next development cycle; inform the team in the #trl-internal channel.
MANIFEST.in

@@ -1,6 +1,6 @@
include settings.ini
include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__
include trl/templates/*.md
include trl/templates/*.md
include trl/accelerate_configs/*.yaml
README.md

@@ -12,8 +12,9 @@
<p align="center">
    <a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
    <a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
    <a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website?label=documentation&url=https%3A%2F%2Fhuggingface.co%2Fdocs%2Ftrl%2Findex&down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
    <a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
    <a href="https://huggingface.co/trl-lib"><img alt="Hugging Face Hub" src="https://img.shields.io/badge/🤗%20Hub-trl--lib-yellow"></a>
</p>

## Overview
_toctree.yml

@@ -51,8 +51,6 @@
    title: Training StackLlama
  - local: detoxifying_a_lm
    title: Detoxifying a Language Model
  - local: learning_tools
    title: Learning to Use Tools
  - local: multi_adapter_rl
    title: Multi Adapter RLHF
  - local: training_vlm_sft

@@ -99,6 +97,8 @@
    title: Trainers
  - local: models
    title: Model Classes
  - local: model_utils
    title: Model Utilities
  - local: best_of_n
    title: Best of N Sampling
  - local: judges

@@ -107,8 +107,8 @@
    title: Callbacks
  - local: data_utils
    title: Data Utilities
  - local: text_environments
    title: Text Environments
  - local: rewards
    title: Reward Functions
  - local: script_utils
    title: Script Utilities
  - local: others
@@ -1,105 +1,225 @@

# Command Line Interfaces (CLIs)

You can use TRL to fine-tune your language model with methods like Supervised Fine-Tuning (SFT) or Direct Policy Optimization (DPO) using the command line interface (CLI).
TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly.

Currently supported CLIs are:
Currently supported commands are:

#### Training commands
#### Training Commands

- `trl dpo`: fine-tune an LLM with DPO
- `trl grpo`: fine-tune an LLM with GRPO
- `trl kto`: fine-tune an LLM with KTO
- `trl sft`: fine-tune an LLM with SFT

#### Other commands
#### Other Commands

- `trl env`: get the system information
- `trl vllm-serve`: serve a model with vLLM

## Fine-tuning with the CLI
## Fine-Tuning with the TRL CLI

Before getting started, pick up a Language Model from Hugging Face Hub. Supported models can be found with the filter "text-generation" within models. Also make sure to pick up a relevant dataset for your task.

### Basic Usage

You can launch training directly from the CLI by specifying required arguments like the model and dataset:

<hfoptions id="command_line">
<hfoption id="SFT">

Before using the `sft` or `dpo` commands, make sure to run:

```bash
accelerate config
trl sft \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name stanfordnlp/imdb
```

and pick the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.

We also recommend passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with the `trl sft` command.

</hfoption>
<hfoption id="DPO">

```bash
trl dpo \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name anthropic/hh-rlhf
```

</hfoption>
</hfoptions>

### Using Configuration Files

To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file:

<hfoptions id="config_file">
<hfoption id="SFT">

```yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
report_to: none
learning_rate: 0.0001
lr_scheduler_type: cosine
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
```

Save that config in a `.yaml` and get started immediately! An example CLI config is available as `examples/cli_configs/example_config.yaml`. Note that you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder:
Launch with:

```bash
trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
trl sft --config sft_config.yaml
```

This will force-use `cosine_with_restarts` for `lr_scheduler_type`.

</hfoption>
<hfoption id="DPO">

### Supported Arguments

```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
```

We support all arguments from `transformers.TrainingArguments`; for loading your model, we support all arguments from `~trl.ModelConfig`:

[[autodoc]] ModelConfig

You can pass any of these arguments either to the CLI or the YAML file.
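For example, model-loading options from `ModelConfig` can be passed directly as CLI flags. The sketch below is illustrative and assumes `--torch_dtype`, `--use_peft`, and `--lora_r` are `ModelConfig` fields available in your TRL version:

```bash
trl sft \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name stanfordnlp/imdb \
    --torch_dtype bfloat16 \
    --use_peft \
    --lora_r 16
```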
### Supervised Fine-tuning (SFT)

Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
Launch with:

```bash
trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb
trl dpo --config dpo_config.yaml
```

The SFT CLI is based on the `trl/scripts/sft.py` script.

</hfoption>
</hfoptions>

### Direct Policy Optimization (DPO)
### Scaling Up with Accelerate

To use the DPO CLI, you need to have a dataset in the TRL format such as
The TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs, machines, or use advanced setups like DeepSpeed, all from the same CLI.

* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style

You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch).

These datasets always have at least three columns `prompt, chosen, rejected`:

* `prompt` is a list of strings.
* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
* `rejected` is the rejected response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)

To do a quick start, you can run the following command:

<hfoptions id="launch_args">
<hfoption id="SFT inline">

```bash
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
trl sft \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name stanfordnlp/imdb \
    --num_processes 4
```

</hfoption>
<hfoption id="SFT w/ config file">

The DPO CLI is based on the `trl/scripts/dpo.py` script.

```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
num_processes: 4
```

#### Custom preference dataset

Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
Launch with:

```bash
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
trl sft --config sft_config.yaml
```

## Chat interface

</hfoption>
<hfoption id="DPO inline">

```bash
trl dpo \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name anthropic/hh-rlhf \
    --num_processes 4
```

</hfoption>
<hfoption id="DPO w/ config file">

```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
num_processes: 4
```

Launch with:

```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
</hfoptions>

### Using `--accelerate_config` for Accelerate Configuration

The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either:

* the name of a predefined config profile (built into TRL), or
* a path to a custom Accelerate YAML config file.

#### Predefined Config Profiles

TRL provides several ready-to-use Accelerate configs to simplify common training setups:

| Name | Description |
| ------------ | ----------------------------------- |
| `fsdp1` | Fully Sharded Data Parallel Stage 1 |
| `fsdp2` | Fully Sharded Data Parallel Stage 2 |
| `zero1` | DeepSpeed ZeRO Stage 1 |
| `zero2` | DeepSpeed ZeRO Stage 2 |
| `zero3` | DeepSpeed ZeRO Stage 3 |
| `multi_gpu` | Multi-GPU training |
| `single_gpu` | Single-GPU training |

To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_configs/`.

#### Example Usage

<hfoptions id="accelerate_config">
<hfoption id="SFT inline">

```bash
trl sft \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name stanfordnlp/imdb \
    --accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```

</hfoption>
<hfoption id="SFT w/ config file">

```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```

Launch with:

```bash
trl sft --config sft_config.yaml
```

</hfoption>
<hfoption id="DPO inline">

```bash
trl dpo \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name anthropic/hh-rlhf \
    --accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```

</hfoption>
<hfoption id="DPO w/ config file">

```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```

Launch with:

```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
</hfoptions>

## Chat Interface

<Tip warning={true}>

@@ -130,7 +250,7 @@ Besides talking to the model there are a few commands you can use:

- `save` or `save {SAVE_NAME}`: save the current chat and settings to file, by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or to `{SAVE_NAME}` if provided
- `exit`: closes the interface

## Getting the system information
## Getting the System Information

You can get the system information by running the following command:

@@ -138,7 +258,7 @@ You can get the system information by running the following command:

```bash
trl env
```

This will print out the system information including the GPU information, the CUDA version, the PyTorch version, the transformers version, and the TRL version, and any optional dependencies that are installed.
This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the transformers version, the TRL version, and any optional dependencies that are installed.

```txt
Copy-paste the following information when reporting an issue:

@@ -146,7 +266,7 @@ Copy-paste the following information when reporting an issue:

- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.1
- CUDA device: NVIDIA H100 80GB HBM3
- accelerator(s): NVIDIA H100 80GB HBM3
- Transformers version: 4.45.0.dev0
- Accelerate version: 0.34.2
- Accelerate config:

@@ -177,6 +297,7 @@ Copy-paste the following information when reporting an issue:

- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.12.0
- vLLM version: not installed
```

This information are required when reporting an issue.
This information is required when reporting an issue.
@@ -168,6 +168,10 @@ The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterativ

The [WPO](https://huggingface.co/papers/2406.11827) paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the `use_weighting` flag to `True` in the [`DPOConfig`].
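For example, a minimal sketch (all other arguments omitted):

```python
from trl import DPOConfig

training_args = DPOConfig(..., use_weighting=True)
```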
### LD-DPO loss

The [LD-DPO](https://huggingface.co/papers/2409.06411) paper decomposes the portion of the response that exceeds the desired length into two components, human-like preferences and verbosity preference, based on a mixing coefficient \\( \alpha \\). To use this method, set `ld_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this value between `0.0` and `1.0`.
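For example, a minimal sketch; `0.5` is just an illustrative value within the suggested range:

```python
from trl import DPOConfig

training_args = DPOConfig(..., ld_alpha=0.5)
```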
### For Mixture of Experts Models: Enabling the auxiliary loss

MOEs are the most efficient if the load is about equally distributed between experts.
@@ -155,6 +155,7 @@ This constant is recommended to be the maximum completion length. To use this fo

- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
- `reward`: The overall average reward after applying reward weights.
- `reward_std`: The standard deviation of the overall reward within each batch after applying reward weights.
- `frac_reward_zero_std`: The fraction of samples in the generation batch with a reward std of zero, implying there is little diversity for that prompt (all answers are correct or incorrect).
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
- `clip_ratio/region_mean`: The ratio of token probabilities where the GRPO objective is clipped to stay within the trust region:

@@ -170,26 +171,59 @@ A higher value means more tokens are clipped, which constrains how much the poli

### Speed up training with vLLM-powered generation

Generation is often the main bottleneck that makes training slow with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation. To enable it, first install the package with
Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with

```shell
pip install trl[vllm]
```

Then, start the vLLM server with the desired model:
We support two ways of using vLLM during training: **server mode** and **colocate mode**.

```bash
trl vllm-serve --model <model_name>
```

#### 🔌 Option 1: Server mode

Then, pass `use_vllm=True` in the training arguments and run the training script:
In this mode, vLLM runs in a separate process (and uses separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.

1. **Start the vLLM server**:
   ```bash
   trl vllm-serve --model <model_name>
   ```

2. **Enable server mode in your training script**:
   ```python
   from trl import GRPOConfig

   training_args = GRPOConfig(
       ...,
       use_vllm=True,
       vllm_mode="server",  # default value, can be omitted
   )
   ```

<Tip warning={true}>

Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable.

</Tip>

#### 🧩 Option 2: Colocate mode

In this mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs.

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_mode="colocate",
)
```

<Tip>

Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`GRPOConfig`] to avoid underutilization or out-of-memory errors.

</Tip>
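For example, a sketch of a colocate setup with a reduced vLLM memory share; the `0.3` value is illustrative and should be tuned for your model size and hardware:

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_mode="colocate",
    vllm_gpu_memory_utilization=0.3,  # leave the rest of the GPU memory for training
)
```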
For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).

### GRPO at scale: train a 70B+ Model on multiple nodes

@@ -274,6 +308,7 @@ The [`GRPOTrainer`] supports using custom reward functions instead of dense rewa

- The function must accept the following as keyword arguments:
  - `prompts` (contains the prompts),
  - `completions` (contains the generated completions),
  - `completions_ids` (contains the tokenized completions),
  - All column names (except `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.

  The easiest way to comply with this requirement is to use `**kwargs` in the function signature.

@@ -287,9 +322,29 @@ The [`GRPOTrainer`] supports using custom reward functions instead of dense rewa

Below is an example of a reward function that rewards longer completions:

```python
def reward_func(completions_ids, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of token count)."""
    return [float(len(ids)) for ids in completions_ids]
```

You can test it as follows:

```python
>>> prompts = ["The sky is", "The sun is"]  # not used in the reward function, but the trainer will pass it
>>> completions = [" blue.", " in the sky."]  # not used in the reward function, but the trainer will pass it
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[2.0, 4.0]
```

#### Example 1.1: Reward longer completions (based on the number of characters)

Same as the previous example, but this time the reward function is based on the number of characters instead of tokens.

```python
def reward_func(completions, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of character count)."""
    return [float(len(completion)) for completion in completions]
```

@@ -298,7 +353,8 @@ You can test it as follows:

```python
>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]  # not used in the reward function, but the trainer will pass it
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[6.0, 12.0]
```
@@ -2,56 +2,138 @@

[](https://huggingface.co/models?other=iterative-sft,trl)

Iterative fine-tuning is a training method that enables performing custom actions (generation and filtering, for example) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code.

## Usage
## Quickstart

To get started quickly, instantiate a model and a tokenizer.
To get started quickly, you can either pass a model identifier or a pre-instantiated model to the trainer:

```python
from trl import IterativeSFTConfig, IterativeSFTTrainer

# Using a model identifier
trainer = IterativeSFTTrainer(
    "facebook/opt-350m",
    args=IterativeSFTConfig(
        max_length=512,
        output_dir="./output",
    ),
)

# Or using a pre-instantiated model
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

trainer = IterativeSFTTrainer(
    model,
    args=IterativeSFTConfig(
        max_length=512,
        output_dir="./output",
    ),
    processing_class=tokenizer,
)
```

You have the choice to either provide a list of strings or a list of tensors to the step function.

## Usage

#### Using a list of tensors as input:
The [`IterativeSFTTrainer`] supports two ways of providing input data to the `step` function:

### Using a list of tensors as input:

```python
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}

trainer.step(**inputs)
```

#### Using a list of strings as input:
### Using a list of strings as input:

```python
inputs = {
    "texts": texts,
    "texts_labels": texts_labels,  # Optional, defaults to texts
}

trainer.step(**inputs)
```

For causal language models, labels will automatically be created from input_ids or from texts. When using sequence to sequence models you will have to provide your own labels or text_labels.
For causal language models, labels will automatically be created from `input_ids` or from `texts`. When using sequence to sequence models you will have to provide your own labels or `text_labels`.
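For example, a sketch of a `step` call for a sequence-to-sequence model with explicit labels; the `labels` key and the tensor names are assumptions for illustration:

```python
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
    "labels": labels,  # assumed: explicit labels provided for a seq2seq model
}

trainer.step(**inputs)
```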
## IterativeTrainer
## Configuration

The [`IterativeSFTConfig`] class provides several parameters to customize the training:

```python
from trl import IterativeSFTConfig

config = IterativeSFTConfig(
    # Model initialization parameters
    model_init_kwargs={"torch_dtype": "bfloat16"},

    # Data preprocessing parameters
    max_length=512,
    truncation_mode="keep_end",

    # Training parameters
    output_dir="./output",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_steps=1000,
    logging_steps=10,
    save_steps=100,
    optim="adamw_torch",
    report_to="wandb",
)
```

### Model Initialization

You can control how the model is initialized by passing keyword arguments to `model_init_kwargs`:

```python
config = IterativeSFTConfig(
    model_init_kwargs={
        "torch_dtype": "bfloat16",
        "device_map": "auto",
        "trust_remote_code": True,
    }
)
```

### Data Preprocessing

The trainer supports two truncation modes:

- `keep_end`: Truncates from the start of the sequence
- `keep_start`: Truncates from the end of the sequence

```python
config = IterativeSFTConfig(
    max_length=512,
    truncation_mode="keep_end",  # or "keep_start"
)
```

### Training Optimization

You can optimize CUDA cache usage for more memory-efficient training:

```python
config = IterativeSFTConfig(
    optimize_device_cache=True,
)
```

## IterativeSFTTrainer

[[autodoc]] IterativeSFTTrainer

## IterativeSFTConfig

[[autodoc]] IterativeSFTConfig
@@ -1,233 +0,0 @@

# Learning Tools (Experimental 🧪)

Using Large Language Models (LLMs) with tools has been a popular topic recently, with awesome works such as [ToolFormer](https://huggingface.co/papers/2302.04761) and [ToolBench](https://huggingface.co/papers/2305.16504). In TRL, we provide a simple example of how to teach an LLM to use tools with reinforcement learning.

Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):

| File | Description |
|---|---|
| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. |
| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. |
| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. |

<Tip warning={true}>

Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
</Tip>

## Learning to Use a Calculator

The rough idea is as follows:

1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parses a text calculation like `"14 + 34"` and returns the calculated number:
   ```python
   from transformers import AutoTokenizer, load_tool
   tool = load_tool("ybelkada/simple-calculator")
   tool_fn = lambda text: str(round(float(tool(text)), 2))  # rounding to 2 decimal places
   ```
2. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.
3. Create a prompt on how to use the tools
   ```python
   # system prompt
   prompt = """\
   What is 13.1-3?

   <request><SimpleCalculatorTool>13.1-3<call>10.1<response>

   Result=10.1<submit>

   What is 4*3?

   <request><SimpleCalculatorTool>4*3<call>12<response>

   Result=12<submit>

   What is 12.1+1?

   <request><SimpleCalculatorTool>12.1+1<call>13.1<response>

   Result=13.1<submit>

   What is 12.1-20?

   <request><SimpleCalculatorTool>12.1-20<call>-7.9<response>

   Result=-7.9<submit>"""
   ```
4. Create a `trl.TextEnvironment` with the model
   ```python
   env = TextEnvironment(
       model,
       tokenizer,
       {"SimpleCalculatorTool": tool_fn},
       reward_fn,
       prompt,
       generation_kwargs=generation_kwargs,
   )
   ```
5. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will visualize the tokens.

   

6. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss; make sure to pass that argument to `step`.

## Experiment results

We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.

```
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
    --command "python examples/research_projects/tools/calculator.py" \
    --num-seeds 10 \
    --start-seed 1 \
    --workers 10 \
    --slurm-gpus-per-task 1 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 8 \
    --slurm-template-path benchmark/trl.slurm_template
```

We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark), which generates the following plot.

```
# pip install openrlbenchmark==0.2.1a5
python -m openrlbenchmark.rlops_multi_metrics \
    --filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
    'wandb?tag=calculator_final&cl=calculator_mask' \
    --env-ids trl \
    --check-empty-runs \
    --pc.ncols 2 \
    --pc.ncols-legend 1 \
    --output-filename static/0compare \
    --scan-history
```



As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near-perfect proficiency in the calculator task.
## (Early Experiments 🧪): learning to use a wiki tool for question answering

The [ToolFormer](https://huggingface.co/papers/2302.04761) paper shows an interesting use case that utilizes a Wikipedia search tool to help answer questions. In this section, we attempt to perform similar experiments but use RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.

<Tip warning={true}>

**Note that many settings are different so the results are not directly comparable.**
</Tip>

### Building a search index

Since [ToolFormer](https://huggingface.co/papers/2302.04761) did not open-source its code, we needed to first replicate the search index. It is mentioned in their paper that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT).

Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.

```python
from pyserini.search.lucene import LuceneSearcher
import json

searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')

def search(query):
    hits = searcher.search(query, k=1)
    hit = hits[0]
    contents = json.loads(hit.raw)['contents']
    return contents

print(search("tennis racket"))
```
```
Racket (sports equipment)
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.

The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
...
```

We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later.



### Experiment settings

We use the following settings:

* use the `bigcode/starcoderbase` model as the base model
* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only use the first paragraphs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
* test if the response contains the answer string; if so, give a reward of 1, otherwise give a reward of 0.
  * note that this is a simplified evaluation criterion; in [ToolFormer](https://huggingface.co/papers/2302.04761), the authors check if the first 20 words of the response contain the correct answer.
* use the following prompt that demonstrates the usage of the wiki tool.
  ```python
  prompt = """\
  Answer the following question:

  Q: In which branch of the arts is Patricia Neary famous?
  A: Ballets
  A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
  Result=Ballets<submit>

  Q: Who won Super Bowl XX?
  A: Chicago Bears
  A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
  Result=Chicago Bears<submit>

  Q: """
  ```

### Result and Discussion

Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves mostly go up, but one of the experiments crashed.



The Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.

Note that the correct rate of the trained model is on the low end, which could be due to the following reasons:

* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"`, if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988." But a correct search should return "Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles.[1][2]"

  

* **unnecessarily long response**: The wiki tool by default sometimes outputs very long sequences. E.g., when the wiki tool searches for "Brown Act":
  * Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
  * [ToolFormer](https://huggingface.co/papers/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.

  

## (Early Experiments 🧪): solving math puzzles with python interpreter

In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:

```python
prompt = """\
Example of using a Python API to solve math questions.

Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?

<request><PythonInterpreter>
def solution():
    money_initial = 23
    bagels = 5
    bagel_cost = 3
    money_spent = bagels * bagel_cost
    money_left = money_initial - money_spent
    result = money_left
    return result
print(solution())
<call>8<response>

Result = 8 <submit>

Q: """
```

The training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y


@@ -1,74 +1,99 @@

# Logging

As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
By default, the TRL [`PPOTrainer`] saves a lot of relevant information to wandb or tensorboard.
By default, TRL trainers like [`PPOTrainer`] and [`GRPOTrainer`] save a lot of relevant information to supported experiment trackers like Weights & Biases (wandb) or TensorBoard.

Upon initialization, pass one of these two options to the [`PPOConfig`]:
Upon initialization, pass the `report_to` argument to the respective configuration object (e.g., [`PPOConfig`] for `PPOTrainer`, or [`GRPOConfig`] for `GRPOTrainer`):

```
training_args = PPOConfig(..., report_to="wandb") # or "tensorboard"
```

```python
# For PPOTrainer
ppo_config = PPOConfig(
    # ...,
    report_to="wandb" # or "tensorboard"
)

# For GRPOTrainer
grpo_config = GRPOConfig(
    # ...,
    report_to="wandb" # or "tensorboard"
)
```

If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.
If you want to log with TensorBoard, you might also need to specify logging directories, for example, by adding `logging_dir=PATH_TO_LOGS` to the configuration object (e.g., `PPOConfig` or `GRPOConfig`).
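For example, a minimal sketch assuming the standard `logging_dir` argument inherited from `transformers.TrainingArguments`:

```python
from trl import GRPOConfig

grpo_config = GRPOConfig(
    # ...,
    report_to="tensorboard",
    logging_dir="./logs",  # directory where TensorBoard event files will be written
)
```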
## PPO Logging
|
||||
|
||||
Here's a brief explanation for the logged metrics provided in the data:
|
||||
|
||||
Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
|
||||
1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is used to specifically monitor the reward model.
|
||||
1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias ``ppo/std_scores`, which is used to specifically monitor the reward model.
|
||||
1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment.
|
||||
1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function.
|
||||
1. `objective/kl_dist`: The histogram distribution of the `objective/kl`.
|
||||
1. `objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function.
|
||||
1. `ppo/mean_non_score_reward`: The **KL penalty** calculated by `objective/kl * objective/kl_coef` as the total reward for optimization to prevent the new policy from deviating too far from the old policy.
|
||||
1. `objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration.
|
||||
|
||||
Training stats:
|
||||
1. `ppo/learning_rate`: The learning rate for the PPO algorithm.
|
||||
1. `ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy.
|
||||
1. `ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process.
|
||||
1. `ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html
|
||||
1. `ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html
|
||||
1. `ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.
|
||||
1. `ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.
|
||||
1. `ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`.
|
||||
1. `ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details.
|
||||
1. `ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance.
|
||||
1. `ppo/val/mean`: The mean of the values, used to monitor the value function's performance.
|
||||
1. `ppo/val/var` : The variance of the values, used to monitor the value function's performance.
|
||||
1. `ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance.
|
||||
1. `ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped.
|
||||
1. `ppo/val/vpred`: The predicted values from the value function.
|
||||
1. `ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance.
|
||||
1. `ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm.
|
||||
1. `ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards.
|
||||
1. `ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss.
|
||||
|
||||
|
||||
Stats on queries, responses, and logprobs:
|
||||
1. `tokens/queries_len_mean`: The average length of the queries tokens.
|
||||
1. `tokens/queries_len_std`: The standard deviation of the length of the queries tokens.
|
||||
1. `tokens/queries_dist`: The histogram distribution of the length of the queries tokens.
|
||||
1. `tokens/responses_len_mean`: The average length of the responses tokens.
|
||||
1. `tokens/responses_len_std`: The standard deviation of the length of the responses tokens.
|
||||
1. `tokens/responses_dist`: The histogram distribution of the length of the responses tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`)
|
||||
1. `objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model.
|
||||
1. `objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model.
|
||||
|
||||
|
||||
* `eps`: Tracks the number of episodes per second.
|
||||
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
|
||||
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
|
||||
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
|
||||
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
|
||||
* `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
|
||||
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
|
||||
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
|
||||
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
|
||||
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to `policy/clipfrac_avg` but for the value function.
|
||||
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
|
||||
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
|
||||
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
|
||||
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
|
||||
* `lr`: The current learning rate used by the optimizer.
|
||||
* `episode`: The current episode count in the training process.
### Crucial values
|
||||
During training, many values are logged, here are the most important ones:
|
||||
|
||||
1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment" / reward model
|
||||
1. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
|
||||
1. `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
1. `objective/rlhf_reward`: The mean RLHF reward. This is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
|
||||
1. `objective/non_score_reward`: The mean reward from non-score-related sources (e.g., KL penalty).
|
||||
|
||||
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
|
||||
|
||||
1. `ppo/loss/value`: it will spike / NaN when not going well.
|
||||
1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
|
||||
1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well.
|
||||
1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy.
|
||||
1. `objective/kl_coef`: The KL penalty coefficient, adapted by [`AdaptiveKLController`] towards the target KL. It often increases before numerical instabilities appear.
|
||||
1. `loss/value_avg`: The average value loss. It will spike / NaN when not going well.
|
||||
1. `val/ratio`: The mean ratio of the current policy probability to the old policy probability. This number should float around 1.0. If this `ratio` is too high (e.g., 2.0 or 1000.0) or too small (e.g., 0.1), it means the updates between consecutive policies are too drastic.
|
||||
1. `policy/clipfrac_avg` and `policy/approxkl_avg`: If `val/ratio` is too high, the `ratio` is going to get clipped, resulting in high `policy/clipfrac_avg` and high `policy/approxkl_avg` as well.
|
||||
1. `objective/kl`: The mean KL divergence. It should stay positive and ideally not too large, so that the policy is not too far away from the reference policy.
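To make the ratio-based diagnostics above more concrete, here is a minimal sketch (not TRL's actual implementation; the tensors are made up) of how the ratio, an approximate KL, and a clip fraction can be computed from old and new per-token log-probabilities:

```python
import torch

# Illustrative per-token log-probabilities under the old and new policies.
old_logprobs = torch.tensor([-1.2, -0.8, -2.1, -0.5])
new_logprobs = torch.tensor([-1.0, -0.9, -1.5, -0.4])
cliprange = 0.2

ratio = torch.exp(new_logprobs - old_logprobs)                  # ~1.0 when the policies agree
approxkl = 0.5 * ((new_logprobs - old_logprobs) ** 2).mean()    # a common approximation of the KL
clipfrac = (torch.abs(ratio - 1.0) > cliprange).float().mean()  # fraction of tokens that get clipped

print(ratio, approxkl.item(), clipfrac.item())
```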
## GRPO Logging
|
||||
|
||||
Here's a brief explanation of the metrics logged for the GRPO trainer:
|
||||
|
||||
* `num_tokens`: Total number of input tokens processed during training so far.
|
||||
|
||||
**Completions:**
|
||||
* `completions/mean_length`: Mean length of all generated completions (including those not ending with an EOS token).
|
||||
* `completions/min_length`: Minimum length among all generated completions.
|
||||
* `completions/max_length`: Maximum length among all generated completions.
|
||||
* `completions/clipped_ratio`: The ratio of completions that did not end with an EOS token before reaching the maximum generation length (i.e., they were truncated).
|
||||
* `completions/mean_terminated_length`: Mean length of only those completions that successfully ended with an EOS token.
|
||||
* `completions/min_terminated_length`: Minimum length among completions that ended with an EOS token.
|
||||
* `completions/max_terminated_length`: Maximum length among completions that ended with an EOS token.
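As a rough illustration (made-up lengths, not the trainer's code), these completion statistics can be derived from the completion lengths and an EOS indicator:

```python
import torch

# Illustrative completion lengths and whether each completion ended with an EOS token.
lengths = torch.tensor([12, 256, 37, 256, 90])
ended_with_eos = torch.tensor([True, False, True, False, True])

mean_length = lengths.float().mean()
clipped_ratio = (~ended_with_eos).float().mean()        # truncated at max_completion_length
terminated_lengths = lengths[ended_with_eos]
mean_terminated_length = terminated_lengths.float().mean()

print(mean_length.item(), clipped_ratio.item(), mean_terminated_length.item())
```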
**Rewards:**
|
||||
* `rewards/{reward_func_name}/mean`: The mean reward obtained from a specific, named reward function (e.g., `rewards/my_custom_reward/mean`). This is logged for each reward function used.
|
||||
* `rewards/{reward_func_name}/std`: The standard deviation of rewards from a specific, named reward function.
|
||||
* `reward`: The overall mean of the (potentially weighted and, if `args.scale_rewards` is true, normalized) rewards, after group-wise normalization (advantages).
|
||||
* `reward_std`: The standard deviation of the (potentially weighted) rewards *before* group-wise normalization for advantages.
|
||||
|
||||
**Policy and Loss Metrics:**
|
||||
* `kl`: The mean Kullback-Leibler (KL) divergence between the current policy and the reference policy. This is logged only if `beta` (the KL coefficient in `GRPOConfig`) is non-zero.
|
||||
* If Liger GRPOLoss is used (`use_liger_loss: True` in `GRPOConfig`):
|
||||
* `clip_ratio`: The fraction of policy updates where the probability ratio was clipped according to the GRPO loss's epsilon bounds.
|
||||
* If standard GRPOLoss is used (`use_liger_loss: False`):
|
||||
* `clip_ratio/low_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the lower bound `1 - epsilon_low` (occurs when advantage is negative and ratio is below the bound).
|
||||
* `clip_ratio/low_min`: The minimum observed fraction for `clip_ratio/low_mean` across batches/processes.
|
||||
* `clip_ratio/high_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the upper bound `1 + epsilon_high` (occurs when advantage is positive and ratio is above the bound).
|
||||
* `clip_ratio/high_max`: The maximum observed fraction for `clip_ratio/high_mean` across batches/processes.
|
||||
* `clip_ratio/region_mean`: The mean fraction of instances where the probability ratio was clipped at either the lower or upper bound.
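As a hedged sketch (illustrative tensors, not the trainer's internals), the low/high clip fractions described above can be computed from the per-token ratio, the advantages, and the two epsilon bounds:

```python
import torch

ratio = torch.tensor([0.7, 1.0, 1.4, 0.9, 1.3])          # r_t(θ) per token
advantages = torch.tensor([-1.0, 0.5, 2.0, -0.3, -0.8])
epsilon_low, epsilon_high = 0.2, 0.2

# Clipping at the lower bound only matters when the advantage is negative,
# clipping at the upper bound only when it is positive.
is_low_clipped = (ratio < 1 - epsilon_low) & (advantages < 0)
is_high_clipped = (ratio > 1 + epsilon_high) & (advantages > 0)

clip_ratio_low = is_low_clipped.float().mean()
clip_ratio_high = is_high_clipped.float().mean()
clip_ratio_region = (is_low_clipped | is_high_clipped).float().mean()

print(clip_ratio_low.item(), clip_ratio_high.item(), clip_ratio_region.item())
```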
### Crucial GRPO values
|
||||
During GRPO training, monitor these values for insights into performance and stability:
|
||||
|
||||
1. `reward`: This is the primary objective. It reflects the (group-wise normalized) rewards the policy is achieving. It should generally increase during successful training.
|
||||
1. `kl`: If `beta > 0`, this tracks the divergence from the reference model. Keep an eye on it to ensure the policy doesn't stray too far, which can lead to instability.
|
||||
1. `clip_ratio/*` (either `clip_ratio` for Liger loss or the more detailed `clip_ratio/...` metrics for standard loss): These indicate how often the policy updates are being constrained by the GRPO clipping mechanism. Very high values might suggest that the policy is trying to change too drastically (potentially due to large advantages or a learning rate that's too high) or that the epsilon clipping range is too restrictive.
|
||||
1. `completions/clipped_ratio`: A high ratio here indicates that the model is frequently generating completions that are cut off by `max_completion_length` rather than naturally ending with an EOS token. This might suggest issues with learning sequence termination or that `max_completion_length` is too short.
|
||||
1. `rewards/{reward_func_name}/mean`: Monitoring the mean of individual reward functions can help diagnose which aspects of the desired behavior the model is learning or struggling with, especially when using multiple reward sources.
|
||||
|
5
docs/source/model_utils.md
Normal file
@ -0,0 +1,5 @@
|
||||
# Model Utilities
|
||||
|
||||
## get_act_offloading_ctx_manager
|
||||
|
||||
[[autodoc]] models.get_act_offloading_ctx_manager
|
@ -14,7 +14,7 @@ Sequence lengths in the dataset can vary widely. When data is batched, sequences
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/why_you_should_truncate.png" alt="Truncation prompt completion" width="600"/>
|
||||
</div>
|
||||
|
||||
To reduce memory usage, it’s important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
|
||||
To reduce memory usage, it's important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
|
||||
|
||||
<hfoptions id="dpo">
|
||||
<hfoption id="DPO">
|
||||
@ -129,6 +129,42 @@ training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_imple
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Activation offloading
|
||||
|
||||
Activation offloading is a memory efficiency technique that reduces GPU VRAM usage by temporarily moving activation tensors to CPU RAM during the forward pass and bringing them back only when needed for the backward pass. This significantly reduces peak memory usage at the cost of slightly increased training time.
|
||||
|
||||
To enable activation offloading in your SFT training configuration:
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="SFT">
|
||||
|
||||
```python
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(..., activation_offloading=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
When using activation offloading with models that use Liger kernels, you must disable Liger cross entropy due to compatibility issues. The issue occurs specifically with `use_liger_kernel=True` because Liger cross entropy performs in-place operations which conflict with activation offloading. The default setting (`use_liger_kernel=False`) works:
|
||||
|
||||
```python
|
||||
# When using activation offloading with a model that uses Liger kernels:
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(
|
||||
activation_offloading=True,
|
||||
use_liger_kernel=False, # Disable Liger cross entropy
|
||||
# Other parameters...
|
||||
)
|
||||
```
|
||||
</Tip>
|
||||
|
||||
Under the hood, activation offloading implements PyTorch's [`saved_tensors_hooks`](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#hooks-for-autograd-saved-tensors) to intercept activations during the forward pass. It intelligently manages which tensors to offload based on size and context, avoiding offloading output tensors which would be inefficient. For performance optimization, it can optionally use CUDA streams to overlap computation with CPU-GPU transfers.
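For intuition, here is a minimal, self-contained sketch of the `saved_tensors_hooks` mechanism itself (not TRL's activation offloading utility, which adds size checks and optional CUDA streams): activations saved for the backward pass are packed to CPU and unpacked back on demand.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

def pack_to_cpu(tensor):
    # Called when autograd saves an activation for backward: move it to CPU RAM.
    return tensor.device, tensor.to("cpu")

def unpack_from_cpu(packed):
    # Called when backward needs the activation again: move it back to its original device.
    device, tensor = packed
    return tensor.to(device)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(1024, 1024).to(device)
x = torch.randn(8, 1024, device=device, requires_grad=True)

with saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
    loss = model(x).pow(2).mean()  # activations saved here are packed to CPU
loss.backward()                    # and brought back during the backward pass
```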
|
||||
|
||||
## Disabling model gathering for generation in online methods
|
||||
|
||||
When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).
|
||||
|
9
docs/source/rewards.md
Normal file
@ -0,0 +1,9 @@
|
||||
# Reward Functions
|
||||
|
||||
This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`].
|
||||
|
||||
## Format rewards
|
||||
|
||||
### think_format_reward
|
||||
|
||||
[[autodoc]] rewards.think_format_reward
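For intuition, a hand-rolled reward of this kind could look like the sketch below. This is only an illustration with string completions and a made-up function name, not TRL's implementation (see the autodoc above for that):

```python
import re

def my_think_format_reward(completions, **kwargs):
    """Toy reward: 1.0 if the completion opens with a closed <think>...</think> block, else 0.0."""
    pattern = re.compile(r"^<think>.*?</think>", re.DOTALL)
    return [1.0 if pattern.match(completion) else 0.0 for completion in completions]

print(my_think_format_reward(["<think>2+2=4</think> The answer is 4.", "The answer is 4."]))
# [1.0, 0.0]
```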
|
@ -50,24 +50,24 @@ trainer = SFTTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults pass in your modification to the `SFTConfig` constructor and pass them to the trainer via the `args` argument.
|
||||
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults, pass in your modification to the `SFTConfig` constructor and pass it to the trainer via the `args` argument.
|
||||
|
||||
## Advanced usage
|
||||
|
||||
### Train on completions only
|
||||
|
||||
To train on completions only, simply use a [prompt-completion](#prompt-completion) dataset. In this mode, loss is computed solely on the completion part.
|
||||
To train on completions only, simply use a [prompt-completion](dataset_formats#prompt-completion) dataset. In this mode, loss is computed solely on the completion part.
|
||||
|
||||
If you’d like to compute loss on both the prompt **and** the completion while still using a prompt-completion dataset, set `completion_only_loss=False` in the [`SFTConfig`]. This is equivalent to [converting the dataset to a language modeling](#from-prompt-completion-to-language-modeling-dataset) format.
|
||||
If you’d like to compute loss on both the prompt **and** the completion while still using a prompt-completion dataset, set `completion_only_loss=False` in the [`SFTConfig`]. This is equivalent to [converting the dataset to a language modeling](dataset_formats#from-prompt-completion-to-language-modeling-dataset) format.
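For example, a minimal sketch of that configuration (the `output_dir` value is just a placeholder):

```python
from trl import SFTConfig

# Keep the prompt-completion dataset but compute the loss over prompt and completion.
training_args = SFTConfig(output_dir="my-model", completion_only_loss=False)
```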
|
||||
|
||||
### Add Special Tokens for Chat Format
|
||||
|
||||
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
|
||||
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system, and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
|
||||
The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
|
||||
- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
|
||||
- Adds special tokens to the tokenizer, e.g., `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
|
||||
- Resizes the model’s embedding layer to accommodate the new tokens.
|
||||
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
|
||||
- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g. 64. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
|
||||
- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g., `64`. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
@ -77,12 +77,12 @@ from trl import setup_chat_format
|
||||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
|
||||
# Set up the chat format with default 'chatml' format
|
||||
# Set up the chat format with the default 'chatml' format
|
||||
model, tokenizer = setup_chat_format(model, tokenizer)
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases it is not necessary to apply `setup_chat_format()`, as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in `SFTConfig`; for example, for `Qwen/Qwen2.5-1.5B` one should set `eos_token="<|im_end|>"`.
|
||||
> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply `setup_chat_format()`, as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in `SFTConfig`; for example, for `Qwen/Qwen2.5-1.5B`, one should set `eos_token="<|im_end|>"`.
|
||||
|
||||
With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
|
||||
|
||||
@ -126,7 +126,7 @@ trainer = SFTTrainer(
|
||||
)
|
||||
```
|
||||
|
||||
If the dataset is not in one of those format you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
|
||||
If the dataset is not in one of those formats, you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
|
||||
|
||||
|
||||
### Format your input prompts
|
||||
@ -158,7 +158,7 @@ trainer = SFTTrainer(
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
|
||||
To properly format your input, make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on the alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
|
||||
|
||||
### Packing dataset
|
||||
|
||||
@ -177,12 +177,12 @@ trainer = SFTTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
|
||||
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments, you will probably train your models for more than a few epochs, depending on the way you have configured the packed dataset and the training protocol. Double-check that you know and understand what you are doing.
|
||||
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTConfig` init method.
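A minimal sketch combining both options (the `output_dir` value is a placeholder):

```python
from trl import SFTConfig

training_args = SFTConfig(
    output_dir="my-packed-model",
    packing=True,        # pack the training dataset
    eval_packing=False,  # keep the evaluation dataset unpacked
)
```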
|
||||
|
||||
#### Customize your prompts using packed dataset
|
||||
|
||||
If your dataset has several fields that you want to combine, for example if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:
|
||||
If your dataset has several fields that you want to combine, for example, if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:
|
||||
|
||||
```python
|
||||
def formatting_func(example):
|
||||
@ -256,7 +256,7 @@ trainer.train()
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsense generations. If the chat template doesn't contain special tokens (e.g. Alpaca), then the `modules_to_save` argument can be ignored or set to `None`.
|
||||
> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsensical generations. If the chat template doesn't contain special tokens (e.g., Alpaca), then the `modules_to_save` argument can be ignored or set to `None`.
|
||||
|
||||
|
||||
You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed.
|
||||
@ -321,7 +321,7 @@ Once you have loaded your model, wrap the `trainer.train()` call under the `with
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration.
|
||||
Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore, you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration.
|
||||
|
||||
Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB.
|
||||
|
||||
@ -351,12 +351,12 @@ model = AutoModelForCausalLM.from_pretrained(
|
||||
```
|
||||
|
||||
If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device.
|
||||
After loading your model, you can either train it as it is, or attach adapters and train adapters on it in case your model is quantized.
|
||||
After loading your model, you can either train it as it is or attach adapters and train adapters on it in case your model is quantized.
|
||||
|
||||
In contrast to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
|
||||
|
||||
|
||||
### Using model creation utility
|
||||
### Using the model creation utility
|
||||
|
||||
We included a utility function to create your model.
|
||||
|
||||
@ -391,17 +391,17 @@ trainer = SFTTrainer(
|
||||
)
|
||||
```
|
||||
|
||||
### Enhance the model's performances using NEFTune
|
||||
### Enhance the model's performance using NEFTune
|
||||
|
||||
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
|
||||
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) from Jain et al. It consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
|
||||
|
||||
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune.
|
||||
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF, such as LLaMA-2-Chat, benefit from additional training with NEFTune.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/neft-screenshot.png">
|
||||
</div>
|
||||
|
||||
To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer.
|
||||
To use it in `SFTTrainer`, simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to revert to the original behaviour of the embedding layer.
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
@ -430,7 +430,7 @@ Note however, that the amount of performance gain is _dataset dependent_ and in
|
||||
|
||||
### Accelerate fine-tuning 2x using `unsloth`
|
||||
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently, `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:
|
||||
|
||||
| 1 A100 40GB | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|
||||
| --------------- | --------- | --- | --------------------- | --------- | ------------ |
|
||||
@ -439,7 +439,7 @@ You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [
|
||||
| Mistral 7b | Slim Orca | 1x | 1.17x | **1.88x** | -65.9% |
|
||||
| Tiny Llama 1.1b | Alpaca | 1x | 1.55x | **2.74x** | -57.8% |
|
||||
|
||||
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
|
||||
First, install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
|
||||
|
||||
```python
|
||||
import torch
|
||||
@ -489,18 +489,18 @@ trainer.train()
|
||||
|
||||
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
|
||||
|
||||
## Liger-Kernel: Increase 20% throughput and reduces 60% memory for multi-GPU training
|
||||
## Liger-Kernel: Increase 20% throughput and reduce 60% memory for multi-GPU training
|
||||
|
||||
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face Compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
|
||||
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face Compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
|
||||
|
||||
With great memory reduction, you can potentially turn off cpu_offloading or gradient checkpointing to further boost the performance.
|
||||
With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost the performance.
|
||||
|
||||
| Speed Up | Memory Reduction |
|
||||
|--------------------------|-------------------------|
|
||||
|  |  |
|
||||
|
||||
|
||||
1. To use Liger-Kernel in [`SFTTrainer`], first install by
|
||||
1. To use Liger-Kernel in [`SFTTrainer`], first install it by:
|
||||
|
||||
```bash
|
||||
pip install liger-kernel
|
||||
@ -510,7 +510,8 @@ pip install liger-kernel
|
||||
|
||||
```python
|
||||
training_args = SFTConfig(
|
||||
use_liger_kernel=True
|
||||
use_liger_kernel=True,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
@ -523,11 +524,11 @@ Pay attention to the following best practices when training a model with that tr
|
||||
- [`SFTTrainer`] always truncates by default the sequences to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
|
||||
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
|
||||
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
|
||||
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
|
||||
- If you create a model outside the trainer, make sure not to pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work you must also check the following:
|
||||
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work, you must also check the following:
|
||||
- If you're using gradient_checkpointing, add the following to the TrainingArguments: `gradient_checkpointing_kwargs={'use_reentrant':False}` (more info [here](https://github.com/huggingface/transformers/issues/26969)
|
||||
- Ensure that the model is placed on the correct device:
|
||||
```python
|
||||
@ -545,7 +546,7 @@ You may experience some issues with GPTQ Quantization after completing training.
|
||||
|
||||
## Extending `SFTTrainer` for Vision Language Models
|
||||
|
||||
`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
|
||||
`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py), which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
|
||||
|
||||
### Preparing the Data
|
||||
|
||||
@ -664,6 +665,6 @@ A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vs
|
||||
|
||||
## Datasets
|
||||
|
||||
In the SFTTrainer we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
|
||||
In the SFTTrainer, we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
|
||||
|
||||
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to re-use it directly.
|
||||
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to reuse it directly.
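As a small sketch of that pre-tokenized path (the `input_ids` values are made up for illustration):

```python
from datasets import Dataset

# A dataset that already contains an `input_ids` column is used as-is by SFTTrainer.
pretokenized_dataset = Dataset.from_dict(
    {"input_ids": [[101, 2023, 2003, 2019, 2742, 102], [101, 2178, 2028, 102]]}
)
# Passed as `train_dataset`, no further tokenization or packing is applied.
```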
|
||||
|
@ -1,197 +0,0 @@
|
||||
# Text Environments
|
||||
|
||||
Text environments provide a learning ground for language agents. They allow a language model to use tools to accomplish a task, such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the model itself but can be trivial with the appropriate tools. A good example is arithmetic with large numbers, which becomes a simple copy-paste task once you have access to a calculator.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/textenv.png">
|
||||
</div>
|
||||
|
||||
Let's dive into how text environments work and start with tools!
|
||||
|
||||
## Tools
|
||||
|
||||
One of the core building blocks of text environments is the set of tools that the model can use to solve tasks. In general, a tool can be any Python function that takes a string as input and returns a string. The `TextEnvironment` offers two options for tools: either use predefined tools from `transformers.Tool`, or define your own function or class with a `__call__` method. Let's have a look at both!
|
||||
|
||||
### `transformers.Tool`
|
||||
|
||||
Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared.
|
||||
|
||||
```Python
|
||||
from transformers import load_tool
|
||||
|
||||
# simple calculator tool that runs +-/* operations
|
||||
calc_tool = load_tool("ybelkada/simple-calculator")
|
||||
|
||||
# python interpreter that executes program and returns outputs
|
||||
py_tool = load_tool("lvwerra/python-interpreter")
|
||||
|
||||
# wikipedia search index that returns best search match
|
||||
wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
|
||||
```
|
||||
|
||||
These tools are either loaded from the hub or from a local folder. Using the tool is as simple as calling them with a text query:
|
||||
|
||||
```Python
|
||||
calc_tool("1/2")
|
||||
>>> "0.5"
|
||||
```
|
||||
|
||||
Note that both input and return values are strings to enable easy usage with a language model.
|
||||
|
||||
### Custom Tools
|
||||
|
||||
The following is an example of a tool that adds two integers:
|
||||
|
||||
```Python
|
||||
def add(text):
|
||||
int_1, int_2 = text.split("+")
|
||||
result = int(int_1) + int(int_2)
|
||||
return str(result)
|
||||
|
||||
print(add("1+1"))
|
||||
>>> "2"
|
||||
```
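Tools can also be classes exposing a `__call__` method, which is handy when the tool needs setup or state. Here is a hedged sketch (the class and its behaviour are made up for illustration):

```python
class RunningSumTool:
    """Toy stateful tool: keeps a running sum of the integers it receives."""

    def __init__(self):
        self.total = 0

    def __call__(self, text):
        self.total += int(text.strip())
        return str(self.total)

summer = RunningSumTool()
print(summer("3"))  # "3"
print(summer("4"))  # "7"
```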
|
||||
|
||||
We looked at basic examples such as a calculator but the principle holds for more complex tools as well such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax.
|
||||
|
||||
### Call syntax
|
||||
|
||||
In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows:
|
||||
|
||||
```python
|
||||
"<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"
|
||||
```
|
||||
|
||||
There are a few special tokens involved, so let's decompose it: First, the model can signal that it wants to use a tool by emitting the `<request>` token. After that, we want to know the name of the tool to call, which is done by enclosing the tool name in `<>` brackets. Once we know which tool to call, the tool query follows in free-text form. The `<call>` token signifies the end of the query and stops the model generation. At this point, the model output is parsed and the query is sent to the tool. The environment appends the tool response to the string, followed by the `<response>` token to mark the end of the tool output.
|
||||
|
||||
Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later):
|
||||
|
||||
```python
|
||||
"<request><Calculator>1/2<call>0.5<response>"
|
||||
```
|
||||
|
||||
Finally, the episode is ended and generation stops when the model generates `<submit>` which marks the interaction as completed.
|
||||
|
||||
Now let's have a look at how we can create a new text environment!
|
||||
|
||||
## Create a `TextEnvironment`
|
||||
|
||||
|
||||
```python
|
||||
from transformers import load_tool
from trl import TextEnvironment

prompt = """\
What is 13-3?
<request><SimpleCalculatorTool>13-3<call>10.0<response>
Result=10<submit>
"""

def reward_fn(result, answer):
    """Simplified reward function returning 1 if result matches answer and 0 otherwise."""
    result_parsed = result.split("=")[1].split("<")[0]
    return int(result_parsed == answer)

# `model` and `tokenizer` are assumed to be a causal LM and its tokenizer loaded beforehand.
text_env = TextEnvironment(
    model=model,
    tokenizer=tokenizer,
    tools={"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
    reward_fn=reward_fn,
    prompt=prompt,
    max_turns=1,
    max_tool_response=100,
    generation_kwargs={"do_sample": True},
)
|
||||
```
|
||||
|
||||
Let's decompose the settings:
|
||||
|
||||
| Argument | Description |
|
||||
|:-------------------|:----------------|
|
||||
| `model` | Language model to interact with the environment and generate requests. |
|
||||
| `tokenizer` | Tokenizer of language model handling tokenization of strings. |
|
||||
| `tools` | A `list` or `dict` of tools. If a `list`, the tool name is inferred from the class name; if a `dict`, the names are the keys of the dictionary.|
|
||||
| `reward_fn` | A function that takes a string as input and returns a reward. It can have extra arguments that are passed to `.run()`, such as the ground truth.|
|
||||
| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. |
|
||||
| `max_turns` | Maximum number of interactions between model and tools before episode ends.|
|
||||
| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
|
||||
| `max_length` | The maximum number of tokens to allow in an episode. |
|
||||
| `generation_kwargs`| Generation settings used by the language model. |
|
||||
|
||||
You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!
|
||||
|
||||
|
||||
## Run an Episode
|
||||
|
||||
To run a set of queries through the text environment one can simply use the `run` method.
|
||||
|
||||
```python
|
||||
queries = ["What is 1/2?"]
|
||||
answers = ["0.5"]
|
||||
|
||||
queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)
|
||||
```
|
||||
|
||||
This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached, or the maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g., `answers=answers` above) passed to `run` will be passed on to the reward function.
|
||||
|
||||
There are five objects that are returned by `run`:
|
||||
|
||||
- `queries`: a list of the tokenized queries
|
||||
- `responses`: all tokens that have been generated within the environment, including model and tool tokens
|
||||
- `masks`: mask that indicates which tokens have been generated by the model and which tokens are generated by the tool
|
||||
- `rewards`: a list of rewards for each query/response
|
||||
- `histories`: list of `TextHistory` objects, which are useful objects containing all the above and also the text equivalents
|
||||
|
||||
The masks are crucial for training, as we don't want to optimize tokens that the model has not generated, i.e., the tokens produced by the tools.
|
||||
|
||||
Next, we'll train a PPO step with the generated responses!
|
||||
|
||||
|
||||
### Train
|
||||
Training on episodes from the `TextEnvironment` is straightforward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method:
|
||||
|
||||
```python
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
```
|
||||
|
||||
## `TextHistory`
|
||||
|
||||
The `TextHistory` object stores the interactions between the model and the text environment. It stores tokens and text generated in each turn and their source in each turn (model or system) as well as rewards. Let's go through the class attributes and methods.
|
||||
|
||||
### Attributes
|
||||
|
||||
The following table summarises the available attributes of the `TextHistory` class:
|
||||
|
||||
| Attribute | Description |
|
||||
|:-------------------|:----------------|
|
||||
| `text` | The full string of the text generated in the text environment with both model and system generated text. |
|
||||
| `text_spans` | A list of tuples with the spans for each model or system generated text segment. |
|
||||
| `system_spans` | A list of boolean values indicating if the segment is model or system generated. |
|
||||
| `tokens` | All tokens generated in text environment with both model and system generated tokens. |
|
||||
| `token_spans` | Similar to `text_spans`, the `token_spans` indicate the boundaries of model- and system-generated tokens. |
|
||||
| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. |
|
||||
| `completed` | Indicates if the interaction with the environment has completed. |
|
||||
| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. |
|
||||
|
||||
With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look!
|
||||
|
||||
### Visualization
|
||||
|
||||
When the model interacts inside the `TextEnvironment`, it can be useful to visualize which parts of the text output were generated by the model and which parts come from the system and tools. For that purpose, there are two methods, [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` library](https://github.com/Textualize/rich) (make sure to install it before using these methods).
|
||||
|
||||
You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue, and in addition to the pure text output, the reward is displayed as additional text in plum. Here is an example of `show_text`:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/textenv_show_text.png" width=600>
|
||||
</div>
|
||||
|
||||
Sometimes there can be tricky tokenization-related issues that are hidden when showing the decoded text. Thus, `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/textenv_show_tokens.png" width=800>
|
||||
</div>
|
||||
|
||||
Note that you can turn on the colour legend by passing `show_legend=True`.
|
||||
|
||||
## API Documentation
|
||||
|
||||
[[autodoc]] TextEnvironment
|
||||
|
||||
[[autodoc]] TextHistory
|
@ -1,20 +1,125 @@
|
||||
# vLLM Integration
|
||||
|
||||
<Tip warning={true}>
|
||||
This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. Let's go! 🔥
|
||||
|
||||
Section under construction. Feel free to contribute!
|
||||
## 🚀 How can I use vLLM with TRL to speed up training?
|
||||
|
||||
</Tip>
|
||||
💡 **Note**: Resources required for this specific example: a single node with 8 GPUs.
|
||||
|
||||
## TRL vLLM server
|
||||
First, install vLLM using the following command:
|
||||
|
||||
TRL provides a way to speedup generation using a dedicated vLLM server.
|
||||
```bash
|
||||
pip install "trl[vllm]"
|
||||
```
|
||||
|
||||
Then run the server:
|
||||
|
||||
```sh
|
||||
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2
|
||||
```
|
||||
|
||||
Once the server is running, you can use it to generate completions for training. In the example below, we are using the `GRPOTrainer` to train a model using the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs.
|
||||
|
||||
In this example, we are sharding two copies of the model across 4 GPUs. Increasing data parallelism increases throughput, while increasing tensor parallelism allows for serving larger models. Then, run the training script by passing `use_vllm=True` in the training arguments as follows:
|
||||
|
||||
Sample of a simple `train.py` script:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Dummy reward function: count the number of unique characters in the completions
|
||||
def reward_num_unique_chars(completions, **kwargs):
|
||||
return [len(set(c)) for c in completions]
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir="my_test",
|
||||
use_vllm=True,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
logging_steps=10,
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2.5-7B",
|
||||
args=training_args,
|
||||
reward_funcs=reward_num_unique_chars,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
And the train command:
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
```
|
||||
|
||||
## 🎬 Flashback: Why do we need to use vLLM in online methods?
|
||||
|
||||
Online methods like GRPO or Online DPO require the model to generate completions during training, which are then used to compute reward signals. However, generation can be extremely time-consuming, especially with large or reasoning models. In the default setup (without vLLM), completions are generated using the [(unwrapped) model's `generate` method](https://github.com/huggingface/trl/blob/f3e8c2304428ef16e9ae5de9e5741ed84d533b7b/trl/trainer/grpo_trainer.py#L965C39-L965C66). This approach quickly becomes a major bottleneck — generation is slow and inefficient, particularly for large batches or models. As a result, training times increase significantly, and overall efficiency drops. To address this, we turn to vLLM, which enables much faster and more scalable generation, helping eliminate this bottleneck in online methods.
|
||||
|
||||
## 🤔 How does vLLM solve the slow generation issue?
|
||||
|
||||
If you've ever done autoregressive decoder training, you know all the input tokens to the LLM produce their attention key and value tensors, and these tensors are kept in GPU memory to later generate subsequent tokens based on them. These cached key and value tensors are often referred to as the KV cache. However, storing the KV cache occupies a lot of memory, so vLLM uses a technique called **PagedAttention** to solve this problem. PagedAttention, which is inspired by the OS’s virtual memory concept, stores continuous keys and values in **non-contiguous memory space**, which is much more efficient. The details of this are beyond the scope of this document, but in short, it allows the model to store the keys and values in a more efficient way, reducing the memory footprint and speeding up the generation process. If you are interested, make sure to check out the [vLLM PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html) for more details.
|
||||
|
||||
## 🤔 What exactly happens when you run `trl vllm-serve --model <model_name>`?
|
||||
|
||||
When you run for example
|
||||
|
||||
```sh
|
||||
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4
|
||||
```
|
||||
|
||||
the following happens:
|
||||
|
||||

|
||||
|
||||
1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (1 × 4).
|
||||
Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation. A key point to understand is that these 4 workers are running in parallel, and each one is responsible for handling a subset of the total incoming load.
|
||||
|
||||
2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model’s weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled. Meanwhile, data parallelism (controlled by `--data-parallel-size`) ensures that different sets of requests are processed independently across the workers. In short: tensor parallelism splits the model across GPUs, and data parallelism splits the batch of requests across different model replicas.
|
||||
|
||||
3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--data-parallel-size=4`, each GPU processes 2 prompts).
|
||||
This GPU-to-GPU communication is managed efficiently by NVIDIA’s NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests — it’s lightweight and doesn’t interfere with generation itself.
|
||||
Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (like in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you would end up with 16 completions total — regardless of the number of GPUs or parallelism settings.
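A quick back-of-the-envelope sketch of the numbers used in this example (purely illustrative):

```python
tensor_parallel_size = 1
data_parallel_size = 4
num_prompts = 8
num_generations = 2

num_workers = tensor_parallel_size * data_parallel_size  # 4 vLLM workers
prompts_per_worker = num_prompts // data_parallel_size   # 2 prompts per worker
total_completions = num_prompts * num_generations        # 16 completions overall

print(num_workers, prompts_per_worker, total_completions)
```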
|
||||
|
||||
## 🥸 More detail on what happens under the hood when running the server
|
||||
|
||||
* The vLLM server starts by running the command: `trl vllm-serve --model Qwen/Qwen2.5-7B`.
|
||||
* Once the server is running, it generates completions based on requests from the client (trainer) using `vllm_client.generate` [here](https://github.com/huggingface/trl/blob/cc044e35b285be7dc062764b3364e1e684db4c7c/trl/trainer/grpo_trainer.py#L1025-L1035).
|
||||
* The client (trainer) then requests these completions from the server.
|
||||
* These completions are used to compute the reward signal.
|
||||
* Based on the reward signal and the model’s output, the loss is computed, and the backward pass is performed to update the model’s weights.
|
||||
* **Note**: The server only handles completion generation — it doesn’t train the model. Therefore, the model’s weights aren’t updated on the server. Once the backward pass is complete, the client sends the updated weights to the server using `vllm_client.update_named_param(name, param.data)`.
|
||||
|
||||
When using vLLM, ensure the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation for training using `CUDA_VISIBLE_DEVICES`. See the example below:
|
||||
|
||||
* **Set GPUs *0–3* for vLLM generation:** Assume `CUDA_VISIBLE_DEVICES=0,1,2,3` are allocated for vLLM generation.
|
||||
|
||||
```sh
|
||||
trl vllm-serve --model <model_name> --tensor-parallel-size 1 --data-parallel-size 4
|
||||
```
|
||||
|
||||
* **And GPUs *4–7* for training:** If you do not set the `CUDA_VISIBLE_DEVICES` environment variable, the training script will use all available GPUs by default, which may lead to resource conflicts. To avoid this, you can specify which GPUs to use for training. For example, if you want to use GPUs 4–7 for training, set the environment variable as follows:
```sh
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
```
## 🍷 More customization options with vLLM?
You can customize the server configuration by passing additional arguments.
```
$ trl vllm-serve --help
usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE]
                      [--data_parallel_size DATA_PARALLEL_SIZE] [--host HOST] [--port PORT]
                      [--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE] [--max_model_len MAX_MODEL_LEN]
                      [--enable_prefix_caching ENABLE_PREFIX_CACHING] [--enforce_eager ENFORCE_EAGER] [--log_level LOG_LEVEL]

options:
  -h, --help            Show this help message and exit
  [...]
  --revision REVISION   Revision to use for the model. If not specified, the default branch will be used. (default: None)
  --tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE
                        Number of tensor parallel workers to use. (default: 1)
  --data_parallel_size DATA_PARALLEL_SIZE, --data-parallel-size DATA_PARALLEL_SIZE
                        Number of data parallel workers to use. (default: 1)
  --host HOST           Host address to run the server on. (default: 0.0.0.0)
  --port PORT           Port to run the server on. (default: 8000)
  --gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION
  [...]
  --enable_prefix_caching ENABLE_PREFIX_CACHING, --enable-prefix-caching ENABLE_PREFIX_CACHING
                        Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware
                        support this feature. (default: None)
  --enforce_eager ENFORCE_EAGER, --enforce-eager ENFORCE_EAGER
                        Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute
                        the model in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution
                        in hybrid. (default: None)
  --log_level LOG_LEVEL, --log-level LOG_LEVEL
                        Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', 'trace'.
                        (default: info)
```
## 🥳 Okay, now that we have the server running, how can we use it to generate completions?
Run the training script and pass `use_vllm=True` in the training arguments:
```python
from trl import GRPOConfig
training_args = GRPOConfig(..., use_vllm=True)
```
## 💆🏻♀️ What's the best distributed setup?


First and foremost, always remember that the optimal setup depends on:
* The model size
* The number of GPUs you have
* The GPU memory size
* The batch size you are using
* The number of requests you are sending to the server (prompts)
* The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size)
* The number of completions you are generating for each request (`num_generations`)
Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that:
* For moderately sized models (3B–14B) and a moderate context window (`max_len < 8k`), using the full capacity for data parallelism gives the best throughput: `(tp=1, dp=8)` is the strongest setup.
* For larger models (32B) and longer context windows (`max_len > 8k`), a smaller data-parallel size combined with some tensor parallelism works better: for example, `(tp=2, dp=4)` is a good setup for a 32B model with a long context window.
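As a concrete illustration, the first regime corresponds to launching the server with `trl vllm-serve --model <model_name> --tensor-parallel-size 1 --data-parallel-size 8`, and the second to `--tensor-parallel-size 2 --data-parallel-size 4` (8 GPUs in both cases).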
examples/accelerate_configs/fsdp1.yaml (new file, 28 lines)
@ -0,0 +1,28 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
  fsdp_version: 1
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
examples/accelerate_configs/fsdp2.yaml (new file, 25 lines)
@ -0,0 +1,25 @@
# Requires accelerate 1.7.0 or higher
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@ -1,25 +0,0 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@ -4,4 +4,4 @@ This directory contains a collection of Jupyter notebooks that demonstrate how t
- [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO.
- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook.
- [`gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.
- [`gpt2-sentiment-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.
@ -1,118 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, load_tool
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
|
||||
|
||||
|
||||
def generate_data(n):
|
||||
"""Generate random arithmetic tasks and answers."""
|
||||
tasks, answers = [], []
|
||||
for _ in range(n):
|
||||
a = np.random.randint(0, 50)
|
||||
b = np.random.randint(0, 50)
|
||||
op = np.random.choice(["-", "+", "*"])
|
||||
tasks.append(f"\n\nWhat is {a} {op} {b}?")
|
||||
if op == "-":
|
||||
answers.append(a - b)
|
||||
elif op == "+":
|
||||
answers.append(a + b)
|
||||
else:
|
||||
answers.append(a * b)
|
||||
return tasks, answers
|
||||
|
||||
|
||||
def exact_match_reward(responses, answers=None):
|
||||
"""Reward if generated response contains correct answer."""
|
||||
rewards = []
|
||||
pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>" # generated by chatGPT
|
||||
for response, answer in zip(responses, answers):
|
||||
reward = 0.0
|
||||
predicted_number = None
|
||||
match_pattern = re.findall(pattern, response)
|
||||
if match_pattern:
|
||||
predicted_number = float(match_pattern[0])
|
||||
if predicted_number is not None:
|
||||
if np.abs(predicted_number - answer) < 0.01:
|
||||
reward += 1.0
|
||||
rewards.append(torch.tensor(reward))
|
||||
return rewards
|
||||
|
||||
|
||||
# set up models
|
||||
model_id = "gpt2"
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# system prompt
|
||||
prompt = """\
|
||||
What is 13-3?
|
||||
|
||||
<request><SimpleCalculatorTool>13-3<call>10.0<response>
|
||||
|
||||
Result=10<submit>
|
||||
|
||||
What is 4*3?
|
||||
|
||||
<request><SimpleCalculatorTool>4*3<call>12.0<response>
|
||||
|
||||
Result=12<submit>"""
|
||||
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
"eos_token_id": -1,
|
||||
"max_new_tokens": 32,
|
||||
}
|
||||
|
||||
# trainer
|
||||
ppo_config = PPOConfig(
|
||||
batch_size=256,
|
||||
learning_rate=1.41e-5,
|
||||
mini_batch_size=64,
|
||||
log_with="wandb",
|
||||
)
|
||||
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)
|
||||
|
||||
# text env
|
||||
text_env = TextEnvironment(
|
||||
model,
|
||||
tokenizer,
|
||||
{"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
|
||||
exact_match_reward,
|
||||
prompt,
|
||||
generation_kwargs=generation_kwargs,
|
||||
)
|
||||
|
||||
# main training loop
|
||||
for _step in range(100):
|
||||
tasks, answers = generate_data(ppo_config.batch_size)
|
||||
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
|
||||
response_texts = [tokenizer.decode(response) for response in responses]
|
||||
query_texts = [tokenizer.decode(query) for query in queries]
|
||||
texts = {"query": [qt.split("<submit>")[-1].strip() for qt in query_texts], "response": response_texts}
|
||||
ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"])
|
||||
ppo_trainer.save_pretrained(model_id + "-calculator")
|
@ -1,193 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, HfArgumentParser, load_tool
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
|
||||
|
||||
|
||||
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"})
|
||||
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
|
||||
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
|
||||
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=16, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"})
|
||||
ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"})
|
||||
n_epochs: Optional[int] = field(default=32, metadata={"help": "max number of ppo epochs"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
|
||||
def exact_match_reward(responses, answers=None):
|
||||
"""Reward if generated response contains correct answer."""
|
||||
rewards = []
|
||||
pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>" # generated by chatGPT
|
||||
for response, answer in zip(responses, answers):
|
||||
reward = 0.0
|
||||
try:
|
||||
predicted_number = None
|
||||
match_pattern = re.findall(pattern, response)
|
||||
if match_pattern:
|
||||
predicted_number = float(match_pattern[0])
|
||||
if predicted_number is not None:
|
||||
if np.abs(predicted_number - float(answer)) < 0.1:
|
||||
reward += 1.0
|
||||
except Exception:
|
||||
pass
|
||||
rewards.append(torch.tensor(reward))
|
||||
return rewards
|
||||
|
||||
|
||||
def evaluate(test_dataloader, text_env, ppo_trainer):
|
||||
test_rewards = []
|
||||
for test_batch in test_dataloader:
|
||||
_, _, _, rewards, _ = text_env.run(test_batch["query"], answers=test_batch["answer"])
|
||||
test_rewards.extend(rewards)
|
||||
test_rewards = ppo_trainer.accelerator.gather_for_metrics(
|
||||
torch.stack(test_rewards).to(ppo_trainer.accelerator.device)
|
||||
)
|
||||
return test_rewards.mean()
|
||||
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=["c_proj", "c_attn", "q_attn"],
|
||||
)
|
||||
|
||||
# set up models
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
script_args.model_name,
|
||||
use_auth_token=True,
|
||||
load_in_4bit=True,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_auth_token=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
ds = load_dataset("openai/gsm8k", "main", split="train")
|
||||
ds = ds.rename_columns({"question": "query"})
|
||||
ds = ds.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
|
||||
ds = ds.select(range(1, len(ds))) # skip the first sample which is used in prompt
|
||||
|
||||
ds_test = load_dataset("openai/gsm8k", "main", split="test")
|
||||
ds_test = ds_test.rename_columns({"question": "query"})
|
||||
ds_test = ds_test.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
|
||||
|
||||
test_dataloader = torch.utils.data.DataLoader(ds_test, batch_size=script_args.batch_size)
|
||||
|
||||
# prompt
|
||||
prompt = """\
|
||||
Example of using a Python API to solve math questions.
|
||||
|
||||
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
|
||||
|
||||
<request><PythonInterpreter>
|
||||
def solution():
|
||||
money_initial = 23
|
||||
bagels = 5
|
||||
bagel_cost = 3
|
||||
money_spent = bagels * bagel_cost
|
||||
money_left = money_initial - money_spent
|
||||
result = money_left
|
||||
return result
|
||||
print(solution())
|
||||
<call>72<response>
|
||||
|
||||
Result = 72 <submit>
|
||||
|
||||
Q: """
|
||||
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
"eos_token_id": -1,
|
||||
"max_new_tokens": script_args.max_new_tokens,
|
||||
}
|
||||
|
||||
# trainer
|
||||
ppo_config = PPOConfig(
|
||||
batch_size=script_args.batch_size,
|
||||
learning_rate=script_args.learning_rate,
|
||||
mini_batch_size=script_args.mini_batch_size,
|
||||
ppo_epochs=script_args.ppo_epochs,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
log_with="wandb",
|
||||
tracker_project_name="trl-gsm8k",
|
||||
remove_unused_columns=False,
|
||||
optimize_cuda_cache=True,
|
||||
)
|
||||
|
||||
ppo_trainer = PPOTrainer(args=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
|
||||
test_dataloader = ppo_trainer.accelerator.prepare(test_dataloader)
|
||||
|
||||
# text env
|
||||
text_env = TextEnvironment(
|
||||
model,
|
||||
tokenizer,
|
||||
[load_tool("lvwerra/python-interpreter")],
|
||||
exact_match_reward,
|
||||
prompt,
|
||||
max_turns=2,
|
||||
generation_kwargs=generation_kwargs,
|
||||
)
|
||||
|
||||
# main training loop
|
||||
for epoch in range(script_args.n_epochs):
|
||||
for step, batch in enumerate(ppo_trainer.dataloader):
|
||||
if (step == 0) and (epoch % 4 == 0): # evaluate every 4 epochs
|
||||
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
|
||||
else:
|
||||
reward_mean_test = None
|
||||
|
||||
queries, responses, masks, rewards, histories = text_env.run(batch["query"], answers=batch["answer"])
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
|
||||
# logging
|
||||
if reward_mean_test is not None:
|
||||
train_stats["env/reward_mean_test"] = reward_mean_test
|
||||
texts = {
|
||||
"query": batch["query"],
|
||||
"response": [tokenizer.decode(response) for response in responses],
|
||||
"answer": batch["answer"],
|
||||
}
|
||||
ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"])
|
||||
|
||||
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
|
||||
ppo_trainer.save_pretrained(f"model/{script_args.model_name}-gsm8k")
|
@ -1,192 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, HfArgumentParser, load_tool
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
|
||||
|
||||
|
||||
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"})
|
||||
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
|
||||
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
|
||||
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
|
||||
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=16, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"})
|
||||
ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"})
|
||||
iterations: Optional[int] = field(default=1000, metadata={"help": "the number of iterations"})
|
||||
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=["c_proj", "c_attn", "q_attn"],
|
||||
)
|
||||
|
||||
# set up models
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
script_args.model_name,
|
||||
use_auth_token=True,
|
||||
trust_remote_code=True,
|
||||
load_in_4bit=True,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_auth_token=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# system prompt
|
||||
prompt = """\
|
||||
Answer the following question:
|
||||
|
||||
Q: In which branch of the arts is Patricia Neary famous?
|
||||
A: Ballets
|
||||
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
|
||||
Result=Ballets<submit>
|
||||
|
||||
Q: Who won Super Bowl XX?
|
||||
A: Chicago Bears
|
||||
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
|
||||
Result=Chicago Bears<submit>
|
||||
|
||||
Q: """
|
||||
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
"eos_token_id": -1,
|
||||
"max_new_tokens": script_args.max_new_tokens,
|
||||
}
|
||||
|
||||
# trainer
|
||||
config = PPOConfig(
|
||||
batch_size=script_args.batch_size,
|
||||
model_name=script_args.model_name,
|
||||
learning_rate=script_args.learning_rate,
|
||||
log_with=script_args.log_with,
|
||||
mini_batch_size=script_args.mini_batch_size,
|
||||
ppo_epochs=script_args.ppo_epochs,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
seed=script_args.seed,
|
||||
optimize_cuda_cache=True,
|
||||
)
|
||||
ppo_trainer = PPOTrainer(args=config, model=model, tokenizer=tokenizer)
|
||||
dataset = load_dataset("mandarjoshi/trivia_qa", "rc", split="train")
|
||||
local_seed = script_args.seed + ppo_trainer.accelerator.process_index * 100003 # Prime
|
||||
dataset = dataset.shuffle(local_seed)
|
||||
|
||||
|
||||
def data_generator():
|
||||
for i in range(len(dataset)):
|
||||
yield dataset[i]["question"], list(dataset[i]["answer"]["normalized_aliases"])
|
||||
|
||||
|
||||
gen = data_generator()
|
||||
gen = iter(gen)
|
||||
|
||||
|
||||
def generate_data(n):
|
||||
tasks, answers = [], []
|
||||
for _i in range(n):
|
||||
q, a = next(gen)
|
||||
tasks.append(q)
|
||||
answers.append(a)
|
||||
return tasks, answers
|
||||
|
||||
|
||||
def exact_match_reward(responses, answers=None):
|
||||
"""Reward if generated response contains correct answer."""
|
||||
rewards = []
|
||||
for response, answer in zip(responses, answers):
|
||||
reward = 0.0
|
||||
for a in answer:
|
||||
if a.lower() in response.lower():
|
||||
reward += 1.0
|
||||
break
|
||||
rewards.append(torch.tensor(reward))
|
||||
return rewards
|
||||
|
||||
|
||||
def tool_fn(x):
|
||||
# limit the amount of tokens
|
||||
return tool(x).split("\n")[1][:600]
|
||||
|
||||
|
||||
# text env
|
||||
tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
|
||||
|
||||
text_env = TextEnvironment(
|
||||
model,
|
||||
tokenizer,
|
||||
{"Wiki": tool_fn},
|
||||
exact_match_reward,
|
||||
prompt,
|
||||
generation_kwargs=generation_kwargs,
|
||||
max_tool_response=400,
|
||||
)
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
trainable_params = 0
|
||||
all_param = 0
|
||||
for _, param in model.named_parameters():
|
||||
all_param += param.numel()
|
||||
if param.requires_grad:
|
||||
trainable_params += param.numel()
|
||||
print(
|
||||
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
||||
)
|
||||
|
||||
|
||||
print_trainable_parameters(model)
|
||||
# main training loop
|
||||
for i in range(script_args.iterations):
|
||||
tasks, answers = generate_data(config.batch_size)
|
||||
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
response_texts = [tokenizer.decode(response) for response in responses]
|
||||
query_texts = [tokenizer.decode(query) for query in queries]
|
||||
texts = {
|
||||
"query": [qt.split("<submit>")[-1].strip() for qt in query_texts],
|
||||
"response": response_texts,
|
||||
"answer": [", ".join(item) for item in answers],
|
||||
}
|
||||
all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device))
|
||||
ppo_trainer.log_stats(train_stats, texts, list(all_rewards), columns_to_log=["query", "response", "answer"])
|
||||
if i % 100 == 0:
|
||||
ppo_trainer.save_pretrained(f"models/{script_args.model_name}_{script_args.seed}_{i}_triviaqa")
|
@ -45,7 +45,6 @@ python examples/scripts/gkd.py \
|
||||
--lora_alpha 16
|
||||
"""
|
||||
|
||||
from accelerate import PartialState
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, GenerationConfig
|
||||
|
||||
@ -106,14 +105,6 @@ if __name__ == "__main__":
|
||||
################
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
|
||||
with PartialState().local_main_process_first():
|
||||
dataset = dataset.map(
|
||||
lambda x: {
|
||||
"prompt": tokenizer.apply_chat_template(x["prompt"], tokenize=False, add_generation_prompt=True)
|
||||
},
|
||||
num_proc=training_args.dataset_num_proc,
|
||||
)
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
|
@ -59,6 +59,9 @@ from transformers import (
|
||||
Qwen2Config,
|
||||
Qwen2ForCausalLM,
|
||||
Qwen2ForSequenceClassification,
|
||||
Qwen3Config,
|
||||
Qwen3ForCausalLM,
|
||||
Qwen3ForSequenceClassification,
|
||||
SiglipVisionConfig,
|
||||
T5Config,
|
||||
T5ForConditionalGeneration,
|
||||
@ -120,6 +123,7 @@ for model_id, config_class, model_class, suffix in [
|
||||
("microsoft/Phi-3.5-mini-instruct", Phi3Config, Phi3ForCausalLM, None),
|
||||
("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForCausalLM, "2.5"),
|
||||
("Qwen/Qwen2.5-Coder-0.5B", Qwen2Config, Qwen2ForCausalLM, "2.5-Coder"),
|
||||
("Qwen/Qwen3-4B", Qwen3Config, Qwen3ForCausalLM, None),
|
||||
]:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
config = config_class(
|
||||
@ -134,7 +138,7 @@ for model_id, config_class, model_class, suffix in [
|
||||
push_to_hub(model, tokenizer, "tiny", suffix)
|
||||
|
||||
|
||||
# A slightly bigger model, required for vLLM testing
|
||||
# Two slightly bigger models, required for vLLM testing
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B-Instruct")
|
||||
config = Qwen2Config(
|
||||
vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
|
||||
@ -147,11 +151,23 @@ config = Qwen2Config(
|
||||
model = Qwen2ForCausalLM(config)
|
||||
push_to_hub(model, tokenizer, "small", "2.5")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
|
||||
config = Qwen3Config(
|
||||
vocab_size=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
|
||||
hidden_size=128, # increase hidden size so that hidden_size // num_attention_heads = 32, required for vLLM
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
num_hidden_layers=2,
|
||||
intermediate_size=32,
|
||||
)
|
||||
model = Qwen3ForCausalLM(config)
|
||||
push_to_hub(model, tokenizer, "small")
|
||||
|
||||
# Reward models
|
||||
for model_id, config_class, model_class, suffix in [
|
||||
("meta-llama/Llama-3.2-1B-Instruct", LlamaConfig, LlamaForSequenceClassification, "3.2"),
|
||||
("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForSequenceClassification, "2.5"),
|
||||
("Qwen/Qwen3-4B", Qwen3Config, Qwen3ForSequenceClassification, None),
|
||||
]:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
config = config_class(
|
||||
|
setup.cfg (99 changed lines)
@ -1,2 +1,99 @@
|
||||
[metadata]
|
||||
license_file = LICENSE
|
||||
name = trl
|
||||
version = 0.18.1
|
||||
description = Train transformer language models with reinforcement learning.
|
||||
long_description = file: README.md
|
||||
long_description_content_type = text/markdown
|
||||
author = Leandro von Werra
|
||||
author_email = leandro.vonwerra@gmail.com
|
||||
url = https://github.com/huggingface/trl
|
||||
keywords = transformers, huggingface, language modeling, post-training, rlhf, sft, dpo, grpo
|
||||
license_file = LICENSE
|
||||
classifiers =
|
||||
Development Status :: 2 - Pre-Alpha
|
||||
Intended Audience :: Developers
|
||||
Intended Audience :: Science/Research
|
||||
Natural Language :: English
|
||||
Operating System :: OS Independent
|
||||
Programming Language :: Python :: 3
|
||||
Programming Language :: Python :: 3.9
|
||||
Programming Language :: Python :: 3.10
|
||||
Programming Language :: Python :: 3.11
|
||||
Programming Language :: Python :: 3.12
|
||||
Programming Language :: Python :: 3.13
|
||||
|
||||
[options]
|
||||
packages = find:
|
||||
python_requires = >=3.9
|
||||
include_package_data = True
|
||||
install_requires =
|
||||
accelerate>=0.34.0
|
||||
datasets>=3.0.0
|
||||
transformers>=4.50.0
|
||||
|
||||
[options.packages.find]
|
||||
exclude =
|
||||
tests*
|
||||
|
||||
[options.package_data]
|
||||
trl =
|
||||
templates/*.md
|
||||
accelerate_configs/*.yaml
|
||||
|
||||
[options.extras_require]
|
||||
bco =
|
||||
scikit-learn
|
||||
joblib
|
||||
deepspeed =
|
||||
deepspeed>=0.14.4
|
||||
diffusers =
|
||||
diffusers>=0.18.0
|
||||
judges =
|
||||
openai>=1.23.2
|
||||
llm-blender>=0.0.2
|
||||
liger =
|
||||
liger-kernel>=0.5.9
|
||||
mergekit =
|
||||
mergekit>=0.0.5.1
|
||||
peft =
|
||||
peft>=0.8.0
|
||||
quantization =
|
||||
bitsandbytes
|
||||
scikit =
|
||||
scikit-learn
|
||||
test =
|
||||
parameterized
|
||||
pytest-cov
|
||||
pytest-rerunfailures
|
||||
pytest-xdist
|
||||
pytest
|
||||
vllm =
|
||||
# vLLM package does not yet support Python 3.13. These constraints can be lifted once support is added:
|
||||
# see https://github.com/vllm-project/vllm/pull/13164
|
||||
vllm>=0.8.3; python_version < "3.13"
|
||||
fastapi; python_version < "3.13"
|
||||
pydantic; python_version < "3.13"
|
||||
requests; python_version < "3.13"
|
||||
uvicorn; python_version < "3.13"
|
||||
|
||||
vlm =
|
||||
Pillow
|
||||
dev =
|
||||
%(bco)s
|
||||
%(deepspeed)s
|
||||
%(diffusers)s
|
||||
%(judges)s
|
||||
%(liger)s
|
||||
%(mergekit)s
|
||||
%(peft)s
|
||||
%(quantization)s
|
||||
%(scikit)s
|
||||
%(test)s
|
||||
%(vlm)s
|
||||
|
||||
[options.entry_points]
|
||||
console_scripts =
|
||||
trl = trl.cli:main
|
||||
|
||||
[coverage:run]
|
||||
branch = True
|
||||
|
setup.py (121 changed lines)
@ -12,124 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""trl is an open library for RL with transformer models.
|
||||
|
||||
Note:
|
||||
|
||||
VERSION needs to be formatted following the MAJOR.MINOR.PATCH convention
|
||||
(we need to follow this convention to be able to retrieve versioned scripts)
|
||||
|
||||
Simple check list for release from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
|
||||
|
||||
To create the package for PyPI.
|
||||
|
||||
0. Prerequisites:
|
||||
- Dependencies:
|
||||
- twine: "pip install twine"
|
||||
- Create an account in (and join the 'trl' project):
|
||||
- PyPI: https://pypi.org/
|
||||
- Test PyPI: https://test.pypi.org/
|
||||
|
||||
1. Change the version in:
|
||||
- __init__.py
|
||||
- setup.py
|
||||
|
||||
2. Commit these changes: "git commit -m 'Release: VERSION'"
|
||||
|
||||
3. Add a tag in git to mark the release: "git tag VERSION -m 'Add tag VERSION for PyPI'"
|
||||
Push the tag to remote: git push --tags origin main
|
||||
|
||||
4. Build both the sources and the wheel. Do not change anything in setup.py between
|
||||
creating the wheel and the source distribution (obviously).
|
||||
|
||||
First, delete any "build" directory that may exist from previous builds.
|
||||
|
||||
For the wheel, run: "python setup.py bdist_wheel" in the top level directory.
|
||||
(this will build a wheel for the python version you use to build it).
|
||||
|
||||
For the sources, run: "python setup.py sdist"
|
||||
You should now have a /dist directory with both .whl and .tar.gz source versions.
|
||||
|
||||
5. Check that everything looks correct by uploading the package to the PyPI test server:
|
||||
|
||||
twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
|
||||
|
||||
Check that you can install it in a virtualenv/notebook by running:
|
||||
pip install -i https://testpypi.python.org/pypi trl
|
||||
|
||||
6. Upload the final version to actual PyPI:
|
||||
twine upload dist/* -r pypi
|
||||
|
||||
7. Fill release notes in the tag in github once everything is looking hunky-dory.
|
||||
|
||||
8. Change the version in __init__.py and setup.py to X.X.X+1.dev0 (e.g. VERSION=1.18.3 -> 1.18.4.dev0).
|
||||
Then push the change with a message 'set dev version'
|
||||
"""
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
from setuptools import setup
|
||||
|
||||
|
||||
__version__ = "0.17.0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
|
||||
REQUIRED_PKGS = [
|
||||
"accelerate>=0.34.0",
|
||||
"datasets>=3.0.0",
|
||||
"rich", # rich shouldn't be a required package for trl, we should remove it from here
|
||||
"transformers>=4.46.0",
|
||||
]
|
||||
EXTRAS = {
|
||||
"deepspeed": ["deepspeed>=0.14.4"],
|
||||
"diffusers": ["diffusers>=0.18.0,<0.33.0"], # Temp set <0.33.0 due to ftfy optional dep issue breaking doc builds
|
||||
"judges": ["openai>=1.23.2", "llm-blender>=0.0.2"],
|
||||
"liger": ["liger-kernel>=0.5.6"],
|
||||
"mergekit": ["mergekit>=0.0.5.1"],
|
||||
"peft": ["peft>=0.8.0"],
|
||||
"quantization": ["bitsandbytes"],
|
||||
"scikit": ["scikit-learn"],
|
||||
"bco": ["scikit-learn", "joblib"],
|
||||
"test": ["parameterized", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "pytest"],
|
||||
"vllm": ["vllm>=0.8.3", "fastapi", "pydantic", "requests", "uvicorn"],
|
||||
"vlm": ["Pillow"],
|
||||
}
|
||||
EXTRAS["dev"] = []
|
||||
for reqs in EXTRAS.values():
|
||||
EXTRAS["dev"].extend(reqs)
|
||||
|
||||
|
||||
setup(
|
||||
name="trl",
|
||||
license="Apache 2.0",
|
||||
classifiers=[
|
||||
"Development Status :: 2 - Pre-Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Natural Language :: English",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
],
|
||||
url="https://github.com/huggingface/trl",
|
||||
entry_points={
|
||||
"console_scripts": ["trl=trl.cli:main"],
|
||||
},
|
||||
include_package_data=True,
|
||||
package_data={
|
||||
"trl": ["templates/*.md"],
|
||||
},
|
||||
packages=find_packages(exclude={"tests", "tests.slow", "trl.templates"}),
|
||||
install_requires=REQUIRED_PKGS,
|
||||
extras_require=EXTRAS,
|
||||
python_requires=">=3.9",
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
zip_safe=False,
|
||||
version=__version__,
|
||||
description="Train transformer language models with reinforcement learning.",
|
||||
keywords="transformers, huggingface, language modeling, post-training, rlhf, sft, dpo, grpo",
|
||||
author="Leandro von Werra",
|
||||
author_email="leandro.vonwerra@gmail.com",
|
||||
)
|
||||
setup()
|
||||
|
@ -22,7 +22,7 @@ from accelerate.utils.memory import release_memory
|
||||
from datasets import load_dataset
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.testing_utils import require_liger_kernel, require_torch_accelerator
|
||||
from transformers.testing_utils import require_liger_kernel, require_peft, require_torch_accelerator
|
||||
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
@ -81,3 +81,67 @@ class GRPOTrainerSlowTester(unittest.TestCase):
|
||||
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
|
||||
|
||||
release_memory(model, trainer)
|
||||
|
||||
@parameterized.expand(MODELS_TO_TEST)
|
||||
@require_liger_kernel
|
||||
@require_peft
|
||||
def test_training_with_liger_grpo_loss_and_peft(self, model_name):
|
||||
from peft import LoraConfig, TaskType
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = GRPOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=3,
|
||||
num_generations=3,
|
||||
use_liger_loss=True,
|
||||
max_completion_length=self.max_length,
|
||||
report_to="none",
|
||||
logging_strategy="no",
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token
|
||||
|
||||
# Configure PEFT with LoRA
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
r=8,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.1,
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model=model,
|
||||
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||
args=training_args,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
|
||||
|
||||
assert isinstance(trainer.liger_grpo_loss, LigerFusedLinearGRPOLoss)
|
||||
|
||||
# Verify PEFT adapter is properly initialized
|
||||
from peft import PeftModel
|
||||
|
||||
self.assertTrue(isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT")
|
||||
|
||||
# Store adapter weights before training
|
||||
previous_trainable_params = {
|
||||
n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad
|
||||
}
|
||||
self.assertTrue(len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model")
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Verify adapter weights have changed after training
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
|
||||
|
||||
release_memory(model, trainer)
|
||||
|
@ -420,3 +420,39 @@ class SFTTrainerSlowTester(unittest.TestCase):
|
||||
trainer.train()
|
||||
|
||||
release_memory(trainer.model, trainer)
|
||||
|
||||
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
|
||||
@require_torch_accelerator
|
||||
def test_train_offloading(self, model_name, packing):
|
||||
"""Test that activation offloading works with SFTTrainer."""
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Initialize the trainer
|
||||
training_args = SFTConfig(
|
||||
output_dir=tmp_dir,
|
||||
activation_offloading=True,
|
||||
report_to="none",
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=2,
|
||||
packing=packing,
|
||||
max_length=self.max_length,
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
model=model_name, args=training_args, train_dataset=self.train_dataset, eval_dataset=self.eval_dataset
|
||||
)
|
||||
|
||||
# Save the initial parameters to compare them later
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
# Train the model
|
||||
trainer.train()
|
||||
|
||||
# Check that the training loss is not None
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# Check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||
|
||||
release_memory(trainer.model, trainer)
|
||||
|
tests/test_activation_offloading.py (new file, 156 lines)
@ -0,0 +1,156 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.testing_utils import require_peft, require_torch_accelerator, torch_device
|
||||
from transformers.utils import is_peft_available
|
||||
|
||||
from trl.models.activation_offloading import NoOpManager, OffloadActivations
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
|
||||
class TestActivationOffloading(unittest.TestCase):
|
||||
@require_torch_accelerator
|
||||
@require_peft
|
||||
def test_offloading_with_peft_models(self) -> None:
|
||||
"""Test that activation offloading works with PEFT models."""
|
||||
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
|
||||
peft_config = LoraConfig(
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
r=8,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
model = get_peft_model(model, peft_config)
|
||||
inp = torch.randint(0, 100, (2, 10), device=torch_device)
|
||||
|
||||
# First forward-backward pass without offloading
|
||||
torch.manual_seed(42)
|
||||
loss = model(inp, labels=inp).loss
|
||||
loss.backward()
|
||||
|
||||
# Store gradients - only from trainable parameters
|
||||
grads_original = []
|
||||
for name, param in model.named_parameters():
|
||||
if param.requires_grad and param.grad is not None:
|
||||
grads_original.append((name, param.grad.clone()))
|
||||
|
||||
# Reset gradients
|
||||
for p in model.parameters():
|
||||
if p.grad is not None:
|
||||
p.grad = None
|
||||
|
||||
# Second forward-backward pass with offloading
|
||||
torch.manual_seed(42)
|
||||
with OffloadActivations():
|
||||
loss_c = model(inp, labels=inp).loss
|
||||
loss_c.backward()
|
||||
|
||||
# Compare gradients - only trainable parameters
|
||||
for name_orig, grad_orig in grads_original:
|
||||
for name_param, param in model.named_parameters():
|
||||
if name_param == name_orig and param.requires_grad and param.grad is not None:
|
||||
self.assertTrue(
|
||||
torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5),
|
||||
f"Gradient mismatch for {name_orig}",
|
||||
)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_noop_manager_with_offloading(self):
|
||||
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
|
||||
inp = torch.randint(0, 100, (2, 10), device=torch_device)
|
||||
|
||||
# Run with offloading but disable for specific section
|
||||
with OffloadActivations():
|
||||
# First forward-backward with normal offloading
|
||||
torch.manual_seed(42)
|
||||
out1 = model(inp, labels=inp)
|
||||
out1.loss.backward()
|
||||
grads1 = [p.grad.clone() for p in model.parameters()]
|
||||
|
||||
# Reset grads
|
||||
for p in model.parameters():
|
||||
p.grad = None
|
||||
|
||||
# Second forward-backward with NoOpManager
|
||||
with NoOpManager():
|
||||
torch.manual_seed(42)
|
||||
out2 = model(inp, labels=inp)
|
||||
out2.loss.backward()
|
||||
|
||||
grads2 = [p.grad.clone() for p in model.parameters()]
|
||||
|
||||
# Gradients should match as NoOpManager should have prevented offloading
|
||||
for g1, g2 in zip(grads1, grads2):
|
||||
self.assertTrue(torch.allclose(g1, g2, rtol=1e-4, atol=1e-5))
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_min_offload_size(self):
|
||||
"""Test that tensors smaller than min_offload_size aren't offloaded"""
|
||||
model = nn.Sequential(
|
||||
nn.Linear(5, 5), # Small layer that shouldn't be offloaded
|
||||
nn.Linear(5, 1000), # Large layer that should be offloaded
|
||||
).to(torch_device)
|
||||
|
||||
inp = torch.randn(2, 5, device=torch_device)
|
||||
|
||||
with OffloadActivations(min_offload_size=1000):
|
||||
out = model(inp)
|
||||
out.sum().backward()
|
||||
|
||||
# The test passes if no errors occur, as we're mainly testing
|
||||
# that the logic handles both offloaded and non-offloaded tensors
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_real_hf_model(self):
|
||||
"""Test with an actual HuggingFace model"""
|
||||
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
|
||||
|
||||
# Create small input
|
||||
inp = torch.randint(0, 100, (2, 10), device=torch_device)
|
||||
|
||||
# Baseline without offloading
|
||||
torch.manual_seed(42)
|
||||
out1 = model(inp, labels=inp).loss
|
||||
out1.backward()
|
||||
grads1 = [p.grad.clone() for p in model.parameters()]
|
||||
|
||||
# Reset grads
|
||||
for p in model.parameters():
|
||||
p.grad = None
|
||||
|
||||
# With offloading
|
||||
with OffloadActivations():
|
||||
torch.manual_seed(42)
|
||||
out2 = model(inp, labels=inp).loss
|
||||
out2.backward()
|
||||
|
||||
grads2 = [p.grad.clone() for p in model.parameters()]
|
||||
|
||||
# Check outputs and gradients match
|
||||
self.assertTrue(torch.allclose(out1, out2, rtol=1e-5))
|
||||
for g1, g2 in zip(grads1, grads2):
|
||||
self.assertTrue(torch.allclose(g1, g2, rtol=1e-5))
|
@ -13,12 +13,15 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from io import StringIO
|
||||
from unittest.mock import patch
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info < (3, 10),
|
||||
@ -67,6 +70,33 @@ class TestCLI(unittest.TestCase):
|
||||
with patch("sys.argv", command.split(" ")):
|
||||
main()
|
||||
|
||||
def test_sft_config_file(self):
|
||||
from trl.cli import main
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir: # Create a temporary directory
|
||||
output_dir = os.path.join(tmp_dir, "output")
|
||||
|
||||
# Create a temporary config file
|
||||
config_path = os.path.join(tmp_dir, "config.yaml")
|
||||
config_content = {
|
||||
"model_name_or_path": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
|
||||
"dataset_name": "trl-internal-testing/zen",
|
||||
"dataset_config": "standard_language_modeling",
|
||||
"report_to": "none",
|
||||
"output_dir": output_dir,
|
||||
"lr_scheduler_type": "cosine_with_restarts",
|
||||
}
|
||||
with open(config_path, "w") as config_file:
|
||||
yaml.dump(config_content, config_file)
|
||||
|
||||
# Test the CLI with config file
|
||||
command = f"trl sft --config {config_path}"
|
||||
with patch("sys.argv", command.split(" ")):
|
||||
main()
|
||||
|
||||
# Verify that output directory was created
|
||||
self.assertTrue(os.path.exists(output_dir))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
@ -163,3 +163,80 @@ class TestTrlParser(unittest.TestCase):
|
||||
self.assertIsInstance(result_args[0], MyDataclass)
|
||||
self.assertEqual(result_args[0].arg1, 2)
|
||||
self.assertEqual(result_args[1], ["--remaining_string_in_config", "abc", "--remaining_string_in_args", "def"])
|
||||
|
||||
@patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value"))
|
||||
@patch("yaml.safe_load")
|
||||
def test_subparsers_with_config_defaults(self, mock_yaml_load):
|
||||
"""Test that config defaults are applied to all subparsers."""
|
||||
mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"}
|
||||
|
||||
# Create the main parser
|
||||
parser = TrlParser()
|
||||
|
||||
# Add subparsers
|
||||
subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser)
|
||||
|
||||
# Create a subparser for a specific command
|
||||
subparsers.add_parser("subcommand", dataclass_types=[MyDataclass])
|
||||
|
||||
# Parse with config file
|
||||
args = ["subcommand", "--config", "config.yaml"]
|
||||
result_args = parser.parse_args_and_config(args)
|
||||
|
||||
# Check main parser arguments
|
||||
self.assertEqual(len(result_args), 1)
|
||||
|
||||
# Check that config values were applied to the subparser
|
||||
self.assertEqual(result_args[0].arg1, 2) # Default from config
|
||||
self.assertEqual(result_args[0].arg2, "config_value") # Default from config
|
||||
|
||||
@patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value"))
|
||||
@patch("yaml.safe_load")
|
||||
def test_subparsers_with_config_defaults_and_arg_override(self, mock_yaml_load):
|
||||
"""Test that config defaults are applied to all subparsers."""
|
||||
mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"}
|
||||
|
||||
# Create the main parser
|
||||
parser = TrlParser()
|
||||
|
||||
# Add subparsers
|
||||
subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser)
|
||||
|
||||
# Create a subparser for a specific command
|
||||
subparsers.add_parser("subcommand", dataclass_types=[MyDataclass])
|
||||
|
||||
# Test with command line arguments overriding config
|
||||
args = ["subcommand", "--arg1", "3", "--config", "config.yaml"]
|
||||
result_args = parser.parse_args_and_config(args)
|
||||
|
||||
# Command line arguments should override config
|
||||
self.assertEqual(result_args[0].arg1, 3)
|
||||
self.assertEqual(result_args[0].arg2, "config_value") # Still from config
|
||||
|
||||
@patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value"))
|
||||
@patch("yaml.safe_load")
|
||||
def test_subparsers_multiple_with_config_defaults(self, mock_yaml_load):
|
||||
"""Test that config defaults are applied to all subparsers."""
|
||||
mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"}
|
||||
|
||||
# Create the main parser
|
||||
parser = TrlParser()
|
||||
|
||||
# Add subparsers
|
||||
subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser)
|
||||
|
||||
# Create a subparser for a specific command
|
||||
subparsers.add_parser("subcommand0", dataclass_types=[MyDataclass])
|
||||
subparsers.add_parser("subcommand1", dataclass_types=[MyDataclass])
|
||||
|
||||
for idx in range(2):
|
||||
# Parse with config file
|
||||
args = [f"subcommand{idx}", "--config", "config.yaml"]
|
||||
result_args = parser.parse_args_and_config(args)
|
||||
|
||||
# Check main parser arguments
|
||||
self.assertEqual(len(result_args), 1)
|
||||
|
||||
# Check that config values were applied to the subparser
|
||||
self.assertEqual(result_args[0].arg1, 2) # Default from config
|
||||
self.assertEqual(result_args[0].arg2, "config_value") # Default from config
|
||||
|
@ -102,6 +102,7 @@ class ApplyChatTemplateTester(unittest.TestCase):
"trl-internal-testing/tiny-MistralForCausalLM-0.2",
"trl-internal-testing/tiny-Phi3ForCausalLM",
"trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
"trl-internal-testing/tiny-Qwen3ForCausalLM",
]

conversational_examples = [
@ -1258,6 +1258,37 @@ class DPOTrainerTester(unittest.TestCase):

self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0)

def test_train_with_length_desensitization(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
tokenizer = AutoTokenizer.from_pretrained(model_id)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = DPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
learning_rate=9e-1,
ld_alpha=0.5,
report_to="none",
)
trainer = DPOTrainer(
model=model_id,
args=training_args,
processing_class=tokenizer,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check that the parameters have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if param.sum() != 0: # ignore 0 biases
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))


@require_vision
class DPOVisionTrainerTester(unittest.TestCase):
@ -1344,13 +1375,15 @@ class DPOVisionTrainerTester(unittest.TestCase):
"trl-internal-testing/tiny-LlavaForConditionalGeneration",
"trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
] and (
n.startswith("vision_tower.vision_model.encoder.layers.1")
or n == "vision_tower.vision_model.post_layernorm.weight"
"vision_tower.vision_model.encoder.layers.1" in n
or "vision_tower.vision_model.post_layernorm.weight" in n
):
# For some reason, these params are not updated. This is probably not related to TRL, but to
# the model itself. We should investigate this further, but for now we just skip these params.
continue
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
self.assertFalse(
torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
)


if __name__ == "__main__":
@ -24,7 +24,7 @@ from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import GRPOConfig, GRPOTrainer
from trl.trainer.grpo_trainer import RepeatSampler
from trl.trainer.grpo_trainer import RepeatSampler, shuffle_tensor_dict, split_tensor_dict

from .testing_utils import require_vllm

@ -33,6 +33,77 @@ if is_peft_available():
from peft import LoraConfig, PeftModel


class SplitTensorDictTester(unittest.TestCase):
def test_split_equal_chunks(self):
x = torch.arange(12).reshape(6, 2)
y = torch.arange(6).reshape(6, 1)
tensor_dict = {"x": x, "y": y}

result = split_tensor_dict(tensor_dict, 3)

expected_x_chunks = torch.chunk(x, 3, dim=0)
expected_y_chunks = torch.chunk(y, 3, dim=0)
self.assertEqual(len(result), 3)
for i in range(3):
self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i]))
self.assertTrue(torch.equal(result[i]["y"], expected_y_chunks[i]))

def test_with_none_tensor(self):
x = torch.arange(12).reshape(6, 2)
tensor_dict = {"x": x, "y": None}

result = split_tensor_dict(tensor_dict, 2)

expected_x_chunks = torch.chunk(x, 2, dim=0)
self.assertEqual(len(result), 2)
for i in range(2):
self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i]))
self.assertIsNone(result[i]["y"])


class ShuffleTensorDictTester(unittest.TestCase):
def test_shuffle_preserves_shape(self):
x = torch.arange(6).reshape(3, 2)
y = torch.arange(3).reshape(3, 1)
tensor_dict = {"x": x.clone(), "y": y.clone()}

shuffled = shuffle_tensor_dict(tensor_dict)

self.assertEqual(shuffled["x"].shape, x.shape)
self.assertEqual(shuffled["y"].shape, y.shape)

def test_shuffle_consistent_across_tensors(self):
# Use known patterns to check alignment
x = torch.tensor([[10, 11], [20, 21], [30, 31]])
y = torch.tensor([[1], [2], [3]])
tensor_dict = {"x": x.clone(), "y": y.clone()}

shuffled = shuffle_tensor_dict(tensor_dict)

# Build a reverse map from shuffled x rows to y values
for i in range(3):
x_row = shuffled["x"][i]
y_val = shuffled["y"][i].item()

if torch.equal(x_row, torch.tensor([10, 11])):
self.assertEqual(y_val, 1)
elif torch.equal(x_row, torch.tensor([20, 21])):
self.assertEqual(y_val, 2)
elif torch.equal(x_row, torch.tensor([30, 31])):
self.assertEqual(y_val, 3)
else:
self.fail("Unexpected x row in shuffled output.")

def test_none_tensor_remains_none(self):
x = torch.arange(6).reshape(3, 2)
tensor_dict = {"x": x.clone(), "y": None}

shuffled = shuffle_tensor_dict(tensor_dict)

self.assertIsNone(shuffled["y"])
self.assertEqual(shuffled["x"].shape, x.shape)


class RepeatRandomSamplerTester(unittest.TestCase):
def test_sampler(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
@ -1076,3 +1147,34 @@ class GRPOTrainerTester(unittest.TestCase):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

def test_training_delta_clipping(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
delta=2.0, # set delta to a non-None value
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
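Aside: the two helpers imported at the top of this file, `split_tensor_dict` and `shuffle_tensor_dict`, are only exercised here, not shown. A minimal sketch of the contract the tests pin down (chunk along dim 0, apply one shared permutation to every tensor, pass `None` values through) could look like this; it is an illustration of the tested behaviour, not the actual TRL implementation:

```python
from typing import Optional

import torch


def split_tensor_dict_sketch(
    tensor_dict: dict[str, Optional[torch.Tensor]], num_chunks: int
) -> list[dict[str, Optional[torch.Tensor]]]:
    # Chunk every tensor along dim 0; None entries are propagated unchanged.
    chunked = {
        key: torch.chunk(value, num_chunks, dim=0) if value is not None else [None] * num_chunks
        for key, value in tensor_dict.items()
    }
    return [{key: chunks[i] for key, chunks in chunked.items()} for i in range(num_chunks)]


def shuffle_tensor_dict_sketch(
    tensor_dict: dict[str, Optional[torch.Tensor]],
) -> dict[str, Optional[torch.Tensor]]:
    # One permutation shared by all tensors keeps rows aligned across keys.
    batch_size = next(v.shape[0] for v in tensor_dict.values() if v is not None)
    perm = torch.randperm(batch_size)
    return {key: value[perm] if value is not None else None for key, value in tensor_dict.items()}
```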
@ -160,6 +160,38 @@ class TestNashMDTrainer(unittest.TestCase):
# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])

@require_peft
def test_training_pre_pefted_model_implicit_ref_with_reward_model(self):
lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM")
# self.model from setUp is a base AutoModelForCausalLM
peft_model_instance = get_peft_model(self.model, lora_config)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = NashMDConfig(
output_dir=tmp_dir,
per_device_train_batch_size=1, # Keep small for quick test
max_steps=2, # Few steps
learning_rate=5.0e-7,
eval_strategy="no",
report_to="none",
remove_unused_columns=False, # Important for the dummy dataset
)
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"]

trainer = NashMDTrainer(
model=peft_model_instance, # Pass the already PEFT model
ref_model=None, # Implicit reference from peft_model_instance's base
reward_model=self.reward_model, # To trigger GeometricMixtureWrapper path
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset,
# peft_config is not passed, as model is already PEFT
)

trainer.train()

self.assertIn("train_loss", trainer.state.log_history[-1])

@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
@require_llm_blender
def test_nash_md_trainer_judge_training(self, config_name):
65
tests/test_rewards.py
Normal file
@ -0,0 +1,65 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from trl.rewards import think_format_reward


class ThinkFormatRewardTester(unittest.TestCase):
def test_valid_format(self):
completions = [
"<think>This is my reasoning.</think>This is my answer.", # Simple, one-line reasoning
"<think>\nThis is my reasoning.\n</think>\nThis is my answer.", # Multiline reasoning
"<think>\nThis is\nmy reasoning.\n</think>\nThis is my answer.", # Multiline reasoning
"<think>\nThis is <some tag> my reasoning.</think>\nThis is my answer.", # Reasoning including other tags
"<think></think>\nThis is my answer.", # Empty reasoning
]
completions = [[{"content": completion}] for completion in completions]
expected_rewards = [1.0, 1.0, 1.0, 1.0, 1.0] # All should be valid
rewards = think_format_reward(completions)
self.assertEqual(rewards, expected_rewards)

def test_invalid_format(self):
completions = [
"<think>\nThis is my reasoning.\nThis is my answer.", # No closing </think>
"<think>This is my reasoning.\nThis is my answer.", # No closing </think>
"This is my reasoning. This is my answer.", # No <think> tags
"This is my reasoning.\nThis is my answer.", # No <think> tags
"This is my reasoning.</think>\nThis is my answer.", # No opening <think>
"This is my reasoning.</think>This is my answer.", # No opening <think>
"This<think>is my reasoning.</think>\nThis is my answer.", # <think> tag in the middle
"<think>This is<think>my reasoning.</think></think>This is my answer.", # Nested <think> tags
"<think>This is</think>\nmy\n<think>reasoning.</think>\nThis is my answer.", # Multiline <think>
]
completions = [[{"content": completion}] for completion in completions]
expected_rewards = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] # All should be invalid
rewards = think_format_reward(completions)
self.assertEqual(rewards, expected_rewards)

def test_mixed_format(self):
completions = [
"<think>This is my reasoning.</think>This is my answer.", # Valid
"<think>\nThis is my reasoning.\n</think>\nThis is my answer.", # Valid
"<think>This is my reasoning.\nThis is my answer.", # Invalid
"This is my reasoning. This is my answer.", # Invalid
]
completions = [[{"content": completion}] for completion in completions]
expected_rewards = [1.0, 1.0, 0.0, 0.0]
rewards = think_format_reward(completions)
self.assertEqual(rewards, expected_rewards)


if __name__ == "__main__":
unittest.main()
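Aside: `think_format_reward` returns one float per completion (1.0 for a well-formed `<think>...</think>` prefix, 0.0 otherwise), so it can be passed directly as a reward function. A hedged usage sketch, reusing the tiny model and dataset names that appear in the GRPO tests earlier on this page:

```python
# Sketch only: model/dataset identifiers mirror the tests above and are not prescriptive.
from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer
from trl.rewards import think_format_reward

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=[think_format_reward],  # 1.0 when the completion starts with <think>...</think>
    args=GRPOConfig(output_dir="out", report_to="none"),
    train_dataset=dataset,
)
trainer.train()
```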
@ -22,6 +22,8 @@ from transformers import Trainer, TrainingArguments

from trl.trainer.callbacks import RichProgressCallback

from .testing_utils import require_rich


class DummyModel(nn.Module):
def __init__(self):
@ -32,6 +34,7 @@ class DummyModel(nn.Module):
return self.a * x


@require_rich
class TestRichProgressCallback(unittest.TestCase):
def setUp(self):
self.dummy_model = DummyModel()
@ -370,7 +370,6 @@ class TrainerArgTester(unittest.TestCase):
packing=True,
max_length=256,
dataset_num_proc=4,
dataset_batch_size=512,
neftune_noise_alpha=0.1,
model_init_kwargs={"trust_remote_code": True},
dataset_kwargs={"append_concat_token": True, "skip_prepare_dataset": True},
@ -381,7 +380,6 @@ class TrainerArgTester(unittest.TestCase):
self.assertEqual(trainer.args.packing, True)
self.assertEqual(trainer.args.max_length, 256)
self.assertEqual(trainer.args.dataset_num_proc, 4)
self.assertEqual(trainer.args.dataset_batch_size, 512)
self.assertEqual(trainer.args.neftune_noise_alpha, 0.1)
self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True})
self.assertIn("append_concat_token", trainer.args.dataset_kwargs)
@ -32,6 +32,7 @@ from trl.trainer.utils import (
batch_generation,
decode_and_strip_padding,
flush_left,
flush_right,
generate_model_card,
get_peft_config,
pad,
@ -39,6 +40,8 @@ from trl.trainer.utils import (
selective_log_softmax,
)

from .testing_utils import require_rich


if is_peft_available():
from peft import LoraConfig
@ -95,6 +98,38 @@ class TestPad(unittest.TestCase):
)
self.assertTrue(torch.equal(output, expected))

def test_pad_to_multiple_of_1(self):
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5])
# Max length is 3, pad to multiple of 4
output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4)
expected = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])
self.assertTrue(torch.equal(output, expected))

def test_pad_to_multiple_of_2(self):
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.tensor([6, 7, 8])
# Max length is 3, pad to multiple of 4
output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4)
expected = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0], [6, 7, 8, 0, 0, 0, 0, 0]])
self.assertTrue(torch.equal(output, expected))

def test_pad_to_multiple_of_side_left(self):
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.tensor([6, 7, 8])
# Max length is 3, pad to multiple of 4
output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4)
expected = torch.tensor([[0, 0, 0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 6, 7, 8]])
self.assertTrue(torch.equal(output, expected))

def test_pad_to_multiple_of_no_extra_padding(self):
x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([5, 6, 7, 8])
# Already multiple of 4
output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4)
expected = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
self.assertTrue(torch.equal(output, expected))

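Aside: the expected shapes in these `pad_to_multiple_of` tests follow from a simple rounding rule implied by the assertions: the batch is padded to its longest sequence, and that length is then rounded up to the next multiple. A one-line sketch of the arithmetic (an illustration, not the TRL source):

```python
import math


def padded_length(max_len: int, pad_to_multiple_of: int) -> int:
    # e.g. max_len=5, multiple=4 -> 8, matching the expected tensors above
    return math.ceil(max_len / pad_to_multiple_of) * pad_to_multiple_of


assert padded_length(3, 4) == 4
assert padded_length(5, 4) == 8
assert padded_length(4, 4) == 4  # already a multiple: no extra padding
```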
@require_peft
class TestGetPEFTConfig(unittest.TestCase):
@ -440,12 +475,12 @@ class TestFlushLeft(unittest.TestCase):
self.assertTrue(torch.equal(new_tensor1, expected_tensor1))

def test_no_shift_needed(self):
mask = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0]])
tensor1 = torch.tensor([[5, 6, 0, 0], [7, 8, 0, 0]])
mask = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]])
tensor1 = torch.tensor([[5, 6, 0, 0], [7, 0, 0, 0]])
new_mask, new_tensor1 = flush_left(mask, tensor1)

expected_mask = torch.tensor([[1, 1], [1, 1]])
expected_tensor1 = torch.tensor([[5, 6], [7, 8]])
expected_mask = torch.tensor([[1, 1], [1, 0]])
expected_tensor1 = torch.tensor([[5, 6], [7, 0]])

self.assertTrue(torch.equal(new_mask, expected_mask))
self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
@ -453,9 +488,51 @@ class TestFlushLeft(unittest.TestCase):
def test_no_tensors(self):
mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]])
new_mask = flush_left(mask)

expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
self.assertTrue(torch.equal(new_mask, expected_mask))


class TestFlushRight(unittest.TestCase):
def test_basic_case(self):
mask = torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]])
tensor1 = torch.tensor([[2, 3, 4, 0, 0], [0, 0, 5, 6, 0]])
tensor2 = torch.tensor([[7, 8, 9, 0, 0], [0, 0, 10, 11, 0]])
new_mask, new_tensor1, new_tensor2 = flush_right(mask, tensor1, tensor2)

expected_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])
expected_tensor1 = torch.tensor([[2, 3, 4], [0, 5, 6]])
expected_tensor2 = torch.tensor([[7, 8, 9], [0, 10, 11]])

self.assertTrue(torch.equal(new_mask, expected_mask))
self.assertTrue(torch.equal(new_tensor1, expected_tensor1))
self.assertTrue(torch.equal(new_tensor2, expected_tensor2))

def test_single_row(self):
mask = torch.tensor([[1, 1, 0, 0]])
tensor1 = torch.tensor([[2, 3, 0, 0]])
new_mask, new_tensor1 = flush_right(mask, tensor1)

expected_mask = torch.tensor([[1, 1]])
expected_tensor1 = torch.tensor([[2, 3]])

self.assertTrue(torch.equal(new_mask, expected_mask))
self.assertTrue(torch.equal(new_tensor1, expected_tensor1))

def test_no_shift_needed(self):
mask = torch.tensor([[0, 0, 1, 1], [0, 0, 0, 1]])
tensor1 = torch.tensor([[0, 0, 5, 6], [0, 0, 0, 7]])
new_mask, new_tensor1 = flush_right(mask, tensor1)

expected_mask = torch.tensor([[1, 1], [0, 1]])
expected_tensor1 = torch.tensor([[5, 6], [0, 7]])

self.assertTrue(torch.equal(new_mask, expected_mask))
self.assertTrue(torch.equal(new_tensor1, expected_tensor1))

def test_no_tensors(self):
mask = torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]])
new_mask = flush_right(mask)
expected_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])
self.assertTrue(torch.equal(new_mask, expected_mask))

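Aside: as these tests show, `flush_left` and `flush_right` push the masked-in entries of each row to one side and trim the columns that end up empty for every row. A minimal reference sketch of `flush_right`, consistent with the expectations above but not the actual TRL implementation (`flush_left` is the mirror image):

```python
import torch


def flush_right_sketch(mask: torch.Tensor, *tensors: torch.Tensor):
    # Right-align the valid (mask == 1) entries of every row, then drop the
    # leading columns that are zero in every row.
    batch_size, seq_len = mask.shape
    new_mask = torch.zeros_like(mask)
    new_tensors = [torch.zeros_like(t) for t in tensors]
    for i in range(batch_size):
        idx = mask[i].nonzero(as_tuple=True)[0]
        n = idx.numel()
        if n:
            new_mask[i, seq_len - n:] = mask[i, idx]
            for t, new_t in zip(tensors, new_tensors):
                new_t[i, seq_len - n:] = t[i, idx]
    keep = int(mask.sum(dim=1).max())  # the widest row decides how many columns survive
    new_mask = new_mask[:, seq_len - keep:]
    new_tensors = [t[:, seq_len - keep:] for t in new_tensors]
    return (new_mask, *new_tensors) if new_tensors else new_mask
```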
@ -480,28 +557,30 @@ class TestSelectiveLogSoftmax(unittest.TestCase):
torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5)


@require_rich
class TestPrintPromptCompletionsSample(unittest.TestCase):
@patch("sys.stdout", new_callable=StringIO)
def test_print_output(self, mock_stdout):
prompts = ["The sky is", "The sun is"]
completions = [" blue.", " in the sky."]
rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]}
advantages = [0.987, 0.654]
step = 42

print_prompt_completions_sample(prompts, completions, rewards, step)
print_prompt_completions_sample(prompts, completions, rewards, advantages, step)

output = mock_stdout.getvalue()

expected_output = textwrap.dedent("""\
╭────────────────────── Step 42 ───────────────────────╮
│ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ │
│ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━┩ │
│ │ The sky is │ blue. │ 0.12 │ 0.79 │ │
│ ├────────────┼──────────────┼─────────────┼────────┤ │
│ │ The sun is │ in the sky. │ 0.46 │ 0.10 │ │
│ └────────────┴──────────────┴─────────────┴────────┘ │
╰──────────────────────────────────────────────────────╯
╭──────────────────────────── Step 42 ─────────────────────────────╮
│ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │
│ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │
│ │ The sky is │ blue. │ 0.12 │ 0.79 │ 0.99 │ │
│ ├────────────┼──────────────┼─────────────┼────────┼───────────┤ │
│ │ The sun is │ in the sky. │ 0.46 │ 0.10 │ 0.65 │ │
│ └────────────┴──────────────┴─────────────┴────────┴───────────┘ │
╰──────────────────────────────────────────────────────────────────╯
""")
self.assertEqual(output, expected_output)

@ -510,29 +589,30 @@ class TestPrintPromptCompletionsSample(unittest.TestCase):
prompts = ["A", "B"]
completions = ["1", "2"]
rewards = {"Score": [0.1, 0.2]}
advantages = [0.3, 0.4]
step = 10

print_prompt_completions_sample(prompts, completions, rewards, step, num_samples=1)
print_prompt_completions_sample(prompts, completions, rewards, advantages, step, num_samples=1)
output = mock_stdout.getvalue()

possible_outputs = [
textwrap.dedent("""\
╭──────────── Step 10 ────────────╮
│ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Score ┃ │
│ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩ │
│ │ A │ 1 │ 0.10 │ │
│ └────────┴────────────┴───────┘ │
╰─────────────────────────────────╯
╭────────────────── Step 10 ──────────────────╮
│ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │
│ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │
│ │ A │ 1 │ 0.10 │ 0.30 │ │
│ └────────┴────────────┴───────┴───────────┘ │
╰─────────────────────────────────────────────╯
"""),
textwrap.dedent("""\
╭──────────── Step 10 ────────────╮
│ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Score ┃ │
│ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━┩ │
│ │ B │ 2 │ 0.20 │ │
│ └────────┴────────────┴───────┘ │
╰─────────────────────────────────╯
╭────────────────── Step 10 ──────────────────╮
│ ┏━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Score ┃ Advantage ┃ │
│ ┡━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━┩ │
│ │ B │ 2 │ 0.20 │ 0.40 │ │
│ └────────┴────────────┴───────┴───────────┘ │
╰─────────────────────────────────────────────╯
"""),
]
self.assertIn(output, possible_outputs)
@ -20,12 +20,12 @@ import unittest
import psutil
import pytest
from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_torch_multi_gpu
from transformers.testing_utils import require_torch_multi_accelerator, torch_device

from trl.extras.vllm_client import VLLMClient
from trl.scripts.vllm_serve import chunk_list

from .testing_utils import require_3_gpus
from .testing_utils import require_3_accelerators


class TestChunkList(unittest.TestCase):
@ -55,15 +55,16 @@ class TestChunkList(unittest.TestCase):


@pytest.mark.slow
@require_torch_multi_gpu
@require_torch_multi_accelerator
class TestVLLMClientServer(unittest.TestCase):
model_id = "Qwen/Qwen2.5-1.5B"

@classmethod
def setUpClass(cls):
# We want the server to run on GPU 1, so we set CUDA_VISIBLE_DEVICES to "1"
# We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = "1" # Restrict to GPU 1
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
env[VISIBLE_DEVICES] = "1" # Restrict to accelerator 1

# Start the server process
cls.server_process = subprocess.Popen(
@ -107,7 +108,86 @@ class TestVLLMClientServer(unittest.TestCase):
self.assertLessEqual(len(seq), 32)

def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
self.client.update_model_params(model)

def test_reset_prefix_cache(self):
# Test resetting the prefix cache
self.client.reset_prefix_cache()

@classmethod
def tearDownClass(cls):
super().tearDownClass()

# Close the client
cls.client.close_communicator()

# vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
# kill the server process and its children explicitly.
parent = psutil.Process(cls.server_process.pid)
children = parent.children(recursive=True)
for child in children:
child.send_signal(signal.SIGTERM)
cls.server_process.terminate()
cls.server_process.wait()


# Same as above but using base_url to instantiate the client.
@pytest.mark.slow
@require_torch_multi_accelerator
class TestVLLMClientServerBaseURL(unittest.TestCase):
model_id = "Qwen/Qwen2.5-1.5B"

@classmethod
def setUpClass(cls):
# We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1"
env = os.environ.copy()
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
env[VISIBLE_DEVICES] = "1" # Restrict to accelerator 1

# Start the server process
cls.server_process = subprocess.Popen(
["trl", "vllm-serve", "--model", cls.model_id], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
)

# Initialize the client
cls.client = VLLMClient(base_url="http://localhost:8000", connection_timeout=240)
cls.client.init_communicator()

def test_generate(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts)

# Check that the output is a list
self.assertIsInstance(outputs, list)

# Check that the number of generated sequences is equal to the number of prompts
self.assertEqual(len(outputs), len(prompts))

# Check that the generated sequences are lists of integers
for seq in outputs:
self.assertTrue(all(isinstance(tok, int) for tok in seq))

def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)

# Check that the output is a list
self.assertIsInstance(outputs, list)

# Check that the number of generated sequences is 2 times the number of prompts
self.assertEqual(len(outputs), 2 * len(prompts))

# Check that the generated sequences are lists of integers
for seq in outputs:
self.assertTrue(all(isinstance(tok, int) for tok in seq))

# Check that the length of the generated sequences is less than or equal to 32
for seq in outputs:
self.assertLessEqual(len(seq), 32)

def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
self.client.update_model_params(model)

def test_reset_prefix_cache(self):
@ -132,15 +212,16 @@ class TestVLLMClientServer(unittest.TestCase):


@pytest.mark.slow
@require_3_gpus
@require_3_accelerators
class TestVLLMClientServerTP(unittest.TestCase):
model_id = "Qwen/Qwen2.5-1.5B"

@classmethod
def setUpClass(cls):
# We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2"
# We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2"
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = "1,2" # Restrict to GPU 1 and 2
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
env[VISIBLE_DEVICES] = "1,2" # Restrict to accelerator 1 and 2

# Start the server process
cls.server_process = subprocess.Popen(
@ -169,7 +250,7 @@ class TestVLLMClientServerTP(unittest.TestCase):
self.assertTrue(all(isinstance(tok, int) for tok in seq))

def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
self.client.update_model_params(model)

def test_reset_prefix_cache(self):
@ -194,15 +275,16 @@ class TestVLLMClientServerTP(unittest.TestCase):


@pytest.mark.slow
@require_3_gpus
@require_3_accelerators
class TestVLLMClientServerDP(unittest.TestCase):
model_id = "Qwen/Qwen2.5-1.5B"

@classmethod
def setUpClass(cls):
# We want the server to run on GPU 1 and 2, so we set CUDA_VISIBLE_DEVICES to "1,2"
# We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2"
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = "1,2" # Restrict to GPU 1 and 2
VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES"
env[VISIBLE_DEVICES] = "1,2" # Restrict to accelerator 1 and 2

# Start the server process
cls.server_process = subprocess.Popen(
@ -230,7 +312,7 @@ class TestVLLMClientServerDP(unittest.TestCase):
self.assertTrue(all(isinstance(tok, int) for tok in seq))

def test_update_model_params(self):
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map="cuda")
model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device)
self.client.update_model_params(model)

def test_reset_prefix_cache(self):
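Aside: putting the client-side calls exercised by these tests together, a typical session against a server started with `trl vllm-serve --model Qwen/Qwen2.5-1.5B` would look roughly like the sketch below. It only combines methods that appear in the tests and docstrings on this page; nothing here is additional API.

```python
from transformers import AutoModelForCausalLM

from trl.extras.vllm_client import VLLMClient

# Either host/server_port or the new base_url form works; base_url is handy behind a proxy.
client = VLLMClient(base_url="http://localhost:8000", connection_timeout=240)
client.init_communicator()

# Token IDs for each completion come back as list[list[int]].
completions = client.generate(["Hello, AI!"], n=1, temperature=0.8, max_tokens=32)

# Push updated weights from the training process to the inference server.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-1.5B")
client.update_model_params(model)
client.reset_prefix_cache()

client.close_communicator()
```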
@ -160,6 +160,36 @@ class TestXPOTrainer(unittest.TestCase):
# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])

@require_peft
def test_training_pre_pefted_model_implicit_ref(self):
lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM")
peft_model_instance = get_peft_model(self.model, lora_config)

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = XPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=1,
max_steps=2,
learning_rate=5.0e-7,
eval_strategy="no",
report_to="none",
remove_unused_columns=False,
)
dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")["train"]

trainer = XPOTrainer(
model=peft_model_instance,
ref_model=None,
reward_model=self.reward_model, # Using reward_model to ensure _generate_completions is used as expected
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset,
)

trainer.train()

self.assertIn("train_loss", trainer.state.log_history[-1])

@require_llm_blender
@parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
def test_xpo_trainer_judge_training(self, config_name):
@ -17,6 +17,8 @@ import unittest

import torch
from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available
from transformers.testing_utils import torch_device
from transformers.utils import is_rich_available

from trl import BaseBinaryJudge, BasePairwiseJudge
from trl.import_utils import (
@ -65,6 +67,13 @@ def require_mergekit(test_case):
return unittest.skipUnless(is_mergekit_available(), "test requires mergekit")(test_case)


def require_rich(test_case):
"""
Decorator marking a test that requires rich. Skips the test if rich is not available.
"""
return unittest.skipUnless(is_rich_available(), "test requires rich")(test_case)


def require_sklearn(test_case):
"""
Decorator marking a test that requires sklearn. Skips the test if sklearn is not available.
@ -86,11 +95,14 @@ def require_no_wandb(test_case):
return unittest.skipUnless(not is_wandb_available(), "test requires no wandb")(test_case)


def require_3_gpus(test_case):
def require_3_accelerators(test_case):
"""
Decorator marking a test that requires at least num_gpus GPUs. Skips the test if num_gpus is not available.
Decorator marking a test that requires at least 3 accelerators. Skips the test if 3 accelerators are not available.
"""
return unittest.skipUnless(torch.cuda.device_count() > 3, "test requires at least 3 GPUs")(test_case)
torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
return unittest.skipUnless(
torch_accelerator_module.device_count() > 3, f"test requires at least 3 {torch_device}s"
)(test_case)


class RandomBinaryJudge(BaseBinaryJudge):
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.17.0"
__version__ = "0.18.1"

from typing import TYPE_CHECKING

@ -66,6 +66,7 @@ _import_structure = {
"GRPOConfig",
"GRPOTrainer",
"HfPairwiseJudge",
"IterativeSFTConfig",
"IterativeSFTTrainer",
"KTOConfig",
"KTOTrainer",
@ -161,6 +162,7 @@ if TYPE_CHECKING:
GRPOConfig,
GRPOTrainer,
HfPairwiseJudge,
IterativeSFTConfig,
IterativeSFTTrainer,
KTOConfig,
KTOTrainer,
28
trl/accelerate_configs/fsdp1.yaml
Normal file
@ -0,0 +1,28 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_cpu_ram_efficient_loading: true
fsdp_forward_prefetch: true
fsdp_offload_params: false
fsdp_reshard_after_forward: FULL_SHARD
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sync_module_states: true
fsdp_use_orig_params: true
fsdp_version: 1
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
25
trl/accelerate_configs/fsdp2.yaml
Normal file
@ -0,0 +1,25 @@
# Requires accelerate 1.7.0 or higher
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
fsdp_activation_checkpointing: false
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_cpu_ram_efficient_loading: true
fsdp_offload_params: false
fsdp_reshard_after_forward: true
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
16
trl/accelerate_configs/multi_gpu.yaml
Normal file
@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
16
trl/accelerate_configs/single_gpu.yaml
Normal file
@ -0,0 +1,16 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: "NO"
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
20
trl/accelerate_configs/zero1.yaml
Normal file
@ -0,0 +1,20 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
gradient_accumulation_steps: 1
zero3_init_flag: false
zero_stage: 1
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
21
trl/accelerate_configs/zero2.yaml
Normal file
@ -0,0 +1,21 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
22
trl/accelerate_configs/zero3.yaml
Normal file
@ -0,0 +1,22 @@
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
deepspeed_multinode_launcher: standard
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: true
zero3_save_16bit_model: true
zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
58
trl/cli.py
@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.resources as resources
import os
import sys
import warnings
@ -45,8 +46,37 @@ def main():
make_sft_parser(subparsers)
make_vllm_serve_parser(subparsers)

# Parse the arguments
args = parser.parse_args()
# Parse the arguments; the remaining ones (`launch_args`) are passed to the 'accelerate launch' subparser.
# Duplicates may occur if the same argument is provided in both the config file and CLI.
# For example: launch_args = `["--num_processes", "4", "--num_processes", "8"]`.
# Deduplication and precedence (CLI over config) are handled later by launch_command_parser.
args, launch_args = parser.parse_args_and_config(return_remaining_strings=True)

# Replace `--accelerate_config foo` with `--config_file trl/accelerate_configs/foo.yaml` if it is present in the
# launch_args. It allows the user to use predefined accelerate configs from the `trl` package.
if "--accelerate_config" in launch_args:
# Get the index of the '--accelerate_config' argument and the corresponding config name
config_index = launch_args.index("--accelerate_config")
config_name = launch_args[config_index + 1]

# If the config_name correspond to a path in the filesystem, we don't want to override it
if os.path.isfile(config_name):
accelerate_config_path = config_name
elif resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml").exists():
# Get the predefined accelerate config path from the package resources
accelerate_config_path = resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml")
else:
raise ValueError(
f"Accelerate config {config_name} is neither a file nor a valid config in the `trl` package. "
"Please provide a valid config name or a path to a config file."
)

# Remove '--accelerate_config' and its corresponding config name
launch_args.pop(config_index)
launch_args.pop(config_index)

# Insert '--config_file' and the absolute path to the front of the list
launch_args = ["--config_file", str(accelerate_config_path)] + launch_args

if args.command == "chat":
(chat_args,) = parser.parse_args_and_config()
@ -54,8 +84,8 @@ def main():

if args.command == "dpo":
# Get the default args for the launch command
dpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "dpo.py")
args = launch_command_parser().parse_args([dpo_training_script])
dpo_training_script = resources.files("trl.scripts").joinpath("dpo.py")
args = launch_command_parser().parse_args([str(dpo_training_script)])

# Feed the args to the launch command
args.training_script_args = sys.argv[2:] # remove "trl" and "dpo"
@ -66,8 +96,8 @@ def main():

elif args.command == "grpo":
# Get the default args for the launch command
grpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "grpo.py")
args = launch_command_parser().parse_args([grpo_training_script])
grpo_training_script = resources.files("trl.scripts").joinpath("grpo.py")
args = launch_command_parser().parse_args([str(grpo_training_script)])

# Feed the args to the launch command
args.training_script_args = sys.argv[2:] # remove "trl" and "grpo"
@ -75,20 +105,22 @@ def main():

elif args.command == "kto":
# Get the default args for the launch command
kto_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "kto.py")
args = launch_command_parser().parse_args([kto_training_script])
kto_training_script = resources.files("trl.scripts").joinpath("kto.py")
args = launch_command_parser().parse_args([str(kto_training_script)])

# Feed the args to the launch command
args.training_script_args = sys.argv[2:] # remove "trl" and "kto"
launch_command(args) # launch training

elif args.command == "sft":
# Get the default args for the launch command
sft_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "sft.py")
args = launch_command_parser().parse_args([sft_training_script])
# Get the path to the training script
sft_training_script = resources.files("trl.scripts").joinpath("sft.py")

# Feed the args to the launch command
args.training_script_args = sys.argv[2:] # remove "trl" and "sft"
# This simulates running: `accelerate launch <launch args> sft.py <training script args>`.
# Note that the training script args may include launch-related arguments (e.g., `--num_processes`),
# but we rely on the script to ignore any that don't apply to it.
training_script_args = sys.argv[2:] # Remove "trl" and "sft"
args = launch_command_parser().parse_args(launch_args + [str(sft_training_script)] + training_script_args)
launch_command(args) # launch training

elif args.command == "vllm-serve":
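Aside: the `--accelerate_config` handling above resolves a bare name to one of the YAML files added under `trl/accelerate_configs/` (e.g. `zero2`, `fsdp1`), falling back to treating the value as a filesystem path. A hedged standalone sketch of that lookup, mirroring the logic in the diff:

```python
import importlib.resources as resources
import os


def resolve_accelerate_config(config_name: str) -> str:
    # A real file path wins; otherwise look for a packaged config such as
    # "zero2" -> trl/accelerate_configs/zero2.yaml (as in the diff above).
    if os.path.isfile(config_name):
        return config_name
    packaged = resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml")
    if packaged.exists():
        return str(packaged)
    raise ValueError(f"Unknown accelerate config: {config_name}")


# A command like `trl sft --accelerate_config zero2 ...` is then expanded to
# `accelerate launch --config_file <resolved path> sft.py ...`.
print(resolve_accelerate_config("zero2"))
```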
@ -13,13 +13,13 @@
# limitations under the License.

import re
import warnings
from typing import Optional

import torch
from accelerate.utils import extract_model_from_parallel
from transformers import StoppingCriteria, StoppingCriteriaList

from ..import_utils import is_rich_available
from transformers.utils import is_rich_available


if is_rich_available():
@ -241,6 +241,10 @@ class TextEnvironment:
max_length (Optional[int]): The maximum number of tokens to allow in an episode.
generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method.
"""
warnings.warn(
"This class is deprecated and will be removed in version 0.21.0. To enable tool use with LLMs, check out smolagents (https://huggingface.co/docs/smolagents/index)",
DeprecationWarning,
)
self.model = model
self.tokenizer = tokenizer
self.prompt = prompt
@ -17,17 +17,22 @@ import functools
import time
from collections.abc import Generator

from transformers import Trainer, is_wandb_available
from transformers import Trainer
from transformers.integrations import is_mlflow_available, is_wandb_available


if is_wandb_available():
import wandb

if is_mlflow_available():
import mlflow


@contextlib.contextmanager
def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
"""
A context manager function for profiling a block of code. Results are logged to Weights & Biases if enabled.
A context manager function for profiling a block of code. Results are logged to Weights & Biases or MLflow
depending on the trainer's configuration.

Args:
trainer (`~transformers.Trainer`):
@ -54,8 +59,12 @@ def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None
end_time = time.perf_counter()
duration = end_time - start_time

profiling_metrics = {f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration}
if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
wandb.log({f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration})
wandb.log(profiling_metrics)

if "mlflow" in trainer.args.report_to and mlflow.run is not None and trainer.accelerator.is_main_process:
mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)


def profiling_decorator(func: callable) -> callable:
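Aside: `profiling_context` is meant to wrap a named block inside a trainer method; on exit it logs a `profiling/Time taken: <TrainerClass>.<name>` duration to whichever of wandb or MLflow is active in `report_to`. A usage sketch is shown below; the import path `trl.extras.profiling` and the method name are assumptions for illustration, not taken from this diff.

```python
# Hedged sketch: `self` is assumed to be a TRL trainer instance, and the wrapped
# call is a placeholder for whatever block you want to time.
from trl.extras.profiling import profiling_context  # assumed module path


def _generate_completions(self, prompts):
    with profiling_context(self, "generate_completions"):
        # Logged as "profiling/Time taken: <TrainerClass>.generate_completions"
        return self.llm.generate(prompts)  # placeholder for the real generation call
```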
@ -14,8 +14,10 @@

import atexit
import logging
import socket
import time
from typing import Optional
from urllib.parse import urlparse

import torch
from torch import nn
@ -47,10 +49,13 @@ class VLLMClient:
weights in a distributed setting. Before using it, start the vLLM server with `trl vllm-serve`.

Args:
base_url (`str` or `None`, *optional*, defaults to `None`):
Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `host` and `server_port` are
ignored.
host (`str`, *optional*, defaults to `"0.0.0.0"`):
IP address of the vLLM server.
IP address of the vLLM server. Ignored if `base_url` is provided.
server_port (`int`, *optional*, defaults to `8000`):
Port number of the vLLM server.
Port number of the vLLM server. Ignored if `base_url` is provided.
group_port (`int`, *optional*, defaults to `51216`):
Port number for the weight update group.
connection_timeout (`float`, *optional*, defaults to `0.0`):
@ -81,10 +86,24 @@ class VLLMClient:
>>> client.init_communicator()
>>> client.update_model_params(model)
```

There are several ways to initialize the client:

```python
VLLMClient(base_url="http://localhost:8000")
VLLMClient(base_url="http://192.168.1.100:8000")
VLLMClient(host="localhost", server_port=8000)
VLLMClient(host="192.168.1.100", server_port=8000)
```
"""

def __init__(
self, host: str = "0.0.0.0", server_port: int = 8000, group_port: int = 51216, connection_timeout: float = 0.0
self,
base_url: Optional[str] = None,
host: str = "0.0.0.0",
server_port: int = 8000,
group_port: int = 51216,
connection_timeout: float = 0.0,
):
if not is_requests_available():
raise ImportError("requests is not installed. Please install it with `pip install requests`.")
@ -92,8 +111,17 @@ class VLLMClient:
raise ImportError("vLLM is not installed. Please install it with `pip install vllm`.")

self.session = requests.Session()
self.host = host
self.server_port = server_port

if base_url is not None:
# Parse the base_url to extract host and port
parsed_url = urlparse(base_url)
self.host = socket.gethostbyname(parsed_url.hostname)
scheme = parsed_url.scheme or "http"
self.base_url = f"{scheme}://{parsed_url.netloc}{parsed_url.path}"
else:
self.host = host
self.server_port = server_port
self.base_url = f"http://{self.host}:{self.server_port}"
self.group_port = group_port
self.check_server(connection_timeout) # check server and fail after timeout

@ -108,7 +136,7 @@ class VLLMClient:
total_timeout (`float`, *optional*, defaults to `0.0`):
Total timeout duration in seconds.
"""
url = f"http://{self.host}:{self.server_port}/health/"
url = f"{self.base_url}/health/"
start_time = time.time() # Record the start time

while True:
@ -119,11 +147,13 @@ class VLLMClient:
elapsed_time = time.time() - start_time
if elapsed_time >= total_timeout:
raise ConnectionError(
f"The vLLM server can't be reached at {self.host}:{self.server_port} after {total_timeout} "
"seconds. Make sure the server is running by running `trl vllm-serve`."
f"The vLLM server can't be reached at {self.base_url} after {total_timeout} seconds. Make "
"sure the server is running by running `trl vllm-serve`."
) from exc
else:
if response.status_code == 200:
if "X-Forwarded-For" in response.headers:
self.host = response.headers["X-Forwarded-For"]
logger.info("Server is up!")
return None

@ -170,7 +200,7 @@ class VLLMClient:
`list[list[int]]`:
List of lists of token IDs representing the model-generated completions for each prompt.
"""
url = f"http://{self.host}:{self.server_port}/generate/"
url = f"{self.base_url}/generate/"
response = self.session.post(
url,
json={
@ -195,7 +225,7 @@ class VLLMClient:
Initializes the weight update group in a distributed setup for model synchronization.
"""
# Get the world size from the server
url = f"http://{self.host}:{self.server_port}/get_world_size/"
url = f"{self.base_url}/get_world_size/"
response = requests.get(url)
if response.status_code == 200:
vllm_world_size = response.json()["world_size"]
@ -206,7 +236,7 @@ class VLLMClient:
self.rank = vllm_world_size # the client's rank is the last process

# Initialize weight update group
url = f"http://{self.host}:{self.server_port}/init_communicator/"
url = f"{self.base_url}/init_communicator/"
# In the server side, the host is set to 0.0.0.0
response = self.session.post(url, json={"host": "0.0.0.0", "port": self.group_port, "world_size": world_size})
if response.status_code != 200:
@ -235,7 +265,7 @@ class VLLMClient:
Tensor containing the updated weights.
"""
dtype, shape = str(weights.dtype), tuple(weights.shape)
url = f"http://{self.host}:{self.server_port}/update_named_param/"
url = f"{self.base_url}/update_named_param/"
response = self.session.post(url, json={"name": name, "dtype": dtype, "shape": shape})
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
@ -260,7 +290,7 @@ class VLLMClient:
"""
Resets the prefix cache for the model.
"""
url = f"http://{self.host}:{self.server_port}/reset_prefix_cache/"
url = f"{self.base_url}/reset_prefix_cache/"
response = self.session.post(url)
if response.status_code != 200:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
@ -269,7 +299,7 @@ class VLLMClient:
"""
Closes the weight update group and cleans up the communication group.
"""
url = f"http://{self.host}:{self.server_port}/close_communicator/"
url = f"{self.base_url}/close_communicator/"

try:
response = self.session.post(url)
@ -22,7 +22,7 @@ from packaging import version
from transformers.utils.import_utils import _is_package_available


LIGER_KERNEL_MIN_VERSION = "0.5.6"
LIGER_KERNEL_MIN_VERSION = "0.5.8"

# Use same as transformers.utils.import_utils
_deepspeed_available = _is_package_available("deepspeed")
@ -33,7 +33,6 @@ _llm_blender_available = _is_package_available("llm_blender")
_mergekit_available = _is_package_available("mergekit")
_pydantic_available = _is_package_available("pydantic")
_requests_available = _is_package_available("requests")
_rich_available = _is_package_available("rich")
_unsloth_available = _is_package_available("unsloth")
_uvicorn_available = _is_package_available("uvicorn")
_vllm_available = _is_package_available("vllm")
@ -73,10 +72,6 @@ def is_requests_available() -> bool:
return _requests_available


def is_rich_available() -> bool:
return _rich_available


def is_unsloth_available() -> bool:
return _unsloth_available

@ -18,6 +18,7 @@ from ..import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffu


_import_structure = {
"activation_offloading": ["get_act_offloading_ctx_manager"],
"modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"],
"modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
"utils": [
@ -43,6 +44,7 @@ else:
]

if TYPE_CHECKING:
from .activation_offloading import get_act_offloading_ctx_manager
from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
from .utils import (
462
trl/models/activation_offloading.py
Normal file
462
trl/models/activation_offloading.py
Normal file
@ -0,0 +1,462 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of https://github.com/pytorch/torchtune.
|
||||
|
||||
import warnings
|
||||
|
||||
import psutil
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.autograd.graph import saved_tensors_hooks
|
||||
|
||||
|
||||
class OffloadActivations(saved_tensors_hooks):
|
||||
"""
|
||||
Context manager under which activation tensors created in the forward pass will be offloaded.
|
||||
|
||||
Enable the memory efficiency technique of activation offloading, where activations bigger than `min_offload_size`
|
||||
bytes will be offloaded to CPU in the forward and brought back in the backward. This is in contrast to maintaining
|
||||
the activation on GPU VRAM throughout the program.
|
||||
|
||||
This manager contains the option of using one additional CUDA stream to handle the communication between CUDA and
|
||||
CPU, which is intended to overlap with the default computation stream to improve runtime. We designed
|
||||
synchronization with a few heuristics for optimizing the tradeoff between runtime vs memory usage.
|
||||
|
||||
Args:
|
||||
use_pin_memory (`bool`, *optional*, defaults to `True`):
|
||||
Whether the offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to
|
||||
be moved back onto GPU more quickly but is a limited resource.
|
||||
use_streams (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use streams for performance optimization where the communications get overlapped with the
|
||||
computation. Requires a torch build after torch-2.5.0.
|
||||
min_offload_size (`int`, *optional*, defaults to `1024`):
|
||||
Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we
|
||||
do not want to waste bandwidth and resources moving it to CPU and back.
|
||||
max_fwd_stash_size (`int`, *optional*, defaults to `5`):
|
||||
Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
|
||||
the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
|
||||
more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping
|
||||
alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing
|
||||
runtime.
|
||||
|
||||
Raises:
|
||||
ValueError: if `max_fwd_stash_size` is not at least `1`.
|
||||
|
||||
Example:
|
||||
>>> with OffloadActivations():
|
||||
>>> outputs = model(inputs, labels=labels)
|
||||
>>> loss = outputs.loss
|
||||
>>> loss.backward()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
use_pin_memory: bool = True,
|
||||
use_streams: bool = True,
|
||||
min_offload_size: int = 1024,
|
||||
max_fwd_stash_size: int = 5,
|
||||
) -> None:
|
||||
self.use_streams = use_streams
|
||||
|
||||
self.min_tensor_size_bytes = min_offload_size # we don't want to bother with small tensors
|
||||
self.tracker = {} # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where
|
||||
self.tensor_id = 0
|
||||
self.is_first_forward_call = True
|
||||
self.is_first_backward_call = True
|
||||
self.is_first_forward_pass = True
|
||||
|
||||
# Managing cpu memory
|
||||
self.use_pin_memory = use_pin_memory
|
||||
self.virtual_memory_safe_pct = 60 # we should not exceed this percentage of memory
|
||||
|
||||
self.accelerator_type = (
|
||||
torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
|
||||
)
|
||||
# NOTE: xpu doesn't have `default_stream` API, use `current_stream` instead
|
||||
self.s0 = (
|
||||
torch.xpu.current_stream() if self.accelerator_type == "xpu" else torch.cuda.default_stream()
|
||||
) # comp stream
|
||||
|
||||
# For streaming
|
||||
if self.use_streams:
|
||||
self.s1 = torch.Stream() if self.accelerator_type == "xpu" else torch.cuda.Stream() # comms stream
|
||||
self.fwd_stash = {} # tensor_id => (activation, ev1)
|
||||
if max_fwd_stash_size < 1:
|
||||
raise ValueError(f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}")
|
||||
self.max_fwd_stash_size = max_fwd_stash_size
|
||||
self.bwd_tensor_stash = {} # tensor_id => activation
|
||||
self.bwd_ev_stash = {} # tensor_id => ev0
|
||||
self.curr_graph_id = None
|
||||
self.curr_autograd_node = None
|
||||
|
||||
# -------- platform util functions -------- #
|
||||
def verify_sufficient_virtual_memory():
|
||||
curr_pct = get_cpu_ram_pct()
|
||||
if curr_pct > self.virtual_memory_safe_pct:
|
||||
warnings.warn(f"{curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used")
|
||||
|
||||
def get_cpu_ram_pct() -> float:
|
||||
# get the percentage of memory used by the system
|
||||
return psutil.virtual_memory().percent
|
||||
|
||||
def get_tensor_id() -> int:
|
||||
# create a unique id for each tensor we are managing
|
||||
self.tensor_id += 1
|
||||
return self.tensor_id
|
||||
|
||||
def get_num_bytes_tensor(x: torch.Tensor) -> int:
|
||||
# get the number of bytes in a tensor, for memory management purposes
|
||||
return x.element_size() * x.nelement() # x.element_size() * x._base_storage().nbytes()
|
||||
|
||||
# -------- core pack / unpack work -------- #
|
||||
def pack_tensor(activation: torch.Tensor) -> int:
|
||||
# activations are passed in during forward pass - from here we take over and return a unique id
|
||||
if self.is_first_forward_call:
|
||||
if len(self.tracker) != 0:
|
||||
raise ValueError("Backward pass should have cleared tracker of all tensors")
|
||||
|
||||
# set training phase trackers
|
||||
self.is_first_forward_call = False
|
||||
self.is_first_backward_call = True
|
||||
|
||||
# query for basic tensor info
|
||||
num_bytes = get_num_bytes_tensor(activation)
|
||||
tensor_id = get_tensor_id()
|
||||
|
||||
# only offload hefty bois if they're activations on CUDA (our heuristic
|
||||
# for that is to check if they're not params or buffers)!
|
||||
if (
|
||||
activation.device.type in ["cuda", "xpu"]
|
||||
and num_bytes >= self.min_tensor_size_bytes
|
||||
and (
|
||||
not isinstance(activation, torch.nn.Parameter)
|
||||
and not (hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer))
|
||||
)
|
||||
):
|
||||
if self.use_streams:
|
||||
# First, sync back and dereference previously offloaded tensors
|
||||
# as the offloading should be done sufficiently long ago.
|
||||
for id in list(self.fwd_stash.keys()):
|
||||
if id <= tensor_id - self.max_fwd_stash_size:
|
||||
_, ev = self.fwd_stash[id]
|
||||
self.s0.wait_event(ev)
|
||||
del self.fwd_stash[id]
|
||||
else:
|
||||
break
|
||||
|
||||
# Sync in, offload, and add an event to sync back later
|
||||
self.s1.wait_stream(self.s0)
|
||||
|
||||
stream = self.s1 if self.use_streams else self.s0
|
||||
with stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream):
|
||||
cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu")
|
||||
cpu_tensor.copy_(activation, non_blocking=True)
|
||||
self.tracker[tensor_id] = (
|
||||
cpu_tensor,
|
||||
True, # True = (in future) modified
|
||||
)
|
||||
|
||||
if self.use_streams:
|
||||
event = self.s1.record_event()
|
||||
|
||||
# Stash to keep activation alive til s1 is done
|
||||
self.fwd_stash[tensor_id] = (activation, event)
|
||||
else:
|
||||
self.tracker[tensor_id] = (
|
||||
activation,
|
||||
False,
|
||||
) # False = not modified, tensor is as is
|
||||
|
||||
return tensor_id
|
||||
|
||||
def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
|
||||
# backward pass - we are called with the tensor_id, which
|
||||
# we will use to retrieve the saved/offloaded tensor
|
||||
if self.is_first_backward_call:
|
||||
if self.is_first_forward_pass:
|
||||
self.is_first_forward_pass = False
|
||||
if self.use_pin_memory:
|
||||
verify_sufficient_virtual_memory()
|
||||
|
||||
self.is_first_backward_call = False
|
||||
self.is_first_forward_call = True
|
||||
|
||||
if unpack_tensor_id not in self.tracker:
|
||||
raise ValueError(f"Untracked tensor with id {unpack_tensor_id}")
|
||||
|
||||
maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id]
|
||||
if modified:
|
||||
accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
|
||||
maybe_accelerator_tensor = accelerator_tensor
|
||||
|
||||
# clear tensor from tracking
|
||||
del self.tracker[unpack_tensor_id]
|
||||
return maybe_accelerator_tensor
|
||||
|
||||
def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
|
||||
# backward pass - we are called with the tensor_id, which
|
||||
# we will use to retrieve the saved/offloaded tensor
|
||||
if self.is_first_backward_call:
|
||||
self.curr_graph_id = torch._C._current_graph_task_id()
|
||||
|
||||
def wait_and_del_remaining_references() -> None:
|
||||
for id in list(self.bwd_tensor_stash.keys()):
|
||||
event = self.bwd_ev_stash[id]
|
||||
self.s1.wait_event(event)
|
||||
del self.bwd_tensor_stash[id]
|
||||
|
||||
# Register a callback to the end of autograd to clean everything up
|
||||
torch.autograd.variable.Variable._execution_engine.queue_callback(wait_and_del_remaining_references)
|
||||
|
||||
if self.is_first_forward_pass:
|
||||
self.is_first_forward_pass = False
|
||||
if self.use_pin_memory:
|
||||
verify_sufficient_virtual_memory()
|
||||
|
||||
self.is_first_backward_call = False
|
||||
self.is_first_forward_call = True
|
||||
|
||||
if unpack_tensor_id not in self.tracker:
|
||||
raise ValueError(f"untracked tensor with id {unpack_tensor_id}")
|
||||
|
||||
maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id]
|
||||
if modified:
|
||||
# Get data on the current autograd node
|
||||
graph_id = torch._C._current_graph_task_id()
|
||||
node = torch._C._current_autograd_node()
|
||||
prev_node_ids = []
|
||||
|
||||
# If we're on a new node, mark prev node's tensors to be freed later
|
||||
if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
|
||||
self.curr_autograd_node = node
|
||||
prev_node_ids = list(self.bwd_tensor_stash.keys())
|
||||
|
||||
brought_back_from_cpu = True
|
||||
if unpack_tensor_id in self.fwd_stash:
|
||||
maybe_accelerator_tensor = self.fwd_stash[unpack_tensor_id][0]
|
||||
brought_back_from_cpu = False
|
||||
else:
|
||||
# Kick off the process to bring tensors back
|
||||
with self.s1 if self.accelerator_type == "xpu" else torch.cuda.stream(self.s1):
|
||||
accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
|
||||
maybe_accelerator_tensor = accelerator_tensor
|
||||
|
||||
# Tell comp stream to wait for the info to be loaded before executing
|
||||
self.s0.wait_stream(self.s1)
|
||||
|
||||
# Stash the tensor to keep memory alive until compute stream is complete
|
||||
self.bwd_tensor_stash[unpack_tensor_id] = maybe_accelerator_tensor
|
||||
|
||||
# Note: [Track views of the unpacked]
|
||||
# Why do we get the use count of the unpacked tensor here? We want an
|
||||
# initial count to compare to later, during the post-hook of the
|
||||
# backward node, when we need to decide whether we're allowed to free
|
||||
# the tensor yet. In what obscure cases must we delay freeing the
|
||||
# tensor (and thus call record_stream)?
|
||||
# 1. Any of the outputs of the backward node is a view of the unpacked
|
||||
# tensor.
|
||||
# 2. In the case that this unpacked tensor will be used in a
|
||||
# checkpointed region, if one of the recomputed saved tensors ends
|
||||
# up as a view of the unpacked tensor.
|
||||
# 3. The user abuses the system somehow and manually relies on the
|
||||
# unpacked tensor to exist after the backward node has executed.
|
||||
storage_refcount = torch._C._storage_Use_Count(maybe_accelerator_tensor.untyped_storage()._cdata)
|
||||
|
||||
def hook(outputs, inputs):
|
||||
# create events for the current node inputs/outputs if they were streamed in
|
||||
if brought_back_from_cpu:
|
||||
# See Note: [Track views of the unpacked]
|
||||
# IF any of the outputs is a view of the tensor, OR if a view of
|
||||
# the tensor has been saved as a part of checkpoint's recompute
|
||||
# process, OR the user has abusedly incurred a reference on the
|
||||
# unpacked tensor, THEN the tensor might be used later and we
|
||||
# cannot presume to delete it after only the current node is
|
||||
# done! So we use our frenemy, record_stream, to ensure the
|
||||
# Tensor stays unmessed with until it's done getting used in the
|
||||
# compute stream (s0 here). Note that the con here is we introduce
|
||||
# non-deterministic (thus higher) memory usage, but this case
|
||||
# should not happen often.
|
||||
unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
|
||||
if torch._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata) > storage_refcount:
|
||||
unpacked_tensor.record_stream(self.s0)
|
||||
del self.bwd_tensor_stash[unpack_tensor_id]
|
||||
else:
|
||||
event = self.s0.record_event()
|
||||
self.bwd_ev_stash[unpack_tensor_id] = event
|
||||
|
||||
# if there are still things in the fwd_stash, get rid of them as we're in bwd now
|
||||
for id in list(self.fwd_stash.keys()):
|
||||
_, ev = self.fwd_stash[id]
|
||||
self.s0.wait_event(ev)
|
||||
del self.fwd_stash[id]
|
||||
|
||||
# wait on prev node's events and del those
|
||||
for id in prev_node_ids:
|
||||
event = self.bwd_ev_stash[id]
|
||||
self.s1.wait_event(event)
|
||||
del self.bwd_tensor_stash[id]
|
||||
|
||||
return outputs
|
||||
|
||||
node.register_hook(hook)
|
||||
|
||||
# clear tensor from tracking
|
||||
del self.tracker[unpack_tensor_id]
|
||||
return maybe_accelerator_tensor
|
||||
|
||||
unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream
|
||||
super().__init__(pack_tensor, unpack_tensor)
|
||||
|
||||
|
||||
class NoOpManager(saved_tensors_hooks):
|
||||
"""
|
||||
A `saved_tensors_hook` manager used to disable any other `saved_tensors_hook` manager applied before. This relies
|
||||
on the behavior that only the most recently registered `saved_tensors_hook` will run.
|
||||
|
||||
One example usage is to opt a local region of code out of activations offloading, which is usually applied globally
|
||||
to best track state.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def noop(tensor):
|
||||
return tensor
|
||||
|
||||
super().__init__(noop, noop)
|
||||
|
||||
|
||||
def get_act_offloading_ctx_manager(
|
||||
model: nn.Module,
|
||||
use_pin_memory: bool = True,
|
||||
use_streams: bool = True,
|
||||
min_offload_size: int = 1024,
|
||||
max_fwd_stash_size: int = 5,
|
||||
warn_if_no_head: bool = True,
|
||||
) -> OffloadActivations:
|
||||
"""
|
||||
Returns the activation offloading context manager for the model. All but the last output Linear in every step will
|
||||
be offloaded.
|
||||
|
||||
If activation offloading is enabled, we return the OffloadActivations context manager.
|
||||
If activation offloading is disabled, we return a NoOpManager context manager.
|
||||
|
||||
Args:
|
||||
model (`nn.Module`):
|
||||
Model to wrap with the activation offloading context manager.
|
||||
use_pin_memory (`bool`, *optional*, defaults to `True`):
|
||||
Whether the offloaded Tensor will be placed in pinned memory on the CPU. Pinned memory allows the Tensor to
|
||||
be moved back onto GPU more quickly but is a limited resource.
|
||||
use_streams (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use streams for performance optimization where the communications get overlapped with the
|
||||
computation. Requires a torch build after torch-2.5.0.
|
||||
min_offload_size (`int`, *optional*, defaults to `1024`):
|
||||
Minimum number of bytes a Tensor must be in order to qualify for offloading. If the tensor is too small, we
|
||||
do not want to waste bandwidth and resources moving it to CPU and back.
|
||||
max_fwd_stash_size (`int`, *optional*, defaults to `5`):
|
||||
Maximum size of the forward stash, or the maximum number of consecutive activations to keep alive during
|
||||
the forward pass. This number must be at least 1. Keeping alive more activations will potentially allow
|
||||
more overlap between the communication and compute streams at the cost of increasing memory usage. Keeping
|
||||
alive fewer activations will conserve memory, but may cause poor overlap between the streams, increasing
|
||||
runtime.
|
||||
warn_if_no_head (`bool`, *optional*, defaults to `True`):
|
||||
Whether to warn if no output head is detected. If set to `False`, no warning will be raised if no output
|
||||
head is detected.
|
||||
|
||||
Returns:
|
||||
`contextlib.ContextDecorator`:
|
||||
Activation offloading context manager for the model.
|
||||
"""
|
||||
activations_handling_ctx = OffloadActivations(
|
||||
use_pin_memory=use_pin_memory,
|
||||
use_streams=use_streams,
|
||||
min_offload_size=min_offload_size,
|
||||
max_fwd_stash_size=max_fwd_stash_size,
|
||||
)
|
||||
|
||||
# Below is our hack to disable offloading the last output Linear in every
|
||||
# step, as the cost for offloading the activation and then soon after bringing
|
||||
# it back is expensive.
|
||||
output_head_detected = False
|
||||
noop_ctx = NoOpManager()
|
||||
|
||||
# Try to get the actual model if it's wrapped
|
||||
unwrapped_model = model
|
||||
if hasattr(unwrapped_model, "module"):
|
||||
unwrapped_model = unwrapped_model.module
|
||||
# check for PEFT models
|
||||
if hasattr(unwrapped_model, "base_model") and hasattr(unwrapped_model, "peft_config"):
|
||||
unwrapped_model = unwrapped_model.base_model
|
||||
|
||||
# Check for different types of output heads
|
||||
if hasattr(unwrapped_model, "output"):
|
||||
if isinstance(unwrapped_model.output, nn.Module):
|
||||
unwrapped_model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
||||
unwrapped_model.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
||||
output_head_detected = True
|
||||
elif hasattr(unwrapped_model.output, "linear") and isinstance(unwrapped_model.output.linear, nn.Module):
|
||||
unwrapped_model.output.linear.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
||||
unwrapped_model.output.linear.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
||||
output_head_detected = True
|
||||
|
||||
# Check for HuggingFace model output heads
|
||||
elif hasattr(unwrapped_model, "lm_head"):
|
||||
unwrapped_model.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
||||
unwrapped_model.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
||||
output_head_detected = True
|
||||
|
||||
# Check for decoder-based models
|
||||
elif hasattr(unwrapped_model, "decoder"):
|
||||
decoder = unwrapped_model.decoder
|
||||
if hasattr(decoder, "output"):
|
||||
decoder.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
||||
decoder.output.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
||||
output_head_detected = True
|
||||
# Some models have lm_head in the decoder
|
||||
elif hasattr(decoder, "lm_head"):
|
||||
decoder.lm_head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
||||
decoder.lm_head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
||||
output_head_detected = True
|
||||
|
||||
# Check for transformer models with final layer norm
|
||||
elif hasattr(unwrapped_model, "final_layer_norm") or hasattr(unwrapped_model, "ln_f"):
|
||||
final_norm = getattr(unwrapped_model, "final_layer_norm", None) or unwrapped_model.ln_f
|
||||
final_norm.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
||||
final_norm.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
||||
output_head_detected = True
|
||||
|
||||
# Check for models with head module
|
||||
elif hasattr(unwrapped_model, "head") and isinstance(unwrapped_model.head, nn.Module):
|
||||
unwrapped_model.head.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
||||
unwrapped_model.head.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
||||
output_head_detected = True
|
||||
|
||||
if not output_head_detected and warn_if_no_head:
|
||||
warnings.warn(
|
||||
"During activation offloading, no output head was detected. If your model has an output head, it will be "
|
||||
"offloaded. This usually greatly slows training, given the large vocabulary size. To change this "
|
||||
"behavior, set your output head as model.output and make it an nn.Module. You can disable this warning by "
|
||||
"passing `warn_if_no_head=False`."
|
||||
)
|
||||
|
||||
# Disable offloading for any Liger modules
|
||||
for name, module in unwrapped_model.named_modules():
|
||||
if "liger" in name.lower():
|
||||
module.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
|
||||
module.register_forward_hook(lambda *args: noop_ctx.__exit__(), always_call=True)
|
||||
|
||||
return activations_handling_ctx
|
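A minimal usage sketch of the manager returned above; the `model`, `optimizer`, and `dataloader` names are illustrative, not part of this diff:

```python
from trl.models import get_act_offloading_ctx_manager

# Offload large activations to pinned CPU memory in the forward pass and stream them back in backward.
offload_ctx = get_act_offloading_ctx_manager(model, use_streams=True, min_offload_size=1024)

for batch in dataloader:
    with offload_ctx:
        loss = model(**batch).loss
        loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

The detected output head (and any Liger module) is wrapped with `NoOpManager` hooks so its activations stay on the accelerator, since offloading them right before the backward pass would mostly add transfer cost.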
@ -16,9 +16,9 @@ import itertools
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
|
||||
|
||||
from accelerate.utils import is_deepspeed_available
|
||||
import torch.nn as nn
|
||||
from packaging import version
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
@ -30,12 +30,10 @@ SUPPORTED_ARCHITECTURES = (
|
||||
AutoModelForSeq2SeqLMWithValueHead,
|
||||
)
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from accelerate import Accelerator
|
||||
from deepspeed.runtime.engine import DeepSpeedEngine
|
||||
from torch.nn import Module
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
|
||||
@ -167,6 +165,8 @@ def iter_params(module, recurse=False):
|
||||
|
||||
def add_hooks(model: "DeepSpeedEngine") -> None:
|
||||
"""Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
|
||||
import deepspeed
|
||||
|
||||
if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
|
||||
return
|
||||
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
|
||||
@ -214,6 +214,8 @@ def unwrap_model_for_generation(
|
||||
if not gather_deepspeed3_params:
|
||||
yield accelerator.unwrap_model(model)
|
||||
else:
|
||||
import deepspeed
|
||||
|
||||
with deepspeed.zero.GatheredParameters(model.parameters()):
|
||||
remove_hooks(model)
|
||||
yield accelerator.unwrap_model(model)
|
||||
@ -222,8 +224,13 @@ def unwrap_model_for_generation(
|
||||
yield unwrapped_model
|
||||
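For reference, a hedged sketch of how trainers typically use this helper when sampling from a ZeRO-3 model; `model`, `accelerator`, and `inputs` are illustrative:

```python
from trl.models.utils import unwrap_model_for_generation

# With gather_deepspeed3_params=True (the default), ZeRO-3 parameter shards are gathered
# and the optimizer hooks removed for the duration of generation, then restored afterwards.
with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
    generations = unwrapped_model.generate(**inputs, max_new_tokens=64)
```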
|
||||
|
||||
def prepare_deepspeed(model, accelerator):
|
||||
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
||||
def prepare_deepspeed(model: "Module", accelerator: "Accelerator"):
|
||||
"""Prepares the model for DeepSpeed inference or evaluation by initializing it with the appropriate configuration.
|
||||
|
||||
Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
||||
"""
|
||||
import deepspeed # local import (instead of top-level) to avoid DS init interfering with other backends (like vllm): https://github.com/deepspeedai/DeepSpeed/issues/7252
|
||||
|
||||
deepspeed_plugin = accelerator.state.deepspeed_plugin
|
||||
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
||||
stage = config_kwargs["zero_optimization"]["stage"]
|
||||
@ -266,7 +273,7 @@ def prepare_fsdp(model, accelerator):
|
||||
accelerator.state.fsdp_plugin.set_auto_wrap_policy(model)
|
||||
fsdp_plugin = accelerator.state.fsdp_plugin
|
||||
kwargs = {
|
||||
"sharding_strategy": fsdp_plugin.sharding_strategy,
|
||||
"sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,
|
||||
"cpu_offload": fsdp_plugin.cpu_offload,
|
||||
"auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
|
||||
"mixed_precision": fsdp_plugin.mixed_precision_policy,
|
||||
@ -282,3 +289,53 @@ def prepare_fsdp(model, accelerator):
|
||||
model = FSDP(model, **kwargs)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
class _ForwardRedirection:
|
||||
"""Implements the `forward-redirection`.
|
||||
|
||||
Taken from Pytorch-lightning: https://github.com/Lightning-AI/pytorch-lightning/blob/02311d03fb982560246eead7c08104481fac9579/src/lightning/pytorch/strategies/strategy.py#L602
|
||||
|
||||
A method call to a wrapped module gets rerouted through the wrapper's `forward` method instead.
|
||||
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, wrapper_module: nn.Module, original_module: nn.Module, method: callable, *args: Any, **kwargs: Any
|
||||
):
|
||||
"""Reroutes a method call through the `wrapper_module`'s `forward` method.
|
||||
|
||||
Args:
|
||||
wrapper_module: The module that has `original_module` wrapped.
|
||||
original_module: The module that was wrapped inside `wrapper_module`.
|
||||
method_name: The name of the method that should be called on the `original_module` after inputs get
|
||||
redirected through the `wrapper_module`'s `forward` method.
|
||||
*args: The positional arguments to the method `method_name`. They will get passed to a patched
|
||||
`forward` method instead.
|
||||
**kwargs: The keyword arguments to the method `method_name`. They will get passed to a patched
|
||||
`forward` method instead.
|
||||
|
||||
"""
|
||||
original_forward = original_module.forward
|
||||
|
||||
def wrapped_forward(*_args: Any, **_kwargs: Any) -> Any:
|
||||
# Unpatch ourselves immediately before calling the method `method_name`
|
||||
# because itself may want to call the real `forward`
|
||||
original_module.forward = original_forward # type: ignore[method-assign]
|
||||
# Call the actual method e.g. `.training_step(...)`
|
||||
out = method(*_args, **_kwargs)
|
||||
self.on_after_inner_forward(wrapper_module, original_module)
|
||||
return out
|
||||
|
||||
# Patch the original_module's forward so we can redirect the arguments back to the real method
|
||||
original_module.forward = wrapped_forward # type: ignore[method-assign]
|
||||
|
||||
wrapper_output = wrapper_module(*args, **kwargs)
|
||||
self.on_after_outer_forward(wrapper_module, original_module)
|
||||
return wrapper_output
|
||||
|
||||
def on_after_inner_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def on_after_outer_forward(self, wrapper_module: nn.Module, original_module: nn.Module) -> None:
|
||||
pass
|
||||
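A hedged sketch of what this helper is for; `wrapper`, `inner`, and `compute_logps` are hypothetical names. Calling a method on the inner module directly would bypass the wrapper's `forward`, and with it the gather/communication logic that DeepSpeed and FSDP wrappers attach there.

```python
redirect = _ForwardRedirection()

# `wrapper` wraps `inner` (e.g. a DeepSpeed engine or FSDP module); `compute_logps` is a
# hypothetical method defined on the inner module.
out = redirect(wrapper, inner, inner.compute_logps, batch)
# Internally, inner.forward is temporarily patched so that wrapper(batch) ends up executing
# inner.compute_logps(batch); the original forward is restored before the method runs.
```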
trl/rewards/__init__.py (new file, 32 lines added)
@ -0,0 +1,32 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import sys
from typing import TYPE_CHECKING

from ..import_utils import _LazyModule


_import_structure = {
"format_rewards": ["think_format_reward"],
}


if TYPE_CHECKING:
from .format_rewards import think_format_reward


else:
sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__)
trl/rewards/format_rewards.py (new file, 49 lines added)
@ -0,0 +1,49 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re


def think_format_reward(completions: list[list[dict[str, str]]], **kwargs) -> list[float]:
r"""
Reward function that checks if the reasoning process is enclosed within `"<think>"` and `"</think>"` tags. The
function returns a reward of 1.0 if the format is correct, otherwise 0.0.

Args:
completions (`list[list[dict[str, str]]]`):
List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary
containing the key `"content"` with the value being the text of the completion.
**kwargs:
Additional keyword arguments. This function does not use them, but they are required in the function
signature to ensure compatibility with trainers like [`GRPOTrainer`].

Returns:
`list[float]`:
A list of rewards, where each reward is 1.0 if the completion matches the expected format, otherwise 0.0.

Example:
```python
>>> from trl.rewards import think_format_reward
>>> completions = [
... [{"content": "<think>\nThis is my reasoning.\n</think>\nThis is my answer."}],
... [{"content": "<think>\nThis is my reasoning.\nThis is my answer."}],
... ]
>>> think_format_reward(completions)
[1.0, 0.0]
```
"""
pattern = r"^<think>(?!.*<think>)(.*?)</think>.*$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]
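A hedged sketch of plugging this reward into GRPO training; the model and dataset choices are illustrative:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
from trl.rewards import think_format_reward

dataset = load_dataset("trl-lib/tldr", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=[think_format_reward],  # 1.0 only when a single leading <think>...</think> block is present
    args=GRPOConfig(output_dir="grpo-think-format"),
    train_dataset=dataset,
)
trainer.train()
```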
@ -26,15 +26,18 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||||
from transformers.utils import is_rich_available
|
||||
|
||||
from trl import TrlParser, init_zero_verbose
|
||||
from trl.trainer.utils import get_quantization_config
|
||||
|
||||
|
||||
if is_rich_available():
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
|
||||
init_zero_verbose()
|
||||
|
||||
HELP_STRING = """\
|
||||
@ -195,6 +198,9 @@ class ChatArguments:
|
||||
|
||||
class RichInterface:
|
||||
def __init__(self, model_name=None, user_name=None):
|
||||
if not is_rich_available():
|
||||
raise ImportError("Rich is not available. Please install it with `pip install rich`.")
|
||||
|
||||
self._console = Console()
|
||||
if model_name is None:
|
||||
self.model_name = "assistant"
|
||||
|
@ -33,8 +33,13 @@ from .utils import get_git_commit_hash
|
||||
|
||||
|
||||
def print_env():
|
||||
devices = None
|
||||
if torch.cuda.is_available():
|
||||
devices = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
|
||||
elif torch.backends.mps.is_available():
|
||||
devices = ["MPS"]
|
||||
elif torch.xpu.is_available():
|
||||
devices = [torch.xpu.get_device_name(i) for i in range(torch.xpu.device_count())]
|
||||
|
||||
accelerate_config = accelerate_config_str = "not found"
|
||||
|
||||
@ -55,7 +60,7 @@ def print_env():
|
||||
"Python version": platform.python_version(),
|
||||
"TRL version": f"{__version__}+{commit_hash[:7]}" if commit_hash else __version__,
|
||||
"PyTorch version": version("torch"),
|
||||
"CUDA device(s)": ", ".join(devices) if torch.cuda.is_available() else "not available",
|
||||
"accelerator(s)": ", ".join(devices) if devices is not None else "cpu",
|
||||
"Transformers version": version("transformers"),
|
||||
"Accelerate version": version("accelerate"),
|
||||
"Accelerate config": accelerate_config_str,
|
||||
|
@ -13,6 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
@ -20,6 +23,12 @@ from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
|
||||
from trl.rewards import think_format_reward
|
||||
|
||||
|
||||
reward_funcs_registry = {
|
||||
"think_format_reward": think_format_reward,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -28,9 +37,12 @@ class GRPOScriptArguments(ScriptArguments):
|
||||
Script arguments for the GRPO training script.
|
||||
|
||||
Args:
|
||||
reward_model_name_or_path (`str` or `None`):
|
||||
reward_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
|
||||
Reward model id of a pretrained model hosted inside a model repo on huggingface.co or local path to a
|
||||
directory containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`].
|
||||
reward_funcs (`list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Reward functions to use. It can be either `"think_format_reward"` or a dotted import path
|
||||
(e.g., `'my_lib.rewards.custom_reward'`).
|
||||
"""
|
||||
|
||||
reward_model_name_or_path: Optional[str] = field(
|
||||
@ -40,6 +52,13 @@ class GRPOScriptArguments(ScriptArguments):
|
||||
"local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`."
|
||||
},
|
||||
)
|
||||
reward_funcs: Optional[list[str]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Reward functions to use. It can be either one of 'think_format_reward'; or a dotted "
|
||||
"import path. (e.g., 'my_lib.rewards.custom_reward')."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def main(script_args, training_args, model_args):
|
||||
@ -50,9 +69,30 @@ def main(script_args, training_args, model_args):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
reward_model = AutoModelForSequenceClassification.from_pretrained(
|
||||
script_args.reward_model_name_or_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
|
||||
)
|
||||
|
||||
# Get the reward models and functions
|
||||
reward_funcs = []
|
||||
if script_args.reward_model_name_or_path:
|
||||
reward_model = AutoModelForSequenceClassification.from_pretrained(
|
||||
script_args.reward_model_name_or_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
|
||||
)
|
||||
reward_funcs.append(reward_model)
|
||||
|
||||
if script_args.reward_funcs:
|
||||
for func_name in script_args.reward_funcs:
|
||||
if func_name in reward_funcs_registry:
|
||||
reward_funcs.append(reward_funcs_registry[func_name])
|
||||
elif "." in func_name:
|
||||
module_path, func_name = func_name.rsplit(".", 1)
|
||||
sys.path.insert(0, os.getcwd())
|
||||
module = importlib.import_module(module_path)
|
||||
reward_func = getattr(module, func_name)
|
||||
reward_funcs.append(reward_func)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Could not load reward function '{func_name}'. Expected one of "
|
||||
f"{list(reward_funcs_registry.keys())} or a valid import path."
|
||||
)
|
||||
|
||||
# Load the dataset
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
|
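Given the loader above, a custom reward only needs to live in an importable module; a hedged sketch, where the file `my_rewards.py` and its function are hypothetical:

```python
# my_rewards.py, placed in the current working directory (which the script adds to sys.path)
def brevity_reward(completions, **kwargs):
    """Toy reward: shorter completions score higher, clamped to [0.0, 1.0]."""
    contents = [completion[0]["content"] for completion in completions]
    return [max(0.0, 1.0 - len(text) / 1024) for text in contents]
```

It can then be selected next to a registry entry, e.g. `--reward_funcs think_format_reward my_rewards.brevity_reward`.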
@ -142,5 +142,8 @@ def make_parser(subparsers: argparse._SubParsersAction = None):
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = make_parser()
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
# When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
|
||||
# To ensure that their parsing does not interfere with the script arguments, parse the arguments with
|
||||
# `return_remaining_strings=True`, then ignore the remaining strings.
|
||||
script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True)
|
||||
main(script_args, training_args, model_args)
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
@ -25,6 +26,7 @@ from typing import Optional, Union
|
||||
import yaml
|
||||
from transformers import HfArgumentParser
|
||||
from transformers.hf_argparser import DataClass, DataClassType
|
||||
from transformers.utils import is_rich_available
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -51,7 +53,7 @@ class ScriptArguments:
|
||||
type, inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992.
|
||||
"""
|
||||
|
||||
dataset_name: str = field(metadata={"help": "Dataset name."})
|
||||
dataset_name: Optional[str] = field(default=None, metadata={"help": "Dataset name."})
|
||||
dataset_config: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
@ -78,14 +80,21 @@ class ScriptArguments:
|
||||
def init_zero_verbose():
|
||||
"""
|
||||
Perform zero verbose init - use this method on top of the CLI modules to make
|
||||
logging and warning output cleaner. Uses Rich if available, falls back otherwise.
|
||||
"""
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
from rich.logging import RichHandler
|
||||
|
||||
FORMAT = "%(message)s"
|
||||
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.ERROR)
|
||||
|
||||
if is_rich_available():
|
||||
from rich.logging import RichHandler
|
||||
|
||||
handler = RichHandler()
|
||||
else:
|
||||
handler = logging.StreamHandler()
|
||||
|
||||
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[handler], level=logging.ERROR)
|
||||
|
||||
# Custom warning handler to redirect warnings to the logging system
|
||||
def warning_handler(message, category, filename, lineno, file=None, line=None):
|
||||
@ -207,20 +216,33 @@ class TrlParser(HfArgumentParser):

def set_defaults_with_config(self, **kwargs) -> list[str]:
"""
Overrides the parser's default values with those provided via keyword arguments.
Overrides the parser's default values with those provided via keyword arguments, including for subparsers.

Any argument with an updated default will also be marked as not required
if it was previously required.

Returns a list of strings that were not consumed by the parser.
"""
# If an argument is in the kwargs, update its default and set it as not required
for action in self._actions:
if action.dest in kwargs:
action.default = kwargs.pop(action.dest)
action.required = False
remaining_strings = [item for key, value in kwargs.items() for item in [f"--{key}", str(value)]]
return remaining_strings

def apply_defaults(parser, kw):
used_keys = set()
for action in parser._actions:
# Handle subparsers recursively
if isinstance(action, argparse._SubParsersAction):
for subparser in action.choices.values():
used_keys.update(apply_defaults(subparser, kw))
elif action.dest in kw:
action.default = kw[action.dest]
action.required = False
used_keys.add(action.dest)
return used_keys

used_keys = apply_defaults(self, kwargs)
# Remaining args not consumed by the parser
remaining = [
item for key, value in kwargs.items() if key not in used_keys for item in (f"--{key}", str(value))
]
return remaining
|
||||
|
||||
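A hedged sketch of the behaviour the rewritten method implements; the keys passed here are illustrative:

```python
from trl import ScriptArguments, TrlParser

parser = TrlParser((ScriptArguments,))
# Known keys (including those belonging to subparsers) update the defaults and drop the
# "required" flag; unknown keys are handed back as CLI-style strings.
remaining = parser.set_defaults_with_config(dataset_name="trl-lib/tldr", not_an_arg=1)
print(remaining)  # ['--not_an_arg', '1']
```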
def get_git_commit_hash(package_name):
|
||||
|
@ -184,6 +184,8 @@ class ScriptArguments:
|
||||
enforce_eager (`bool` or `None`, *optional*, defaults to `None`):
|
||||
Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the
|
||||
model in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid.
|
||||
kv_cache_dtype (`str`, *optional*, defaults to `"auto"`):
|
||||
Data type to use for KV cache. If set to `"auto"`, the dtype will default to the model data type.
|
||||
log_level (`str`, *optional*, defaults to `"info"`):
|
||||
Log level for uvicorn. Possible choices: `"critical"`, `"error"`, `"warning"`, `"info"`, `"debug"`,
|
||||
`"trace"`.
|
||||
@ -251,6 +253,12 @@ class ScriptArguments:
|
||||
"execution in hybrid."
|
||||
},
|
||||
)
|
||||
kv_cache_dtype: str = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": "Data type to use for KV cache. If set to 'auto', the dtype will default to the model data type."
|
||||
},
|
||||
)
|
||||
log_level: str = field(
|
||||
default="info",
|
||||
metadata={
|
||||
@ -280,6 +288,7 @@ def llm_worker(
|
||||
# directly reuse the KV cache if it shares the same prefix with one of the existing queries.
|
||||
# This is particularly useful here because we generate completions from the same prompts.
|
||||
enable_prefix_caching=script_args.enable_prefix_caching,
|
||||
kv_cache_dtype=script_args.kv_cache_dtype,
|
||||
max_model_len=script_args.max_model_len,
|
||||
worker_extension_cls="trl.scripts.vllm_serve.WeightSyncWorkerExtension",
|
||||
)
|
||||
|
@ -38,6 +38,7 @@ _import_structure = {
|
||||
"gkd_trainer": ["GKDTrainer"],
|
||||
"grpo_config": ["GRPOConfig"],
|
||||
"grpo_trainer": ["GRPOTrainer"],
|
||||
"iterative_sft_config": ["IterativeSFTConfig"],
|
||||
"iterative_sft_trainer": ["IterativeSFTTrainer"],
|
||||
"judges": [
|
||||
"AllTrueJudge",
|
||||
@ -109,7 +110,7 @@ if TYPE_CHECKING:
|
||||
from .gkd_trainer import GKDTrainer
|
||||
from .grpo_config import GRPOConfig
|
||||
from .grpo_trainer import GRPOTrainer
|
||||
from .iterative_sft_trainer import IterativeSFTTrainer
|
||||
from .iterative_sft_trainer import IterativeSFTConfig, IterativeSFTTrainer
|
||||
from .judges import (
|
||||
AllTrueJudge,
|
||||
BaseBinaryJudge,
|
||||
|
@ -19,7 +19,6 @@ import textwrap
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from copy import deepcopy
|
||||
from operator import itemgetter
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
|
||||
|
||||
@ -32,7 +31,7 @@ import torch.nn.functional as F
|
||||
import transformers
|
||||
from accelerate import PartialState
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import is_deepspeed_available, tqdm
|
||||
from accelerate.utils import tqdm
|
||||
from datasets import Dataset
|
||||
from packaging import version
|
||||
from torch.utils.data import DataLoader, SequentialSampler
|
||||
@ -56,7 +55,7 @@ from transformers.utils import is_peft_available
|
||||
|
||||
from ..data_utils import maybe_apply_chat_template
|
||||
from ..import_utils import is_joblib_available
|
||||
from ..models import PreTrainedModelWrapper, create_reference_model
|
||||
from ..models import create_reference_model, prepare_deepspeed
|
||||
from .bco_config import BCOConfig
|
||||
from .utils import (
|
||||
DPODataCollatorWithPadding,
|
||||
@ -83,9 +82,6 @@ if is_sklearn_available():
|
||||
if is_joblib_available():
|
||||
import joblib
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
@ -712,7 +708,7 @@ class BCOTrainer(Trainer):
|
||||
)
|
||||
else:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
@ -846,37 +842,6 @@ class BCOTrainer(Trainer):
|
||||
|
||||
return all_embeddings
|
||||
|
||||
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
||||
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
||||
|
||||
if model is not None:
|
||||
if hasattr(model, "config"):
|
||||
hidden_size = (
|
||||
max(model.config.hidden_sizes)
|
||||
if getattr(model.config, "hidden_sizes", None)
|
||||
else getattr(model.config, "hidden_size", None)
|
||||
)
|
||||
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
||||
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
||||
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
||||
config_kwargs.update(
|
||||
{
|
||||
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
||||
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
||||
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
||||
}
|
||||
)
|
||||
|
||||
# If ZeRO-3 is used, we shard both the active and reference model.
|
||||
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
||||
if config_kwargs["zero_optimization"]["stage"] != 3:
|
||||
config_kwargs["zero_optimization"]["stage"] = 0
|
||||
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def _save_optimizer_and_scheduler(self, output_dir):
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
super()._save_optimizer_and_scheduler(output_dir)
|
||||
@ -1310,10 +1275,12 @@ class BCOTrainer(Trainer):
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.train_dataset is None or not has_length(self.train_dataset):
|
||||
def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
|
||||
if dataset is None:
|
||||
dataset = self.train_dataset
|
||||
if dataset is None or not has_length(dataset):
|
||||
return None
|
||||
return SequentialSampler(self.train_dataset)
|
||||
return SequentialSampler(dataset)
|
||||
|
||||
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
|
||||
"""Generate samples from the model and reference model for the given batch of inputs."""
|
||||
|
@ -19,11 +19,7 @@ import pandas as pd
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.state import AcceleratorState
|
||||
from accelerate.utils import gather_object, is_comet_ml_available, is_deepspeed_available, is_wandb_available
|
||||
from rich.console import Console, Group
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.progress import Progress
|
||||
from accelerate.utils import gather_object, is_wandb_available
|
||||
from transformers import (
|
||||
GenerationConfig,
|
||||
PreTrainedModel,
|
||||
@ -35,6 +31,7 @@ from transformers import (
|
||||
TrainingArguments,
|
||||
)
|
||||
from transformers.trainer_utils import has_length
|
||||
from transformers.utils import is_rich_available
|
||||
|
||||
from ..data_utils import maybe_apply_chat_template
|
||||
from ..import_utils import is_mergekit_available
|
||||
@ -44,11 +41,11 @@ from .judges import BasePairwiseJudge
|
||||
from .utils import log_table_to_comet_experiment
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
if is_comet_ml_available():
|
||||
pass
|
||||
if is_rich_available():
|
||||
from rich.console import Console, Group
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.progress import Progress
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
@ -115,6 +112,8 @@ class SyncRefModelCallback(TrainerCallback):
|
||||
def sync_target_model(model, target_model, alpha):
|
||||
deepspeed_plugin = AcceleratorState().deepspeed_plugin
|
||||
if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
|
||||
import deepspeed
|
||||
|
||||
with deepspeed.zero.GatheredParameters(
|
||||
list(model.parameters()) + list(target_model.parameters()), modifier_rank=0
|
||||
):
|
||||
@ -138,6 +137,9 @@ class RichProgressCallback(TrainerCallback):
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not is_rich_available():
|
||||
raise ImportError("RichProgressCallback requires the `rich` extra. To install, run `pip install rich`.")
|
||||
|
||||
self.training_bar = None
|
||||
self.prediction_bar = None
|
||||
|
||||
|
@ -131,14 +131,19 @@ class DPOConfig(TrainingArguments):
|
||||
Whether to ignore the provided reference model and implicitly use a reference model that assigns equal
|
||||
probability to all responses.
|
||||
label_smoothing (`float`, *optional*, defaults to `0.0`):
|
||||
Robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report and
|
||||
Robust DPO label smoothing parameter from the [cDPO report](https://ericmitchell.ai/cdpo.pdf) and
|
||||
[Robust DPO](https://huggingface.co/papers/2403.00409) paper that should be between `0.0` and `0.5`.
|
||||
use_weighting (`bool`, *optional*, defaults to `False`):
|
||||
Whether to weight the loss as done in the [WPO](https://huggingface.co/papers/2406.11827) paper.
|
||||
Whether to weight the loss as done in the [WPO paper](https://huggingface.co/papers/2406.11827).
|
||||
rpo_alpha (`float`, *optional*, defaults to `None`):
|
||||
α parameter from the [RPO](https://huggingface.co/papers/2404.19733) paper (v3), which controls the
|
||||
α parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) (v3), which controls the
|
||||
weighting of the NLL term in the loss. If `None`, no weighting is applied and the loss is the same as the
|
||||
DPO loss. The paper recommends `rpo_alpha=1.0`.
|
||||
ld_alpha (`float` or `None`, *optional*, defaults to `None`):
|
||||
α parameter from the [LD-DPO paper](https://huggingface.co/papers/2409.06411), which controls the weighting
|
||||
of the verbose token log-probabilities in responses. If `None`, no weighting is applied to the verbose
|
||||
part, and the loss is equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between
|
||||
`0.0` and `1.0`.
|
||||
discopop_tau (`float`, *optional*, defaults to `0.05`):
|
||||
τ/temperature parameter from the [DiscoPOP](https://huggingface.co/papers/2406.08414) paper, which controls
|
||||
the shape of log ratio modulated loss. The paper recommends the default value `discopop_tau=0.05`.
|
||||
@ -346,6 +351,14 @@ class DPOConfig(TrainingArguments):
|
||||
"`rpo_alpha=1.0`."
|
||||
},
|
||||
)
|
||||
ld_alpha: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "α parameter from the LD-DPO paper, which controls the weighting of the verbose token "
|
||||
"log-probabilities in responses. If `None`, no weighting is applied to the verbose part, and the loss is "
|
||||
"equivalent to the standard DPO loss. The paper recommends setting `ld_alpha` between `0.0` and `1.0`.",
|
||||
},
|
||||
)
|
||||
discopop_tau: float = field(
|
||||
default=0.05,
|
||||
metadata={
|
||||
|
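A hedged configuration sketch for the new `ld_alpha` option; the other values are illustrative, not recommendations from this diff:

```python
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="dpo-ld",
    beta=0.1,
    ld_alpha=0.3,                  # LD-DPO: down-weight the verbose-token log-probabilities
    truncation_mode="keep_start",  # interacts with the reworked flush/truncate logic further down
)
```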
@ -19,7 +19,6 @@ import textwrap
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Literal, Optional, Union
|
||||
|
||||
@ -30,7 +29,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import is_deepspeed_available, tqdm
|
||||
from accelerate.utils import tqdm
|
||||
from datasets import Dataset, IterableDataset
|
||||
from packaging import version
|
||||
from torch.utils.data import DataLoader
|
||||
@ -53,7 +52,7 @@ from transformers.trainer_utils import EvalLoopOutput
|
||||
from transformers.utils import is_peft_available, is_torch_xpu_available
|
||||
|
||||
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
|
||||
from ..models import PreTrainedModelWrapper, create_reference_model
|
||||
from ..models import create_reference_model, prepare_deepspeed
|
||||
from ..models.utils import prepare_fsdp
|
||||
from .callbacks import SyncRefModelCallback
|
||||
from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
|
||||
@ -63,6 +62,7 @@ from .utils import (
|
||||
disable_dropout_in_model,
|
||||
empty_cache,
|
||||
flush_left,
|
||||
flush_right,
|
||||
generate_model_card,
|
||||
get_comet_experiment_url,
|
||||
log_table_to_comet_experiment,
|
||||
@ -80,9 +80,6 @@ if is_peft_available():
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForPreference(DataCollatorMixin):
|
||||
@ -184,7 +181,6 @@ class DPOTrainer(Trainer):
|
||||
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
||||
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
||||
reuse the fine-tuned model.
|
||||
This supersedes the `tokenizer` argument, which is now deprecated.
|
||||
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
||||
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
||||
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
||||
@ -510,7 +506,7 @@ class DPOTrainer(Trainer):
|
||||
)
|
||||
else:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
elif self.is_fsdp_enabled:
|
||||
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
|
||||
else:
|
||||
@ -558,7 +554,7 @@ class DPOTrainer(Trainer):
|
||||
|
||||
dataset = dataset.map(
|
||||
self.tokenize_row if not self.is_vision_model else self.process_row,
|
||||
remove_columns=["prompt", "chosen", "rejected"],
|
||||
remove_columns=["chosen", "rejected"],
|
||||
fn_kwargs={
|
||||
"processing_class": processing_class,
|
||||
"max_prompt_length": args.max_prompt_length,
|
||||
@ -676,37 +672,6 @@ class DPOTrainer(Trainer):
|
||||
|
||||
return output
|
||||
|
||||
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
||||
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
||||
|
||||
if model is not None:
|
||||
if hasattr(model, "config"):
|
||||
hidden_size = (
|
||||
max(model.config.hidden_sizes)
|
||||
if getattr(model.config, "hidden_sizes", None)
|
||||
else getattr(model.config, "hidden_size", None)
|
||||
)
|
||||
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
||||
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
||||
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
||||
config_kwargs.update(
|
||||
{
|
||||
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
||||
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
||||
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
||||
}
|
||||
)
|
||||
|
||||
# If ZeRO-3 is used, we shard both the active and reference model.
|
||||
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
||||
if config_kwargs["zero_optimization"]["stage"] != 3:
|
||||
config_kwargs["zero_optimization"]["stage"] = 0
|
||||
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def _set_signature_columns_if_needed(self):
|
||||
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
||||
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
||||
@ -840,9 +805,9 @@ class DPOTrainer(Trainer):
|
||||
with torch.no_grad(), compte_ref_context_manager:
|
||||
if self.ref_model is None:
|
||||
with self.null_ref_context():
|
||||
ref_model_output = self.concatenated_forward(self.model, batch)
|
||||
ref_model_output = self.concatenated_forward(self.model, batch, is_ref_model=True)
|
||||
else:
|
||||
ref_model_output = self.concatenated_forward(self.ref_model, batch)
|
||||
ref_model_output = self.concatenated_forward(self.ref_model, batch, is_ref_model=True)
|
||||
return ref_model_output["chosen_logps"], ref_model_output["rejected_logps"]
|
||||
|
||||
@staticmethod
|
||||
@ -1102,16 +1067,28 @@ class DPOTrainer(Trainer):
|
||||
|
||||
return losses, chosen_rewards, rejected_rewards
|
||||
|
||||
def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]]):
|
||||
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
||||
def concatenated_forward(
|
||||
self, model: nn.Module, batch: dict[str, Union[list, torch.LongTensor]], is_ref_model: bool = False
|
||||
):
|
||||
"""
|
||||
Runs the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
||||
|
||||
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
||||
|
||||
Args:
|
||||
model:
|
||||
Model to run the forward pass on.
|
||||
batch:
|
||||
Batch of input data.
|
||||
is_ref_model:
|
||||
Whether this method is being called for the reference model. If `True`, length desensitization is not
|
||||
applied.
|
||||
"""
|
||||
num_examples = batch["prompt_input_ids"].shape[0]
|
||||
|
||||
concatenated_batch = self.concatenated_inputs(batch, padding_value=self.padding_value)
|
||||
|
||||
model_kwargs = {}
|
||||
model_kwargs = {"use_cache": False}
|
||||
if self.aux_loss_enabled:
|
||||
model_kwargs["output_router_logits"] = True
|
||||
|
||||
@ -1148,26 +1125,35 @@ class DPOTrainer(Trainer):
|
||||
dim=1,
|
||||
)
|
||||
|
||||
# Flush left to reduce the memory usage
|
||||
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
|
||||
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
|
||||
# Truncate right
|
||||
if self.max_length is not None:
|
||||
if self.truncation_mode == "keep_end":
|
||||
# Flush and truncate
|
||||
if self.max_length is not None and self.max_length < attention_mask.size(1):
|
||||
if self.truncation_mode == "keep_start":
|
||||
# Flush left to reduce the memory usage
|
||||
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
|
||||
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
attention_mask = attention_mask[:, : self.max_length]
|
||||
input_ids = input_ids[:, : self.max_length]
|
||||
loss_mask = loss_mask[:, : self.max_length]
|
||||
elif self.truncation_mode == "keep_end":
|
||||
# Flush right before truncating left, then flush left
|
||||
# [[0, 0, x, x, x, x], -> [[0, 0, x, x],
|
||||
# [0, x, x, x, 0, 0]] [0, x, x, x]]
|
||||
attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
|
||||
input_ids = input_ids[:, -self.max_length :]
|
||||
attention_mask = attention_mask[:, -self.max_length :]
|
||||
loss_mask = loss_mask[:, -self.max_length :]
|
||||
elif self.truncation_mode == "keep_start":
|
||||
input_ids = input_ids[:, : self.max_length]
|
||||
attention_mask = attention_mask[:, : self.max_length]
|
||||
loss_mask = loss_mask[:, : self.max_length]
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
|
||||
"'keep_start']."
|
||||
)
|
||||
else:
|
||||
# Flush left to reduce the memory usage
|
||||
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
|
||||
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
|
||||
if self.use_logits_to_keep:
|
||||
# Compute logits_to_keep based on loss_mask pattern:
|
||||
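To make the reworked truncation path above easier to follow, here is a minimal self-contained sketch of the flush-and-truncate idea for `keep_start`; the toy `flush_left` below is an illustrative stand-in, not TRL's helper.

```python
import torch

def flush_left(mask, *tensors):
    # Illustrative only: shift every row left so the leading padding (0s in `mask`) disappears.
    out_mask = torch.zeros_like(mask)
    outs = [torch.zeros_like(t) for t in tensors]
    for i in range(mask.size(0)):
        idx = mask[i].nonzero(as_tuple=True)[0]
        out_mask[i, : idx.numel()] = 1
        for out, t in zip(outs, tensors):
            out[i, : idx.numel()] = t[i, idx]
    return (out_mask, *outs)

attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1],
                               [0, 1, 1, 1, 0, 0]])
input_ids = torch.tensor([[0, 0, 11, 12, 13, 14],
                          [0, 21, 22, 23, 0, 0]])
max_length = 4

# "keep_start": flush left, then keep the first `max_length` tokens (drop the end).
attention_mask, input_ids = flush_left(attention_mask, input_ids)
print(input_ids[:, :max_length])  # tensor([[11, 12, 13, 14], [21, 22, 23,  0]])
# "keep_end" instead flushes right, keeps the last `max_length` tokens, then flushes left again.
```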
@ -1254,6 +1240,28 @@ class DPOTrainer(Trainer):
|
||||
if self.loss_type == "ipo":
|
||||
all_logps = all_logps / loss_mask.sum(-1)
|
||||
|
||||
if self.args.ld_alpha is not None and not is_ref_model:
|
||||
# Compute response lengths based on loss_mask
|
||||
completion_lengths = loss_mask.sum(dim=1)
|
||||
|
||||
chosen_lengths = completion_lengths[:num_examples]
|
||||
rejected_lengths = completion_lengths[num_examples:]
|
||||
public_lengths = torch.min(chosen_lengths, rejected_lengths) # l_p in the paper
|
||||
public_lengths = torch.cat([public_lengths, public_lengths], dim=0)
|
||||
|
||||
seq_len = per_token_logps.size(1)
|
||||
position_ids = torch.arange(seq_len, device=per_token_logps.device).expand_as(per_token_logps)
|
||||
|
||||
ld_mask = position_ids < public_lengths.unsqueeze(1)
|
||||
mask = position_ids < completion_lengths.unsqueeze(1)
|
||||
|
||||
front_mask = (ld_mask & mask).float()
|
||||
rear_mask = (~ld_mask & mask).float()
|
||||
front_logps = (per_token_logps * front_mask).sum(dim=1)
|
||||
rear_logps = (per_token_logps * rear_mask).sum(dim=1)
|
||||
|
||||
all_logps = front_logps + self.args.ld_alpha * rear_logps
|
||||
|
||||
output["chosen_logps"] = all_logps[:num_examples]
|
||||
output["rejected_logps"] = all_logps[num_examples:]
|
||||
|
||||
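For readers unfamiliar with LD-DPO, the `ld_alpha` block above splits each completion at the shorter ("public") length of its chosen/rejected pair and down-weights the surplus tokens. A rough standalone restatement with toy tensors (not the trainer's variables):

```python
import torch

per_token_logps = torch.randn(4, 10)              # 2 chosen + 2 rejected completions
completion_lengths = torch.tensor([8, 5, 6, 9])   # i.e. loss_mask.sum(dim=1)
ld_alpha = 0.5

num_examples = per_token_logps.size(0) // 2
public = torch.min(completion_lengths[:num_examples], completion_lengths[num_examples:])
public = torch.cat([public, public])              # l_p in the paper, shared per pair

pos = torch.arange(per_token_logps.size(1)).expand_as(per_token_logps)
front = (pos < public.unsqueeze(1)).float()                                              # shared prefix
rear = ((pos >= public.unsqueeze(1)) & (pos < completion_lengths.unsqueeze(1))).float()  # surplus tokens

# Full weight on the shared prefix, `ld_alpha` weight on the surplus of the longer completion.
all_logps = (per_token_logps * front).sum(dim=1) + ld_alpha * (per_token_logps * rear).sum(dim=1)
```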
@ -1488,7 +1496,7 @@ class DPOTrainer(Trainer):
|
||||
)
|
||||
],
|
||||
)
|
||||
if "wandb" in self.args.report_to:
|
||||
if "wandb" in self.args.report_to and self.accelerator.is_main_process:
|
||||
wandb.log({"game_log": wandb.Table(data=table)})
|
||||
|
||||
if "comet_ml" in self.args.report_to:
|
||||
|
@ -15,13 +15,11 @@
|
||||
import os
|
||||
import random
|
||||
import textwrap
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from accelerate.utils import is_deepspeed_available
|
||||
from datasets import Dataset
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
@ -38,7 +36,7 @@ from transformers.trainer_callback import TrainerCallback
|
||||
from transformers.trainer_utils import EvalPrediction
|
||||
from transformers.utils import is_peft_available
|
||||
|
||||
from ..models import PreTrainedModelWrapper
|
||||
from ..models import prepare_deepspeed
|
||||
from ..models.utils import unwrap_model_for_generation
|
||||
from .gkd_config import GKDConfig
|
||||
from .sft_trainer import SFTTrainer
|
||||
@ -51,10 +49,6 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftConfig
|
||||
|
||||
@ -124,7 +118,7 @@ class GKDTrainer(SFTTrainer):
|
||||
disable_dropout_in_model(self.model)
|
||||
|
||||
if self.is_deepspeed_enabled:
|
||||
self.teacher_model = self._prepare_deepspeed(teacher_model)
|
||||
self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
|
||||
else:
|
||||
self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
|
||||
|
||||
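As in the DPO changes earlier in this diff, the per-trainer `_prepare_deepspeed` copy is dropped in favour of the shared helper in `trl.models`. A hedged sketch of the resulting call site, assuming the helper keeps the contract of the removed private method (ZeRO-3 shards the model, otherwise it is kept whole at stage 0, returned in eval mode):

```python
from accelerate import Accelerator
from transformers import AutoModelForCausalLM
from trl.models import prepare_deepspeed

def load_frozen_teacher(model_id: str, accelerator: Accelerator):
    """Load a teacher/reference model and shard or replicate it depending on the backend."""
    model = AutoModelForCausalLM.from_pretrained(model_id)
    if accelerator.state.deepspeed_plugin is not None:
        # Shared helper replaces the trainer-local `_prepare_deepspeed` copies.
        return prepare_deepspeed(model, accelerator)
    return accelerator.prepare_model(model, evaluation_mode=True)
```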
@ -311,37 +305,6 @@ class GKDTrainer(SFTTrainer):
|
||||
loss = super().training_step(model, inputs, num_items_in_batch)
|
||||
return loss
|
||||
|
||||
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
||||
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
||||
|
||||
if model is not None:
|
||||
if hasattr(model, "config"):
|
||||
hidden_size = (
|
||||
max(model.config.hidden_sizes)
|
||||
if getattr(model.config, "hidden_sizes", None)
|
||||
else getattr(model.config, "hidden_size", None)
|
||||
)
|
||||
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
||||
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
||||
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
||||
config_kwargs.update(
|
||||
{
|
||||
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
||||
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
||||
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
||||
}
|
||||
)
|
||||
|
||||
# If ZeRO-3 is used, we shard both the active and reference model.
|
||||
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
||||
if config_kwargs["zero_optimization"]["stage"] != 3:
|
||||
config_kwargs["zero_optimization"]["stage"] = 0
|
||||
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
|
@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Union
|
||||
|
||||
@ -64,14 +64,19 @@ class GRPOConfig(TrainingArguments):
|
||||
|
||||
> Parameters that control generation
|
||||
|
||||
temperature (`float`, defaults to `0.9`):
|
||||
generation_batch_size: (`int` or `None`, *optional*, defaults to `None`):
|
||||
Batch size to use for generation. If `None`, it defaults to the effective training batch size:
|
||||
`per_device_train_batch_size * num_processes * gradient_accumulation_steps`.
|
||||
steps_per_generation: (`int` or `None`, *optional*, defaults to `None`):
|
||||
Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps.
|
||||
temperature (`float`, defaults to `1.0`):
|
||||
Temperature for sampling. The higher the temperature, the more random the completions.
|
||||
top_p (`float`, *optional*, defaults to `1.0`):
|
||||
Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. Set to
|
||||
`1.0` to consider all tokens.
|
||||
top_k (`int` or `None`, *optional*, defaults to `50`):
|
||||
top_k (`int` or `None`, *optional*, defaults to `None`):
|
||||
Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, top-k-filtering is
|
||||
disabled.
|
||||
disabled and all tokens are considered.
|
||||
min_p (`float` or `None`, *optional*, defaults to `None`):
|
||||
Minimum token probability, which will be scaled by the probability of the most likely token. It must be a
|
||||
value between `0.0` and `1.0`. Typical values are in the `0.01-0.2` range.
|
||||
@ -85,18 +90,42 @@ class GRPOConfig(TrainingArguments):
|
||||
> Parameters that control generation acceleration powered by vLLM
|
||||
|
||||
use_vllm (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use vLLM for generating completions. If set to `True`, ensure that a GPU is kept unused for
|
||||
training, as vLLM will require one for generation. vLLM must be installed (`pip install vllm`).
|
||||
vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
|
||||
Host of the vLLM server to connect to.
|
||||
vllm_server_port (`int`, *optional*, defaults to `8000`):
|
||||
Port of the vLLM server to connect to.
|
||||
vllm_server_timeout (`float`, *optional*, defaults to `120.0`):
|
||||
Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
|
||||
timeout, a `ConnectionError` is raised.
|
||||
Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for generation
|
||||
instead of the default model.generate(). Requires `vllm` to be installed.
|
||||
vllm_mode (`str`, *optional*, defaults to `"server"`):
|
||||
Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `"server"` or
|
||||
`"colocate"`.
|
||||
|
||||
- `"server"`: The trainer will send generation requests to a separate vLLM server. Make sure a TRL vLLM
|
||||
server is running (start with `trl vllm-serve`).
|
||||
- `"colocate"`: vLLM will run in the same process and share the training GPUs. This avoids the need for a
|
||||
separate server but may cause resource contention with training.
|
||||
vllm_guided_decoding_regex (`str` or `None`, *optional*, defaults to `None`):
|
||||
Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled.
|
||||
|
||||
> Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
|
||||
vllm_server_base_url (`str` or `None`, *optional*, defaults to `None`):
|
||||
Base URL for the vLLM server (e.g., `"http://localhost:8000"`). If provided, `vllm_server_host` and
|
||||
`vllm_server_port` are ignored.
|
||||
vllm_server_host (`str`, *optional*, defaults to `"0.0.0.0"`):
|
||||
Host of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
|
||||
vllm_server_port (`int`, *optional*, defaults to `8000`):
|
||||
Port of the vLLM server to connect to. Ignored if `vllm_server_base_url` is provided.
|
||||
vllm_server_timeout (`float`, *optional*, defaults to `240.0`):
|
||||
Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up after the
|
||||
timeout, a `ConnectionError` is raised.
|
||||
|
||||
> Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
|
||||
|
||||
vllm_gpu_memory_utilization (`float`, *optional*, defaults to `0.3`):
|
||||
Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set to
|
||||
`"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
|
||||
launching the vLLM server via the `--vllm_gpu_memory_utilization` flag.
|
||||
vllm_tensor_parallel_size (`int`, *optional*, defaults to `1`):
|
||||
Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set to
|
||||
`"colocate"`. If you are using `vllm_mode="server"`, this parameter must be passed separately when
|
||||
launching the vLLM server via the `--vllm_tensor_parallel_size` flag.
|
||||
|
||||
> Parameters that control the training
|
||||
|
||||
learning_rate (`float`, *optional*, defaults to `1e-6`):
|
||||
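Since the docstring above introduces the server/colocate split, a short configuration sketch may help; the values are illustrative and the field names come from the definitions added later in this diff.

```python
from trl import GRPOConfig

# Server mode: completions are requested from a separate `trl vllm-serve` process.
server_cfg = GRPOConfig(
    output_dir="grpo-server",
    use_vllm=True,
    vllm_mode="server",
    vllm_server_base_url="http://localhost:8000",  # takes precedence over host/port
)

# Colocate mode: vLLM shares the training GPUs, so cap its memory slice and parallelism.
colocate_cfg = GRPOConfig(
    output_dir="grpo-colocate",
    use_vllm=True,
    vllm_mode="colocate",
    vllm_gpu_memory_utilization=0.3,
    vllm_tensor_parallel_size=1,
)
```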
@ -109,6 +138,10 @@ class GRPOConfig(TrainingArguments):
|
||||
Number of iterations per batch (denoted as μ in the algorithm).
|
||||
epsilon (`float`, *optional*, defaults to `0.2`):
|
||||
Epsilon value for clipping.
|
||||
delta: (`float` or `None`, *optional*, defaults to `None`):
|
||||
Enables the upper clipping bound in two-sided GRPO loss when set to a float. If `None` (default), standard
|
||||
GRPO clipping is used. Recommended to be greater than `1 + ε` when enabled. This method is introduced in
|
||||
the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291).
|
||||
epsilon_high (`float` or `None`, *optional*, defaults to `None`):
|
||||
Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
|
||||
specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
|
||||
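The interplay between `epsilon`, `epsilon_high`, and the new `delta` bound is easiest to see on toy numbers; the snippet below is an illustrative restatement of the clipping described above, not the trainer's loss code.

```python
import torch

ratio = torch.tensor([0.5, 1.0, 1.6, 3.0])       # pi_theta / pi_old for a few tokens
advantage = torch.tensor([1.0, -1.0, 1.0, -1.0])
epsilon, epsilon_high, delta = 0.2, 0.28, 1.5

# One-sided (standard) clipping, optionally with a looser upper bound `epsilon_high`:
clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon_high)
loss_standard = -torch.min(ratio * advantage, clipped * advantage)

# Two-sided variant: the unclipped ratio itself is additionally capped at `delta` (> 1 + epsilon),
# which bounds the loss even when the advantage is negative.
loss_two_sided = -torch.min(torch.clamp(ratio, max=delta) * advantage, clipped * advantage)
```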
@ -139,7 +172,7 @@ class GRPOConfig(TrainingArguments):
|
||||
[DAPO](https://huggingface.co/papers/2503.14476) paper, this is a good practice for training stability.
|
||||
sync_ref_model (`bool`, *optional*, defaults to `False`):
|
||||
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
|
||||
the `ref_model_mixup_alpha` parameter. This synchronization originites from the
|
||||
the `ref_model_mixup_alpha` parameter. This synchronization originates from the
|
||||
[TR-DPO](https://huggingface.co/papers/2404.09656) paper.
|
||||
ref_model_mixup_alpha (`float`, *optional*, defaults to `0.6`):
|
||||
α parameter from the [TR-DPO](https://huggingface.co/papers/2404.09656) paper, which controls the mix
|
||||
@ -230,8 +263,21 @@ class GRPOConfig(TrainingArguments):
|
||||
)
|
||||
|
||||
# Parameters that control generation
|
||||
generation_batch_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Batch size to use for generation. If `None`, it defaults to the effective training batch size: "
|
||||
"`per_device_train_batch_size * num_processes * gradient_accumulation_steps`."
|
||||
},
|
||||
)
|
||||
steps_per_generation: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of optimization steps per generation. If `None`, it defaults to gradient_accumulation_steps."
|
||||
},
|
||||
)
|
||||
temperature: float = field(
|
||||
default=0.9,
|
||||
default=1.0,
|
||||
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
|
||||
)
|
||||
top_p: float = field(
|
||||
@ -242,10 +288,10 @@ class GRPOConfig(TrainingArguments):
|
||||
},
|
||||
)
|
||||
top_k: Optional[int] = field(
|
||||
default=50,
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
|
||||
"top-k-filtering is disabled."
|
||||
"top-k-filtering is disabled and all tokens are considered."
|
||||
},
|
||||
)
|
||||
min_p: Optional[float] = field(
|
||||
@ -272,17 +318,40 @@ class GRPOConfig(TrainingArguments):
|
||||
use_vllm: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a vLLM server is "
|
||||
"running. To run the server, install vLLM (`pip install vllm`) and run `trl vllm-serve`."
|
||||
"help": "Whether to use vLLM for generating completions. If set to `True`, the trainer will use vLLM for "
|
||||
"generation instead of the default model.generate(). Requires `vllm` to be installed."
|
||||
},
|
||||
)
|
||||
vllm_server_base_url: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Base URL for the vLLM server (e.g., 'http://localhost:8000'). If provided, `vllm_server_host` "
|
||||
"and `vllm_server_port` are ignored."
|
||||
},
|
||||
)
|
||||
vllm_mode: str = field(
|
||||
default="server",
|
||||
metadata={
|
||||
"help": "Mode to use for vLLM integration when `use_vllm` is set to `True`. Must be one of `server` or "
|
||||
"`'colocate'`. `'server'`: The trainer will send generation requests to a separate vLLM server. Make sure a "
|
||||
"TRL vLLM server is running (start with `trl vllm-serve`). `'colocate'`: vLLM will run in the same "
|
||||
"process and share the training GPUs. This avoids the need for a separate server but may cause resource "
|
||||
"contention with training."
|
||||
},
|
||||
)
|
||||
vllm_guided_decoding_regex: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
|
||||
)
|
||||
|
||||
# Parameters that control the vLLM server (only used when `vllm_mode` is `"server"`)
|
||||
vllm_server_host: str = field(
|
||||
default="0.0.0.0",
|
||||
metadata={"help": "Host of the vLLM server to connect to."},
|
||||
metadata={"help": "Host of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
|
||||
)
|
||||
vllm_server_port: int = field(
|
||||
default=8000,
|
||||
metadata={"help": "Port of the vLLM server to connect to."},
|
||||
metadata={"help": "Port of the vLLM server to connect to. Ignored if vllm_server_base_url is provided."},
|
||||
)
|
||||
vllm_server_timeout: float = field(
|
||||
default=240.0,
|
||||
@ -291,9 +360,23 @@ class GRPOConfig(TrainingArguments):
|
||||
"after the timeout, a `ConnectionError` is raised."
|
||||
},
|
||||
)
|
||||
vllm_guided_decoding_regex: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
|
||||
|
||||
# Parameters that control colocated vLLM execution (only used when `vllm_mode` is `"colocate"`)
|
||||
vllm_gpu_memory_utilization: float = field(
|
||||
default=0.3,
|
||||
metadata={
|
||||
"help": "Control the GPU memory utilization for vLLM. This setting only applies when `vllm_mode` is set "
|
||||
"to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when "
|
||||
"launching the vLLM server via the `--vllm_gpu_memory_utilization` flag."
|
||||
},
|
||||
)
|
||||
vllm_tensor_parallel_size: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "Control the tensor parallel size for vLLM. This setting only applies when `vllm_mode` is set "
|
||||
"to `'colocate'`. If you are using `vllm_mode='server'`, this parameter must be passed separately when "
|
||||
"launching the vLLM server via the `--vllm_tensor_parallel_size` flag."
|
||||
},
|
||||
)
|
||||
|
||||
# Parameters that control the training
|
||||
@ -319,6 +402,14 @@ class GRPOConfig(TrainingArguments):
|
||||
default=0.2,
|
||||
metadata={"help": "Epsilon value for clipping."},
|
||||
)
|
||||
delta: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Enables the upper clipping bound in two-sided GRPO loss when set to a float. If `None` "
|
||||
"(default), standard GRPO clipping is used. Recommended to be greater than `1 + ε` when enabled. This "
|
||||
"method is introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)."
|
||||
},
|
||||
)
|
||||
epsilon_high: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
@ -413,82 +504,58 @@ class GRPOConfig(TrainingArguments):
|
||||
},
|
||||
)
|
||||
|
||||
# Deprecated parameters
|
||||
vllm_device: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "This parameter is deprecated and will be removed in version 0.18.0. To use vLLM, start a vLLM "
|
||||
"server with the `trl vllm-serve` command."
|
||||
},
|
||||
)
|
||||
vllm_gpu_memory_utilization: Optional[float] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control the GPU memory "
|
||||
"utilization for vLLM, you should now use the `gpu_memory_utilization` parameter in the vLLM server "
|
||||
"configuration."
|
||||
},
|
||||
)
|
||||
vllm_dtype: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control the data type for "
|
||||
"vLLM generation, you should now use the `dtype` parameter in the vLLM server configuration."
|
||||
},
|
||||
)
|
||||
vllm_max_model_len: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control the "
|
||||
"`max_model_len` for vLLM, you should now use the `max_model_len` parameter in the vLLM server "
|
||||
"configuration."
|
||||
},
|
||||
)
|
||||
vllm_enable_prefix_caching: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "This parameter is deprecated and will be removed in version 0.18.0. To control prefix caching in "
|
||||
"vLLM, you should now use the `enable_prefix_caching` parameter in the vLLM server configuration."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.vllm_device is not None:
|
||||
warnings.warn(
|
||||
"`vllm_device` is deprecated and will be removed in version 0.18.0. To use vLLM, start a vLLM server "
|
||||
"with the `trl vllm-serve` command.",
|
||||
DeprecationWarning,
|
||||
num_processes = self.world_size
|
||||
# The current default effective batch size
|
||||
if self.generation_batch_size is not None and self.steps_per_generation is not None:
|
||||
raise ValueError(
|
||||
"'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time"
|
||||
)
|
||||
|
||||
if self.vllm_gpu_memory_utilization is not None:
|
||||
warnings.warn(
|
||||
"`vllm_gpu_memory_utilization` is deprecated and will be removed in v0.18. To control the GPU memory "
|
||||
"utilization for vLLM, you should now use the `gpu_memory_utilization` parameter in the vLLM server "
|
||||
"configuration.",
|
||||
DeprecationWarning,
|
||||
if self.steps_per_generation is None:
|
||||
self.steps_per_generation = self.gradient_accumulation_steps
|
||||
|
||||
if self.generation_batch_size is None:
|
||||
self.generation_batch_size = self.per_device_train_batch_size * num_processes * self.steps_per_generation
|
||||
|
||||
if self.generation_batch_size % self.per_device_train_batch_size * num_processes != 0:
|
||||
raise ValueError(
|
||||
f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size "
|
||||
f"({self.per_device_train_batch_size * num_processes})."
|
||||
)
|
||||
|
||||
if self.vllm_dtype is not None:
|
||||
warnings.warn(
|
||||
"`vllm_dtype` is deprecated and will be removed in version 0.18.0. To control the data type for vLLM "
|
||||
"generation, you should now use the `dtype` parameter in the vLLM server configuration.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
self.steps_per_generation = self.generation_batch_size // (self.per_device_train_batch_size * num_processes)
|
||||
|
||||
if self.vllm_max_model_len is not None:
|
||||
warnings.warn(
|
||||
"`vllm_max_model_len` is deprecated and will be removed in version 0.18.0. To control the "
|
||||
"`max_model_len` for vLLM, you should now use the `max_model_len` parameter in the vLLM server "
|
||||
"configuration.",
|
||||
DeprecationWarning,
|
||||
# Check if the effective batch size can be divided by the number of generations
|
||||
if self.num_generations < 2:
|
||||
raise ValueError(
|
||||
"GRPO requires at least 2 generations per prompt to calculate the advantages. You provided "
|
||||
f"{self.num_generations}, which is less than the minimum required."
|
||||
)
|
||||
possible_values = [
|
||||
n_gen for n_gen in range(2, self.generation_batch_size + 1) if (self.generation_batch_size) % n_gen == 0
|
||||
]
|
||||
|
||||
if self.vllm_enable_prefix_caching is not None:
|
||||
warnings.warn(
|
||||
"`vllm_enable_prefix_caching` is deprecated and will be removed in version 0.18.0. To control prefix "
|
||||
"caching in vLLM, you should now use the `enable_prefix_caching` parameter in the vLLM server "
|
||||
"configuration.",
|
||||
DeprecationWarning,
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"The effective train batch size ({num_processes} x {self.per_device_train_batch_size} x "
|
||||
f"{self.steps_per_generation}) must be evenly divisible by the number of generations per "
|
||||
f"prompt ({self.num_generations}). Given the current effective train batch size, the valid values for "
|
||||
f"the number of generations are: {possible_values}."
|
||||
)
|
||||
if self.eval_strategy != "no":
|
||||
global_eval_batch_size = self.per_device_eval_batch_size * num_processes
|
||||
possible_values = [
|
||||
n_gen for n_gen in range(2, global_eval_batch_size + 1) if (global_eval_batch_size) % n_gen == 0
|
||||
]
|
||||
if self.num_generations not in possible_values:
|
||||
raise ValueError(
|
||||
f"The global eval batch size ({num_processes} x {self.per_device_eval_batch_size}) must be "
|
||||
f"evenly divisible by the number of generations per prompt ({self.num_generations}). Given the "
|
||||
"current global eval batch size, the valid values for the number of generations are: "
|
||||
f"{possible_values}."
|
||||
)
|
||||
if self.delta is not None and self.use_liger_loss:
|
||||
raise ValueError("Liger loss does not support two-sided GRPO loss yet.")
|
||||
|
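The batch-size bookkeeping in `__post_init__` above is easier to check with concrete numbers; the values below are assumptions for illustration, not defaults.

```python
# Assumed setup: 8 processes, per-device batch 4, 2 optimization steps per generation round.
num_processes = 8
per_device_train_batch_size = 4
steps_per_generation = 2
num_generations = 16

# Derived default when `generation_batch_size` is left as None:
generation_batch_size = per_device_train_batch_size * num_processes * steps_per_generation  # 64

# Each generation round must split evenly into groups of `num_generations` completions:
assert generation_batch_size % num_generations == 0      # 64 / 16 -> 4 distinct prompts per round

# Conversely, an explicit generation_batch_size must be a multiple of the global batch size,
# and it determines steps_per_generation:
global_batch_size = per_device_train_batch_size * num_processes               # 32
assert generation_batch_size % global_batch_size == 0
steps_per_generation = generation_batch_size // global_batch_size             # back to 2
```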
File diff suppressed because it is too large
Load Diff
79
trl/trainer/iterative_sft_config.py
Normal file
@ -0,0 +1,79 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class IterativeSFTConfig(TrainingArguments):
|
||||
r"""
|
||||
Configuration class for the [`IterativeSFTTrainer`].
|
||||
|
||||
Only the parameters specific to iterative SFT training are listed here. For details on other parameters, refer to the
|
||||
[`~transformers.TrainingArguments`] documentation.
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
> Parameters that control the model
|
||||
|
||||
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
||||
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
||||
argument of the [`IterativeSFTTrainer`] is provided as a string.
|
||||
|
||||
> Parameters that control the data preprocessing
|
||||
|
||||
max_length (`int` or `None`, *optional*, defaults to `None`):
|
||||
Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated.
|
||||
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
||||
The truncation mode to use, either `"keep_end"` or `"keep_start"`.
|
||||
optimize_device_cache (`bool`, *optional*, defaults to `False`):
|
||||
Whether to optimize CUDA cache for slightly more memory-efficient training.
|
||||
"""
|
||||
|
||||
# Parameters that control the model
|
||||
model_init_kwargs: Optional[dict[str, Any]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
|
||||
"the `IterativeSFTTrainer` is provided as a string."
|
||||
},
|
||||
)
|
||||
|
||||
# Parameters that control the data preprocessing
|
||||
max_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated."
|
||||
},
|
||||
)
|
||||
truncation_mode: str = field(
|
||||
default="keep_end",
|
||||
metadata={"help": "The truncation mode to use, either 'keep_end' or 'keep_start'."},
|
||||
)
|
||||
optimize_device_cache: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to optimize CUDA cache for slightly more memory-efficient training."},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
if self.truncation_mode not in ["keep_end", "keep_start"]:
|
||||
raise ValueError(f"truncation_mode must be either 'keep_end' or 'keep_start', got {self.truncation_mode}")
|
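The new config class is meant to absorb the constructor arguments that the trainer below deprecates. A tentative usage sketch, assuming both classes are exported at the package root and using an illustrative model id:

```python
from trl import IterativeSFTConfig, IterativeSFTTrainer

# Preprocessing options now live on the config rather than on the trainer constructor.
config = IterativeSFTConfig(
    output_dir="iterative-sft",
    max_steps=100,
    max_length=512,
    truncation_mode="keep_end",
    optimize_device_cache=False,
    model_init_kwargs={"torch_dtype": "bfloat16"},  # only used when `model` is a string id
)

trainer = IterativeSFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # illustrative model id
    args=config,
)
trainer.step(texts=["The capital of France is Paris."])  # one iterative optimization step
```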
@ -20,6 +20,8 @@ import torch
|
||||
from datasets import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BaseImageProcessor,
|
||||
DataCollator,
|
||||
DataCollatorForLanguageModeling,
|
||||
@ -36,6 +38,7 @@ from transformers.trainer_utils import EvalLoopOutput
|
||||
from transformers.utils import is_peft_available
|
||||
|
||||
from ..core import PPODecorators
|
||||
from .iterative_sft_config import IterativeSFTConfig
|
||||
from .utils import generate_model_card, get_comet_experiment_url
|
||||
|
||||
|
||||
@ -52,39 +55,49 @@ class IterativeSFTTrainer(Trainer):
|
||||
The IterativeSFTTrainer can be used to fine-tune models with methods that require some steps between optimization.
|
||||
|
||||
Args:
|
||||
model (`PreTrainedModel`):
|
||||
Model to be optimized, either an 'AutoModelForCausalLM' or an 'AutoModelForSeq2SeqLM'.
|
||||
Check the documentation of `PreTrainedModel` for more details.
|
||||
args (`transformers.TrainingArguments`):
|
||||
The arguments to use for training.
|
||||
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
||||
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
||||
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
||||
reuse the fine-tuned model.
|
||||
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
||||
The optimizer and scheduler to use for training.
|
||||
data_collator (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], *optional*):
|
||||
Data collator to be used for training and passed along the dataloader.
|
||||
model (`Union[str, PreTrainedModel]`):
|
||||
Model to be trained. Can be either:
|
||||
|
||||
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
||||
a path to a *directory* containing model weights saved using
|
||||
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
||||
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
|
||||
in `args.model_init_kwargs`.
|
||||
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
||||
args ([`IterativeSFTConfig`], *optional*, defaults to `None`):
|
||||
Configuration for this trainer. If `None`, a default configuration is used.
|
||||
data_collator (`DataCollator`, *optional*):
|
||||
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
|
||||
Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, an instance
|
||||
of [`~transformers.DataCollatorWithPadding`] otherwise if the processing_class is a feature extractor or
|
||||
tokenizer.
|
||||
eval_dataset (`datasets.Dataset`):
|
||||
The dataset to use for evaluation.
|
||||
max_length (`int`, defaults to `None`):
|
||||
The maximum length of the input.
|
||||
truncation_mode (`str`, defaults to `keep_end`):
|
||||
The truncation mode to use, either `keep_end` or `keep_start`.
|
||||
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
||||
Processing class used to process the data. If `None`, the processing class is loaded from the model's name
|
||||
with [`~transformers.AutoTokenizer.from_pretrained`].
|
||||
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
||||
The optimizer and scheduler to use for training.
|
||||
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
||||
The function to use to preprocess the logits before computing the metrics.
|
||||
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
||||
The function to use to compute the metrics. Must take an `EvalPrediction` and return a dictionary mapping strings to metric values.
|
||||
optimize_device_cache (`bool`, *optional*, defaults to `False`):
|
||||
Optimize CUDA cache for slightly more memory-efficient training.
|
||||
max_length (`int`, *optional*, deprecated):
|
||||
Maximum length of the tokenized sequence. Use `args.max_length` instead.
|
||||
truncation_mode (`str`, *optional*, deprecated):
|
||||
The truncation mode to use. Use `args.truncation_mode` instead.
|
||||
optimize_device_cache (`bool`, *optional*, deprecated):
|
||||
Whether to optimize CUDA cache. Use `args.optimize_device_cache` instead.
|
||||
"""
|
||||
|
||||
_tag_names = ["trl", "iterative-sft"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[PreTrainedModel] = None,
|
||||
args: Optional[TrainingArguments] = None,
|
||||
model: Union[str, PreTrainedModel],
|
||||
args: Optional[Union[IterativeSFTConfig, TrainingArguments]] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
@ -92,35 +105,74 @@ class IterativeSFTTrainer(Trainer):
|
||||
None,
|
||||
None,
|
||||
),
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
max_length: Optional[int] = None,
|
||||
truncation_mode: Optional[str] = "keep_end",
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
||||
optimize_device_cache: Optional[bool] = False,
|
||||
# Deprecated parameters
|
||||
max_length: Optional[int] = None,
|
||||
truncation_mode: Optional[str] = None,
|
||||
optimize_device_cache: Optional[bool] = None,
|
||||
):
|
||||
# Step 0: check positional arguments validity
|
||||
if not isinstance(processing_class, (PreTrainedTokenizerBase)):
|
||||
raise ValueError(
|
||||
f"processing_class must be a PreTrainedTokenizerBase like a PreTrainedTokenizer or a PreTrainedTokenizerFast, got {type(processing_class)}"
|
||||
)
|
||||
if not isinstance(model, PreTrainedModel):
|
||||
raise ValueError(f"model must be a PreTrainedModel, got {type(model)}")
|
||||
if not model.can_generate():
|
||||
# Handle deprecated parameters
|
||||
deprecated_params = {}
|
||||
if max_length is not None:
|
||||
deprecated_params["max_length"] = max_length
|
||||
warnings.warn(
|
||||
f"The current model class {type(model)} is not compatible with `.generate()`"
|
||||
"Please make sure that this is intended."
|
||||
"The `max_length` parameter is deprecated and will be removed in version 0.20. "
|
||||
"Pass it through the `args` parameter using `IterativeSFTConfig(max_length=...)` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if optimizers[1] is None and args.max_steps == -1:
|
||||
raise ValueError(
|
||||
"When no scheduler is provided, you need to set the total number of training steps to perform `max_steps`"
|
||||
if truncation_mode is not None:
|
||||
deprecated_params["truncation_mode"] = truncation_mode
|
||||
warnings.warn(
|
||||
"The `truncation_mode` parameter is deprecated and will be removed in version 0.20. "
|
||||
"Pass it through the `args` parameter using `IterativeSFTConfig(truncation_mode=...)` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if optimize_device_cache is not None:
|
||||
deprecated_params["optimize_device_cache"] = optimize_device_cache
|
||||
warnings.warn(
|
||||
"The `optimize_device_cache` parameter is deprecated and will be removed in version 0.20 "
|
||||
"Pass it through the `args` parameter using `IterativeSFTConfig(optimize_device_cache=...)` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)
|
||||
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
|
||||
# Args
|
||||
model_id = model if isinstance(model, str) else model.config._name_or_path
|
||||
if args is None:
|
||||
model_name = model_id.split("/")[-1]
|
||||
args = IterativeSFTConfig(f"{model_name}-IterativeSFT")
|
||||
elif isinstance(args, TrainingArguments) and not isinstance(args, IterativeSFTConfig):
|
||||
dict_args = args.to_dict()
|
||||
dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
|
||||
dict_args.pop("push_to_hub_token")
|
||||
args = IterativeSFTConfig(**dict_args)
|
||||
|
||||
# Update args with deprecated parameters if provided
|
||||
if deprecated_params:
|
||||
for key, value in deprecated_params.items():
|
||||
setattr(args, key, value)
|
||||
|
||||
# Handle the tokenizer
|
||||
if processing_class is None:
|
||||
processing_class = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Model
|
||||
if args.model_init_kwargs is not None and not isinstance(model, str):
|
||||
warnings.warn(
|
||||
"You passed model_init_kwargs to the `IterativeSFTConfig`, but your model is already instantiated. "
|
||||
"The `model_init_kwargs` will be ignored."
|
||||
)
|
||||
if isinstance(model, str):
|
||||
model = self._create_model_from_path(model, args)
|
||||
|
||||
# PEFT configuration and model wrapping
|
||||
if is_peft_available() and isinstance(model, PeftModel):
|
||||
self.is_peft_model = True
|
||||
else:
|
||||
self.is_peft_model = False
|
||||
|
||||
self.processing_class = processing_class
|
||||
self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)
|
||||
|
||||
if data_collator is None:
|
||||
if self.is_encoder_decoder:
|
||||
@ -132,9 +184,9 @@ class IterativeSFTTrainer(Trainer):
|
||||
else:
|
||||
self.data_collator = data_collator
|
||||
|
||||
self.max_length = max_length
|
||||
self.truncation_mode = truncation_mode
|
||||
self.optimize_device_cache = optimize_device_cache
|
||||
self.max_length = args.max_length
|
||||
self.truncation_mode = args.truncation_mode
|
||||
self.optimize_device_cache = args.optimize_device_cache
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
@ -167,6 +219,11 @@ class IterativeSFTTrainer(Trainer):
|
||||
|
||||
PPODecorators.optimize_device_cache = self.optimize_device_cache
|
||||
|
||||
def _create_model_from_path(self, model_path: str, args: IterativeSFTConfig) -> PreTrainedModel:
|
||||
"""Creates a model from a path or model identifier."""
|
||||
model_init_kwargs = args.model_init_kwargs or {}
|
||||
return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
||||
|
||||
def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
|
||||
if attention_mask is None:
|
||||
attention_mask = [torch.ones_like(ids) for ids in input_ids]
|
||||
|
@ -19,7 +19,6 @@ import textwrap
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from copy import deepcopy
|
||||
from operator import itemgetter
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
|
||||
|
||||
@ -31,7 +30,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import is_deepspeed_available, tqdm
|
||||
from accelerate.utils import tqdm
|
||||
from datasets import Dataset, concatenate_datasets
|
||||
from packaging import version
|
||||
from torch.utils.data import DataLoader, SequentialSampler
|
||||
@ -54,7 +53,7 @@ from transformers.utils import is_peft_available
|
||||
|
||||
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt, maybe_unpair_preference_dataset
|
||||
from ..import_utils import is_liger_kernel_available
|
||||
from ..models import PreTrainedModelWrapper, create_reference_model
|
||||
from ..models import create_reference_model, prepare_deepspeed
|
||||
from .kto_config import KTOConfig
|
||||
from .utils import (
|
||||
DPODataCollatorWithPadding,
|
||||
@ -68,9 +67,6 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
if is_liger_kernel_available():
|
||||
from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
|
||||
|
||||
@ -779,7 +775,7 @@ class KTOTrainer(Trainer):
|
||||
)
|
||||
else:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
@ -808,37 +804,6 @@ class KTOTrainer(Trainer):
|
||||
ignore_index=self.label_pad_token_id, beta=self.beta, use_ref_model=(self.ref_model is not None)
|
||||
)
|
||||
|
||||
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
||||
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
||||
|
||||
if model is not None:
|
||||
if hasattr(model, "config"):
|
||||
hidden_size = (
|
||||
max(model.config.hidden_sizes)
|
||||
if getattr(model.config, "hidden_sizes", None)
|
||||
else getattr(model.config, "hidden_size", None)
|
||||
)
|
||||
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
||||
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
||||
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
||||
config_kwargs.update(
|
||||
{
|
||||
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
||||
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
||||
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
||||
}
|
||||
)
|
||||
|
||||
# If ZeRO-3 is used, we shard both the active and reference model.
|
||||
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
||||
if config_kwargs["zero_optimization"]["stage"] != 3:
|
||||
config_kwargs["zero_optimization"]["stage"] = 0
|
||||
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
@contextmanager
|
||||
def null_ref_context(self):
|
||||
"""Context manager for handling null reference model (that is, peft adapter manipulation)."""
|
||||
@ -1497,10 +1462,12 @@ class KTOTrainer(Trainer):
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.train_dataset is None or not has_length(self.train_dataset):
|
||||
def _get_train_sampler(self, dataset: Optional[Dataset] = None) -> Optional[torch.utils.data.Sampler]:
|
||||
if dataset is None:
|
||||
dataset = self.train_dataset
|
||||
if dataset is None or not has_length(dataset):
|
||||
return None
|
||||
return SequentialSampler(self.train_dataset)
|
||||
return SequentialSampler(dataset)
|
||||
|
||||
def generate_from_model_and_ref(self, model, batch: dict[str, torch.LongTensor]) -> tuple[str, str]:
|
||||
"""Generate samples from the model and reference model for the given batch of inputs."""
|
||||
|
@ -32,7 +32,7 @@ from transformers import (
|
||||
)
|
||||
from transformers.trainer_utils import EvalPrediction
|
||||
from transformers.training_args import OptimizerNames
|
||||
from transformers.utils import is_apex_available
|
||||
from transformers.utils import is_apex_available, is_peft_available
|
||||
|
||||
from ..data_utils import is_conversational, maybe_apply_chat_template
|
||||
from ..models.modeling_base import GeometricMixtureWrapper
|
||||
@ -59,6 +59,10 @@ if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftModel
|
||||
|
||||
|
||||
class NashMDTrainer(OnlineDPOTrainer):
|
||||
r"""
|
||||
Initialize NashMDTrainer as a subclass of [`OnlineDPOTrainer`].
|
||||
@ -170,28 +174,50 @@ class NashMDTrainer(OnlineDPOTrainer):
|
||||
return self._mixture_coef
|
||||
|
||||
def _generate_completions(self, model, prompts):
|
||||
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
||||
model_output = unwrapped_model.generate(
|
||||
# Generate completions from the policy model.
|
||||
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_for_gen_ctx:
|
||||
model_output = unwrapped_policy_for_gen_ctx.generate(
|
||||
input_ids=prompts["input_ids"],
|
||||
attention_mask=prompts["attention_mask"],
|
||||
generation_config=self.generation_config,
|
||||
)
|
||||
|
||||
ref_model = model if self.ref_model is None else self.ref_model
|
||||
with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
|
||||
mixture_model = GeometricMixtureWrapper(
|
||||
model=unwrapped_model,
|
||||
ref_model=unwrapped_ref_model,
|
||||
generation_config=self.generation_config,
|
||||
mixture_coef=self.mixture_coef,
|
||||
device=self.accelerator.device,
|
||||
)
|
||||
# Get the DDP/FSDP unwrapped version of the main model.
|
||||
# This will be the policy model for GeometricMixtureWrapper (PEFT adapters active if PEFT is used).
|
||||
policy_model_for_gmw = self.accelerator.unwrap_model(model)
|
||||
|
||||
mixture_output = mixture_model.generate(
|
||||
input_ids=prompts["input_ids"],
|
||||
attention_mask=prompts["attention_mask"],
|
||||
generation_config=self.generation_config,
|
||||
)
|
||||
# Determine the correct reference model for GeometricMixtureWrapper.
|
||||
# This also needs to be DDP/FSDP unwrapped.
|
||||
ref_model_for_gmw: torch.nn.Module
|
||||
if self.ref_model is None:
|
||||
# No explicit ref_model is provided.
|
||||
# Use the base of the main `model` if it's a PEFT model.
|
||||
# policy_model_for_gmw is already DDP-unwrapped.
|
||||
if is_peft_available() and isinstance(policy_model_for_gmw, PeftModel):
|
||||
ref_model_for_gmw = policy_model_for_gmw.get_base_model()
|
||||
else:
|
||||
# Not a PEFT model (or PEFT not available), or already a base model.
|
||||
# Use the DDP-unwrapped policy model itself as the reference.
|
||||
ref_model_for_gmw = policy_model_for_gmw
|
||||
else:
|
||||
# An explicit ref_model is provided. Unwrap it for DDP/FSDP.
|
||||
ref_model_for_gmw = self.accelerator.unwrap_model(self.ref_model)
|
||||
|
||||
# Both models given to GeometricMixtureWrapper (policy_model_for_gmw and ref_model_for_gmw) are DDP-unwrapped.
|
||||
with torch.no_grad(): # Ensure no_grad context for mixture model generation
|
||||
mixture_model = GeometricMixtureWrapper(
|
||||
model=policy_model_for_gmw,
|
||||
ref_model=ref_model_for_gmw,
|
||||
generation_config=self.generation_config,
|
||||
mixture_coef=self.mixture_coef,
|
||||
device=self.accelerator.device,
|
||||
)
|
||||
|
||||
mixture_output = mixture_model.generate(
|
||||
input_ids=prompts["input_ids"],
|
||||
attention_mask=prompts["attention_mask"],
|
||||
generation_config=self.generation_config,
|
||||
)
|
||||
|
||||
return model_output, mixture_output
|
||||
|
||||
|
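The reference-model selection above reduces to a small decision tree; a condensed, hedged restatement outside the trainer (function and argument names are illustrative):

```python
from typing import Optional

import torch.nn as nn
from accelerate import Accelerator

def pick_mixture_ref_model(accelerator: Accelerator, model: nn.Module, ref_model: Optional[nn.Module]):
    """Mirror of the branch above: return the DDP/FSDP-unwrapped model to mix against."""
    policy = accelerator.unwrap_model(model)            # strip DDP/FSDP wrappers
    if ref_model is not None:
        return accelerator.unwrap_model(ref_model)      # explicit reference model provided
    try:
        from peft import PeftModel
        if isinstance(policy, PeftModel):
            return policy.get_base_model()              # PEFT: the base weights act as the reference
    except ImportError:
        pass
    return policy                                       # otherwise the policy doubles as its own reference
```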
@ -19,7 +19,6 @@ import textwrap
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
@ -30,7 +29,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import is_deepspeed_available
|
||||
from datasets import Dataset
|
||||
from packaging import version
|
||||
from torch.utils.data import DataLoader
|
||||
@ -52,7 +50,6 @@ from transformers.trainer_utils import EvalLoopOutput
|
||||
from transformers.utils import is_peft_available, is_torch_fx_proxy
|
||||
|
||||
from ..data_utils import maybe_apply_chat_template, maybe_extract_prompt
|
||||
from ..models import PreTrainedModelWrapper
|
||||
from .orpo_config import ORPOConfig
|
||||
from .utils import (
|
||||
DPODataCollatorWithPadding,
|
||||
@ -75,9 +72,6 @@ if is_peft_available():
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
@ -358,37 +352,6 @@ class ORPOTrainer(Trainer):
|
||||
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
||||
)
|
||||
|
||||
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
||||
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
||||
|
||||
if model is not None:
|
||||
if hasattr(model, "config"):
|
||||
hidden_size = (
|
||||
max(model.config.hidden_sizes)
|
||||
if getattr(model.config, "hidden_sizes", None)
|
||||
else getattr(model.config, "hidden_size", None)
|
||||
)
|
||||
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
||||
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
||||
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
||||
config_kwargs.update(
|
||||
{
|
||||
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
||||
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
||||
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
||||
}
|
||||
)
|
||||
|
||||
# If ZeRO-3 is used, we shard both the active and reference model.
|
||||
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
||||
if config_kwargs["zero_optimization"]["stage"] != 3:
|
||||
config_kwargs["zero_optimization"]["stage"] = 0
|
||||
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def build_tokenized_answer(self, prompt, answer):
|
||||
"""
|
||||
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
|
||||
|
@@ -44,7 +44,7 @@ from transformers import (
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils import is_peft_available
from transformers.utils import is_peft_available, is_rich_available

from ..core import masked_mean, masked_whiten
from ..models import create_reference_model

@@ -54,6 +54,7 @@ from .utils import (
OnlineTrainerState,
batch_generation,
disable_dropout_in_model,
empty_cache,
exact_div,
first_true_indices,
forward,

@@ -107,7 +108,7 @@ class PPOTrainer(Trainer):
ref_model: Optional[nn.Module],
reward_model: nn.Module,
train_dataset: Dataset,
value_model: Optional[nn.Module] = None,
value_model: nn.Module,
data_collator: Optional[DataCollatorWithPadding] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
# less commonly used

@@ -437,7 +438,7 @@ class PPOTrainer(Trainer):
logits = logitss[i : i + args.local_rollout_forward_batch_size]
logprob = selective_log_softmax(logits, response)
del logits
torch.cuda.empty_cache()
empty_cache()

if ref_policy is None:
with self.null_ref_context():

@@ -448,7 +449,7 @@ class PPOTrainer(Trainer):
ref_logits /= args.temperature + 1e-7
ref_logprob = selective_log_softmax(ref_logits, response)
del ref_output, ref_logits
torch.cuda.empty_cache()
empty_cache()

# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response

@@ -484,7 +485,7 @@ class PPOTrainer(Trainer):
scores = torch.cat(scores, 0)
values = torch.cat(values, 0)
del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
torch.cuda.empty_cache()
empty_cache()
gc.collect()

# Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id

@@ -531,7 +532,7 @@ class PPOTrainer(Trainer):
returns = advantages + values
advantages = masked_whiten(advantages, ~padding_mask)
advantages = torch.masked_fill(advantages, padding_mask, 0)
torch.cuda.empty_cache()
empty_cache()

# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
for ppo_epoch_idx in range(args.num_ppo_epochs):

@@ -612,7 +613,7 @@ class PPOTrainer(Trainer):
mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
)
# fmt: on
torch.cuda.empty_cache()
empty_cache()
with torch.no_grad():
mean_kl = kl.sum(1).mean()
mean_entropy = (-logprobs).sum(1).mean()

@@ -649,12 +650,12 @@ class PPOTrainer(Trainer):
self._save_checkpoint(model, trial=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
torch.cuda.empty_cache()
empty_cache()
gc.collect()

if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
self.generate_completions(sampling=True)
torch.cuda.empty_cache()
empty_cache()
del (
query_responses,
responses,

@@ -674,7 +675,7 @@ class PPOTrainer(Trainer):
advantages,
returns,
)
torch.cuda.empty_cache()
empty_cache()

# HF trainer specifics
self.control = self.callback_handler.on_train_end(args, self.state, self.control)

@@ -732,7 +733,8 @@ class PPOTrainer(Trainer):
df = pd.DataFrame(table)

if self.accelerator.is_main_process:
print_rich_table(df.iloc[0 : 0 + 5])
if is_rich_available():
print_rich_table(df.iloc[0 : 0 + 5])
if "wandb" in args.report_to:
import wandb
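Throughout the hunks above, the hard-coded `torch.cuda.empty_cache()` calls are replaced by an `empty_cache()` helper imported from the trainer utilities. The helper's body is not part of this compare view; a hedged sketch of what a device-agnostic version typically looks like:

```python
import torch

def empty_cache() -> None:
    # Hedged sketch only; the actual helper in trl.trainer.utils may differ.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
    elif hasattr(torch, "mps") and torch.backends.mps.is_available():
        torch.mps.empty_cache()
```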
@@ -38,7 +38,7 @@ from transformers import (
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available
from transformers.utils import is_peft_available, is_rich_available

from ..data_utils import maybe_apply_chat_template
from .reward_config import RewardConfig

@@ -350,7 +350,8 @@ class RewardTrainer(Trainer):
break
df = pd.DataFrame(table)
if self.accelerator.process_index == 0:
print_rich_table(df[:num_print_samples])
if is_rich_available():
print_rich_table(df[:num_print_samples])
if "wandb" in self.args.report_to:
import wandb
@@ -43,6 +43,7 @@ from transformers import (
from transformers.integrations import get_reporting_integration_callbacks
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
from transformers.utils import is_rich_available

from ..models.utils import unwrap_model_for_generation
from ..trainer.utils import (

@@ -59,7 +60,7 @@ from ..trainer.utils import (
truncate_response,
)
from .rloo_config import RLOOConfig
from .utils import generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment
from .utils import empty_cache, generate_model_card, get_comet_experiment_url, log_table_to_comet_experiment


if is_wandb_available():

@@ -332,14 +333,14 @@ class RLOOTrainer(Trainer):
logits = logitss[i : i + args.local_rollout_forward_batch_size]
logprob = selective_log_softmax(logits, response)
del logits
torch.cuda.empty_cache()
empty_cache()

ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
ref_logits = ref_output.logits[:, context_length - 1 : -1]
ref_logits /= args.temperature + 1e-7
ref_logprob = selective_log_softmax(ref_logits, response)
del ref_output, ref_logits
torch.cuda.empty_cache()
empty_cache()

# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
postprocessed_response = response

@@ -380,7 +381,7 @@ class RLOOTrainer(Trainer):
sequence_lengths = torch.cat(sequence_lengths, 0)
scores = torch.cat(scores, 0)
del (logprob, ref_logprob, score)
torch.cuda.empty_cache()
empty_cache()
gc.collect()

# Response Processing 3. filter response. Ensure that the sample contains stop_token_id

@@ -438,7 +439,7 @@ class RLOOTrainer(Trainer):
if args.normalize_advantage:
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

torch.cuda.empty_cache()
empty_cache()

# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
for ppo_epoch_idx in range(args.num_ppo_epochs):

@@ -514,7 +515,7 @@ class RLOOTrainer(Trainer):
mb_advantage, mb_responses, mb_query_responses, mb_logprobs,
)
# fmt: on
torch.cuda.empty_cache()
empty_cache()

# Compute metrics
with torch.no_grad():

@@ -551,7 +552,7 @@ class RLOOTrainer(Trainer):
if self.control.should_save:
self._save_checkpoint(model, trial=None)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
torch.cuda.empty_cache()
empty_cache()
gc.collect()

if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:

@@ -625,7 +626,8 @@ class RLOOTrainer(Trainer):
df = pd.DataFrame(table)

if self.accelerator.is_main_process:
print_rich_table(df.iloc[0 : 0 + 5])
if is_rich_available():
print_rich_table(df.iloc[0 : 0 + 5])
if "wandb" in args.report_to:
import wandb
@@ -62,6 +62,8 @@ class SFTConfig(TrainingArguments):
continuous sequence. This reduces memory usage by eliminating padding overhead. Currently, this is only
supported with the `flash_attention_2` attention implementation, which can efficiently handle the flattened
batch structure.
pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
If set, the sequences will be padded to a multiple of this value.
eval_packing (`bool` or `None`, *optional*, defaults to `None`):
Whether to pack the eval dataset. If `None`, uses the same value as `packing`.

@@ -76,6 +78,9 @@ class SFTConfig(TrainingArguments):
`False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset:
loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on
the full sequence for [language modeling](#language-modeling) datasets.
activation_offloading (`bool`, *optional*, defaults to `False`):
Whether to offload the activations to the CPU.

"""

# Parameters that control the model

@@ -140,6 +145,10 @@ class SFTConfig(TrainingArguments):
"handle the flattened batch structure."
},
)
pad_to_multiple_of: Optional[int] = field(
default=None,
metadata={"help": "If set, the sequences will be padded to a multiple of this value."},
)
eval_packing: Optional[bool] = field(
default=None,
metadata={"help": "Whether to pack the eval dataset. If `None`, uses the same value as `packing`."},

@@ -165,79 +174,25 @@ class SFTConfig(TrainingArguments):
)
},
)
activation_offloading: bool = field(
default=False,
metadata={"help": "Whether to offload the activations to the CPU."},
)

# Deprecated parameters
dataset_batch_size: Optional[int] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. You can safely remove this "
"parameter from your configuration."
},
)
num_of_sequences: Optional[int] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. Use `max_length` instead, "
"which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which referred "
"to string sequences."
},
)
chars_per_token: Optional[float] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. If you want to customize the "
"packing length, use `max_length`."
},
)
max_seq_length: Optional[int] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.20.0. Use `max_length` instead."
},
)
use_liger: Optional[bool] = field(
default=None,
metadata={
"help": "This parameter is deprecated and will be removed in version 0.18.0. Use `use_liger_kernel` "
"instead."
},
)

def __post_init__(self):
super().__post_init__()

if self.dataset_batch_size is not None:
warnings.warn(
"`dataset_batch_size` is deprecated and will be removed in version 0.18.0. You can safely remove this "
"parameter from your configuration.",
DeprecationWarning,
)

if self.num_of_sequences is not None:
warnings.warn(
"`num_of_sequences` is deprecated and will be removed in version 0.18.0. Use `max_length` instead, "
"which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which "
"referred to string sequences.",
DeprecationWarning,
)

if self.chars_per_token is not None:
warnings.warn(
"`chars_per_token` is deprecated and will be removed in version 0.18.0. If you want to customize the "
"packing length, use `max_length`.",
DeprecationWarning,
)

if self.max_seq_length is not None:
warnings.warn(
"`max_seq_length` is deprecated and will be removed in version 0.20.0. Use `max_length` instead.",
DeprecationWarning,
)
self.max_length = self.max_seq_length

if self.use_liger is not None:
warnings.warn(
"`use_liger` is deprecated and will be removed in version 0.18.0. Use `use_liger_kernel` instead.",
DeprecationWarning,
)
self.use_liger_kernel = self.use_liger
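For the deprecation shims kept above, the still-supported alias `max_seq_length` continues to be forwarded to `max_length` while emitting a `DeprecationWarning`. A hedged usage sketch (assumes `SFTConfig` is importable from `trl`):

```python
import warnings

from trl import SFTConfig  # assumed import

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    config = SFTConfig(output_dir="tmp-sft", max_seq_length=2048)  # deprecated alias

assert config.max_length == 2048  # value forwarded from `max_seq_length`
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```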
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import dataclasses
import os
import warnings

@@ -51,6 +52,7 @@ from ..data_utils import (
pack_dataset,
truncate_dataset,
)
from ..models import get_act_offloading_ctx_manager
from .sft_config import SFTConfig
from .utils import (
ConstantLengthDataset,

@@ -80,7 +82,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
Token ID to use for padding.
completion_only_loss (`bool`, *optional*, defaults to `True`):
When the input contains a completion mask (`completion_mask`), the labels are set to -100 for the tokens
that are not in the completion.
that are no in the completion.
pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
If set, the sequences will be padded to a multiple of this value.
return_tensors (`str`, *optional*, defaults to `"pt"`):
Type of Tensor to return. Only `"pt"` is currently supported.

@@ -116,6 +120,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):

pad_token_id: int
completion_only_loss: bool = True
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"

def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:

@@ -128,11 +133,22 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):

# Pad
output = {}
output["input_ids"] = pad(input_ids, padding_value=self.pad_token_id, padding_side="right")
output["attention_mask"] = pad(attention_mask, padding_value=0, padding_side="right")
output["labels"] = pad(labels, padding_value=-100, padding_side="right")
output["input_ids"] = pad(
input_ids,
padding_value=self.pad_token_id,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
)
output["attention_mask"] = pad(
attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"] = pad(
labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
completion_mask = pad(
completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"][completion_mask == 0] = -100  # mask everything that is not in the completion

return output
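A hedged sketch of how the collator behaves once `pad_to_multiple_of` is threaded through (the import path is assumed from the module shown above):

```python
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling  # assumed import path

collator = DataCollatorForLanguageModeling(pad_token_id=0, pad_to_multiple_of=8)
batch = collator([{"input_ids": [101, 7592, 102]}, {"input_ids": [101, 102]}])

print(batch["input_ids"].shape)  # torch.Size([2, 8]): rounded up to a multiple of 8
print(batch["labels"][0])        # padded positions carry the ignore index -100
```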
@@ -210,7 +226,8 @@ class SFTTrainer(Trainer):
peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
formatting_func (`Optional[Callable]`):
Formatting function applied to the dataset before tokenization.
Formatting function applied to the dataset before tokenization. Applying the formatting function explicitly
converts the dataset into a [language modeling](#language-modeling) type.
"""

_tag_names = ["trl", "sft"]

@@ -315,11 +332,21 @@ class SFTTrainer(Trainer):
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
"in the vocabulary before using it as a padding token."
)
data_collator = DataCollatorForLanguageModeling(pad_token_id, self.completion_only_loss)
data_collator = DataCollatorForLanguageModeling(
pad_token_id, self.completion_only_loss, args.pad_to_multiple_of
)

# Dataset
preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
if preprocess_dataset:
if self.completion_only_loss and formatting_func:
raise ValueError(
"A formatting function was provided while `completion_only_loss=True`, which is incompatible. "
"Using a formatter converts the dataset to a language modeling type, conflicting with "
"completion-only loss. To resolve this, apply your formatting function before passing the "
"dataset, or disable `completion_only_loss` in `SFTConfig`."
)

train_dataset = self._prepare_dataset(
train_dataset, processing_class, args, args.packing, formatting_func, "train"
)

@@ -370,6 +397,12 @@ class SFTTrainer(Trainer):
**super_init_kwargs,
)

# Initialize activation offloading context
if self.args.activation_offloading:
self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
else:
self.maybe_activation_offload_context = contextlib.nullcontext()

# Add tags for models that have been loaded with the correct transformers version
if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names)

@@ -650,7 +683,7 @@ class SFTTrainer(Trainer):
"""
Compute training loss and additionally compute token accuracies
"""
mode = "eval" if self.control.should_evaluate else "train"
mode = "train" if self.model.training else "eval"
(loss, outputs) = super().compute_loss(
model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
)

@@ -694,8 +727,13 @@ class SFTTrainer(Trainer):

return (loss, outputs) if return_outputs else loss

# Override training step to add activation offloading context.
def training_step(self, *args, **kwargs):
with self.maybe_activation_offload_context:
return super().training_step(*args, **kwargs)

def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
mode = "eval" if self.control.should_evaluate else "train"
mode = "train" if self.model.training else "eval"
metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics

# This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
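The `training_step` override above simply wraps the parent implementation in whichever context manager was chosen at init time (a real activation-offloading manager or a no-op). A hedged, self-contained sketch of that pattern:

```python
import contextlib

class OffloadAwareStep:
    # Toy illustration of the wrap-the-step pattern; not the TRL implementation.
    def __init__(self, offload_ctx=None):
        # Either a real activation-offloading context manager or a reusable no-op.
        self.maybe_activation_offload_context = (
            offload_ctx if offload_ctx is not None else contextlib.nullcontext()
        )

    def training_step(self, *args, **kwargs):
        with self.maybe_activation_offload_context:
            return self._forward_backward(*args, **kwargs)

    def _forward_backward(self, *args, **kwargs):
        ...  # forward pass, loss computation, and backward would run here
```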
@@ -45,12 +45,12 @@ from transformers import (
)
from transformers.utils import (
is_peft_available,
is_rich_available,
is_torch_mlu_available,
is_torch_npu_available,
is_torch_xpu_available,
)

from ..import_utils import is_rich_available
from ..trainer.model_config import ModelConfig

@@ -415,7 +415,12 @@ class RewardDataCollatorWithPadding:
return batch


def pad(tensors: list[torch.Tensor], padding_value: int = 0, padding_side: str = "right") -> torch.Tensor:
def pad(
tensors: list[torch.Tensor],
padding_value: int = 0,
padding_side: str = "right",
pad_to_multiple_of: Optional[int] = None,
) -> torch.Tensor:
"""
Pads a list of tensors to the same shape along the first dimension.

@@ -426,6 +431,8 @@ def pad(tensors: list[torch.Tensor], padding_value: int = 0, padding_side: str =
Value to use for padding. Default is 0.
padding_side (`str`):
Side on which to add padding. Must be 'left' or 'right'. Default is 'right'.
pad_to_multiple_of (`int`, *optional*, defaults to `None`):
If set will pad the sequence to a multiple of the provided value.

Returns:
`torch.Tensor`:

@@ -446,18 +453,25 @@ def pad(tensors: list[torch.Tensor], padding_value: int = 0, padding_side: str =
# Determine the maximum shape for each dimension
output_shape = np.max([t.shape for t in tensors], 0).tolist()

# Apply pad_to_multiple_of to the first (sequence) dimension
if pad_to_multiple_of is not None:
remainder = output_shape[0] % pad_to_multiple_of
if remainder != 0:
output_shape[0] += pad_to_multiple_of - remainder

# Create an output tensor filled with the padding value
output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device)

for i, t in enumerate(tensors):
# Determine the slice for the sequence dimension
if padding_side == "left":
seq_slice = slice(output_shape[0] - t.shape[0], output_shape[0])
seq_start = output_shape[0] - t.shape[0]
elif padding_side == "right":
seq_slice = slice(0, t.shape[0])
seq_start = 0
else:
raise ValueError("padding_side must be 'left' or 'right'")

# Define the slices
seq_slice = slice(seq_start, seq_start + t.shape[0])
slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:])
output[i][slices] = t
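A short usage example of the extended `pad` helper above, showing the sequence dimension being rounded up (assuming the helper is importable as `trl.trainer.utils.pad`):

```python
import torch

from trl.trainer.utils import pad  # assumed import path

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
out = pad(seqs, padding_value=0, padding_side="right", pad_to_multiple_of=4)
print(out)
# tensor([[1, 2, 3, 0],
#         [4, 5, 0, 0]])  <- padded length rounded up from 3 to the next multiple of 4
```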
@@ -964,7 +978,11 @@ def cap_exp(value, cap=-1):
return torch.exp(torch.clamp(value, max=cap))


def print_rich_table(df: pd.DataFrame) -> Table:
def print_rich_table(df: pd.DataFrame) -> None:
if not is_rich_available():
raise ImportError(
"The function `print_rich_table` requires the `rich` library. Please install it with `pip install rich`."
)
console = Console()
table = Table(show_lines=True)
for column in df.columns:

@@ -1600,7 +1618,7 @@ def log_table_to_comet_experiment(name: str, table: pd.DataFrame) -> None:
experiment.log_table(tabular_data=table, filename=name)


def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor, ...]:
def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
"""
Shift non-zero elements in the mask and corresponding tensors to the left.

@@ -1642,28 +1660,59 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
[5, 6, 0]])
```
"""
_, M = mask.shape

# Create copy of mask and tensors
mask = mask.clone()
mask_copy = mask.clone()
tensors = [t.clone() for t in tensors]

# Shift non-zero values to the left
for i in range(mask.size(0)):
first_one_idx = torch.nonzero(mask[i])[0].item()
mask[i] = torch.roll(mask[i], shifts=-first_one_idx)
for tensor in tensors:
tensor[i] = torch.roll(tensor[i], shifts=-first_one_idx)
first_non_zero = mask_copy.argmax(dim=1)
pos = torch.arange(M, device=mask_copy.device).unsqueeze(0)
idx_roll = (pos + first_non_zero.unsqueeze(1)) % M
mask_roll = mask_copy.gather(1, idx_roll)
rolled_tensors = [t.gather(1, idx_roll) for t in tensors]

# Get the first column idx that is all zeros and remove every column after that
empty_cols = torch.sum(mask, dim=0) == 0
first_empty_col = torch.nonzero(empty_cols)[0].item() if empty_cols.any() else mask.size(1)
mask = mask[:, :first_empty_col]
for i, tensor in enumerate(tensors):
tensors[i] = tensor[:, :first_empty_col]
# Truncate trailing columns that are all zeros in mask_roll
col_sums = mask_roll.sum(dim=0)
empty_cols = col_sums == 0
first_empty_col = int(empty_cols.to(torch.int8).argmax()) if empty_cols.any() else M
flushed_mask = mask_roll[:, :first_empty_col]
flushed_tensors = [t[:, :first_empty_col] for t in rolled_tensors]

if not tensors:
return mask
else:
return mask, *tensors
if not flushed_tensors:
return flushed_mask
return flushed_mask, *flushed_tensors


def flush_right(mask: torch.Tensor, *tensors: torch.Tensor) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
"""
Shift non-zero elements in the mask and corresponding tensors to the right. See `flush_left` for details.
"""
_, M = mask.shape

# Create copy of mask and tensors
mask_copy = mask.clone()
tensors = [t.clone() for t in tensors]

# Shift non-zero values to the right
flipped_mask = torch.fliplr(mask_copy)
first_non_zero = flipped_mask.argmax(dim=1)
pos = torch.arange(M, device=mask_copy.device).unsqueeze(0)
idx_roll = (pos - first_non_zero.unsqueeze(1)) % M
mask_roll = mask_copy.gather(1, idx_roll)
rolled_tensors = [t.gather(1, idx_roll) for t in tensors]

# Truncate leading columns that are all zeros in mask_roll
col_sums = mask_roll.sum(dim=0)
non_empty_cols = col_sums != 0
first_non_empty_col = int(non_empty_cols.to(torch.int8).argmax()) if non_empty_cols.any() else M
flushed_mask = mask_roll[:, first_non_empty_col:]
flushed_tensors = [t[:, first_non_empty_col:] for t in rolled_tensors]

if not flushed_tensors:
return flushed_mask
return flushed_mask, *flushed_tensors


def selective_log_softmax(logits, index):
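The rewrite above vectorises `flush_left` and introduces a symmetric `flush_right`. A hedged illustration of the new helper on a small mask (import path assumed):

```python
import torch

from trl.trainer.utils import flush_right  # assumed import path

mask = torch.tensor([[0, 0, 1, 1, 1],
                     [0, 1, 1, 0, 0]])
values = torch.tensor([[0, 0, 2, 3, 4],
                       [0, 5, 6, 0, 0]])
new_mask, new_values = flush_right(mask, values)
print(new_mask)
# tensor([[1, 1, 1],
#         [0, 1, 1]])
print(new_values)
# tensor([[2, 3, 4],
#         [0, 5, 6]])
```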
@@ -1702,7 +1751,12 @@ def selective_log_softmax(logits, index):


def print_prompt_completions_sample(
prompts: list[str], completions: list[str], rewards: dict[str, list[float]], step: int, num_samples: int = None
prompts: list[str],
completions: list[str],
rewards: dict[str, list[float]],
advantages: list[float],
step: int,
num_samples: int = None,
) -> None:
"""
Print out a sample of model completions to the console with multiple reward metrics.

@@ -1717,6 +1771,8 @@ def print_prompt_completions_sample(
List of completions corresponding to the prompts.
rewards (`dict[str, list[float]]`):
Dictionary where keys are reward names and values are lists of rewards.
advantages (`list[float]`):
List of advantages corresponding to the prompts and completions.
step (`int`):
Current training step number, used in the output title.
num_samples (`int` or `None`, *optional*, defaults to `None`):

@@ -1728,18 +1784,24 @@ def print_prompt_completions_sample(
>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> rewards = {"Correctness": [0.123, 0.456], "Format": [0.789, 0.101]}
>>> print_prompt_completions_sample(prompts, completions, rewards, 42)
╭────────────────────── Step 42 ───────────────────────╮
│ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ │
│ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━┩ │
│ │ The sky is │ blue. │ 0.12 │ 0.79 │ │
│ ├────────────┼──────────────┼─────────────┼────────┤ │
│ │ The sun is │ in the sky. │ 0.46 │ 0.10 │ │
│ └────────────┴──────────────┴─────────────┴────────┘ │
╰──────────────────────────────────────────────────────╯
>>> advantages = [0.987, 0.654]
>>> print_prompt_completions_sample(prompts, completions, rewards, advantages, 42)
╭──────────────────────────── Step 42 ─────────────────────────────╮
│ ┏━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━┓ │
│ ┃ Prompt ┃ Completion ┃ Correctness ┃ Format ┃ Advantage ┃ │
│ ┡━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━┩ │
│ │ The sky is │ blue. │ 0.12 │ 0.79 │ 0.99 │ │
│ ├────────────┼──────────────┼─────────────┼────────┼───────────┤ │
│ │ The sun is │ in the sky. │ 0.46 │ 0.10 │ 0.65 │ │
│ └────────────┴──────────────┴─────────────┴────────┴───────────┘ │
╰──────────────────────────────────────────────────────────────────╯
```
"""
if not is_rich_available():
raise ImportError(
"The function `print_prompt_completions_sample` requires the `rich` library. Please install it with "
"`pip install rich`."
)
console = Console()
table = Table(show_header=True, header_style="bold white", expand=True)

@@ -1748,6 +1810,7 @@ def print_prompt_completions_sample(
table.add_column("Completion", style="bright_green")
for reward_name in rewards.keys():
table.add_column(reward_name, style="bold cyan", justify="right")
table.add_column("Advantage", style="bold magenta", justify="right")

# Some basic input validation
if num_samples is not None:

@@ -1762,10 +1825,11 @@ def print_prompt_completions_sample(
prompts = [prompts[i] for i in indices]
completions = [completions[i] for i in indices]
rewards = {key: [val[i] for i in indices] for key, val in rewards.items()}
advantages = [advantages[i] for i in indices]

for i in range(len(prompts)):
reward_values = [f"{rewards[key][i]:.2f}" for key in rewards.keys()] # 2 decimals
table.add_row(Text(prompts[i]), Text(completions[i]), *reward_values)
table.add_row(Text(prompts[i]), Text(completions[i]), *reward_values, f"{advantages[i]:.2f}")
table.add_section() # Adds a separator between rows

panel = Panel(table, expand=False, title=f"Step {step}", border_style="bold white")
@@ -33,6 +33,7 @@ from transformers import (
)
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import OptimizerNames
from transformers.utils import is_peft_available

from ..data_utils import is_conversational, maybe_apply_chat_template
from ..models.utils import unwrap_model_for_generation

@@ -58,6 +59,10 @@ if is_wandb_available():
import wandb


if is_peft_available():
from peft import PeftModel


class XPOTrainer(OnlineDPOTrainer):
r"""
Initialize XPOTrainer as a subclass of [`OnlineDPOConfig`].

@@ -174,16 +179,26 @@ class XPOTrainer(OnlineDPOTrainer):
return self._alpha

def _generate_completions(self, prompts, model):
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
model_output = unwrapped_model.generate(
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_model_for_gen:
model_output = unwrapped_policy_model_for_gen.generate(
input_ids=prompts["input_ids"],
attention_mask=prompts["attention_mask"],
generation_config=self.generation_config,
)

ref_model = model if self.ref_model is None else self.ref_model
with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model:
ref_output = unwrapped_ref_model.generate(
actual_model_for_ref_generation: torch.nn.Module
if self.ref_model is None:
unwrapped_main_model_for_ref_logic = self.accelerator.unwrap_model(model)

if is_peft_available() and isinstance(unwrapped_main_model_for_ref_logic, PeftModel):
actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic.get_base_model()
else:
actual_model_for_ref_generation = unwrapped_main_model_for_ref_logic
else:
actual_model_for_ref_generation = self.accelerator.unwrap_model(self.ref_model)

with unwrap_model_for_generation(actual_model_for_ref_generation, self.accelerator) as final_ref_model_for_gen:
ref_output = final_ref_model_for_gen.generate(
input_ids=prompts["input_ids"],
attention_mask=prompts["attention_mask"],
generation_config=self.generation_config,
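The XPO hunk above changes how the model used for reference generation is chosen when no explicit `ref_model` is passed: a PEFT-wrapped policy falls back to its base model, so the adapters are excluded from reference sampling. A hedged sketch of that selection logic in isolation (the helper name is hypothetical):

```python
def resolve_ref_generation_model(model, ref_model, accelerator, peft_model_cls=None):
    # Hypothetical helper mirroring the branching in the hunk above.
    if ref_model is not None:
        return accelerator.unwrap_model(ref_model)
    unwrapped = accelerator.unwrap_model(model)
    if peft_model_cls is not None and isinstance(unwrapped, peft_model_cls):
        # PEFT case: the frozen base model (adapters disabled) serves as the reference.
        return unwrapped.get_base_model()
    return unwrapped
```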