Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)

Compare commits: v0.23-release...a932e2796d (155 commits)
Author | SHA1 | Date | |
---|---|---|---|
a932e2796d | |||
04fd1203af | |||
19d2f97932 | |||
31caf64778 | |||
8e2d5516ca | |||
94aac4a101 | |||
26b7c2507e | |||
aa25c2697c | |||
93c7d88563 | |||
c7c041ecc8 | |||
ef40c047aa | |||
7e0adbc552 | |||
773afd9314 | |||
966b397201 | |||
927cf6ba46 | |||
56cb6ccf76 | |||
49c8f14b06 | |||
cefbacb30e | |||
fae245a062 | |||
2aa9506c69 | |||
d6eeb290d9 | |||
1684ef279a | |||
aab21eb5e7 | |||
b997a31981 | |||
86d1963cc1 | |||
039d526d24 | |||
bcd059a384 | |||
0e57b4a9df | |||
98488e0946 | |||
f45e86571b | |||
f5827928a0 | |||
f853e091ea | |||
803ec0d856 | |||
7a0a615d50 | |||
c38cb69ec7 | |||
68ef15c686 | |||
3dd7fc2850 | |||
51ced65153 | |||
4bb883a6e6 | |||
f7846321e7 | |||
a944890ff1 | |||
521db3520a | |||
e2c97a805a | |||
d1d0407d3c | |||
824ff8c73e | |||
f15399d3d3 | |||
cc578b6b14 | |||
30cf68a97b | |||
452284b8dc | |||
6be53e19bc | |||
3080fc1bd7 | |||
5d870955f8 | |||
8265800abf | |||
65eb45c32b | |||
ae6837f8d4 | |||
56a8f1128b | |||
529101537f | |||
0588b1f01d | |||
45ee98b05e | |||
3800a6ecc7 | |||
7ad9ce8acc | |||
0c2dc14014 | |||
ced8b337ba | |||
1eff7da9e0 | |||
1cbfb00b6a | |||
e086f073cf | |||
e5d437ed76 | |||
d1b4691900 | |||
39c603872f | |||
5a4021f23e | |||
ea66a9e650 | |||
da209f89fc | |||
ebb8899f5d | |||
70e2017dbc | |||
4368f54c97 | |||
22720d176b | |||
c8a5add88a | |||
a7b54f988b | |||
78bf77abbd | |||
3b9ac65a05 | |||
7a78320f58 | |||
67e83aee90 | |||
a0df357591 | |||
864e593e9f | |||
6428647063 | |||
8a5bfecc3a | |||
910aeebe06 | |||
e208823b3e | |||
f397a61e82 | |||
7fe9dd42ac | |||
79c774af54 | |||
9603b41d7e | |||
5ee56ed04f | |||
e85e634bff | |||
d633c4337f | |||
d1e24df031 | |||
094e0760d4 | |||
01c9b4c414 | |||
18faf03c4e | |||
d144e73e78 | |||
be1ffe59d2 | |||
fb6bdab33b | |||
526303edbd | |||
9e5e60c933 | |||
5c52f46f9a | |||
deac14a39f | |||
3d5a30bb77 | |||
251fdb228a | |||
37806e618b | |||
008c7ad9aa | |||
e8ba9eaf27 | |||
abe07c9e32 | |||
fe02ea2b52 | |||
68408d7219 | |||
94f8d00a62 | |||
b5ca3799ad | |||
a68b4af50f | |||
9f0ed8b130 | |||
27f22ba5a1 | |||
86f74b486f | |||
26b497ea63 | |||
d22bdb8031 | |||
0e204482e6 | |||
3c8d7209f1 | |||
0450f05ad9 | |||
7e2075347e | |||
20cc58d777 | |||
a6c0c57f6b | |||
10dc36d610 | |||
d2d1912d96 | |||
08ea00289a | |||
4ff8b4e007 | |||
6356343fd2 | |||
45e59f77ea | |||
4bd4acf172 | |||
8380869d33 | |||
5139af3712 | |||
2f46c18a66 | |||
e2b18ec4e7 | |||
78f1a928ce | |||
1d0b196f6b | |||
5a1c2f9b3b | |||
9955ee7eaa | |||
304eaf8053 | |||
69e288ebad | |||
d655ce48f8 | |||
91c4bba922 | |||
2845d024a4 | |||
f4ff248407 | |||
b8eb5c5d2d | |||
07f9ad982d | |||
417915a3e4 | |||
44ddc28bcd | |||
e8b8499f1f | |||
7eb7f42372 |
1 .github/workflows/build_documentation.yml (vendored)
@@ -14,6 +14,5 @@ jobs:
commit_sha: ${{ github.sha }}
package: trl
version_tag_suffix: ""
custom_container: huggingface/transformers-doc-builder
secrets:
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
1 .github/workflows/build_pr_documentation.yml (vendored)
@@ -16,4 +16,3 @@ jobs:
pr_number: ${{ github.event.number }}
package: trl
version_tag_suffix: ""
custom_container: huggingface/transformers-doc-builder
89 .github/workflows/docker-build.yml (vendored)
@@ -1,95 +1,84 @@
name: Build Docker images (scheduled)
name: Build TRL Docker image
on:
push:
branches:
- main
workflow_dispatch:
workflow_call:
schedule:
- cron: "0 1 * * *"
concurrency:
group: docker-image-builds
cancel-in-progress: false
env:
CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
jobs:
trl-latest:
name: "Latest TRL GPU"
trl:
name: "Build and push TRL Docker image"
runs-on: ubuntu-latest
steps:
- name: Cleanup disk
run: |
sudo ls -l /usr/local/lib/
sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Check out code
- name: Checkout code
uses: actions/checkout@v4
- name: Get TRL version from PyPI
run: |
VERSION=$(curl -s https://pypi.org/pypi/trl/json | jq -r .info.version)
echo "VERSION=$VERSION" >> $GITHUB_ENV
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v1
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
- name: Build and Push
uses: docker/build-push-action@v4
with:
context: ./docker/trl-latest-gpu
context: docker/trl
push: true
tags: huggingface/trl-latest-gpu
tags: |
huggingface/trl:${{ env.VERSION }}
huggingface/trl
- name: Post to Slack
if: always()
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: 🤗 Results of the trl-latest-gpu Docker Image build
slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }}
title: 🤗 Results of the TRL Dev Docker Image build
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
trl-source:
name: "Latest TRL + HF ecosystem from source"
trl-dev:
name: "Build and push TRL Dev Docker image"
runs-on: ubuntu-latest
steps:
- name: Cleanup disk
run: |
sudo ls -l /usr/local/lib/
sudo ls -l /usr/share/
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
sudo rm -rf /usr/local/lib/android
sudo rm -rf /usr/share/dotnet
sudo du -sh /usr/local/lib/
sudo du -sh /usr/share/
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Check out code
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v1
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_PASSWORD }}
- name: Build and Push GPU
- name: Build and Push
uses: docker/build-push-action@v4
with:
context: ./docker/trl-source-gpu
context: docker/trl-dev
push: true
tags: huggingface/trl-source-gpu
tags: |
huggingface/trl:dev
- name: Post to Slack
if: always()
uses: huggingface/hf-workflows/.github/actions/post-slack@main
with:
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
title: 🤗 Results of the trl-source-gpu Docker Image build
slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }}
title: 🤗 Results of the TRL Dev Docker Image build
status: ${{ job.status }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
99 .github/workflows/slow-tests.yml (vendored)
@@ -14,91 +14,98 @@ env:
jobs:
run_all_tests_single_gpu:
strategy:
fail-fast: false
matrix:
docker-image-name:
[
"huggingface/trl-latest-gpu:latest",
"huggingface/trl-source-gpu:latest",
]
runs-on:
group: aws-g4dn-2xlarge
env:
CUDA_VISIBLE_DEVICES: "0"
TEST_TYPE: "single_gpu_${{ matrix.docker-image-name }}"
TEST_TYPE: "single_gpu"
container:
image: ${{ matrix.docker-image-name }}
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all --shm-size "16gb"
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v4
- name: Pip install
- name: Git checkout
uses: actions/checkout@v4
- name: Install system dependencies
run: |
source activate trl
pip install -e ".[test,vlm]" --no-deps
pip install pytest-reportlog parameterized
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install pytest-reportlog parameterized
- name: Run slow SFT tests on single GPU
if: always()
run: |
source activate trl
source .venv/bin/activate
make slow_tests
- name: Generate Report
if: always()
run: |
pip install slack_sdk tabulate
source .venv/bin/activate
uv pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
run_all_tests_multi_gpu:
strategy:
fail-fast: false
matrix:
docker-image-name:
[
"huggingface/trl-latest-gpu:latest",
"huggingface/trl-source-gpu:latest",
]
runs-on:
group: aws-g4dn-2xlarge
env:
CUDA_VISIBLE_DEVICES: "0,1"
TEST_TYPE: "multi_gpu_${{ matrix.docker-image-name }}"
TEST_TYPE: "multi_gpu"
container:
image: ${{ matrix.docker-image-name }}
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all --shm-size "16gb"
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v4
- name: Pip install
- name: Git checkout
uses: actions/checkout@v4
- name: Install system dependencies
run: |
source activate trl
pip install -e ".[test,vlm]" --no-deps
pip install pytest-reportlog parameterized
apt-get update && apt-get install -y make git curl
- name: Install uv
run: |
curl -LsSf https://astral.sh/uv/install.sh | sh
- name: Create Python virtual environment
run: |
uv venv
uv pip install --upgrade setuptools wheel
- name: Install dependencies
run: |
source .venv/bin/activate
uv pip install ".[dev]"
uv pip install pytest-reportlog parameterized
- name: Run slow SFT tests on Multi GPU
if: always()
run: |
source activate trl
source .venv/bin/activate
make slow_tests
- name: Run end-to-end examples tests on multi GPU
if: always()
run: |
source activate trl
pip install deepspeed
make test_examples
- name: Generate Reports
if: always()
run: |
pip install slack_sdk tabulate
source .venv/bin/activate
uv pip install slack_sdk tabulate
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY
python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY
rm *.txt
rm *.txt
13 .github/workflows/tests.yml (vendored)
@@ -11,11 +11,12 @@ on:
- "scripts/**.py"
- "tests/**.py"
- "trl/**.py"
- "setup.py"
- "pyproject.toml"
env:
TQDM_DISABLE: 1
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
jobs:
check_code_quality:
@@ -41,7 +42,7 @@ jobs:
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
@@ -93,7 +94,7 @@ jobs:
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
@@ -128,7 +129,7 @@ jobs:
uv pip install -U git+https://github.com/huggingface/accelerate.git
uv pip install -U git+https://github.com/huggingface/datasets.git
uv pip install -U git+https://github.com/huggingface/transformers.git
uv pip install -U git+https://github.com/huggingface/peft.git
- name: Test with pytest
run: |
@@ -149,7 +150,7 @@ jobs:
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
@@ -201,7 +202,7 @@ jobs:
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
4 .github/workflows/tests_latest.yml (vendored)
@@ -16,7 +16,7 @@ jobs:
runs-on:
group: aws-g4dn-2xlarge
container:
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
options: --gpus all
defaults:
run:
@@ -24,7 +24,7 @@ jobs:
steps:
- name: Git checkout
uses: actions/checkout@v4
with: { ref: v0.23-release }
with: { ref: v0.24-release }
- name: Set up Python 3.12
uses: actions/setup-python@v5
@@ -31,4 +31,4 @@ keywords:
- pytorch
- transformers
license: Apache-2.0
version: "0.23"
version: "0.24"
184 CONTRIBUTING.md
@@ -1,15 +1,10 @@
# How to contribute to TRL?
Everyone is welcome to contribute, and we value everybody's contribution. Code
contributions are not the only way to help the community. Answering questions, helping
others, and improving the documentation are also immensely valuable.
Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
It also helps us if you spread the word! Reference the library in blog posts
about the awesome projects it made possible, shout out on Twitter every time it has
helped you, or simply ⭐️ the repository to say thank you.
It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
However you choose to contribute, please be mindful and respect our
[code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
However you choose to contribute, please be mindful and respect our [code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
@@ -22,9 +17,7 @@ There are several ways you can contribute to TRL:
* Implement trainers for new post-training algorithms.
* Contribute to the examples or the documentation.
If you don't know where to start, there is a special [Good First
Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of
open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
If you don't know where to start, there is a special [Good First Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀
@@ -48,14 +41,12 @@ Do your best to follow these guidelines when submitting a bug-related issue or a
The TRL library is robust and reliable thanks to users who report the problems they encounter.
Before you report an issue, we would really appreciate it if you could **make sure the bug was not
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
Before you report an issue, we would really appreciate it if you could **make sure the bug was not already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:
* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions.
* A short, self-contained, code snippet that allows us to reproduce the bug in
less than 30s.
* A short, self-contained, code snippet that allows us to reproduce the bug in less than 30s.
* The *full* traceback if an exception is raised.
* Attach any other additional information, like screenshots, you think may help.
@@ -106,29 +97,20 @@ We're always looking for improvements to the documentation that make it more cle
## Submitting a pull request (PR)
Before writing code, we strongly advise you to search through the existing PRs or
issues to make sure that nobody is already working on the same thing. If you are
unsure, it is always a good idea to open an issue to get some feedback.
Before writing code, we strongly advise you to search through the existing PRs or issues to make sure that nobody is already working on the same thing. If you are unsure, it is always a good idea to open an issue to get some feedback.
You will need basic `git` proficiency to be able to contribute to
TRL. `git` is not the easiest tool to use but it has the greatest
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.
You will need basic `git` proficiency to be able to contribute to TRL. `git` is not the easiest tool to use but it has the greatest manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow these steps to start contributing:
1. Fork the [repository](https://github.com/huggingface/trl) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
under your GitHub user account.
1. Fork the [repository](https://github.com/huggingface/trl) by clicking on the 'Fork' button on the repository's page. This creates a copy of the code under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
2. Clone your fork to your local disk, and add the base repository as a remote. The following command assumes you have your public SSH key uploaded to GitHub. See the following guide for more [information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
```bash
$ git clone git@github.com:<your Github handle>/trl.git
$ cd trl
$ git remote add upstream https://github.com/huggingface/trl.git
git clone git@github.com:<your Github handle>/trl.git
cd trl
git remote add upstream https://github.com/huggingface/trl.git
```
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
@@ -136,15 +118,15 @@ Follow these steps to start contributing:
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
```bash
$ git checkout main
$ git fetch upstream
$ git merge upstream/main
git checkout main
git fetch upstream
git merge upstream/main
```
Once your `main` branch is synchronized, create a new branch from it:
```bash
$ git checkout -b a-descriptive-name-for-my-changes
git checkout -b a-descriptive-name-for-my-changes
```
**Do not** work on the `main` branch.
@@ -152,32 +134,28 @@ Follow these steps to start contributing:
4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
```bash
$ pip install -e .[dev]
pip install -e .[dev]
```
(If TRL was already installed in the virtual environment, remove
it with `pip uninstall trl` before reinstalling it.)
(If TRL was already installed in the virtual environment, remove it with `pip uninstall trl` before reinstalling it.)
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using the provided Dev Container. Check [the documentation on how to get started with dev containers](https://code.visualstudio.com/docs/remote/containers).
5. Develop the features on your branch.
As you work on the features, you should make sure that the test suite
passes. You should run the tests impacted by your changes like this (see
below an explanation regarding the environment variable):
As you work on the features, you should make sure that the test suite passes. You should run the tests impacted by your changes like this (see below an explanation regarding the environment variable):
```bash
$ pytest tests/<TEST_TO_RUN>.py
```
> For the following commands leveraging the `make` utility.
```bash
pytest tests/<TEST_TO_RUN>.py
```
You can also run the full suite with the following command.
> For the following commands leveraging the `make` utility.
```bash
$ make test
```
You can also run the full suite with the following command.
```bash
make test
```
TRL relies on `ruff` for maintaining consistent code formatting across its source files. Before submitting any PR, you should apply automatic style corrections and run code verification checks.
@@ -186,59 +164,51 @@ Follow these steps to start contributing:
To apply these checks and corrections in one step, use:
```bash
$ make precommit
make precommit
```
This command runs the following:
- Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
- Runs additional scripts such as adding copyright information.
* Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
* Runs additional scripts such as adding copyright information.
If you prefer to apply the style corrections separately or review them individually, the `pre-commit` hook will handle the formatting for the files in question.
Once you're happy with your changes, add changed files using `git add` and
make a commit with `git commit` to record your changes locally:
Once you're happy with your changes, add changed files using `git add` and make a commit with `git commit` to record your changes locally:
```bash
$ git add modified_file.py
$ git commit
```
```bash
git add modified_file.py
git commit
```
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
It is a good idea to sync your copy of the code with the original
repository regularly. This way you can quickly account for changes:
It is a good idea to sync your copy of the code with the original
repository regularly. This way you can quickly account for changes:
```bash
$ git fetch upstream
$ git rebase upstream/main
```
```bash
git fetch upstream
git rebase upstream/main
```
Push the changes to your account using:
Push the changes to your account using:
```bash
$ git push -u origin a-descriptive-name-for-my-changes
```
```bash
git push -u origin a-descriptive-name-for-my-changes
```
6. Once you are satisfied (**and the checklist below is happy too**), go to the
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
to the project maintainers for review.
6. Once you are satisfied (**and the checklist below is happy too**), go to the webpage of your fork on GitHub. Click on 'Pull request' to send your changes to the project maintainers for review.
7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request.
### Checklist
1. The title of your pull request should be a summary of its contribution;
2. If your pull request addresses an issue, please mention the issue number in
the pull request description to make sure they are linked (and people
consulting the issue know you are working on it);
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
it from PRs ready to be merged;
2. If your pull request addresses an issue, please mention the issue number in the pull request description to make sure they are linked (and people consulting the issue know you are working on it);
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate it from PRs ready to be merged;
4. Make sure existing tests pass;
5. Add high-coverage tests. No quality testing = no merge.
### Tests
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
@@ -248,7 +218,7 @@ We use `pytest` to run the tests. From the root of the
repository here's how to run tests with `pytest` for the library:
```bash
$ python -m pytest -sv ./tests
python -m pytest -sv ./tests
```
That's how `make test` is implemented (without the `pip install` line)!
@@ -260,23 +230,23 @@ you're working on.
1. **Use defaults when appropriate**:
Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should.
Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should.
2. **Prioritize proven defaults**:
Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases.
Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases.
3. **Ensure safety and predictability**:
Defaults must be safe, expected and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.
Defaults must be safe, expected and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.
4. **Balance consistency and flexibility**:
Aim for consistent defaults across similar functions or methods. However, consistency should not be preferred to point 2 or 3.
Aim for consistent defaults across similar functions or methods. However, consistency should not be preferred to point 2 or 3.
5. **Opt-in for new features**:
Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt-in to use these.
Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt-in to use these.
### Writing documentation
@@ -318,26 +288,26 @@ def replicate_str(string: str, n: int, sep: str = " ") -> str:
* Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value.
E.g., for arguments that can't be `None` and aren't required:
```python
```txt
foo (`int`, *optional*, defaults to `4`):
```
For arguments that can be `None` and are required:
```python
```txt
foo (`Optional[int]`):
```
for arguments that can be `None` and aren't required:
for arguments that can be `None` and aren't required (in this case, if the default value is `None`, you can omit it):
```python
foo (`Optional[int]`, *optional*, defaults to `None`):
```txt
foo (`Optional[int]`, *optional*):
```
* **String Defaults:**
* Ensured that default string values are wrapped in double quotes:
```python
```txt
defaults to `"foo"`
```
@@ -346,7 +316,7 @@ def replicate_str(string: str, n: int, sep: str = " ") -> str:
* **Default Value Formatting:**
* Consistently surrounded default values with backticks for improved formatting:
```python
```txt
defaults to `4`
```
@@ -383,8 +353,8 @@ Our approach to deprecation and backward compatibility is flexible and based on
When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:
- **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
- **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
* **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
* **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
Example:
@@ -398,9 +368,9 @@ Example:
The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:
- **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.
* **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.
- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.
* **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.
These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
@@ -410,22 +380,22 @@ Warnings play a critical role in guiding users toward resolving potential issues
#### Definitions
- **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
- **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.
* **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
* **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.
#### Choosing the right message
- **Correct → No warning**:
* **Correct → No warning**:
If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.
- **Correct but deserves attention → No warning, possibly a log message**:
* **Correct but deserves attention → No warning, possibly a log message**:
When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:
```python
logger.info("This is an informational message about a rare but correct operation.")
```
- **Correct but very likely a mistake → Warning with option to disable**:
* **Correct but very likely a mistake → Warning with option to disable**:
In rare cases, you may want to issue a warning for a correct operation that’s very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:
```python
@@ -436,7 +406,7 @@ Warnings play a critical role in guiding users toward resolving potential issues
# Do something
```
- **Supported but not correct → Warning**:
* **Supported but not correct → Warning**:
If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:
```python
@@ -446,7 +416,7 @@ Warnings play a critical role in guiding users toward resolving potential issues
# Do something
```
- **Not supported → Exception**:
* **Not supported → Exception**:
If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:
```python
@@ -1,6 +1,7 @@
include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__
include trl/accelerate_configs/*.yaml
include trl/templates/*.md
include trl/accelerate_configs/*.yaml
recursive-exclude * __pycache__
prune tests
19 Makefile
@@ -1,12 +1,11 @@
.PHONY: test precommit common_tests slow_tests test_examples tests_gpu
.PHONY: test precommit common_tests slow_tests tests_gpu test_experimental
check_dirs := examples tests trl
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
COMMAND_FILES_PATH = `pwd`/commands
test:
pytest -n auto -m "not slow and not low-priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/
pytest -n auto -m "not slow and not low_priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/
precommit:
python scripts/add_copyrights.py
@@ -16,15 +15,5 @@ precommit:
slow_tests:
pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
test_examples:
touch temp_results_sft_tests.txt
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
echo $$?','$${file} >> temp_results_sft_tests.txt; \
done
touch temp_results_dpo_tests.txt
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
echo $$?','$${file} >> temp_results_dpo_tests.txt; \
done
test_experimental:
pytest -k "experimental"
26 README.md
@@ -136,23 +136,13 @@ trainer.train()
Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):
```python
from trl import RewardConfig, RewardTrainer
from trl import RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
"Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
args=training_args,
model=model,
processing_class=tokenizer,
model="Qwen/Qwen2.5-0.5B-Instruct",
train_dataset=dataset,
)
trainer.train()
@@ -190,6 +180,18 @@ cd trl/
pip install -e .[dev]
```
## Experimental
A minimal incubation area is available under `trl.experimental` for unstable / fast-evolving features. Anything there may change or be removed in any release without notice.
Example:
```python
from trl.experimental.new_trainer import NewTrainer
```
Read more in the [Experimental docs](https://huggingface.co/docs/trl/main/en/experimental).
## Citation
```bibtex
@@ -1,58 +0,0 @@
#!/bin/bash
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_dpo/"
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
--load_in_4bit
"""
# This is a hack to get the number of available GPUs
NUM_GPUS=2
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
EXTRA_ACCELERATE_ARGS=""
else
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
EXTRA_TRAINING_ARGS="--fp16"
else
echo "Keeping QLoRA + PEFT"
fi
fi
CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
--num_processes $NUM_GPUS \
--mixed_precision 'fp16' \
`pwd`/trl/scripts/dpo.py \
--model_name_or_path $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
--max_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""
echo "Starting program..."
{ # try
echo $CMD
eval "$CMD"
} || { # catch
# save log for exception
echo "Operation Failed!"
exit 1
}
exit 0
@@ -1,59 +0,0 @@
#!/bin/bash
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_sft/"
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
DATASET_NAME="stanfordnlp/imdb"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
--load_in_4bit
"""
# Set your number of GPUs here
NUM_GPUS=2
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
EXTRA_ACCELERATE_ARGS=""
else
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
EXTRA_TRAINING_ARGS="--fp16"
else
echo "Keeping QLoRA + PEFT"
fi
fi
CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
--num_processes $NUM_GPUS \
--mixed_precision 'fp16' \
`pwd`/trl/scripts/sft.py \
--model_name $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
--max_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""
echo "Starting program..."
{ # try
echo $CMD
eval "$CMD"
} || { # catch
# save log for exception
echo "Operation Failed!"
exit 1
}
exit 0
6 docker/trl-dev/Dockerfile (new file)
@@ -0,0 +1,6 @@
FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-runtime
RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade pip uv
RUN uv pip install --system --no-cache "git+https://github.com/huggingface/trl.git#egg=trl[liger,peft,vlm]"
RUN uv pip install --system hf_transfer liger_kernel trackio peft
RUN uv pip install --system https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
@@ -1,66 +0,0 @@
# Builds GPU docker image of PyTorch
# Uses multi-staged approach to reduce size
# Stage 1
# Use base conda image to reduce time
FROM continuumio/miniconda3:latest AS compile-image
# Specify py version
ENV PYTHON_VERSION=3.10
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN apt-get update && \
apt-get install -y curl git wget software-properties-common git-lfs && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Install audio-related libraries
RUN apt-get update && \
apt install -y ffmpeg
RUN apt install -y libsndfile1-dev
RUN git lfs install
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
RUN python3 -m pip install --no-cache-dir --upgrade pip
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
# We don't install pytorch here yet since CUDA isn't available
# instead we use the direct torch wheel
ENV PATH /opt/conda/envs/trl/bin:$PATH
# Activate our bash shell
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
# Stage 2
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
COPY --from=compile-image /opt/conda /opt/conda
ENV PATH /opt/conda/bin:$PATH
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
RUN source activate trl && \
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
# Install apt libs
RUN apt-get update && \
apt-get install -y curl git wget && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Activate the conda env and install transformers + accelerate from source
RUN source activate trl && \
python3 -m pip install -U --no-cache-dir \
librosa \
"soundfile>=0.12.1" \
scipy \
transformers \
accelerate \
peft \
trl[test]@git+https://github.com/huggingface/trl
RUN source activate trl && \
pip freeze | grep trl
RUN echo "source activate trl" >> ~/.profile
# Activate the virtualenv
CMD ["/bin/bash"]
@@ -1,66 +0,0 @@
# Builds GPU docker image of PyTorch
# Uses multi-staged approach to reduce size
# Stage 1
# Use base conda image to reduce time
FROM continuumio/miniconda3:latest AS compile-image
# Specify py version
ENV PYTHON_VERSION=3.10
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN apt-get update && \
apt-get install -y curl git wget software-properties-common git-lfs && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Install audio-related libraries
RUN apt-get update && \
apt install -y ffmpeg
RUN apt install -y libsndfile1-dev
RUN git lfs install
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
RUN python3 -m pip install --no-cache-dir --upgrade pip
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
# We don't install pytorch here yet since CUDA isn't available
# instead we use the direct torch wheel
ENV PATH /opt/conda/envs/trl/bin:$PATH
# Activate our bash shell
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
# Stage 2
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
COPY --from=compile-image /opt/conda /opt/conda
ENV PATH /opt/conda/bin:$PATH
RUN chsh -s /bin/bash
SHELL ["/bin/bash", "-c"]
RUN source activate trl && \
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
# Install apt libs
RUN apt-get update && \
apt-get install -y curl git wget && \
apt-get clean && \
rm -rf /var/lib/apt/lists*
# Activate the conda env and install transformers + accelerate from source
RUN source activate trl && \
python3 -m pip install -U --no-cache-dir \
librosa \
"soundfile>=0.12.1" \
scipy \
git+https://github.com/huggingface/transformers \
git+https://github.com/huggingface/accelerate \
git+https://github.com/huggingface/peft \
trl[test]@git+https://github.com/huggingface/trl
RUN source activate trl && \
pip freeze | grep transformers
RUN echo "source activate trl" >> ~/.profile
# Activate the virtualenv
CMD ["/bin/bash"]
4 docker/trl/Dockerfile (new file)
@@ -0,0 +1,4 @@
FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-runtime
RUN pip install --upgrade pip uv
RUN uv pip install --system trl[liger,peft,vlm] hf_transfer trackio
RUN uv pip install --system https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp311-cp311-linux_x86_64.whl
@@ -11,10 +11,8 @@
title: Dataset Formats
- local: paper_index
title: Paper Index
- local: how_to_train
title: Training FAQ
- local: logging
title: Understanding Logs
- local: experimental
title: Experimental
title: Conceptual Guides
- sections:
- local: clis
@@ -53,25 +51,19 @@
title: Example Overview
- local: community_tutorials
title: Community Tutorials
- local: lora_without_regret
title: LoRA Without Regret
- local: sentiment_tuning
title: Sentiment Tuning
- local: using_llama_models
title: Training StackLlama
- local: detoxifying_a_lm
title: Detoxifying a Language Model
- local: multi_adapter_rl
title: Multi Adapter RLHF
title: Examples
- sections:
- sections: # Sorted alphabetically
- local: alignprop_trainer
title: AlignProp
- local: bco_trainer
title: BCO
- local: cpo_trainer
title: CPO
- local: ddpo_trainer
title: DDPO
- local: dpo_trainer
title: DPO
- local: online_dpo_trainer
@@ -96,8 +88,6 @@
title: RLOO
- local: sft_trainer
title: SFT
- local: iterative_sft_trainer
title: Iterative SFT
- local: xpo_trainer
title: XPO
title: Trainers
@ -1,102 +0,0 @@
|
||||
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation
|
||||
|
||||
[](https://huggingface.co/models?other=alignprop,trl)
|
||||
|
||||
## The why
|
||||
|
||||
If your reward function is differentiable, directly backpropagating gradients from the reward models to the diffusion model is significantly more sample and compute efficient (25x) than doing policy gradient algorithm like DDPO.
|
||||
AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
|
||||
|
||||
<div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/reward_tuning.png"/></div>
|
||||
|
||||
|
||||
## Getting started with `examples/scripts/alignprop.py`
|
||||
|
||||
The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`).
|
||||
|
||||
**Note:** one A100 GPU is recommended to get this running. For lower memory setting, consider setting truncated_backprop_rand to False. With default settings this will do truncated backpropagation with K=1.
|
||||
|
||||
Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post-finetuning to HuggingFace hub. The following bash command is to be entered to get things running
|
||||
|
||||
```batch
|
||||
python alignprop.py --hf_user_access_token <token>
|
||||
```
|
||||
|
||||
To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`
|
||||
|
||||
The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)
|
||||
|
||||
- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater than 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps)
|
||||
- The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`) the number should be less than the number of diffusion timesteps (sample_num_steps), it only matters when truncated_backprop_rand is set to False
|
||||
|
||||
## Setting up the image logging hook function
|
||||
|
||||
Expect the function to be given a dictionary with keys
|
||||
```python
|
||||
['image', 'prompt', 'prompt_metadata', 'rewards']
|
||||
|
||||
```
|
||||
and `image`, `prompt`, `prompt_metadata`, `rewards`are batched.
|
||||
You are free to log however you want the use of `wandb` or `tensorboard` is recommended.
|
||||
|
||||
### Key terms
|
||||
|
||||
- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process
|
||||
- `prompt` : The prompt is the text that is used to generate the image
|
||||
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup, where questions and ground-truth answers (linked to the generated image) are expected alongside the generated image (see here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
|
||||
- `image` : The image generated by the Stable Diffusion model
|
||||
|
||||
Example code for logging sampled images with `wandb` is given below.
|
||||
|
||||
```python
|
||||
# for logging these images to wandb
import numpy as np
from PIL import Image
|
||||
|
||||
def image_outputs_hook(image_data, global_step, accelerate_logger):
|
||||
# In AlignProp, the hook receives a single dict with batched entries,
|
||||
# so we read the images, prompts, and rewards directly from it
|
||||
result = {}
|
||||
images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']]
|
||||
for i, image in enumerate(images):
|
||||
pil = Image.fromarray(
|
||||
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
|
||||
)
|
||||
pil = pil.resize((256, 256))
|
||||
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
|
||||
accelerate_logger.log_images(
|
||||
result,
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
### Using the finetuned model
|
||||
|
||||
Assuming you're done with all the epochs and have pushed your model up to the Hub, you can use the finetuned model as follows:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
pipeline.to("cuda")
|
||||
|
||||
pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics')
|
||||
|
||||
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
|
||||
results = pipeline(prompts)
|
||||
|
||||
for prompt, image in zip(prompts,results.images):
|
||||
image.save(f"dump/{prompt}.png")
|
||||
```
|
||||
|
||||
## Credits
|
||||
|
||||
This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation
|
||||
by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://huggingface.co/papers/2310.03739).
|
||||
|
||||
## AlignPropTrainer
|
||||
|
||||
[[autodoc]] AlignPropTrainer
|
||||
- train
|
||||
|
||||
## AlignPropConfig
|
||||
|
||||
[[autodoc]] AlignPropConfig
|
@ -1,6 +1,6 @@
|
||||
# BCO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=bco,trl)
|
||||
[](https://huggingface.co/models?other=bco,trl)
|
||||
|
||||
TRL supports Binary Classifier Optimization (BCO).
|
||||
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
|
||||
@ -12,17 +12,16 @@ The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unp
|
||||
The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
## Expected model format
|
||||
|
||||
The BCO trainer expects a model of type `AutoModelForCausalLM`, whereas PPO expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `BCOTrainer`
|
||||
|
||||
For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
|
||||
For a detailed example, have a look at the `examples/scripts/bco.py` script. At a high level, we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model`, which we will use to calculate the implicit rewards of the preferred and rejected responses.
|
||||
|
||||
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (i.e., decoder-only or encoder-decoder).
|
||||
|
||||
|
||||
|
||||
```py
|
||||
```python
|
||||
training_args = BCOConfig(
|
||||
beta=0.1,
|
||||
)
|
||||
@ -35,9 +34,10 @@ bco_trainer = BCOTrainer(
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
```
|
||||
|
||||
After this, one can then call:
|
||||
|
||||
```py
|
||||
```python
|
||||
bco_trainer.train()
|
||||
```
|
||||
|
||||
@ -49,7 +49,7 @@ If the prompts in your desired and undesired datasets differ a lot, it is useful
|
||||
|
||||
Choose an embedding model and tokenizer:
|
||||
|
||||
```py
|
||||
```python
|
||||
embedding_model = AutoModel.from_pretrained(your_model_id)
|
||||
embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id)
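# As a hedged sketch (not necessarily the exact helper used in the script),
# a minimal embedding function could mean-pool the model's last hidden state:
def embed_prompt(input_ids, attention_mask, model):
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    return outputs.last_hidden_state.mean(dim=1)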
|
||||
|
||||
@ -64,7 +64,7 @@ embedding_func = partial(embed_prompt, model=embedding_model)
|
||||
|
||||
Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
|
||||
|
||||
```py
|
||||
```python
|
||||
training_args = BCOConfig(
|
||||
beta=0.1,
|
||||
prompt_sample_size=512,
|
||||
|
@ -1,4 +1,7 @@
|
||||
# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning
|
||||
# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning
|
||||
|
||||
> [!WARNING]
|
||||
> Best-of-N sampling is deprecated and will be removed in TRL 0.25.0.
|
||||
|
||||
Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output.
|
||||
As to how it fares against RL-based fine-tuning, please look in the `examples` directory for a comparison example.
|
||||
@ -8,7 +11,6 @@ As to how it fares against the RL based fine-tuning, please look in the `example
|
||||
To get started quickly, create an instance of the class with a model, a length sampler, a tokenizer, and a callable that serves as a proxy reward pipeline, outputting reward scores for input queries:
|
||||
|
||||
```python
|
||||
|
||||
from transformers import pipeline, AutoTokenizer
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from trl.core import LengthSampler
|
||||
@ -19,41 +21,33 @@ reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ref_model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
||||
# callable that takes a list of raw text and returns a list of corresponding reward scores
|
||||
def queries_to_scores(list_of_strings):
|
||||
return [output["score"] for output in reward_pipe(list_of_strings)]
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler)
|
||||
|
||||
|
||||
```
|
||||
|
||||
And assuming you have a list/tensor of tokenized queries, you can generate better output by calling the `generate` method
|
||||
|
||||
```python
|
||||
|
||||
best_of_n.generate(query_tensors, device=device, **gen_kwargs)
|
||||
|
||||
```
|
||||
|
||||
The default sample size is 4, but you can change it at the time of instance initialization like so
|
||||
|
||||
```python
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, sample_size=8)
|
||||
|
||||
```
|
||||
|
||||
The default output is the top-scored output for each query, but you can change it to the top 2 and so on by passing the `n_candidates` argument at the time of instance initialization.
|
||||
|
||||
```python
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, n_candidates=2)
|
||||
|
||||
```
|
||||
|
||||
There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method.
|
||||
This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization
|
||||
This is done by passing a [`~transformers.GenerationConfig`] from the `transformers` library at the time of initialization
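
As a hedged sketch (reusing the objects defined above and assuming the `generation_config` keyword argument), this could look like:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(
    min_length=-1, top_k=0, top_p=1.0, do_sample=True, pad_token_id=tokenizer.eos_token_id
)

best_of_n = BestOfNSampler(
    model,
    tokenizer,
    queries_to_scores,
    length_sampler=output_length_sampler,
    generation_config=generation_config,
)

best_of_n.generate(query_tensors, device=device)
```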
|
||||
|
||||
```python
|
||||
|
||||
|
@ -23,3 +23,7 @@
|
||||
## BEMACallback
|
||||
|
||||
[[autodoc]] BEMACallback
|
||||
|
||||
## WeaveCallback
|
||||
|
||||
[[autodoc]] WeaveCallback
|
||||
|
@ -2,17 +2,20 @@
|
||||
|
||||
TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly.
|
||||
|
||||
## Commands
|
||||
|
||||
Currently supported commands are:
|
||||
|
||||
#### Training Commands
|
||||
### Training Commands
|
||||
|
||||
- `trl dpo`: fine-tune a LLM with DPO
|
||||
- `trl grpo`: fine-tune a LLM with GRPO
|
||||
- `trl kto`: fine-tune a LLM with KTO
|
||||
- `trl reward`: train a Reward Model
|
||||
- `trl rloo`: fine-tune a LLM with RLOO
|
||||
- `trl sft`: fine-tune a LLM with SFT
|
||||
|
||||
#### Other Commands
|
||||
### Other Commands
|
||||
|
||||
- `trl env`: get the system information
|
||||
- `trl vllm-serve`: serve a model with vLLM
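
For example (the `--model` flag for `vllm-serve` is an assumption based on the current CLI):

```bash
# print system and library information, useful for bug reports
trl env

# serve a model with vLLM for fast generation
trl vllm-serve --model Qwen/Qwen2.5-7B
```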
|
||||
@ -41,6 +44,15 @@ trl dpo \
|
||||
--dataset_name anthropic/hh-rlhf
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward">
|
||||
|
||||
```bash
|
||||
trl reward \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
@ -78,6 +90,21 @@ Launch with:
|
||||
trl dpo --config dpo_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward">
|
||||
|
||||
```yaml
|
||||
# reward_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: trl-lib/ultrafeedback_binarized
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl reward --config reward_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
@ -138,6 +165,33 @@ Launch with:
|
||||
```bash
|
||||
trl dpo --config dpo_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward inline">
|
||||
|
||||
```bash
|
||||
trl reward \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized \
|
||||
--num_processes 4
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward w/ config file">
|
||||
|
||||
```yaml
|
||||
# reward_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: trl-lib/ultrafeedback_binarized
|
||||
num_processes: 4
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl reward --config reward_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
@ -145,22 +199,22 @@ trl dpo --config dpo_config.yaml
|
||||
|
||||
The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either:
|
||||
|
||||
* the name of a predefined config profile (built into TRL), or
|
||||
* a path to a custom Accelerate YAML config file.
|
||||
- the name of a predefined config profile (built into TRL), or
|
||||
- a path to a custom Accelerate YAML config file.
|
||||
|
||||
#### Predefined Config Profiles
|
||||
|
||||
TRL provides several ready-to-use Accelerate configs to simplify common training setups:
|
||||
|
||||
| Name | Description |
|
||||
| ------------ | ----------------------------------- |
|
||||
| `fsdp1` | Fully Sharded Data Parallel Stage 1 |
|
||||
| `fsdp2` | Fully Sharded Data Parallel Stage 2 |
|
||||
| `zero1` | DeepSpeed ZeRO Stage 1 |
|
||||
| `zero2` | DeepSpeed ZeRO Stage 2 |
|
||||
| `zero3` | DeepSpeed ZeRO Stage 3 |
|
||||
| `multi_gpu` | Multi-GPU training |
|
||||
| `single_gpu` | Single-GPU training |
|
||||
| Name | Description |
|
||||
| --- | --- |
|
||||
| `fsdp1` | Fully Sharded Data Parallel Stage 1 |
|
||||
| `fsdp2` | Fully Sharded Data Parallel Stage 2 |
|
||||
| `zero1` | DeepSpeed ZeRO Stage 1 |
|
||||
| `zero2` | DeepSpeed ZeRO Stage 2 |
|
||||
| `zero3` | DeepSpeed ZeRO Stage 3 |
|
||||
| `multi_gpu` | Multi-GPU training |
|
||||
| `single_gpu` | Single-GPU training |
|
||||
|
||||
To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_config/`.
|
||||
|
||||
@ -217,6 +271,33 @@ Launch with:
|
||||
```bash
|
||||
trl dpo --config dpo_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward inline">
|
||||
|
||||
```bash
|
||||
trl reward \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized \
|
||||
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward w/ config file">
|
||||
|
||||
```yaml
|
||||
# reward_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: trl-lib/ultrafeedback_binarized
|
||||
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl reward --config reward_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
@ -224,7 +305,7 @@ trl dpo --config dpo_config.yaml
|
||||
|
||||
You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data.
|
||||
|
||||
<hfoptions id="accelerate_config">
|
||||
<hfoptions id="dataset_mixtures">
|
||||
<hfoption id="SFT">
|
||||
|
||||
```yaml
|
||||
@ -258,6 +339,23 @@ Launch with:
|
||||
trl dpo --config dpo_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Reward">
|
||||
|
||||
```yaml
|
||||
# reward_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
datasets:
|
||||
- path: trl-lib/tldr-preference
|
||||
- path: trl-lib/lm-human-preferences-sentiment
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl reward --config reward_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
@ -2,10 +2,13 @@
|
||||
|
||||
Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities.
|
||||
|
||||
# Language Models
|
||||
## Language Models
|
||||
|
||||
### Tutorials
|
||||
|
||||
| Task | Class | Description | Author | Tutorial | Colab |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | Efficient Online Training with GRPO and vLLM in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/grpo_vllm_online_training) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/grpo_vllm_online_training.ipynb) |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | Post training an LLM for reasoning with GRPO in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb) |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/mini-deepseek-r1) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb) |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | RL on LLaMA 3.1-8B with GRPO and Unsloth optimizations | [Andrea Manzoni](https://huggingface.co/AManzoni) | [Link](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) | [](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) |
|
||||
@ -15,9 +18,28 @@ Community tutorials are made by active members of the Hugging Face community who
|
||||
| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) |
|
||||
| Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) |
|
||||
|
||||
<Youtube id="cnGyyM0vOes" />
|
||||
### Videos
|
||||
|
||||
# Vision Language Models
|
||||
| Task | Title | Author | Video |
|
||||
| --- | --- | --- | --- |
|
||||
| Instruction tuning | Fine-tuning open AI models using Hugging Face TRL | [Wietse Venema](https://huggingface.co/wietsevenema) | [<img src="https://img.youtube.com/vi/cnGyyM0vOes/0.jpg">](https://youtu.be/cnGyyM0vOes) |
|
||||
| Instruction tuning | How to fine-tune a smol-LM with Hugging Face, TRL, and the smoltalk Dataset | [Mayurji](https://huggingface.co/iammayur) | [<img src="https://img.youtube.com/vi/jKdXv3BiLu0/0.jpg">](https://youtu.be/jKdXv3BiLu0) |
|
||||
|
||||
|
||||
<details>
|
||||
<summary>⚠️ Deprecated features notice for "How to fine-tune a smol-LM with Hugging Face, TRL, and the smoltalk Dataset" (click to expand)</summary>
|
||||
|
||||
> [!WARNING]
|
||||
> The tutorial uses two deprecated features:
|
||||
>
|
||||
> - `SFTTrainer(..., tokenizer=tokenizer)`: Use `SFTTrainer(..., processing_class=tokenizer)` instead, or simply omit it (it will be inferred from the model).
|
||||
> - `setup_chat_format(model, tokenizer)`: Use `SFTConfig(..., chat_template_path="Qwen/Qwen3-0.6B")`, where `chat_template_path` specifies the model whose chat template you want to copy.
|
||||
|
||||
</details>
|
||||
|
||||
## Vision Language Models
|
||||
|
||||
### Tutorials
|
||||
|
||||
| Task | Class | Description | Author | Tutorial | Colab |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
|
@ -1,6 +1,6 @@
|
||||
# CPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=cpo,trl)
|
||||
[](https://huggingface.co/models?other=cpo,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
@ -98,15 +98,13 @@ To use this loss as described in the paper, we can set the `loss_type="alphapo"`
|
||||
|
||||
The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported:
|
||||
|
||||
| `loss_type=` | Description |
|
||||
| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact, the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
|
||||
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
|
||||
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, and thus the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over log-likelihoods of the completion (unlike DPO, which is summed only). |
|
||||
| `"simpo"` | The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`] and `simpo_gamma` to a recommended value. |
|
||||
| `"alphapo"` | The [AlphaPO](https://huggingface.co/papers/2501.03884) method is also implemented in the [`CPOTrainer`]. This is syntactic sugar that automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. AlphaPO applies a transformation to the reward function shape in the context of SimPO loss when the `alpha` parameter is non-zero. |
|
||||
|
||||
|
||||
| `loss_type=` | Description |
|
||||
| --- | --- |
|
||||
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact, the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
|
||||
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
|
||||
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, and thus the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over log-likelihoods of the completion (unlike DPO, which is summed only). |
|
||||
| `"simpo"` | The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`] and `simpo_gamma` to a recommended value. |
|
||||
| `"alphapo"` | The [AlphaPO](https://huggingface.co/papers/2501.03884) method is also implemented in the [`CPOTrainer`]. This is syntactic sugar that automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. AlphaPO applies a transformation to the reward function shape in the context of SimPO loss when the `alpha` parameter is non-zero. |
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
|
@ -2,8 +2,6 @@
|
||||
|
||||
TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are some examples of how you can apply and test different techniques. Note: Although these examples use the `DPOTrainer`, the customization applies to most (if not all) trainers.
|
||||
|
||||
|
||||
|
||||
## Use different optimizers and schedulers
|
||||
|
||||
By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can create a different optimizer and pass it to `DPOTrainer` as follows:
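
As a hedged sketch (assuming `model`, `train_dataset`, and `tokenizer` are defined as in the surrounding examples, and using the standard `Trainer`-style `optimizers` argument):

```python
import torch
from trl import DPOConfig, DPOTrainer

training_args = DPOConfig(output_dir="dpo-custom-optimizer")
optimizer = torch.optim.SGD(model.parameters(), lr=training_args.learning_rate)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    optimizers=(optimizer, None),  # (optimizer, lr_scheduler)
)
trainer.train()
```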
|
||||
@ -84,11 +82,11 @@ trainer = DPOTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Pass 8-bit reference models
|
||||
|
||||
## Pass 8-bit reference models
|
||||
|
||||
Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory-efficient fine-tuning.
|
||||
|
||||
Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).
|
||||
Read more about 8-bit model loading in `transformers` [Load in 8bit or 4bit](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).
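
As a condensed, hedged sketch of the pattern (the model name is a placeholder and `train_dataset` is assumed to be a preference dataset):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# load only the frozen reference model in 8-bit to save memory
ref_model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True)
)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=DPOConfig(output_dir="dpo-8bit-ref"),
    train_dataset=train_dataset,  # assumed preference dataset
    processing_class=tokenizer,
)
```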
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
@ -114,7 +112,7 @@ trainer.train()
|
||||
|
||||
## Use the accelerator cache optimizer
|
||||
|
||||
When training large models, you should better handle the accelerator cache by iteratively clearing it. To do so, simply pass `optimize_device_cache=True` to `DPOConfig`:
|
||||
When training large models, it is best to handle the accelerator cache by iteratively clearing it. To do so, simply pass `optimize_device_cache=True` to [`DPOConfig`]:
|
||||
|
||||
```python
|
||||
training_args = DPOConfig(..., optimize_device_cache=True)
|
||||
|
@ -81,7 +81,7 @@ This guide provides an overview of the dataset formats and types supported by ea
|
||||
<td>Stepwise supervision</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "Which number is larger, 9.8 or 9.11?",
|
||||
"completions": ["The fractional part of 9.8 is 0.8.",
|
||||
"completions": ["The fractional part of 9.8 is 0.8.",
|
||||
"The fractional part of 9.11 is 0.11.",
|
||||
"0.11 is greater than 0.8.",
|
||||
"Hence, 9.11 > 9.8."],
|
||||
@ -289,31 +289,28 @@ prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the
|
||||
|
||||
For examples of prompt-only datasets, refer to the [Prompt-only datasets collection](https://huggingface.co/collections/trl-lib/prompt-only-datasets-677ea25245d20252cea00368).
|
||||
|
||||
<Tip>
|
||||
|
||||
While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from trl import apply_chat_template
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
|
||||
|
||||
# Example for prompt-only type
|
||||
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
|
||||
apply_chat_template(prompt_only_example, tokenizer)
|
||||
# Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
|
||||
|
||||
# Example for language modeling type
|
||||
lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
|
||||
apply_chat_template(lm_example, tokenizer)
|
||||
# Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
|
||||
```
|
||||
|
||||
- The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistant’s turn and expecting the model to generate a completion.
|
||||
- In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
|
||||
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
|
||||
>
|
||||
> ```python
|
||||
> from transformers import AutoTokenizer
|
||||
> from trl import apply_chat_template
|
||||
>
|
||||
> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")
|
||||
>
|
||||
> # Example for prompt-only type
|
||||
> prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
|
||||
> apply_chat_template(prompt_only_example, tokenizer)
|
||||
> # Output: {'prompt': '<|user|>\nWhat color is the sky?<|end|>\n<|assistant|>\n'}
|
||||
>
|
||||
> # Example for language modeling type
|
||||
> lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
|
||||
> apply_chat_template(lm_example, tokenizer)
|
||||
> # Output: {'text': '<|user|>\nWhat color is the sky?<|end|>\n<|endoftext|>'}
|
||||
> ```
|
||||
>
|
||||
> - The prompt-only output includes a `'<|assistant|>\n'`, indicating the beginning of the assistant’s turn and expecting the model to generate a completion.
|
||||
> - In contrast, the language modeling output treats the input as a complete sequence and terminates it with `'<|endoftext|>'`, signaling the end of the text and not expecting any additional content.
|
||||
|
||||
#### Prompt-completion
|
||||
|
||||
@ -390,31 +387,27 @@ For examples of stepwise supervision datasets, refer to the [Stepwise supervisio
|
||||
|
||||
Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.
|
||||
|
||||
| Trainer | Expected dataset type |
|
||||
| ----------------------- | ------------------------------------------------------------------------------------------------------ |
|
||||
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
|
||||
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) |
|
||||
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`PPOTrainer`] | Tokenized language modeling |
|
||||
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
|
||||
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
|
||||
| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
|
||||
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| Trainer | Expected dataset type |
|
||||
| --- | --- |
|
||||
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
|
||||
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`PPOTrainer`] | Tokenized language modeling |
|
||||
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
|
||||
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
|
||||
| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
|
||||
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
|
||||
<Tip>
|
||||
|
||||
TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
|
||||
For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
|
||||
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
|
||||
> For more information on how to work with conversational datasets, refer to the [Working with conversational datasets in TRL](#working-with-conversational-datasets-in-trl) section.
|
||||
|
||||
## Working with conversational datasets in TRL
|
||||
|
||||
@ -423,7 +416,7 @@ Fortunately, TRL offers tools to easily handle this conversion, which are detail
|
||||
|
||||
### Converting a conversational dataset into a standard dataset
|
||||
|
||||
To convert a conversational dataset into a standard dataset, you need to _apply a chat template_ to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use.
|
||||
To convert a conversational dataset into a standard dataset, you need to *apply a chat template* to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use.
|
||||
|
||||
For detailed instructions on using chat templating, refer to the [Chat templating section in the `transformers` documentation](https://huggingface.co/docs/transformers/en/chat_templating).
|
||||
|
||||
@ -466,27 +459,21 @@ dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
|
||||
# 'completion': ['It is blue.<|end|>\n<|endoftext|>', 'In the sky.<|end|>\n<|endoftext|>']}
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
> [!WARNING]
|
||||
> We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
|
||||
> For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
|
||||
|
||||
We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
|
||||
For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
|
||||
|
||||
```python
|
||||
apply_chat_template(example, AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct"))
|
||||
# Output:
|
||||
# {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
|
||||
# 'completion': 'It is blue.<|im_end|>\n'}
|
||||
```
|
||||
|
||||
Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> It's important to note that chat templates are model-specific. For example, if you use the chat template from [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) with the above example, you get a different output:
|
||||
>
|
||||
> ```python
|
||||
> apply_chat_template(example, AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct"))
|
||||
> # Output:
|
||||
> # {'prompt': '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n',
|
||||
> # 'completion': 'It is blue.<|im_end|>\n'}
|
||||
> ```
|
||||
>
|
||||
> Always use the chat template associated with the model you're working with. Using the wrong template can lead to inaccurate or unexpected results.
|
||||
|
||||
## Using any dataset with TRL: preprocessing and conversion
|
||||
|
||||
@ -532,15 +519,15 @@ This section provides example code to help you convert between different dataset
|
||||
|
||||
For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification.
|
||||
|
||||
| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision |
|
||||
| ------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------- | --------------------------------------------------------- | --------------------------------------------------------- | ------------------------------------------------------------------------- | -------------------- |
|
||||
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
||||
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A |
|
||||
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A |
|
||||
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
||||
| Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A |
|
||||
| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision |
|
||||
| --- | --- | --- | --- | --- | --- | --- | --- |
|
||||
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
||||
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A |
|
||||
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A |
|
||||
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
||||
| Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A |
|
||||
|
||||
### From prompt-completion to language modeling dataset
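
As a hedged sketch, the conversion simply concatenates the prompt and the completion into a single `text` column:

```python
from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": ["The sky is", "The sun is"],
    "completion": [" blue.", " in the sky."],
})

def concat_prompt_completion(example):
    return {"text": example["prompt"] + example["completion"]}

dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])
```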
|
||||
|
||||
@ -716,13 +703,10 @@ dataset = unpair_preference_dataset(dataset)
|
||||
'label': True}
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
|
||||
Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
|
||||
This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
|
||||
> Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
|
||||
> This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
|
||||
|
||||
### From preference to language modeling dataset
|
||||
|
||||
@ -857,13 +841,10 @@ dataset = unpair_preference_dataset(dataset)
|
||||
'label': True}
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
|
||||
Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
|
||||
This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can be both good or bad.
|
||||
> Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
|
||||
> This can be ensured by checking absolute rating of each completion, e.g. from a reward model.
|
||||
|
||||
### From unpaired preference to language modeling dataset
|
||||
|
||||
@ -1038,7 +1019,7 @@ Some trainers also support fine-tuning vision-language models (VLMs) using image
|
||||
|
||||
A conversational vision dataset differs from a standard conversational dataset in two key ways:
|
||||
|
||||
1. The dataset must contain the key `images` with the image data.
|
||||
1. The dataset must contain the key `images` with the image data (as lists of PIL images) or `image` with a single PIL image.
|
||||
2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`.
|
||||
|
||||
Example:
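
A hedged sketch of a single example (the image path is a placeholder):

```python
from PIL import Image

example = {
    "images": [Image.open("path/to/sky_image.png")],  # placeholder path
    "messages": [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
        {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
    ],
}
```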
|
||||
@ -1062,3 +1043,23 @@ An example of a conversational vision dataset is the [openbmb/RLAIF-V-Dataset](h
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
> [!NOTE]
|
||||
> Mixing text-only and vision-language data in the dataset is possible, but it requires `transformers` version 4.57.0 or later. Example:
|
||||
>
|
||||
> ```python
|
||||
> dataset = Dataset.from_dict({
|
||||
> "prompt": [
|
||||
> [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky in the image?"}]}],
|
||||
> [{"role": "user", "content": [{"type": "text", "text": "What is the capital of France?"}]}],
|
||||
> ],
|
||||
> "completion": [
|
||||
> [{"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}],
|
||||
> [{"role": "assistant", "content": [{"type": "text", "text": "Paris."}]}],
|
||||
> ],
|
||||
> "images": [
|
||||
> [PIL.Image.open("path/to/sky_image1.png")],
|
||||
> [],
|
||||
> ],
|
||||
> })
|
||||
> ```
|
||||
|
@ -1,131 +0,0 @@
|
||||
# Denoising Diffusion Policy Optimization
|
||||
|
||||
[](https://huggingface.co/models?other=ddpo,trl)
|
||||
|
||||
## The why
|
||||
|
||||
| Before | After DDPO finetuning |
|
||||
| --- | --- |
|
||||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_squirrel.png"/></div> |
|
||||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_crab.png"/></div> |
|
||||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_starfish.png"/></div> |
|
||||
|
||||
|
||||
## Getting started with Stable Diffusion finetuning with reinforcement learning
|
||||
|
||||
The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers`
|
||||
library. Getting started therefore requires a bit of familiarity with `diffusers` concepts, mainly two of them: pipelines and schedulers.
|
||||
Right out of the box, the `diffusers` library provides neither a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning, so some adjustments need to be made.
|
||||
|
||||
This library provides a pipeline interface that must be implemented in order to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.**
|
||||
There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide.
|
||||
|
||||
The point of the interface is to fuse the pipeline and the scheduler into one object, keeping all the constraints in a single place. The interface was designed in the hope of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at the time of writing. The scheduler step is also a method of this pipeline interface; this may seem redundant given that the raw scheduler is accessible via the interface, but it is the only way to constrain the scheduler step output to an output type befitting the algorithm at hand (DDPO).
|
||||
|
||||
For a more detailed look at the interface and the associated default implementation, see [`trl/models/modeling_sd_base.py`](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py).
|
||||
|
||||
Note that the default implementation has a LoRA implementation path and a non-LoRA implementation path. The LoRA flag is enabled by default and can be turned off by passing the flag to do so. LoRA-based training is faster, and the LoRA-associated hyperparameters responsible for model convergence aren't as finicky as in non-LoRA training.
|
||||
|
||||
In addition, you are expected to provide a reward function and a prompt function. The reward function is used to evaluate the generated images, and the prompt function is used to generate the prompts from which the images are generated.
|
||||
|
||||
## Getting started with `examples/scripts/ddpo.py`
|
||||
|
||||
The `ddpo.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`).
|
||||
|
||||
**Note:** one A100 GPU is recommended to get this running. Anything below an A100 will not be able to run this example script, and even if it does with relatively smaller parameters, the results will most likely be poor.
|
||||
|
||||
Almost every configuration parameter has a default. Only one command-line flag is required to get things up and running: a [Hugging Face user access token](https://huggingface.co/docs/hub/security-tokens), which is used to upload the finetuned model to the Hugging Face Hub. Run the following bash command to get things running:
|
||||
|
||||
```bash
|
||||
python ddpo.py --hf_user_access_token <token>
|
||||
```
|
||||
|
||||
To obtain the documentation of `ddpo.py`, please run `python ddpo.py --help`
|
||||
|
||||
The following points are worth keeping in mind when configuring the trainer (the code checks these for you as well), beyond the use case of the example script:
|
||||
|
||||
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`)
|
||||
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`)
|
||||
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count
|
||||
|
||||
## Setting up the image logging hook function
|
||||
|
||||
Expect the function to be given a list of lists of the form
|
||||
```python
|
||||
[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]
|
||||
|
||||
```
|
||||
and `image`, `prompt`, `prompt_metadata`, `rewards`, `reward_metadata` are batched.
|
||||
The last list in the lists of lists represents the last sample batch, which you are most likely to want to log.
|
||||
While you are free to log however you want, the use of `wandb` or `tensorboard` is recommended.
|
||||
|
||||
### Key terms
|
||||
|
||||
- `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
|
||||
- `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward
|
||||
- `prompt` : The prompt is the text that is used to generate the image
|
||||
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup, where questions and ground-truth answers (linked to the generated image) are expected alongside the generated image (see here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
|
||||
- `image` : The image generated by the Stable Diffusion model
|
||||
|
||||
Example code for logging sampled images with `wandb` is given below.
|
||||
|
||||
```python
|
||||
# for logging these images to wandb
import numpy as np
from PIL import Image
|
||||
|
||||
def image_outputs_hook(image_data, global_step, accelerate_logger):
|
||||
# For the sake of this example, we only care about the last batch
|
||||
# hence we extract the last element of the list
|
||||
result = {}
|
||||
images, prompts, _, rewards, _ = image_data[-1]
|
||||
for i, image in enumerate(images):
|
||||
pil = Image.fromarray(
|
||||
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
|
||||
)
|
||||
pil = pil.resize((256, 256))
|
||||
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
|
||||
accelerate_logger.log_images(
|
||||
result,
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
### Using the finetuned model
|
||||
|
||||
Assuming you're done with all the epochs and have pushed your model up to the Hub, you can use the finetuned model as follows:
|
||||
|
||||
```python
|
||||
|
||||
import torch
|
||||
from trl import DefaultDDPOStableDiffusionPipeline
|
||||
|
||||
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
# memory optimization
|
||||
pipeline.vae.to(device, torch.float16)
|
||||
pipeline.text_encoder.to(device, torch.float16)
|
||||
pipeline.unet.to(device, torch.float16)
|
||||
|
||||
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
|
||||
results = pipeline(prompts)
|
||||
|
||||
for prompt, image in zip(prompts,results.images):
|
||||
image.save(f"{prompt}.png")
|
||||
|
||||
```
|
||||
|
||||
## Credits
|
||||
|
||||
This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models
|
||||
with Reinforcement Learning by Kevin Black, Michael Janner, Yilun Du, Ilya Kostrikov, Sergey Levine](https://huggingface.co/papers/2305.13301).
|
||||
|
||||
## DDPOTrainer
|
||||
|
||||
[[autodoc]] DDPOTrainer
|
||||
|
||||
## DDPOConfig
|
||||
|
||||
[[autodoc]] DDPOConfig
|
||||
|
@ -1,10 +1,7 @@
|
||||
# DeepSpeed Integration
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Section under construction. Feel free to contribute!
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> Section under construction. Feel free to contribute!
|
||||
|
||||
TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more.
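
For example, with the TRL CLI you can pick one of the built-in DeepSpeed Accelerate profiles described in the CLI docs (the model and dataset below are placeholders):

```bash
trl sft \
  --model_name_or_path Qwen/Qwen2.5-0.5B \
  --dataset_name trl-lib/Capybara \
  --accelerate_config zero2  # DeepSpeed ZeRO Stage 2
```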
|
||||
|
||||
|
@ -1,187 +0,0 @@
|
||||
# Detoxifying a Language Model using PPO
|
||||
|
||||
Language models (LMs) are known to sometimes generate toxic outputs. In this example, we will show how to "detoxify" an LM by feeding it toxic prompts and then using [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/index) and Proximal Policy Optimization (PPO) to steer it toward less toxic generations.
|
||||
|
||||
Read this section to follow our investigation of how to reduce toxicity in a wide range of LMs, from 125M to 6B parameters!
|
||||
|
||||
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
|
||||
|
||||
| File | Description | Colab link |
|
||||
|---|---| --- |
|
||||
| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
|
||||
| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
|
||||
| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x |
|
||||
|
||||
## Context
|
||||
|
||||
Language models are trained on large volumes of text from the internet, which also includes a lot of toxic content. Naturally, language models pick up these toxic patterns during training. Especially when prompted with already toxic text, the models are likely to continue the generation in a toxic way. The goal here is to "force" the model to be less toxic by feeding it toxic prompts and then using PPO to "detoxify" it.
|
||||
|
||||
### Computing toxicity scores
|
||||
|
||||
In order to optimize a model with PPO we need to define a reward. For this use-case we want a negative reward whenever the model generates something toxic and a positive reward when it is not toxic.
|
||||
Therefore, we used [`facebook/roberta-hate-speech-dynabench-r4-target`](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target), a RoBERTa model fine-tuned to classify text as "neutral" or "toxic", as our toxicity classifier.
|
||||
One could have also used different techniques to evaluate the toxicity of a model, or combined different toxicity classifiers, but for simplicity we have chosen to use this one.
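As a rough sketch of how such a classifier can be turned into a score (assuming, as in the reward snippet later in this section, that index 0 of the logits corresponds to the non-toxic label):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target"
toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_id)
toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_id)

texts = ["what a nice day", "some generated continuation"]  # placeholder generations
toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    logits = toxicity_model(**toxicity_inputs).logits.float()
rewards = logits[:, 0].tolist()  # raw logit of the non-toxic label, used as the reward
```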
|
||||
|
||||
### Selection of models
|
||||
|
||||
We selected the following models for our experiments to show that TRL can easily be scaled to 10B-parameter models:
|
||||
|
||||
* [`EleutherAI/gpt-neo-125M`](https://huggingface.co/EleutherAI/gpt-neo-125M) (125 million parameters)
|
||||
* [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters)
|
||||
* [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters)
|
||||
|
||||
For the smallest model, we chose `EleutherAI/gpt-neo-125M` because it has been shown to be the "most toxic" model compared to the others. We ran a toxicity evaluation using the `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures over a subset of the `allenai/real-toxicity-prompts` dataset. Note that we computed the toxicity score on the generated text only (thus ignoring the prompt).
|
||||
|
||||
| Model | Mean toxicity score |
|
||||
|---|---|
|
||||
| `gpt2` | 0.01602 |
|
||||
| `facebook/opt-350m` | 0.01628 |
|
||||
| `bigscience/bloom-560m` | 0.00767 |
|
||||
| `EleutherAI/gpt-neo-125M` | **0.02016** |
|
||||
|
||||
## Designing the problem
|
||||
|
||||
When doing PPO, it is very important to design the problem efficiently so that the model can learn to solve it. Let's cover the topics that were important for the model to converge.
|
||||
|
||||
### Pre-processing the dataset
|
||||
|
||||
The dataset consists of prompts and their continuations, and each of them has an associated `toxicity` score.
|
||||
|
||||
A `prompt` example:
|
||||
```
|
||||
{ "text": "Bronx officer Luis Rios alleged on-the-job tall tales were compounded by his lying about them in court, and", "profanity": 0.08360514, "sexually_explicit": 0.118823394, "identity_attack": 0.09608547, "flirtation": 0.23102224, "threat": 0.13215046, "insult": 0.10130461, "severe_toxicity": 0.04068885, "toxicity": 0.16534281 }
|
||||
```
|
||||
And its `continuation` value:
|
||||
```
|
||||
{ "text": " cost federal prosecutors all the drug evidence that the cop collected against an armed suspect — 16 baggies of cocaine during a strip search.", "severe_toxicity": 0.067997746, "toxicity": 0.1694093, "profanity": 0.11931301, "sexually_explicit": 0.12521537, "identity_attack": 0.09268324, "flirtation": 0.13452998, "threat": 0.31312028, "insult": 0.10761123 }
|
||||
```
|
||||
|
||||
We want to increase the chance that the model generates toxic outputs so that we get more learning signal. For this reason, we pre-process the dataset to keep only prompts whose toxicity score is greater than a threshold. We can do this in a few lines of code:
|
||||
```python
|
||||
from datasets import load_dataset

train_dataset = load_dataset("allenai/real-toxicity-prompts", split="train")
|
||||
|
||||
def filter_fn(sample):
|
||||
toxicity = sample["prompt"]["toxicity"]
|
||||
return toxicity is not None and toxicity > 0.3
|
||||
|
||||
train_dataset = train_dataset.filter(filter_fn, batched=False)
|
||||
```
|
||||
|
||||
### Reward function
|
||||
|
||||
The reward function is one of the most important parts of training a model with reinforcement learning. It is the function that tells the model whether it is doing well or not.
|
||||
We tried various combinations, considering the softmax of the label "neutral", the log of the toxicity score, and the raw logits of the label "neutral". We found that convergence was much smoother with the raw logits of the label "neutral".
|
||||
```python
|
||||
logits = toxicity_model(**toxicity_inputs).logits.float()
|
||||
rewards = (logits[:, 0]).tolist()
|
||||
```
|
||||
|
||||
### Impact of input prompts length
|
||||
|
||||
We found that training the model with a small context (5 to 8 tokens) or a long context (15 to 20 tokens) does not have any impact on its convergence; however, when training with longer prompts, the model tends to generate more toxic outputs.
|
||||
As a compromise between the two, we settled on a context window of 10 to 15 tokens for training (a minimal sketch of this truncation follows the figure below).
|
||||
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-long-vs-short-context.png">
|
||||
</div>
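As promised above, a rough sketch of the truncation step (the helper name and exact slicing are illustrative, not the original training script):

```python
import random

# Truncate each prompt to a random window of 10-15 tokens before feeding it to the model.
def build_query(sample, tokenizer, min_tokens=10, max_tokens=15):
    num_tokens = random.randint(min_tokens, max_tokens)
    input_ids = tokenizer(sample["prompt"]["text"])["input_ids"][:num_tokens]
    sample["input_ids"] = input_ids
    sample["query"] = tokenizer.decode(input_ids)
    return sample
```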
|
||||
|
||||
### How to deal with OOM issues
|
||||
|
||||
Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here are two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU:
|
||||
|
||||
- Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` to halve the size of the model:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is pure `bfloat16` training, which is different from mixed-precision training. If you want to train a model in mixed precision, do not load the model with `dtype`; instead, specify the mixed precision argument when calling `accelerate config`.
|
||||
|
||||
- Use shared layers: Since the PPO algorithm requires both the active and the reference model to be on the same device, we decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying the `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-shared-layers.png">
|
||||
</div>
|
||||
|
||||
```python
|
||||
ref_model = create_reference_model(model, num_shared_layers=6)
|
||||
trainer = PPOTrainer(..., ref_model=ref_model)
|
||||
```
|
||||
|
||||
In the example above, this means that the first 6 layers of the model are frozen (i.e. these layers are shared between the active model and the reference model).
|
||||
|
||||
- One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower).
|
||||
|
||||
## Training the model!
|
||||
|
||||
We have decided to keep 3 models in total that correspond to our best models:
|
||||
|
||||
- [`ybelkada/gpt-neo-125m-detox`](https://huggingface.co/ybelkada/gpt-neo-125m-detox)
|
||||
- [`ybelkada/gpt-neo-2.7B-detox`](https://huggingface.co/ybelkada/gpt-neo-2.7B-detox)
|
||||
- [`ybelkada/gpt-j-6b-detox`](https://huggingface.co/ybelkada/gpt-j-6b-detox)
|
||||
|
||||
We used different learning rates for each model and found that the largest models were quite hard to train and can easily collapse if the learning rate is not chosen correctly (i.e. if the learning rate is too high):
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-collapse-mode.png">
|
||||
</div>
|
||||
|
||||
The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-gpt-j-final-run-2.png">
|
||||
</div>
|
||||
|
||||
As you can see, the model converges nicely, but obviously we don't observe a very large improvement from the first step, as the original model is not trained to generate toxic content.
|
||||
|
||||
We also observed that training with a larger `mini_batch_size` leads to smoother convergence and better results on the test set:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-gpt-j-mbs-run.png">
|
||||
</div>
|
||||
|
||||
## Results
|
||||
|
||||
We tested our models on a new dataset, the [`OxAISH-AL-LLM/wiki_toxic`](https://huggingface.co/datasets/OxAISH-AL-LLM/wiki_toxic) dataset. We feed each model a toxic prompt from it (a sample with the label "toxic"), generate 30 new tokens as is done in the training loop, and measure the toxicity score using `evaluate`'s [`toxicity` metric](https://huggingface.co/spaces/ybelkada/toxicity).
|
||||
We compute the toxicity score on 400 sampled examples, take its mean and standard deviation, and report the results in the table below:
|
||||
|
||||
| Model | Mean toxicity score | Std toxicity score |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 |
|
||||
| `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-neo-2.7B` | 0.1884 | 0.3178 |
|
||||
| `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 |
|
||||
| `ybelkada/gpt-j-6b-detox` | **0.1510** | **0.2798** |
|
||||
|
||||
<div class="column" style="text-align:center">
|
||||
<figure>
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-final-barplot.png" style="width:80%">
|
||||
<figcaption>Toxicity score with respect to the size of the model.</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
Below are a few generation examples from the `gpt-j-6b-detox` model:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-toxicity-examples.png">
|
||||
</div>
|
||||
|
||||
The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
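In spirit, the evaluation boils down to something like the following sketch, which scores a batch of generated continuations with `evaluate`'s `toxicity` measurement (the `completions` list is a placeholder for the model's 30-token generations):

```python
import evaluate

toxicity = evaluate.load("toxicity", module_type="measurement")

completions = ["generated continuation 1", "generated continuation 2"]  # placeholder generations
scores = toxicity.compute(predictions=completions)["toxicity"]
mean_toxicity = sum(scores) / len(scores)
```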
|
||||
|
||||
### Discussions
|
||||
|
||||
The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for the `gpt-neo-2.7B` model, but less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model, starting with training with a larger `mini_batch_size` and probably allowing back-propagation through more layers (i.e. using fewer shared layers).
|
||||
|
||||
To sum up, in addition to human feedback, this could be a useful additional signal when training large language models to ensure their outputs are less toxic as well as more useful.
|
||||
|
||||
### Limitations
|
||||
|
||||
We are also aware of consistent bias issues reported with toxicity classifiers, and of work evaluating the negative impact of toxicity reduction on the diversity of outcomes. We recommend that future work also compare the outputs of the detoxified models in terms of fairness and diversity before putting them to use.
|
||||
|
||||
## What is next?
|
||||
|
||||
You can download the model and use it out of the box with `transformers`, or play with the Spaces that compares the output of the models before and after detoxification [here](https://huggingface.co/spaces/ybelkada/detoxified-lms).
|
@ -1,8 +1,7 @@
|
||||
# Distributing Training
|
||||
|
||||
<Tip warning={true}>
|
||||
Section under construction. Feel free to contribute!
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> Section under construction. Feel free to contribute!
|
||||
|
||||
## Multi-GPU Training with TRL
|
||||
|
||||
@ -27,11 +26,12 @@ accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml train
|
||||
This automatically distributes the workload across all available GPUs.
|
||||
|
||||
Under the hood, [🤗 Accelerate](https://github.com/huggingface/accelerate) creates one model per GPU. Each process:
|
||||
|
||||
- Processes its own batch of data
|
||||
- Computes the loss and gradients for that batch
|
||||
- Shares gradient updates across all GPUs
|
||||
|
||||

|
||||

|
||||
|
||||
The effective batch size is calculated as:
|
||||
|
||||
@ -49,12 +49,142 @@ Example, these configurations are equivalent, and should yield the same results:
|
||||
| 1 | 4 | 8 | Lower memory usage, slower training |
|
||||
| 8 | 4 | 1 | Multi-GPU to get the best of both worlds |
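For reference, the effective batch size of each row can be reproduced with a one-line computation (plain illustration, not TRL code):

```python
num_gpus = 8
per_device_train_batch_size = 4
gradient_accumulation_steps = 1
effective_batch_size = num_gpus * per_device_train_batch_size * gradient_accumulation_steps  # 32
```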
|
||||
|
||||
<Tip>
|
||||
> [!TIP]
|
||||
> Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration) guide for more details.
|
||||
|
||||
Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration) guide for more details.
|
||||
## Context Parallelism
|
||||
|
||||
</Tip>
|
||||
Context Parallelism (CP) is a parallelization technique that enables training with longer sequences by splitting the sequence dimension across multiple GPUs. Each GPU processes a portion of the sequence, allowing you to train with sequences longer than what would fit on a single GPU's memory.
|
||||
|
||||
For more details on CP, see the [Ultrascale Playbook - Context Parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism).
|
||||
|
||||
CP is particularly useful when:
|
||||
|
||||
- You want to train with very long sequences (>32k tokens)
|
||||
- Single GPU memory is insufficient for your desired sequence length
|
||||
- You need to maintain sequence coherence across the full context
|
||||
|
||||
### Requirements and Limitations
|
||||
|
||||
CP has specific requirements:
|
||||
|
||||
1. **Accelerate 1.10 or higher** is required
|
||||
2. **FSDP2 (PyTorch FSDP v2)** is required as the distributed training backend
|
||||
3. **SDPA attention** - Flash Attention is currently not supported with CP
|
||||
4. **Sequence length divisibility** - sequence lengths must be divisible by `cp_size * 2`. This is now automatically handled using the `pad_to_multiple_of` parameter in the data collator, which works seamlessly with both standard and padding-free modes.
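As a quick illustration of requirement 4 (plain Python, not part of TRL):

```python
cp_size = 2                 # parallelism_config_cp_size in the accelerate config below
multiple = cp_size * 2      # every sequence length must be divisible by this value

seq_len = 16384
assert seq_len % multiple == 0  # satisfied e.g. by setting pad_to_multiple_of=4 in the collator
```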
|
||||
|
||||
### Configuration
|
||||
|
||||
To enable CP, you need to configure both Accelerate and your training arguments:
|
||||
|
||||
#### Accelerate Configuration
|
||||
|
||||
Use one of the provided accelerate config files (e.g. [`context_parallel_2gpu.yaml`](https://github.com/huggingface/trl/blob/main/examples/accelerate_configs/context_parallel_2gpu.yaml) for 2 GPUs):
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
fsdp_config:
|
||||
fsdp_activation_checkpointing: true # Enable activation checkpointing for memory efficiency
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_version: 2
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 2 # Number of GPUs
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
parallelism_config:
|
||||
parallelism_config_dp_replicate_size: 1
|
||||
parallelism_config_dp_shard_size: 1
|
||||
parallelism_config_tp_size: 1
|
||||
parallelism_config_cp_size: 2 # Context parallel size
|
||||
```
|
||||
|
||||
#### Training Configuration
|
||||
|
||||
```python
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(
|
||||
# required
|
||||
pad_to_multiple_of=4, # ensures divisibility by cp_size * 2
|
||||
# to get the most out of CP
|
||||
max_length=16384, # long sequence length
|
||||
packing=True, # use packing to reduce padding
|
||||
use_liger_kernel=True, # compatible with CP
|
||||
gradient_checkpointing=False, # The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg can't be set to True simultaneously
|
||||
per_device_train_batch_size=1,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
Then, launch your training script with the appropriate accelerate config file:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file context_parallel_2gpu.yaml train.py
|
||||
```
|
||||
|
||||
### Best Practices
|
||||
|
||||
1. **Use the `pad_to_multiple_of` parameter** - This is now the recommended way to ensure sequence length divisibility:
|
||||
- For `cp_size=2`: use `pad_to_multiple_of=4` (since `cp_size * 2 = 4`)
|
||||
- For `cp_size=4`: use `pad_to_multiple_of=8` (since `cp_size * 2 = 8`)
|
||||
- The data collator automatically pads sequences to the required multiple, ensuring compatibility with CP
|
||||
|
||||
2. **Use packing with padding** - The default BFD (Best Fit Decreasing) strategy works perfectly:
|
||||
- Preserves sequence boundaries and maintains training quality
|
||||
- Works seamlessly with both `padding_free=True` and standard padding modes
|
||||
|
||||
3. **Combine with other memory optimizations** like Liger kernels, bfloat16, and gradient checkpointing
|
||||
|
||||
4. **Start with smaller context parallel sizes** (2-4 GPUs) before scaling up
|
||||
|
||||
5. **Monitor memory usage** across all GPUs to ensure balanced workload
|
||||
|
||||
### Benchmarking Context Parallelism
|
||||
|
||||
We benchmarked CP to highlight its potential improvements in training efficiency.
|
||||
Our experiments were conducted using **1, 2, 4, and 8 H100 GPUs**, though the results can be extended to larger clusters with more nodes and GPUs.
|
||||
|
||||
For the setup, we fine-tuned an **8B model** ([Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B)) using the provided accelerate configuration
|
||||
([`context_parallel_2gpu.yaml`](https://github.com/huggingface/trl/blob/main/examples/accelerate_configs/context_parallel_2gpu.yaml)).
|
||||
We adjusted `num_processes` and `parallelism_config_cp_size` based on the number of GPUs for each run.
|
||||
Training was performed with the [sft.py](https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py) example script, combined with the parameters described above.
|
||||
|
||||
The results below summarize the **maximum trainable sequence length** and **iterations per second** for different numbers of GPUs. A value marked as `OOM` indicates that the configuration ran out of memory and could not be trained.
|
||||
|
||||
These results show that **Context Parallelism (CP) scales effectively with more GPUs**, enabling training on much longer sequences. With **8 GPUs**, context lengths of over **300k tokens** become feasible, unlocking training with extremely long contexts while maintaining reasonable throughput.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/context_parallelism_max_length_plot.png" alt="CP Max content length" width="45%"/>
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/context_parallelism_s_it_plot.png" alt="CP seconds/iteration" width="45%"/>
|
||||
</div>
|
||||
|
||||
> [!TIP]
|
||||
> Accelerate also supports **N-Dimensional Parallelism (ND-parallelism)**, which enables you to combine different parallelization strategies to efficiently distribute model training across multiple GPUs.
|
||||
>
|
||||
> You can learn more and explore configuration examples in the [Accelerate ND-parallelism guide](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#nd-parallelism).
|
||||
|
||||
### Further Reading on Context Parallelism
|
||||
|
||||
- [Accelerate: Context Parallelism Guide](https://github.com/huggingface/accelerate/blob/main/docs/source/concept_guides/context_parallelism.md)
|
||||
- [Accelerate Example: 128k Sequence Length](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#context-parallelism-128k-sequence-length)
|
||||
- [Hugging Face Blog: Enabling Long-Context Training with Sequence Parallelism in Axolotl](https://huggingface.co/blog/axolotl-ai-co/long-context-with-sequence-parallelism-in-axolotl)
|
||||
- [Snowflake Engineering Blog: Arctic Long Sequence Training (ALST) — Scalable and Efficient Training for Multi-Million Token Sequences (Note that they use a different strategy)](https://www.snowflake.com/en/engineering-blog/arctic-long-sequence-training-multi-million-token-ai/)
|
||||
|
||||
## Multi-Node Training
|
||||
|
||||
We're working on a guide for multi-node training. Stay tuned! 🚀
|
||||
We're working on a guide for multi-node training. Stay tuned! 🚀
|
||||
|
@ -1,6 +1,6 @@
|
||||
# DPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=dpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
|
||||
[](https://huggingface.co/models?other=dpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
|
||||
|
||||
## Overview
|
||||
|
||||
@ -19,7 +19,7 @@ Then, fine-tuning a language model via DPO consists of two steps and is easier t
|
||||
|
||||
This process is illustrated in the sketch below (from [Figure 1 of the DPO paper](https://huggingface.co/papers/2305.18290)):
|
||||
|
||||

|
||||

|
||||
|
||||
Read more about DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290).
|
||||
|
||||
@ -101,7 +101,6 @@ Additionally, unlike standard text-based models where a `tokenizer` is used, for
|
||||
|
||||
For a complete example of fine-tuning a vision-language model, refer to the script in [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py).
|
||||
|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py)
|
||||
@ -192,10 +191,10 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype
|
||||
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks for DPO listed below:
|
||||
|
||||
| GPU | Model | Dataset | 🤗 | 🤗 + FlashAttention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|
||||
| -------- | --------- | ---------- | --- | --------------------- | --------- | ------------ |
|
||||
| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% |
|
||||
| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% |
|
||||
| GPU | Model | Dataset | 🤗 | 🤗 + FlashAttention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|
||||
| --- | --- | --- | --- | --- | --- | --- |
|
||||
| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% |
|
||||
| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% |
|
||||
|
||||
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
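The exact snippet falls outside this hunk; as a rough sketch (the model name and parameter values are illustrative, see the unsloth documentation for the authoritative API):

```python
from unsloth import FastLanguageModel

# Rough sketch only; adjust the arguments to your setup.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="HuggingFaceH4/zephyr-7b-beta",  # illustrative choice, matching the benchmark above
    max_seq_length=2048,
    load_in_4bit=True,  # QLoRA-style 4-bit loading
)
```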
|
||||
|
||||
@ -295,3 +294,7 @@ dpo_trainer = DPOTrainer(
|
||||
## DataCollatorForPreference
|
||||
|
||||
[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference
|
||||
|
||||
## FDivergenceType
|
||||
|
||||
[[autodoc]] trainer.dpo_trainer.FDivergenceType
|
||||
|
@ -1,16 +1,15 @@
|
||||
# Examples
|
||||
|
||||
|
||||
## Introduction
|
||||
|
||||
The examples should work in any of the following settings (with the same script):
|
||||
- single GPU
|
||||
- multi GPUs (using PyTorch distributed mode)
|
||||
- multi GPUs (using DeepSpeed ZeRO-Offload stages 1, 2, & 3)
|
||||
- fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)
|
||||
|
||||
To run it in each of these various modes, first initialize the accelerate
|
||||
configuration with `accelerate config`
|
||||
- single GPU
|
||||
- multi GPUs (using PyTorch distributed mode)
|
||||
- multi GPUs (using DeepSpeed ZeRO-Offload stages 1, 2, & 3)
|
||||
- fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)
|
||||
|
||||
To run it in each of these various modes, first initialize the accelerate configuration with `accelerate config`.
|
||||
|
||||
To train with a 4-bit or 8-bit model, please run:
|
||||
|
||||
@ -28,25 +27,22 @@ accelerate config # will prompt you to define the training configuration
|
||||
|
||||
Then, it is encouraged to launch jobs with `accelerate launch`!
|
||||
|
||||
|
||||
## Maintained Examples
|
||||
|
||||
Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly.
|
||||
|
||||
| File | Description |
|
||||
| --- | --- |
|
||||
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
|
||||
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
|
||||
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
|
||||
| [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
|
||||
| [`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py) | This script shows how to use [`HfPairwiseJudge`] or [`OpenAIPairwiseJudge`] to judge model generations. |
|
||||
| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
|
||||
| [`trl/scripts/grpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/grpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
|
||||
| [`examples/scripts/gspo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune model for reasoning using the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. |
|
||||
| [`examples/scripts/gspo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo_vlm.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
|
||||
| [`examples/scripts/grpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
|
||||
| [`examples/scripts/gspo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune model for reasoning using the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. |
|
||||
| [`examples/scripts/gspo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo_vlm.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
|
||||
| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a mixture of losses with corresponding weights. |
|
||||
| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. |
|
||||
@ -74,11 +70,6 @@ Here are also some easier-to-run colab notebooks that you can use to get started
|
||||
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
|
||||
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
|
||||
|
||||
|
||||
We also have some other examples that are less maintained but can be used as a reference:
|
||||
1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
|
||||
|
||||
|
||||
## Distributed training
|
||||
|
||||
All the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments).
|
||||
|
163
docs/source/experimental.md
Normal file
163
docs/source/experimental.md
Normal file
@ -0,0 +1,163 @@
|
||||
# Experimental Features
|
||||
|
||||
The `trl.experimental` namespace provides a minimal, clearly separated space for fast iteration on new ideas.
|
||||
|
||||
> [!WARNING]
|
||||
> **Stability contract:** Anything under `trl.experimental` may change or be removed in *any* release (including patch versions) without prior deprecation. Do not rely on these APIs for production workloads.
|
||||
|
||||
## Current Experimental Features
|
||||
|
||||
The following modules are currently available under [`trl.experimental`](https://github.com/huggingface/trl/tree/main/trl/experimental).
|
||||
This list is not exhaustive and may change at any time.
|
||||
|
||||
### BEMA for Reference Model
|
||||
|
||||
This feature implements the BEMA algorithm to update the reference model during DPO training.
|
||||
|
||||
```python
|
||||
from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
|
||||
ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
||||
|
||||
bema_callback = BEMACallback(update_ref_model=True)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
||||
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
train_dataset=pref_dataset,
|
||||
processing_class=tokenizer,
|
||||
callbacks=[bema_callback],
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### GFPO
|
||||
|
||||
This feature implements the GFPO algorithm to enforce concise reasoning in the model's output generation, as proposed in the paper [Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning](https://huggingface.co/papers/2508.09726).
|
||||
|
||||
To activate GFPO in [`GFPOTrainer`]:
|
||||
|
||||
- set `num_remains_in_group` in [`GFPOConfig`]
|
||||
- define a group filter function and pass it as `group_filter_func` to [`GFPOTrainer`]. `group_filter_func` scores the `num_generations` completions, and the [`GFPOTrainer`] filters each group according to these scores to keep the top `num_remains_in_group` completions as a new group. The model is then trained on the filtered group.
|
||||
|
||||
```python
|
||||
# train_gfpo.py
|
||||
from trl.experimental.gfpo import GFPOConfig, GFPOTrainer
|
||||
|
||||
# dummy group filter that scores the completions based on their index in the group
|
||||
class GroupFilter:
|
||||
def __call__(self, group_completions, group_rewards, **kwargs):
|
||||
group_scores = []
|
||||
for completions, rewards in zip(group_completions, group_rewards):
|
||||
scores = [float(i) for i in range(len(completions))]
|
||||
group_scores.append(scores)
|
||||
return group_scores
|
||||
|
||||
training_args = GFPOConfig(
|
||||
output_dir="Qwen3-0.6B-GFPO",
|
||||
per_device_train_batch_size=4,
|
||||
num_remains_in_group=2,
|
||||
bf16=True,
|
||||
)
|
||||
trainer = GFPOTrainer(
|
||||
model="Qwen/Qwen3-0.6B",
|
||||
reward_funcs=...,
|
||||
train_dataset=...,
|
||||
args=training_args,
|
||||
group_filter_func=GroupFilter(),
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### GSPO-token
|
||||
|
||||
In the paper [Group Sequence Policy Optimization](https://huggingface.co/papers/2507.18071), the authors propose a token-level objective variant to GSPO, called GSPO-token. To use GSPO-token, you can use the `GRPOTrainer` class in `trl.experimental.gspo_token`.
|
||||
|
||||
```python
|
||||
from trl.experimental.gspo_token import GRPOTrainer
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
importance_sampling_level="sequence_token",
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> To leverage GSPO-token, the user will need to provide the per-token advantage \\( \hat{A_{i,t}} \\) for each token \\( t \\) in the sequence \\( i \\) (i.e., make \\( \hat{A_{i,t}} \\) vary with \\( t \\), which isn't the case here, since \\( \hat{A_{i,t}}=\hat{A_{i}} \\)). Otherwise, the GSPO-token gradient is equivalent to the original GSPO implementation.
|
||||
|
||||
### GRPO With Replay Buffer
|
||||
|
||||
This experimental trainer trains a model with GRPO, but replaces groups (and their corresponding completions) whose rewards have zero standard deviation with high-reward, high-standard-deviation groups from prior batches that are kept in a replay buffer.
|
||||
|
||||
#### Usage
|
||||
|
||||
```python
import torch
from datasets import load_dataset
from trl.experimental.grpo_with_replay_buffer import (
    GRPOWithReplayBufferConfig,
    GRPOWithReplayBufferTrainer,
)

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

# Guarantee that some rewards have 0 std
def custom_reward_func(completions, **kwargs):
    if torch.rand(1).item() < 0.25:
        return [0] * len(completions)  # simulate a group whose rewards have zero std
    else:
        return torch.rand(len(completions)).tolist()

training_args = GRPOWithReplayBufferConfig(
    output_dir="grpo-replay-buffer-output",  # placeholder output directory
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=8,
    replay_buffer_size=8,
    report_to="none",
)
trainer = GRPOWithReplayBufferTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=[custom_reward_func],
    args=training_args,
    train_dataset=dataset,
)

trainer.train()
```
|
||||
|
||||
To silence the runtime notice:
|
||||
|
||||
```bash
|
||||
export TRL_EXPERIMENTAL_SILENCE=1
|
||||
```
|
||||
|
||||
## Promotion Path (Simple)
|
||||
|
||||
1. **Prototype outside the main repo:** Start development in your own fork or a separate repository to iterate quickly.
|
||||
2. **Experimental inclusion:** Once it’s ready for early users, move the idea into `trl.experimental.<feature>`.
|
||||
3. **Improve:** Add tests, a short doc/example, and demonstrate the usage.
|
||||
4. **Promote:** Once the API proves stable and there is clear interest or adoption from the community, move it into `trl.<feature>` (stable module).
|
||||
|
||||
## FAQ
|
||||
|
||||
**Why not just use branches?**
|
||||
Because branches are not shipped to users; experimental code inside the package lets early adopters try things and give feedback.
|
||||
|
||||
**Can these APIs change or vanish without warning?**
|
||||
Yes. Anything inside `trl.experimental` can change or disappear in *any* release.
|
||||
|
||||
**Should I use this in production?**
|
||||
Only if you are fine with updating your code quickly when things change.
|
||||
|
||||
**Will maintainers promptly fix issues in `trl.experimental`?**
|
||||
Not necessarily. The experimental module is a playground for new ideas, and maintainers may not prioritize bug fixes or feature requests there. Issues may remain unresolved until (or unless) the feature graduates to the stable API.
|
@ -1,17 +1,17 @@
|
||||
# Generalized Knowledge Distillation Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=gkd,trl)
|
||||
[](https://huggingface.co/models?other=gkd,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Generalized Knowledge Distillation (GKD) was proposed in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.
|
||||
Generalized Knowledge Distillation (GKD) was proposed in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning.
|
||||
|
||||
|
||||
The key aspects of GKD are:
|
||||
|
||||
1. It addresses the train-inference distribution mismatch in auto-regressive sequence models by training the student model on its self-generated output sequences.
|
||||
2. GKD allows flexibility in choosing different divergence measures between student and teacher models via the generalized Jensen-Shannon Divergence (JSD), which can be useful when the student lacks the capacity to fully mimic the teacher.
|
||||
|
||||
@ -20,6 +20,7 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.
|
||||
## Usage tips
|
||||
|
||||
The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`] namely:
|
||||
|
||||
* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD, where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and receives token-specific feedback on these sequences from the teacher. For values in between [0, 1], the choice between the two is made at random for each batch based on the `lmbda` value.
|
||||
* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated output). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
|
||||
* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.
|
||||
@ -85,6 +86,7 @@ trainer.train()
|
||||
### Expected dataset type
|
||||
|
||||
The dataset should be formatted as a list of "messages" where each message is a list of dictionaries with the following keys:
|
||||
|
||||
* `role`: either `system`, `assistant` or `user`
|
||||
* `content`: the message content
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# GRPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=grpo,trl)
|
||||
[](https://huggingface.co/models?other=grpo,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
@ -56,13 +56,13 @@ accelerate launch train_grpo.py
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 1 day.
|
||||
|
||||

|
||||

|
||||
|
||||
## Looking deeper into the GRPO method
|
||||
|
||||
GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind the GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**.
|
||||
|
||||

|
||||

|
||||
|
||||
### Generating completions
|
||||
|
||||
@ -76,17 +76,11 @@ $$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$
|
||||
|
||||
This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.
|
||||
|
||||
<Tip>
|
||||
> [!TIP]
|
||||
> It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`].
|
||||
|
||||
It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`].
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
[Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)](https://huggingface.co/papers/2508.08221) showed that calculating the mean at the local (group) level and the standard deviation at the global (batch) level enables more robust reward shaping. You can use this scaling strategy by setting `scale_rewards="batch"` in [`GRPOConfig`].
|
||||
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> [Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)](https://huggingface.co/papers/2508.08221) showed that calculating the mean at the local (group) level and the standard deviation at the global (batch) level enables more robust reward shaping. You can use this scaling strategy by setting `scale_rewards="batch"` in [`GRPOConfig`].
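For reference, both scaling variants mentioned in the tips above are selected through a single [`GRPOConfig`] field (a minimal sketch, other arguments omitted):

```python
from trl import GRPOConfig

args_no_scaling = GRPOConfig(output_dir="grpo-no-scale", scale_rewards=False)          # disable std scaling
args_batch_scaling = GRPOConfig(output_dir="grpo-batch-scale", scale_rewards="batch")  # group mean, batch std
```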
|
||||
|
||||
### Estimating the KL divergence
|
||||
|
||||
@ -105,17 +99,11 @@ $$
|
||||
|
||||
where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.
|
||||
|
||||
<Tip>
|
||||
> [!TIP]
|
||||
> Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that this introduces a response-level length bias. More details in [loss types](#loss-types).
|
||||
|
||||
Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that this introduces a response-level length bias. More details in [loss types](#loss-types).
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we use \\( \beta = 0.0 \\) by default, meaning that the KL divergence term is not used. This choice is motivated by several recent studies (e.g., [Open-Reasoner-Zero: An Open Source Approach to Scaling Up Reinforcement Learning on the Base Model](https://huggingface.co/papers/2503.24290)) which have shown that the KL divergence term is not essential for training with GRPO. As a result, it has become common practice to exclude it (e.g. [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783), [DAPO: An Open-Source LLM Reinforcement Learning System at Scale](https://huggingface.co/papers/2503.14476)). If you wish to include the KL divergence term, you can set `beta` in [`GRPOConfig`] to a non-zero value.
|
||||
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we use \\( \beta = 0.0 \\) by default, meaning that the KL divergence term is not used. This choice is motivated by several recent studies (e.g., [Open-Reasoner-Zero: An Open Source Approach to Scaling Up Reinforcement Learning on the Base Model](https://huggingface.co/papers/2503.24290)) which have shown that the KL divergence term is not essential for training with GRPO. As a result, it has become common practice to exclude it (e.g. [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783), [DAPO: An Open-Source LLM Reinforcement Learning System at Scale](https://huggingface.co/papers/2503.14476)). If you wish to include the KL divergence term, you can set `beta` in [`GRPOConfig`] to a non-zero value.
|
||||
|
||||
In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:
|
||||
|
||||
@ -167,7 +155,7 @@ While training and evaluating, we record the following reward metrics:
|
||||
- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS.
|
||||
- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS.
|
||||
- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS.
|
||||
- `completions/clipped_ratio` : The ratio of truncated (clipped) completions.
|
||||
- `completions/clipped_ratio`: The ratio of truncated (clipped) completions.
|
||||
- `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
|
||||
- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
|
||||
- `reward`: The overall average reward after applying reward weights.
|
||||
@ -178,10 +166,10 @@ While training and evaluating, we record the following reward metrics:
|
||||
- `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, tokens of masked sequences are excluded.)
|
||||
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
|
||||
- `clip_ratio/region_mean`: The ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities where the GRPO objective is clipped to stay within the trust region:
|
||||
$$
|
||||
\text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,.
|
||||
$$
|
||||
A higher value means more tokens are clipped, which constrains how much the policy $\pi_\theta$ can change.
|
||||
$$
|
||||
\text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,.
|
||||
$$
|
||||
A higher value means more tokens are clipped, which constrains how much the policy $\pi_\theta$ can change.
|
||||
- `clip_ratio/low_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
|
||||
- `clip_ratio/low_min`: The minimum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
|
||||
- `clip_ratio/high_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\)
|
||||
@ -192,26 +180,28 @@ A higher value means more tokens are clipped, which constrains how much the poli
|
||||
### Speed up training with vLLM-powered generation
|
||||
|
||||
Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with
|
||||
|
||||
```shell
|
||||
pip install trl[vllm]
|
||||
```
|
||||
|
||||
We support two ways of using vLLM during training: **server mode** and **colocate mode**.
|
||||
|
||||
<Tip>
|
||||
By default, Truncated Importance Sampling is activated for vLLM generation to address the generation-training mismatch that occurs when using different frameworks. This can be turned off by setting `vllm_importance_sampling_correction=False`. For more information, see [Truncated Importance Sampling](paper_index#truncated-importance-sampling)
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> By default, Truncated Importance Sampling is activated for vLLM generation to address the generation-training mismatch that occurs when using different frameworks. This can be turned off by setting `vllm_importance_sampling_correction=False`. For more information, see [Truncated Importance Sampling](paper_index#truncated-importance-sampling)
|
||||
|
||||
#### 🔌 Option 1: Server mode
|
||||
|
||||
In this mode, vLLM runs in a separate process (and on separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.
|
||||
|
||||
1. **Start the vLLM server**:
|
||||
|
||||
```bash
|
||||
trl vllm-serve --model <model_name>
|
||||
```
|
||||
|
||||
2. **Enable server mode in your training script**:
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
@ -222,11 +212,8 @@ In this mode, vLLM runs in a separate process (and using separate GPUs) and comm
|
||||
)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable.
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable.
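Putting the two steps together, a server-mode setup typically looks like the following (a minimal sketch; `use_vllm`, `vllm_mode`, `vllm_server_host`, and `vllm_server_port` are assumed [`GRPOConfig`] options):

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    # ...,
    use_vllm=True,
    vllm_mode="server",          # connect to the external `trl vllm-serve` process
    vllm_server_host="0.0.0.0",  # adjust if the server runs on another machine
    vllm_server_port=8000,
)
```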
|
||||
|
||||
#### 🧩 Option 2: Colocate mode
|
||||
|
||||
@ -242,30 +229,19 @@ training_args = GRPOConfig(
|
||||
)
|
||||
```
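For reference, colocate mode typically boils down to a configuration along these lines (a minimal sketch; `use_vllm` and `vllm_mode` are assumed [`GRPOConfig`] options, and the memory setting is discussed in the tip below):

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    # ...,
    use_vllm=True,
    vllm_mode="colocate",             # run vLLM inside the trainer processes
    vllm_gpu_memory_utilization=0.3,  # leave headroom for training on the shared GPUs
)
```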
|
||||
|
||||
<Tip>
|
||||
> [!TIP]
|
||||
> Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`GRPOConfig`] to avoid underutilization or out-of-memory errors.
|
||||
>
|
||||
> We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-memory) to help estimate the recommended GPU memory utilization based on your model configuration and experiment settings. Simply use it as follows to get a `vllm_gpu_memory_utilization` recommendation:
|
||||
>
|
||||
> <iframe src="https://trl-lib-recommend-vllm-memory.hf.space" frameborder="0" width="850" height="450"></iframe>
|
||||
>
|
||||
> If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.
|
||||
>
|
||||
> If you still run into out-of-memory errors, set `vllm_enable_sleep_mode` to `True` so that the vLLM parameters and cache are offloaded during the optimization step. For more information, see [Reducing Memory Usage with vLLM Sleep Mode](reducing_memory_usage#vllm-sleep-mode).
|
||||
|
||||
Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`GRPOConfig`] to avoid underutilization or out-of-memory errors.
|
||||
|
||||
We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-memory) to help estimate the recommended GPU memory utilization based on your model configuration and experiment settings. Simply use it as follows to get a `vllm_gpu_memory_utilization` recommendation:
|
||||
|
||||
<iframe
|
||||
src="https://trl-lib-recommend-vllm-memory.hf.space"
|
||||
frameborder="0"
|
||||
width="850"
|
||||
height="450"
|
||||
></iframe>
|
||||
|
||||
If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.
|
||||
|
||||
If you still find you are getting out-of-memory errors set `vllm_sleep_enabled` to True and the vllm parameters and cache will be offloaded during the optimization step. For more information, see [Reducing Memory Usage with vLLM Sleep Mode](reducing_memory_usage#vllm-sleep-mode).
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
By default, GRPO uses `MASTER_ADDR=localhost` and `MASTER_PORT=12345` for vLLM, but you can override these values by setting the environment variables accordingly.
|
||||
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> By default, GRPO uses `MASTER_ADDR=localhost` and `MASTER_PORT=12345` for vLLM, but you can override these values by setting the environment variables accordingly.
|
||||
|
||||
For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
|
||||
|
||||
@ -273,7 +249,7 @@ For more information, see [Speeding up training with vLLM](speeding_up_training#
|
||||
|
||||
When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:
|
||||
|
||||
- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such model. For more details, see [DeepSpeed Integration](deepspeed_integration).
|
||||
- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see [DeepSpeed Integration](deepspeed_integration).
|
||||
- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training).
|
||||
- **vLLM**: See the previous section on how to use vLLM to speed up generation.
|
||||
|
||||
@ -352,7 +328,7 @@ The [`GRPOTrainer`] supports using custom reward functions instead of dense rewa
|
||||
- `completions` (contains the generated completions),
|
||||
- `completions_ids` (contains the tokenized completions),
|
||||
- `trainer_state` ([`~transformers.TrainerState`]): The current state of the trainer. This can be used to implement dynamic reward functions, such as curriculum learning, where the reward is adjusted based on the training progress.
|
||||
- All columns names (but `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
|
||||
- All column names (but `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
|
||||
|
||||
The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
|
||||
- Depending on the dataset format, the input will vary:
|
||||
@ -381,7 +357,7 @@ You can test it as follows:
|
||||
[2.0, 4.0]
|
||||
```
|
||||
|
||||
#### Example 1.1: Reward longer completions (based in the number of characters)
|
||||
#### Example 1.1: Reward longer completions (based on the number of characters)
|
||||
|
||||
Same as the previous example, but this time the reward function is based on the number of characters instead of tokens.
|
||||
|
||||
@ -401,10 +377,10 @@ You can test it as follows:
|
||||
[6.0, 12.0]
|
||||
```
|
||||
|
||||
#### Example 2: Reward completions with specific format
|
||||
#### Example 2: Reward completions with a specific format
|
||||
|
||||
Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
|
||||
It is designed for conversational format, where prompts and completions consist of structured messages.
|
||||
It is designed for a conversational format, where prompts and completions consist of structured messages.
|
||||
|
||||
```python
|
||||
import re
|
||||
@ -457,6 +433,7 @@ You can test this function as follows:
|
||||
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
|
||||
[1.0, 0.0]
|
||||
```
|
||||
|
||||
#### Example 4: Multi-task reward functions
|
||||
|
||||
Below is an example of using multiple reward functions in the [`GRPOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works.
|
||||
@ -513,12 +490,10 @@ trainer = GRPOTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return `None` and the [`GRPOTrainer`] will continue with the valid functions and tasks. This allows the [`GRPOTrainer`] to handle multiple reward functions with different applicability.
|
||||
In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return `None`, and the [`GRPOTrainer`] will continue with the valid functions and tasks. This allows the [`GRPOTrainer`] to handle multiple reward functions with different applicability.
|
||||
|
||||
Note that the [`GRPOTrainer`] will ignore the `None` rewards returned by the reward functions and only consider the rewards returned by the relevant functions. This ensures that the model is trained on the relevant tasks and ignores the tasks for which there is no relevant reward function.
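For illustration, the task-aware pattern described above can be sketched as follows (`is_correct_math_answer` is a hypothetical helper standing in for the actual scoring logic):

```python
def math_reward_func(prompts, completions, task, **kwargs):
    rewards = []
    for completion, t in zip(completions, task):
        if t == "math":
            # Score math samples, e.g. by checking the final answer.
            rewards.append(1.0 if is_correct_math_answer(completion) else 0.0)
        else:
            # Not a math sample: return None so another reward function handles it.
            rewards.append(None)
    return rewards
```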
|
||||
|
||||
|
||||
|
||||
#### Passing the reward function to the trainer
|
||||
|
||||
To use your custom reward function, pass it to the [`GRPOTrainer`] as follows:
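A minimal sketch of that call (assuming `reward_func` is one of the functions defined above and `dataset` is a prompt dataset):

```python
from trl import GRPOConfig, GRPOTrainer

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_func,              # or a list of reward functions
    args=GRPOConfig(output_dir="output"),
    train_dataset=dataset,
)
trainer.train()
```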
|
||||
@ -561,9 +536,8 @@ Tested with:
|
||||
- **Qwen2.5-VL** — e.g., `Qwen/Qwen2.5-VL-3B-Instruct`
|
||||
- **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct`
|
||||
|
||||
<Tip>
|
||||
Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes.
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes.
|
||||
|
||||
### Quick Start
|
||||
|
||||
@ -589,9 +563,8 @@ accelerate launch \
|
||||
|
||||
### Configuration Tips
|
||||
|
||||
<Tip warning={true}>
|
||||
VLM training may fail if image tokens are truncated. We highly recommend to disable truncation by setting `max_prompt_length` to `None`.
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> VLM training may fail if image tokens are truncated. We highly recommend disabling truncation by setting `max_prompt_length` to `None`.
|
||||
|
||||
- Use LoRA on vision-language projection layers
|
||||
- Enable 4-bit quantization to reduce memory usage
|
||||
@ -603,7 +576,7 @@ VLM training may fail if image tokens are truncated. We highly recommend to disa
|
||||
Each training sample should include:
|
||||
|
||||
- `prompt`: Text formatted via the processor's chat template
|
||||
- `image`: A single image (PIL or NumPy array)
|
||||
- `image`/`images`: PIL Image or list of PIL Images
|
||||
|
||||
The trainer automatically handles image-to-tensor conversion via the model’s image processor.
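Concretely, a single training sample could be assembled like this (a sketch using one of the models listed above; the image path is a placeholder):

```python
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is shown in this image?"}]}
]

sample = {
    # prompt rendered with the processor's chat template
    "prompt": processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False),
    "image": Image.open("example.jpg"),
}
```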
|
||||
|
||||
|
@ -1,65 +0,0 @@
|
||||
# Training FAQ
|
||||
|
||||
## What Metrics Should I Look at?
|
||||
|
||||
When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves.
|
||||
|
||||
To address this, we recommend focusing on two key metrics first:
|
||||
|
||||
- **Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training.
- **Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces.
|
||||
|
||||
However, there are more metrics that can be useful for debugging; check out the [logging section](logging).
|
||||
|
||||
## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence?
|
||||
|
||||
When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.
|
||||
|
||||
However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kl-example.png">
|
||||
<p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://huggingface.co/papers/1909.08593">https://huggingface.co/papers/1909.08593</a>. </p>
|
||||
</div>
|
||||
|
||||
To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.
|
||||
|
||||
## What Is the Concern with Negative KL Divergence?
|
||||
|
||||
If you generate text by purely sampling from the model distribution, things generally work fine. But when you use the `generate` method, there are a few caveats: depending on the settings, it does not always sample purely, which can cause the KL divergence to go negative. Essentially, whenever the active model achieves `log_p_token_active < log_p_token_ref`, we get a negative KL divergence. This can happen in several cases:
|
||||
|
||||
- **top-k sampling**: the model can smooth out the probability distribution, causing the top-k tokens to have a smaller probability than those of the reference model while still being selected
|
||||
- **min_length**: this ignores the EOS token until `min_length` is reached, so the model can assign a very low log probability to the EOS token and very high probabilities to all other tokens until `min_length` is reached
|
||||
|
||||
These are just a few examples. Why is negative KL an issue? The total reward `R` is computed as `R = r - beta * KL`, so if the model learns how to drive the KL divergence negative, it effectively receives a positive reward. In many cases it is much easier to exploit such a bug in the generation than to actually learn the reward function. In addition, the KL can become arbitrarily negative, so the actual reward `r` can be very small compared to it.
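For example, with a score of `r = 0.1`, `beta = 0.2`, and `KL = -1.0`, the total reward is `R = 0.1 - 0.2 * (-1.0) = 0.3`, so the negative KL term alone contributes more than the actual score.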
|
||||
|
||||
So how should you generate text for PPO training? Let's have a look!
|
||||
|
||||
## How to generate text for training?
|
||||
|
||||
To avoid the KL issues described above, we recommend using the following settings:
|
||||
|
||||
```python
|
||||
generation_kwargs = {
|
||||
"min_length": -1, # don't ignore the EOS token (see above)
|
||||
"top_k": 0.0, # no top-k sampling
|
||||
"top_p": 1.0, # no nucleus sampling
|
||||
"do_sample": True, # yes, we want to sample
|
||||
"pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead
|
||||
"max_new_tokens": 32, # specify how many tokens you want to generate at most
|
||||
}
|
||||
```
|
||||
|
||||
With these settings, we usually don't encounter any issues. You can experiment with other settings, but if you run into negative KL divergence, try going back to these and see whether the issue persists.
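As a sketch of how these settings are used (assuming `model` and `tokenizer` are an already-loaded causal LM and its tokenizer):

```python
query = tokenizer("The movie was", return_tensors="pt")
response = model.generate(**query, **generation_kwargs)
print(tokenizer.decode(response[0], skip_special_tokens=True))
```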
|
||||
|
||||
## How can you debug your own use case?
|
||||
|
||||
Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:
|
||||
|
||||
- **Start from a working example**: Begin with a working example from the TRL repository and gradually modify it to fit your specific use case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and, once you have found good hyperparameters, switch to your dataset and reward model.
|
||||
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
|
||||
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that.
|
||||
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt, or bad settings might cut off generations too soon. Such issues are very hard to spot in the metrics but obvious when you look at the generations.
|
||||
- **Inspect the reward model**: If your reward is not improving over time, there may be an issue with the reward model. Look at extreme cases to check that it does what it should: e.g., for sentiment, check whether simple positive and negative examples really get different rewards. Also look at the distribution of your dataset. Finally, the reward may be dominated by the query, which the model can't affect, so you might need to normalize it (e.g., reward of query+response minus reward of the query).
|
||||
|
||||
These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!
|
@ -7,6 +7,46 @@
|
||||
TRL is a full stack library where we provide a set of tools to train transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more.
|
||||
The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
|
||||
|
||||
Below is the current list of TRL trainers, organized by method type (⚡️ = vLLM support).
|
||||
|
||||
## Taxonomy
|
||||
|
||||
<div style="display: flex; justify-content: space-between; width: 100%; gap: 2rem;">
|
||||
<div style="flex: 1; min-width: 0;">
|
||||
|
||||
### Online methods
|
||||
|
||||
- [`GRPOTrainer`] ⚡️
|
||||
- [`RLOOTrainer`] ⚡️
|
||||
- [`OnlineDPOTrainer`] ⚡️
|
||||
- [`NashMDTrainer`] ⚡️
|
||||
- [`XPOTrainer`] ⚡️
|
||||
- [`PPOTrainer`]
|
||||
|
||||
### Reward modeling
|
||||
|
||||
- [`PRMTrainer`]
|
||||
- [`RewardTrainer`]
|
||||
|
||||
</div>
|
||||
<div style="flex: 1; min-width: 0;">
|
||||
|
||||
### Offline methods
|
||||
|
||||
- [`SFTTrainer`]
|
||||
- [`DPOTrainer`]
|
||||
- [`ORPOTrainer`]
|
||||
- [`BCOTrainer`]
|
||||
- [`CPOTrainer`]
|
||||
- [`KTOTrainer`]
|
||||
|
||||
### Knowledge distillation
|
||||
|
||||
- [`GKDTrainer`]
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
## 🎉 What's New
|
||||
|
||||
**✨ OpenAI GPT OSS Support**: TRL now fully supports fine-tuning the latest [OpenAI GPT OSS models](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4)! Check out the:
|
||||
|
@ -1,13 +1,15 @@
|
||||
# Installation
|
||||
|
||||
You can install TRL either from PyPI or from source:
|
||||
|
||||
## PyPI
|
||||
|
||||
Install the library with pip or [uv](https://docs.astral.sh/uv/):
|
||||
|
||||
<hfoptions id="install">
|
||||
<hfoption id="uv">
|
||||
|
||||
uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions).
|
||||
uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions.
|
||||
|
||||
```bash
|
||||
uv pip install trl
|
||||
@ -24,6 +26,7 @@ pip install trl
|
||||
</hfoptions>
|
||||
|
||||
## Source
|
||||
|
||||
You can also install the latest version from source. First clone the repo and then run the installation with `pip`:
|
||||
|
||||
```bash
|
||||
|
@ -1,147 +0,0 @@
|
||||
# Iterative Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=iterative-sft,trl)
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The [`IterativeSFTTrainer`] is deprecated and will be removed in version 0.24.0. Please use the [`SFTTrainer`] instead.
|
||||
|
||||
</Tip>
|
||||
|
||||
Iterative fine-tuning is a training method that lets you perform custom actions (for example, generation and filtering) between optimization steps. In TRL, we provide an easy-to-use API to fine-tune your models iteratively in just a few lines of code.
|
||||
|
||||
## Quickstart
|
||||
|
||||
To get started quickly, you can either pass a model identifier or a pre-instantiated model to the trainer:
|
||||
|
||||
```python
|
||||
from trl import IterativeSFTConfig, IterativeSFTTrainer
|
||||
|
||||
# Using a model identifier
|
||||
trainer = IterativeSFTTrainer(
|
||||
"facebook/opt-350m",
|
||||
args=IterativeSFTConfig(
|
||||
max_length=512,
|
||||
output_dir="./output",
|
||||
),
|
||||
)
|
||||
|
||||
# Or using a pre-instantiated model
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
|
||||
trainer = IterativeSFTTrainer(
|
||||
model,
|
||||
args=IterativeSFTConfig(
|
||||
max_length=512,
|
||||
output_dir="./output",
|
||||
),
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
The [`IterativeSFTTrainer`] supports two ways of providing input data to the `step` function:
|
||||
|
||||
### Using a list of tensors as input:
|
||||
|
||||
```python
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
trainer.step(**inputs)
|
||||
```
|
||||
|
||||
### Using a list of strings as input:
|
||||
|
||||
```python
|
||||
inputs = {
|
||||
"texts": texts,
|
||||
"texts_labels": texts_labels, # Optional, defaults to texts
|
||||
}
|
||||
|
||||
trainer.step(**inputs)
|
||||
```
|
||||
|
||||
For causal language models, labels will automatically be created from `input_ids` or from `texts`. When using sequence-to-sequence models, you must provide your own `labels` or `texts_labels`.
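For a sequence-to-sequence model, that means passing the targets explicitly, for example (a sketch with placeholder tensors):

```python
inputs = {
    "input_ids": input_ids,            # encoder inputs
    "attention_mask": attention_mask,
    "labels": labels,                  # decoder targets, required for seq2seq models
}

trainer.step(**inputs)
```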
|
||||
|
||||
## Configuration
|
||||
|
||||
The [`IterativeSFTConfig`] class provides several parameters to customize the training:
|
||||
|
||||
```python
|
||||
from trl import IterativeSFTConfig
|
||||
|
||||
config = IterativeSFTConfig(
|
||||
# Model initialization parameters
|
||||
model_init_kwargs={"dtype": "bfloat16"},
|
||||
|
||||
# Data preprocessing parameters
|
||||
max_length=512,
|
||||
truncation_mode="keep_end",
|
||||
|
||||
# Training parameters
|
||||
output_dir="./output",
|
||||
learning_rate=2e-5,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
max_steps=1000,
|
||||
save_steps=100,
|
||||
optim="adamw_torch",
|
||||
report_to="wandb",
|
||||
)
|
||||
```
|
||||
|
||||
### Model Initialization
|
||||
|
||||
You can control how the model is initialized by passing keyword arguments to `model_init_kwargs`:
|
||||
|
||||
```python
|
||||
config = IterativeSFTConfig(
|
||||
model_init_kwargs={
|
||||
"dtype": "bfloat16",
|
||||
"device_map": "auto",
|
||||
"trust_remote_code": True,
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Data Preprocessing
|
||||
|
||||
The trainer supports two truncation modes:
|
||||
|
||||
- `keep_end`: Truncates from the start of the sequence
|
||||
- `keep_start`: Truncates from the end of the sequence
|
||||
|
||||
```python
|
||||
config = IterativeSFTConfig(
|
||||
max_length=512,
|
||||
truncation_mode="keep_end", # or "keep_start"
|
||||
)
|
||||
```
|
||||
|
||||
### Training Optimization
|
||||
|
||||
You can optimize CUDA cache usage for more memory-efficient training:
|
||||
|
||||
```python
|
||||
config = IterativeSFTConfig(
|
||||
optimize_device_cache=True,
|
||||
)
|
||||
```
|
||||
|
||||
## IterativeSFTTrainer
|
||||
|
||||
[[autodoc]] IterativeSFTTrainer
|
||||
- train
|
||||
- save_model
|
||||
- push_to_hub
|
||||
|
||||
## IterativeSFTConfig
|
||||
|
||||
[[autodoc]] IterativeSFTConfig
|
@ -1,46 +1,64 @@
|
||||
# Training using Jobs
|
||||
# Training with Jobs
|
||||
|
||||
[Jobs](https://huggingface.co/docs/huggingface_hub/guides/jobs) lets you run training scripts on fully managed infrastructure (no need to handle GPUs, dependencies, or environment setup locally). This makes it easy to scale and monitor your experiments directly from the Hub.
|
||||
[](https://huggingface.co/models?other=hf_jobs,trl)
|
||||
|
||||
In this guide, you’ll learn how to:
|
||||
[Hugging Face Jobs](https://huggingface.co/docs/huggingface_hub/guides/jobs) lets you run training scripts on fully managed infrastructure—no need to manage GPUs or local environment setup.
|
||||
|
||||
- Run TRL training scripts using Jobs.
|
||||
- Configure hardware, timeouts, environment variables, and secrets.
|
||||
- Monitor and manage jobs from the CLI or Python.
|
||||
In this guide, you'll learn how to:
|
||||
|
||||
<Tip>
|
||||
* Use [TRL Jobs](https://github.com/huggingface/trl-jobs) to easily run pre-optimized TRL training
|
||||
* Run any TRL training script with uv scripts
|
||||
|
||||
When a model is trained using **TRL + Jobs**, a tag is automatically added to the model card.
|
||||
You can explore models trained with this method on the [Hugging Face model hub](https://huggingface.co/models?other=hf_jobs).
|
||||
|
||||
</Tip>
|
||||
For general details about Hugging Face Jobs (hardware selection, job monitoring, etc.), see the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/guides/jobs).
|
||||
|
||||
## Requirements
|
||||
|
||||
- [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan.
|
||||
- Logged into the Hugging Face Hub (`hf auth login`).
|
||||
* A [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan
|
||||
* Logged in to the Hugging Face Hub (`hf auth login`)
|
||||
|
||||
## Preparing your Script
|
||||
## Using TRL Jobs
|
||||
|
||||
You can launch Jobs using either the [`hf jobs` CLI](https://huggingface.co/docs/huggingface_hub/guides/cli#hf-jobs) or the Python API. A convenient option is to use [UV scripts](https://docs.astral.sh/uv/guides/scripts/), which packages all dependencies directly into a single Python file. You can run them like this:
|
||||
[TRL Jobs](https://github.com/huggingface/trl-jobs) is a high-level wrapper around Hugging Face Jobs and TRL that streamlines training. It provides optimized default configurations so you can start quickly without manually tuning parameters.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
pip install trl-jobs
|
||||
trl-jobs sft --model_name Qwen/Qwen3-0.6B --dataset_name trl-lib/Capybara
|
||||
```
|
||||
|
||||
TRL Jobs supports everything covered in this guide, with additional optimizations to simplify workflows.
|
||||
|
||||
## Using uv Scripts
|
||||
|
||||
For more control, you can run Hugging Face Jobs directly with your own scripts, using [uv scripts](https://docs.astral.sh/uv/guides/scripts/).
|
||||
|
||||
Create a Python script (e.g., `train.py`) containing your training code:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
trainer = SFTTrainer(
|
||||
model="Qwen/Qwen2.5-0.5B",
|
||||
train_dataset=dataset,
|
||||
)
|
||||
trainer.train()
|
||||
trainer.push_to_hub("Qwen2.5-0.5B-SFT")
|
||||
```
|
||||
|
||||
Launch the job using either the [`hf jobs` CLI](https://huggingface.co/docs/huggingface_hub/guides/cli#hf-jobs) or the Python API:
|
||||
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
|
||||
```bash
|
||||
hf jobs uv run --flavor a100-large --secrets HF_TOKEN "https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" --model_name_or_path Qwen/Qwen2-0.5B --dataset_name trl-lib/Capybara
|
||||
```
|
||||
|
||||
The script can also be a local file:
|
||||
|
||||
```bash
|
||||
hf jobs uv run --flavor a100-large --secrets HF_TOKEN trl/scripts/sft.py --model_name_or_path Qwen/Qwen2-0.5B --dataset_name trl-lib/Capybara
|
||||
```
|
||||
|
||||
Since it runs using a Docker Image from Hugging Face Spaces or Docker Hub, you can also specify it:
|
||||
|
||||
```bash
|
||||
hf jobs uv run --flavor a100-large --secrets HF_TOKEN --image <docker-image> trl/scripts/sft.py --model_name_or_path Qwen/Qwen2-0.5B --dataset_name trl-lib/Capybara
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--with trl \
|
||||
--secrets HF_TOKEN \
|
||||
train.py
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
@ -48,236 +66,113 @@ hf jobs uv run --flavor a100-large --secrets HF_TOKEN --image <docker-image> trl
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
|
||||
run_uv_job(
|
||||
"https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py",
|
||||
token="hf...",
|
||||
"train.py",
|
||||
dependencies=["trl"],
|
||||
flavor="a100-large",
|
||||
script_args=[
|
||||
"--model_name_or_path", "Qwen/Qwen2-0.5B",
|
||||
"--dataset_name", "trl-lib/Capybara",
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
The script can also be a local file:
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
run_uv_job(
|
||||
"trl/scripts/sft.py",
|
||||
token="hf...",
|
||||
flavor="a100-large",
|
||||
script_args=[
|
||||
"--model_name_or_path", "Qwen/Qwen2-0.5B",
|
||||
"--dataset_name", "trl-lib/Capybara",
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
Since it runs using a Docker Image from Hugging Face Spaces or Docker Hub, you can also specify it:
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
run_uv_job(
|
||||
"sft.py",
|
||||
token="hf...",
|
||||
flavor="a100-large",
|
||||
image="<docker-image>",
|
||||
script_args=[
|
||||
"--model_name_or_path", "Qwen/Qwen2-0.5B",
|
||||
"--dataset_name", "trl-lib/Capybara",
|
||||
],
|
||||
secrets={"HF_TOKEN": "hf_..."},
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
You can also run jobs without UV:
|
||||
To run successfully, the script needs:
|
||||
|
||||
* **TRL installed**: Use the `--with trl` flag or the `dependencies` argument. uv installs these dependencies automatically before running the script.
|
||||
* **An authentication token**: Required to push the trained model (or perform other authenticated operations). Provide it with the `--secrets HF_TOKEN` flag or the `secrets` argument.
|
||||
|
||||
> [!WARNING]
|
||||
> When training with Jobs, be sure to:
|
||||
>
|
||||
> * **Set a sufficient timeout**. Jobs time out after 30 minutes by default. If your job exceeds the timeout, it will fail and all progress will be lost. See [Setting a custom timeout](https://huggingface.co/docs/huggingface_hub/guides/jobs#setting-a-custom-timeout).
|
||||
> * **Push the model to the Hub**. The Jobs environment is ephemeral—files are deleted when the job ends. If you don’t push the model, it will be lost.
|
||||
|
||||
You can also run a script directly from a URL:
|
||||
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
|
||||
In this case, we give the cli the Docker image and run it as:
|
||||
|
||||
```bash
|
||||
hf jobs run --flavor a100-large pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel python -c "import torch; print(torch.cuda.get_device_name())"
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--with trl \
|
||||
--secrets HF_TOKEN \
|
||||
"https://gist.githubusercontent.com/qgallouedec/eb6a7d20bd7d56f9c440c3c8c56d2307/raw/69fd78a179e19af115e4a54a1cdedd2a6c237f2f/train.py"
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="python">
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_job
|
||||
run_job(
|
||||
image="pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel",
|
||||
command=["python", "-c", "import torch; print(torch.cuda.get_device_name())"],
|
||||
from huggingface_hub import run_uv_job
|
||||
|
||||
run_uv_job(
|
||||
"https://gist.githubusercontent.com/qgallouedec/eb6a7d20bd7d56f9c440c3c8c56d2307/raw/69fd78a179e19af115e4a54a1cdedd2a6c237f2f/train.py",
|
||||
flavor="a100-large",
|
||||
dependencies=["trl"],
|
||||
secrets={"HF_TOKEN": "hf_..."},
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Adding Dependencies with UV
|
||||
|
||||
All example scripts in TRL are compatible with `uv`, allowing seamless execution with Jobs. You can check the full list of examples in [Maintained examples](example_overview#maintained-examples).
|
||||
|
||||
Dependencies are specified at the top of the script using this structure:
|
||||
To make a script self-contained, declare dependencies at the top:
|
||||
|
||||
```python
|
||||
# /// script
|
||||
# dependencies = [
|
||||
# "trl @ git+https://github.com/huggingface/trl.git",
|
||||
# "trl",
|
||||
# "peft",
|
||||
# ]
|
||||
# ///
|
||||
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from trl import SFTTrainer
|
||||
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model="Qwen/Qwen2.5-0.5B",
|
||||
train_dataset=dataset,
|
||||
peft_config=LoraConfig(),
|
||||
)
|
||||
trainer.train()
|
||||
trainer.push_to_hub("Qwen2.5-0.5B-SFT")
|
||||
```
|
||||
|
||||
When you run the UV script, these dependencies are automatically installed. In the example above, `trl` and `peft` would be installed before the script runs.
|
||||
|
||||
You can also provide dependencies directly in the `uv run` command:
|
||||
You can then run the script without specifying dependencies:
|
||||
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
|
||||
Using the `--with` flag.
|
||||
|
||||
```bash
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--secrets HF_TOKEN \
|
||||
--with transformers \
|
||||
--with torch \
|
||||
"https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B \
|
||||
--dataset_name trl-lib/Capybara
|
||||
train.py
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="python">
|
||||
|
||||
Using the `dependencies` argument.
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
|
||||
run_uv_job(
|
||||
"https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py",
|
||||
dependencies=["transformers", "torch"],
|
||||
token="hf...",
|
||||
"train.py",
|
||||
flavor="a100-large",
|
||||
script_args=[
|
||||
"--model_name_or_path", "Qwen/Qwen2-0.5B",
|
||||
"--dataset_name", "trl-lib/Capybara",
|
||||
],
|
||||
secrets={"HF_TOKEN": "hf_..."},
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Hardware and Timeout Settings
|
||||
|
||||
Jobs allow you to select a specific hardware configuration using the `--flavor` flag. As of August 2025, the available options are:
|
||||
|
||||
- **CPU:** `cpu-basic`, `cpu-upgrade`
- **GPU:** `t4-small`, `t4-medium`, `l4x1`, `l4x4`, `a10g-small`, `a10g-large`, `a10g-largex2`, `a10g-largex4`, `a100-large`
- **TPU:** `v5e-1x1`, `v5e-2x2`, `v5e-2x4`
|
||||
|
||||
You can always check the latest list of supported hardware flavors in [Spaces config reference](https://huggingface.co/docs/hub/en/spaces-config-reference).
|
||||
|
||||
By default, jobs have a **30-minute timeout**, after which they will automatically stop. For long-running tasks like training, you can increase the timeout as needed. Supported time units are:
|
||||
|
||||
- `s`: seconds
|
||||
- `m`: minutes
|
||||
- `h`: hours
|
||||
- `d`: days
|
||||
|
||||
Example with a 2-hour timeout:
|
||||
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
|
||||
Using the `--timeout` flag:
|
||||
|
||||
```bash
|
||||
hf jobs uv run \
|
||||
--timeout 2h \
|
||||
--flavor a100-large \
|
||||
--secrets HF_TOKEN \
|
||||
--with transformers \
|
||||
--with torch \
|
||||
"https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B \
|
||||
--dataset_name trl-lib/Capybara
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="python">
|
||||
|
||||
Using the `timeout` argument:
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
run_uv_job(
|
||||
"https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py",
|
||||
timeout="2h",
|
||||
token="hf...",
|
||||
flavor="a100-large",
|
||||
script_args=[
|
||||
"--model_name_or_path", "Qwen/Qwen2-0.5B",
|
||||
"--dataset_name", "trl-lib/Capybara",
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Environment Variables, Secrets, and Token
|
||||
|
||||
You can pass environment variables, secrets, and your auth token to your jobs.
|
||||
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
|
||||
Using the `--env`, `--secrets`, and/or `--token` options.
|
||||
|
||||
```bash
|
||||
hf jobs uv run \
|
||||
trl/scripts/sft.py \
|
||||
--flavor a100-large \
|
||||
--env FOO=foo \
|
||||
--env BAR=bar \
|
||||
--secrets HF_TOKEN=HF_TOKEN \
|
||||
--secrets MY_SECRET=password \
|
||||
--token hf...
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="python">
|
||||
|
||||
|
||||
Using the `env`, `secrets`, and/or `token` arguments.
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
run_uv_job(
|
||||
"trl/scripts/sft.py",
|
||||
env={"FOO": "foo", "BAR": "bar"},
|
||||
secrets={"MY_SECRET": "psswrd"},
|
||||
token="hf..."
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Training and Evaluating a Model with Jobs
|
||||
|
||||
TRL example scripts are fully UV-compatible, allowing you to run a complete training workflow directly on Jobs. You can customize the training by providing the usual script arguments, along with hardware specifications and secrets.
|
||||
|
||||
To evaluate your training runs, in addition to reviewing the job logs, you can use [**Trackio**](https://huggingface.co/blog/trackio), a lightweight experiment tracking library. Trackio enables end-to-end experiment management on the Hugging Face Hub. All TRL example scripts already support reporting to Trackio via the `report_to` argument. Using this feature saves your experiments in an interactive HF Space, making it easy to monitor metrics, compare runs, and track progress over time.
|
||||
TRL example scripts are fully uv-compatible, so you can run a complete training workflow directly on Jobs. You can customize training with standard script arguments plus hardware and secrets:
|
||||
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
@ -286,19 +181,10 @@ To evaluate your training runs, in addition to reviewing the job logs, you can u
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--secrets HF_TOKEN \
|
||||
"trl/scripts/sft.py" \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B \
|
||||
--dataset_name trl-lib/Capybara \
|
||||
--learning_rate 2.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--eos_token '<|im_end|>' \
|
||||
--eval_strategy steps \
|
||||
--eval_steps 100 \
|
||||
--output_dir Qwen2-0.5B-SFT \
|
||||
--report_to trackio \
|
||||
https://raw.githubusercontent.com/huggingface/trl/refs/heads/main/examples/scripts/prm.py \
|
||||
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
|
||||
--dataset_name trl-lib/prm800k \
|
||||
--output_dir Qwen2-0.5B-Reward \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
@ -307,24 +193,14 @@ hf jobs uv run \
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
|
||||
run_uv_job(
|
||||
"trl/scripts/sft.py",
|
||||
"https://raw.githubusercontent.com/huggingface/trl/refs/heads/main/examples/scripts/prm.py",
|
||||
flavor="a100-large",
|
||||
secrets={"HF_TOKEN": "your_hf_token"},
|
||||
secrets={"HF_TOKEN": "hf_..."},
|
||||
script_args=[
|
||||
"--model_name_or_path", "Qwen/Qwen2-0.5B",
|
||||
"--dataset_name", "trl-lib/Capybara",
|
||||
"--learning_rate", "2.0e-5",
|
||||
"--num_train_epochs", "1",
|
||||
"--packing",
|
||||
"--per_device_train_batch_size", "2",
|
||||
"--gradient_accumulation_steps", "8",
|
||||
"--eos_token", "<|im_end|>",
|
||||
"--eval_strategy", "steps",
|
||||
"--eval_steps", "100",
|
||||
"--output_dir", "Qwen2-0.5B-SFT",
|
||||
"--report_to", "trackio",
|
||||
"--model_name_or_path", "Qwen/Qwen2-0.5B-Instruct",
|
||||
"--dataset_name", "trl-lib/prm800k",
|
||||
"--output_dir", "Qwen2-0.5B-Reward",
|
||||
"--push_to_hub"
|
||||
]
|
||||
)
|
||||
@ -332,61 +208,67 @@ run_uv_job(
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
See the full list of examples in [Maintained examples](example_overview#maintained-examples).
|
||||
|
||||
## Monitoring and Managing Jobs
|
||||
### Docker Images
|
||||
|
||||
After launching a job, you can track its progress on the [Jobs page](https://huggingface.co/settings/jobs). Additionally, Jobs provides CLI and Python commands to check status, view logs, or cancel a job.
|
||||
An up-to-date Docker image with all TRL dependencies is available at [huggingface/trl](https://hub.docker.com/r/huggingface/trl) and can be used directly with Hugging Face Jobs:
|
||||
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
|
||||
```bash
|
||||
# List your jobs
|
||||
hf jobs ps -a
|
||||
|
||||
# List your running jobs
|
||||
hf jobs ps
|
||||
|
||||
# Inspect the status of a job
|
||||
hf jobs inspect
|
||||
|
||||
# View logs from a job
|
||||
hf jobs logs job_id
|
||||
|
||||
# Cancel a job
|
||||
hf jobs cancel job_id
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--secrets HF_TOKEN \
|
||||
--image huggingface/trl \
|
||||
train.py
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="python">
|
||||
|
||||
|
||||
```python
|
||||
from huggingface_hub import list_jobs, inspect_job, fetch_job_logs, cancel_job
|
||||
from huggingface_hub import run_uv_job
|
||||
|
||||
# List your jobs
|
||||
jobs = list_jobs()
|
||||
jobs[0]
|
||||
|
||||
# List your running jobs
|
||||
running_jobs = [job for job in list_jobs() if job.status.stage == "RUNNING"]
|
||||
|
||||
# Inspect the status of a job
|
||||
inspect_job(job_id=job_id)
|
||||
|
||||
# View logs from a job
|
||||
for log in fetch_job_logs(job_id=job_id):
|
||||
print(log)
|
||||
|
||||
# Cancel a job
|
||||
cancel_job(job_id=job_id)
|
||||
run_uv_job(
|
||||
"train.py",
|
||||
flavor="a100-large",
|
||||
secrets={"HF_TOKEN": "hf_..."},
|
||||
image="huggingface/trl",
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Best Practices and Tips
|
||||
Jobs runs on a Docker image from Hugging Face Spaces or Docker Hub, so you can also specify any custom image:
|
||||
|
||||
- Choose hardware that fits the size of your model and dataset for optimal performance.
|
||||
- Training jobs can be long-running. Consider increasing the default timeout.
|
||||
- Reuse training and evaluation scripts whenever possible to streamline workflows.
|
||||
<hfoptions id="script_type">
|
||||
<hfoption id="bash">
|
||||
|
||||
```bash
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--secrets HF_TOKEN \
|
||||
--image <docker-image> \
|
||||
train.py
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="python">
|
||||
|
||||
```python
|
||||
from huggingface_hub import run_uv_job
|
||||
|
||||
run_uv_job(
|
||||
"train.py",
|
||||
flavor="a100-large",
|
||||
secrets={"HF_TOKEN": "hf_..."},
|
||||
image="<docker-image>",
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
@ -1,10 +1,7 @@
|
||||
# Judges
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
TRL Judges is an experimental API which is subject to change at any time.
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> TRL Judges is an experimental API which is subject to change at any time.
|
||||
|
||||
TRL provides judges to easily compare two completions.
|
||||
|
||||
@ -16,7 +13,7 @@ pip install trl[judges]
|
||||
|
||||
## Using the provided judges
|
||||
|
||||
TRL provides several judges out of the box. For example, you can use the `HfPairwiseJudge` to compare two completions using a pre-trained model from the Hugging Face model hub:
|
||||
TRL provides several judges out of the box. For example, you can use the [`HfPairwiseJudge`] to compare two completions using a pre-trained model from the Hugging Face model hub:
|
||||
|
||||
```python
|
||||
from trl import HfPairwiseJudge
|
||||
|
@ -43,12 +43,8 @@ Or using the TRL CLI:
|
||||
trl sft ... --attn_implementation kernels-community/flash-attn
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
Now you can leverage faster attention backends with a pre-optimized kernel for your hardware configuration from the Hub, speeding up both development and training.
|
||||
|
||||
</Tip>
|
||||
|
||||
> [!TIP]
|
||||
> Now you can leverage faster attention backends with a pre-optimized kernel for your hardware configuration from the Hub, speeding up both development and training.
|
||||
|
||||
## Comparing Attention Implementations
|
||||
|
||||
@ -57,15 +53,14 @@ The experiments were run on a single **H100 GPU** with **CUDA 12.9**, leveraging
|
||||
Keep in mind that the results shown here are specific to this setup and may vary with different training configurations.
|
||||
|
||||
The following figure illustrates both **latency** (time per training step) and **peak allocated memory** for the different attention implementations and kernel backends.
|
||||
Kernel-based implementations perform on par with custom-installed attention, and increasing the model’s `max_length` further enhances performance. Memory consumption is similar across all implementations, showing no significant differences. We get the same performance but with less friction, as described in [the following section](#benchmarking-flash-attention-build-from-source-vs-hub-kernels).
|
||||
|
||||
Kernel-based implementations perform on par with custom-installed attention, and increasing the model’s `max_length` further enhances performance. Memory consumption is similar across all implementations, showing no significant differences. We get the same performance but with less friction, as described in [the following section](#flash-attention-vs-hub-kernels).
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kernels_guide_latency.png" alt="Latency and Memory Usage" width="600"/>
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kernels_guide_peak_allocated_memory.png" alt="Latency and Memory Usage" width="600"/>
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kernels_guide_latency.png" alt="Latency and Memory Usage" width="45%"/>
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kernels_guide_peak_allocated_memory.png" alt="Latency and Memory Usage" width="45%"/>
|
||||
</div>
|
||||
|
||||
## Flash Attention (Build-from-Source) vs. Hub Kernels
|
||||
## Flash Attention vs. Hub Kernels
|
||||
|
||||
Building Flash Attention from source can be time-consuming, often taking anywhere from several minutes to hours, depending on your hardware, CUDA/PyTorch configuration, and whether precompiled wheels are available.
|
||||
|
||||
@ -77,7 +72,6 @@ You can combine **FlashAttention kernels** with **Liger kernels** for additional
|
||||
|
||||
First, install the Liger kernel dependency:
|
||||
|
||||
|
||||
```bash
|
||||
pip install liger-kernel
|
||||
```
|
||||
@ -99,6 +93,4 @@ training_args = SFTConfig(
|
||||
)
|
||||
```
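Putting the two together, the configuration might look like the following (a minimal sketch; here the attention kernel is forwarded to `from_pretrained` via `model_init_kwargs`):

```python
from trl import SFTConfig

training_args = SFTConfig(
    # ...,
    model_init_kwargs={"attn_implementation": "kernels-community/flash-attn"},
    use_liger_kernel=True,  # enable Liger's fused Triton kernels
)
```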
|
||||
|
||||
Learn more about this integration [here](./liger_kernel_integration).
|
||||
|
||||
|
||||
Learn more about the [Liger Kernel Integration](./liger_kernel_integration).
|
||||
|
@ -1,12 +1,11 @@
|
||||
# KTO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=kto,trl)
|
||||
[](https://huggingface.co/models?other=kto,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Kahneman-Tversky Optimization (KTO) was introduced in [KTO: Model Alignment as Prospect Theoretic Optimization](https://huggingface.co/papers/2402.01306) by [Kawin Ethayarajh](https://huggingface.co/kawine), [Winnie Xu](https://huggingface.co/xwinxu), [Niklas Muennighoff](https://huggingface.co/Muennighoff), Dan Jurafsky, [Douwe Kiela](https://huggingface.co/douwekiela).
|
||||
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Kahneman & Tversky's prospect theory tells us that humans perceive random variables in a biased but well-defined manner; for example, humans are famously loss-averse. We show that objectives for aligning LLMs with human feedback implicitly incorporate many of these biases -- the success of these objectives (e.g., DPO) over cross-entropy minimization can partly be ascribed to them being human-aware loss functions (HALOs). However, the utility functions these methods attribute to humans still differ from those in the prospect theory literature. Using a Kahneman-Tversky model of human utility, we propose a HALO that directly maximizes the utility of generations instead of maximizing the log-likelihood of preferences, as current methods do. We call this approach Kahneman-Tversky Optimization (KTO), and it matches or exceeds the performance of preference-based methods at scales from 1B to 30B. Crucially, KTO does not need preferences -- only a binary signal of whether an output is desirable or undesirable for a given input. This makes it far easier to use in the real world, where preference data is scarce and expensive.
|
||||
@ -51,7 +50,7 @@ accelerate launch train_kto.py
|
||||
|
||||
Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
|
||||
|
||||

|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
@ -60,14 +59,14 @@ To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) pe
|
||||
What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-KTO>:</span></strong>
|
||||
The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:
|
||||
The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:
|
||||
|
||||
Here are some other factors to consider when choosing a programming language for a project:
|
||||
|
||||
<strong><span style="color: green;">1</span> JavaScript</strong>: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.
|
||||
<strong><span style="color: green;">2</span> Java</strong>: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.
|
||||
<strong><span style="color: green;">3</span> C++</strong>: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.
|
||||
<strong><span style="color: green;">4</span> Python</strong>: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.
|
||||
<strong><span style="color: green;">1</span> JavaScript</strong>: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.
|
||||
<strong><span style="color: green;">2</span> Java</strong>: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.
|
||||
<strong><span style="color: green;">3</span> C++</strong>: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.
|
||||
<strong><span style="color: green;">4</span> Python</strong>: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.
|
||||
</code></pre>
|
||||
|
||||
## Expected dataset format
|
||||
@ -102,7 +101,6 @@ To ensure that we train MOEs similarly during preference-tuning, it is beneficia
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
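For example (a sketch assuming a Mixtral-style MoE checkpoint; both options are forwarded to the model config):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",
    output_router_logits=True,   # add the auxiliary load-balancing loss
    router_aux_loss_coef=0.001,  # weight of the auxiliary loss in the total loss
)
```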
|
||||
|
||||
|
||||
### Batch size recommendations
|
||||
|
||||
Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor.
|
||||
|
@ -1,24 +1,21 @@
|
||||
# Liger Kernel Integration
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Section under construction. Feel free to contribute!
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> Section under construction. Feel free to contribute!
|
||||
|
||||
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, with more to come. The kernel works out of the box with [FlashAttention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
|
||||
|
||||
With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost the performance.
|
||||
|
||||
| Speed Up | Memory Reduction |
| --- | --- |
|
||||
|  |  |
|
||||
|
||||
1. To use Liger-Kernel in [`SFTTrainer`], first install it by running:
|
||||
|
||||
```bash
|
||||
pip install liger-kernel
|
||||
```
|
||||
|
||||
|
||||
|
||||
2. Once installed, set `use_liger_kernel=True` in [`SFTConfig`]. No other changes are needed!
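For example (a minimal sketch; the output directory is just a placeholder):

```python
from trl import SFTConfig

training_args = SFTConfig(
    output_dir="my-model-sft-liger",
    use_liger_kernel=True,  # patch supported model layers with Liger's fused Triton kernels
)
```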
|
||||
|
||||
|
@ -1,106 +0,0 @@
|
||||
# Logging
|
||||
|
||||
As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
|
||||
By default, TRL trainers like [`PPOTrainer`] and [`GRPOTrainer`] save a lot of relevant information to supported experiment trackers like Trackio, Weights & Biases (wandb) or TensorBoard.
|
||||
|
||||
Upon initialization, pass the `report_to` argument to the respective configuration object (e.g., [`PPOConfig`] for `PPOTrainer`, or [`GRPOConfig`] for `GRPOTrainer`):
|
||||
|
||||
```python
|
||||
# For PPOTrainer
|
||||
ppo_config = PPOConfig(
|
||||
# ...,
|
||||
report_to="trackio" # or "wandb" or "tensorboard"
|
||||
)
|
||||
|
||||
# For GRPOTrainer
|
||||
grpo_config = GRPOConfig(
|
||||
# ...,
|
||||
report_to="trackio" # or "wandb" or "tensorboard"
|
||||
)
|
||||
```
|
||||
|
||||
If you want to log with TensorBoard, you might also need to specify logging directories, for example, by adding `logging_dir=PATH_TO_LOGS` to the configuration object (e.g., `PPOConfig` or `GRPOConfig`).
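For example, with the GRPO configuration above (a minimal sketch; the path is just a placeholder):

```python
from trl import GRPOConfig

grpo_config = GRPOConfig(
    # ...,
    report_to="tensorboard",
    logging_dir="./tensorboard_logs",  # where the TensorBoard event files are written
)
```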
|
||||
|
||||
## PPO Logging
|
||||
|
||||
Here's a brief explanation for the logged metrics provided in the data:
|
||||
|
||||
* `eps`: Tracks the number of episodes per second.
|
||||
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
|
||||
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
|
||||
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
|
||||
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
|
||||
* `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
|
||||
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
|
||||
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
|
||||
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
|
||||
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to `policy/clipfrac_avg` but for the value function.
|
||||
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
|
||||
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
|
||||
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
|
||||
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
|
||||
* `lr`: The current learning rate used by the optimizer.
|
||||
* `episode`: The current episode count in the training process.
|
||||
|
||||
### Crucial values
|
||||
During training, many values are logged; here are the most important ones:
|
||||
|
||||
1. `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
1. `objective/rlhf_reward`: The mean RLHF reward. This is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
|
||||
1. `objective/non_score_reward`: The mean reward from non-score-related sources (e.g., KL penalty).
|
||||
|
||||
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning your hyperparameters):
|
||||
|
||||
1. `loss/value_avg`: The average value loss. It will spike / NaN when not going well.
|
||||
1. `val/ratio`: The mean ratio of the current policy probability to the old policy probability. This number should float around 1.0. If this `ratio` is too high (e.g., 2.0 or 1000.0) or too small (e.g., 0.1), it means the updates between consecutive policies are too drastic.
|
||||
1. `policy/clipfrac_avg` and `policy/approxkl_avg`: If `val/ratio` is too high, the `ratio` is going to get clipped, resulting in high `policy/clipfrac_avg` and high `policy/approxkl_avg` as well.
|
||||
1. `objective/kl`: The mean KL divergence. It should stay positive and ideally not too large, so that the policy is not too far away from the reference policy.
|
||||
|
||||
## GRPO Logging
|
||||
|
||||
Here's a brief explanation for the logged metrics provided in the data for the GRPO trainer:
|
||||
|
||||
* `num_tokens`: Total number of input tokens processed during training so far.
|
||||
|
||||
#### Completions
|
||||
|
||||
* `completions/mean_length`: Mean length of all generated completions (including those not ending with an EOS token).
|
||||
* `completions/min_length`: Minimum length among all generated completions.
|
||||
* `completions/max_length`: Maximum length among all generated completions.
|
||||
* `completions/clipped_ratio`: The ratio of completions that did not end with an EOS token before reaching the maximum generation length (i.e., they were truncated).
|
||||
* `completions/mean_terminated_length`: Mean length of only those completions that successfully ended with an EOS token.
|
||||
* `completions/min_terminated_length`: Minimum length among completions that ended with an EOS token.
|
||||
* `completions/max_terminated_length`: Maximum length among completions that ended with an EOS token.
|
||||
|
||||
#### Rewards
|
||||
|
||||
* `rewards/{reward_func_name}/mean`: The mean reward obtained from a specific, named reward function (e.g., `rewards/my_custom_reward/mean`). This is logged for each reward function used.
|
||||
* `rewards/{reward_func_name}/std`: The standard deviation of rewards from a specific, named reward function.
|
||||
* `reward`: The overall mean of the (potentially weighted and, if `args.scale_rewards` is true, normalized) rewards, after group-wise normalization (advantages).
|
||||
* `reward_std`: The standard deviation of the (potentially weighted) rewards *before* group-wise normalization for advantages.
|
||||
|
||||
#### Policy and Loss Metrics
|
||||
|
||||
* `kl`: The mean Kullback-Leibler (KL) divergence between the current policy and the reference policy. This is logged only if `beta` (the KL coefficient in `GRPOConfig`) is non-zero.
|
||||
* `entropy`: Average entropy of token predictions across generated completions.
|
||||
* If Liger GRPOLoss is used (`use_liger_loss: True` in `GRPOConfig`):
|
||||
* `clip_ratio`: The fraction of policy updates where the probability ratio was clipped according to the GRPO loss's epsilon bounds.
|
||||
* If standard GRPOLoss is used (`use_liger_loss: False`):
|
||||
* `clip_ratio/low_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the lower bound `1 - epsilon_low` (occurs when advantage is negative and ratio is below the bound).
|
||||
* `clip_ratio/low_min`: The minimum observed fraction for `clip_ratio/low_mean` across batches/processes.
|
||||
* `clip_ratio/high_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the upper bound `1 + epsilon_high` (occurs when advantage is positive and ratio is above the bound).
|
||||
* `clip_ratio/high_max`: The maximum observed fraction for `clip_ratio/high_mean` across batches/processes.
|
||||
* `clip_ratio/region_mean`: The mean fraction of instances where the probability ratio was clipped at either the lower or upper bound.
|
||||
|
||||
### Crucial GRPO values
|
||||
|
||||
During GRPO training, monitor these values for insights into performance and stability:
|
||||
|
||||
1. `reward`: This is the primary objective. It reflects the (group-wise normalized) rewards the policy is achieving. It should generally increase during successful training.
|
||||
1. `kl`: If `beta > 0`, this tracks the divergence from the reference model. Keep an eye on it to ensure the policy doesn't stray too far, which can lead to instability.
|
||||
1. `clip_ratio/*` (either `clip_ratio` for Liger loss or the more detailed `clip_ratio/...` metrics for standard loss): These indicate how often the policy updates are being constrained by the GRPO clipping mechanism. Very high values might suggest that the policy is trying to change too drastically (potentially due to large advantages or a learning rate that's too high) or that the epsilon clipping range is too restrictive.
|
||||
1. `completions/clipped_ratio`: A high ratio here indicates that the model is frequently generating completions that are cut off by `max_completion_length` rather than naturally ending with an EOS token. This might suggest issues with learning sequence termination or that `max_completion_length` is too short.
|
||||
1. `rewards/{reward_func_name}/mean`: Monitoring the mean of individual reward functions can help diagnose which aspects of the desired behavior the model is learning or struggling with, especially when using multiple reward sources.
|
||||
1. `entropy`: Measures how uncertain the policy is in its action choices; higher entropy suggests more exploration. A collapse in entropy means the policy is becoming overconfident and deterministic, often too early. This can stall learning by reducing exploration and making updates overly biased. Stable but non-zero entropy is usually a sign that the policy retains flexibility and continues to explore.
|
||||
|
docs/source/lora_without_regret.md (new file, 442 lines)
@ -0,0 +1,442 @@
|
||||
# LoRA Without Regret
|
||||
|
||||
Recent research from the team at [Thinking Machines Lab](https://thinkingmachines.ai/blog/lora/) (Schulman et al., 2025) shows that **LoRA can match full fine-tuning performance** when configured correctly, while using only ~67% of the compute. These findings are exciting to TRL users because they're straightforward to implement and can improve model performance on smaller budgets.
|
||||
|
||||
This guide provides simple instructions to reproduce the results of the blog post in TRL.
|
||||
|
||||
> [!TIP]
|
||||
> It is recommended to read the blog post before following this guide, or to consult both resources in parallel for best results.
|
||||
|
||||
## Benefits of LoRA over full fine-tuning
|
||||
|
||||
First of all, let's remind ourselves of the benefits of [LoRA over full fine-tuning](https://huggingface.co/docs/trl/en/peft_integration).
|
||||
|
||||
LoRA adds adapter layers on top of the base model, which contain significantly fewer parameters than the base model itself. This design reduces GPU memory requirements and enables more efficient training. As described in the [blog](https://thinkingmachines.ai/blog/lora/), this approach was originally thought to involve a performance trade-off, although careful configuration can overcome this trade-off and match full fine-tuning performance.
|
||||
|
||||
## Examples with TRL
|
||||
|
||||
Let's implement and train LoRA adapters in TRL scripts based on the core findings of the blog post. Afterwards, we'll revisit each finding in light of the TRL results.
|
||||
|
||||
### Supervised Fine-Tuning (SFT)
|
||||
|
||||
The blog post performs SFT on a range of models and datasets from the Hub, which we can reproduce in TRL.
|
||||
|
||||
| Model | Dataset |
|
||||
| --- | --- |
|
||||
| [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B) | [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) |
|
||||
| [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B) | [open-thoughts/OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) |
|
||||
| [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B) | [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) |
|
||||
| [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B) | [open-thoughts/OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) |
|
||||
|
||||
<hfoptions id="sft">
|
||||
<hfoption id="python">
|
||||
|
||||
We can integrate these findings with the TRL Python API like so:
|
||||
|
||||
```python
|
||||
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from trl import SFTTrainer, SFTConfig
|
||||
|
||||
dataset = load_dataset("open-thoughts/OpenThoughts-114k", split="train")
|
||||
|
||||
peft_config = LoraConfig(r=256, lora_alpha=16, target_modules="all-linear")
|
||||
|
||||
training_args = SFTConfig(
|
||||
learning_rate=2e-4,
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
num_train_epochs=1,
|
||||
report_to=["trackio"],
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model="Qwen/Qwen2.5-3B-Instruct",
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config,
|
||||
args=training_args,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="jobs">
|
||||
|
||||
```bash
|
||||
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--timeout 8h \
|
||||
--secrets HF_TOKEN \
|
||||
"https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" \
|
||||
--model_name_or_path Qwen/Qwen2.5-3B-Instruct \
|
||||
--dataset_name open-thoughts/OpenThoughts-114k \
|
||||
--learning_rate 2.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 16 \
|
||||
--use_peft \
|
||||
--lora_r 256 \
|
||||
--lora_alpha 16 \
|
||||
--lora_target_modules all-linear \
|
||||
--output_dir Qwen2.5-3B-OpenThoughts-LoRA \
|
||||
--report_to trackio \
|
||||
--push_to_hub
|
||||
|
||||
```
|
||||
|
||||
To use Hugging Face Jobs, you will need to be logged in to the Hugging Face Hub (`hf auth login`) and have a [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan. Check out the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) for more details.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="local">
|
||||
|
||||
```bash
|
||||
|
||||
uv run "https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" \
|
||||
--model_name_or_path Qwen/Qwen2.5-3B-Instruct \
|
||||
--dataset_name open-thoughts/OpenThoughts-114k \
|
||||
--learning_rate 2.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 16 \
|
||||
--gradient_checkpointing \
|
||||
--eval_strategy no \
|
||||
--use_peft \
|
||||
--lora_r 256 \
|
||||
--lora_alpha 16 \
|
||||
--lora_target_modules all-linear \
|
||||
--output_dir Qwen2.5-3B-OpenThoughts-LoRA \
|
||||
--report_to trackio \
|
||||
--push_to_hub
|
||||
|
||||
```
|
||||
|
||||
To run the script locally, you will need to have `uv` installed. Check out the [uv documentation](https://docs.astral.sh/uv/) for more details.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Once training starts, you can monitor the progress in [Trackio](https://huggingface.co/trackio), which will log the URL.
|
||||
|
||||
### Reinforcement Learning (GRPO)
|
||||
|
||||
The blog post performs GRPO on a range of models and datasets from the Hub, and once again we can reproduce the results in TRL.
|
||||
|
||||
| Model | Dataset |
|
||||
| --- | --- |
|
||||
| [Llama-3.1-8B-Base](https://huggingface.co/meta-llama/Llama-3.2-1B) | [GSM8k](https://huggingface.co/datasets/openai/gsm8k) |
|
||||
| [Llama-3.1-8B-Base](https://huggingface.co/meta-llama/Llama-3.2-1B) | [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) |
|
||||
| [Qwen3-8b-base](https://huggingface.co/Qwen/Qwen3-8b-base) | [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) |
|
||||
|
||||
For reinforcement learning, the blog uses a math reasoning task that we can reproduce as a Python function.
|
||||
|
||||
<details>
|
||||
<summary>Reward function</summary>
|
||||
|
||||
```python
# Assumed imports for the helpers used below: `math_verify` and `latex2sympy2_extended`.
from typing import Optional

from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify


def strip_reasoning_accuracy_reward(
|
||||
completions: list[list[dict[str, str]]], solution: list[str], **kwargs
|
||||
) -> list[Optional[float]]:
|
||||
"""Reward function that strips reasoning tags and checks mathematical accuracy.
|
||||
|
||||
This function:
|
||||
1. Extracts the content from completions
|
||||
2. Removes <think></think> tags (for reasoning that shouldn't be evaluated)
|
||||
3. Parses both the gold solution and the predicted answer
|
||||
4. Uses math_verify to check if they are mathematically equivalent
|
||||
|
||||
Args:
|
||||
completions: List of model completions, each containing a list of messages
|
||||
solution: List of ground truth solutions
|
||||
**kwargs: Additional arguments (ignored but required for trainer compatibility)
|
||||
|
||||
Returns:
|
||||
List of rewards where:
|
||||
- 1.0 if the answer is correct
|
||||
- 0.0 if the answer is incorrect
|
||||
- None if the solution is not parseable (skips this example)
|
||||
"""
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
rewards = []
|
||||
|
||||
for content, sol in zip(contents, solution):
|
||||
# Strip reasoning tags from completion
|
||||
while "<think>" in content and "</think>" in content:
|
||||
start = content.find("<think>")
|
||||
end = content.find("</think>", start)
|
||||
if start != -1 and end != -1:
|
||||
content = content[:start] + content[end + len("</think>") :]
|
||||
else:
|
||||
break
|
||||
|
||||
# Parse gold solution
|
||||
gold_parsed = parse(
|
||||
f"${sol}$",
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
boxed_match_priority=0, try_extract_without_anchor=True
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
if len(gold_parsed) != 0:
|
||||
# We require the answer to be provided in correct latex (no malformed operators)
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
boxed_match_priority=0,
|
||||
normalization_config=NormalizationConfig(
|
||||
basic_latex=True,
|
||||
units=True,
|
||||
malformed_operators=False,
|
||||
nits=False,
|
||||
boxed=True,
|
||||
),
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
|
||||
# Compute binary rewards if verifiable, `None` otherwise to skip this example
|
||||
try:
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception as e:
|
||||
print(
|
||||
f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}"
|
||||
)
|
||||
reward = None
|
||||
else:
|
||||
# If the gold solution is not parseable, we assign `None` to skip this example
|
||||
reward = None
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<hfoptions id="grpo">
|
||||
<hfoption id="python">
|
||||
|
||||
We can implement these recommendations with the TRL Python API like so:
|
||||
|
||||
```python
|
||||
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
|
||||
dataset = load_dataset("HuggingFaceH4/OpenR1-Math-220k-default-verified", split="train")
|
||||
|
||||
def strip_reasoning_accuracy_reward(completions, **kwargs):
|
||||
"""Reward function that strips reasoning and accuracy scores from the model outputs."""
|
||||
|
||||
...
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=1,
|
||||
lora_alpha=32,
|
||||
target_modules="all-linear"
|
||||
)
|
||||
|
||||
training_args = GRPOConfig(
|
||||
learning_rate=5e-5,
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
num_train_epochs=1,
|
||||
num_generations=8,
|
||||
generation_batch_size=8,
|
||||
report_to=["trackio"],
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen3-0.6B",
|
||||
reward_funcs=strip_reasoning_accuracy_reward,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> To keep the example concise, this snippet omits the body of the reward function defined above.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="jobs">
|
||||
|
||||
```bash
|
||||
|
||||
hf jobs uv run \
|
||||
--flavor a100-large \
|
||||
--timeout 4h \
|
||||
--secrets HF_TOKEN \
|
||||
--env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
|
||||
"https://huggingface.co/datasets/burtenshaw/lora-without-regrets/resolve/main/grpo.py" \
|
||||
--model_name_or_path Qwen/Qwen3-0.6B \
|
||||
--dataset_name HuggingFaceH4/OpenR1-Math-220k-default-verified \
|
||||
--output_dir grpo-full-qwen3-0.6b \
|
||||
--learning_rate 1.0e-6 \
|
||||
--lr_scheduler_type cosine \
|
||||
--warmup_ratio 0.0 \
|
||||
--max_grad_norm 1.0 \
|
||||
--beta 0.0 \
|
||||
--max_prompt_length 1024 \
|
||||
--max_completion_length 4096 \
|
||||
--num_generations 16 \
|
||||
--generation_batch_size 16 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--num_train_epochs 1 \
|
||||
--lora_r 1 \
|
||||
--lora_alpha 32 \
|
||||
--lora_dropout 0.0 \
|
||||
--lora_target_modules all-linear \
|
||||
--vllm_mode colocate \
|
||||
--save_strategy steps \
|
||||
--save_steps 50 \
|
||||
--save_total_limit 1 \
|
||||
--logging_steps 1 \
|
||||
--max_steps 200 \
|
||||
--report_to trackio
|
||||
```
|
||||
|
||||
To use Hugging Face Jobs, you will need to be logged in to the Hugging Face Hub (`hf auth login`) and have a [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan. Check out the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) for more details.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="local">
|
||||
|
||||
```bash
|
||||
uv run "https://huggingface.co/datasets/burtenshaw/lora-without-regrets/resolve/main/grpo.py" \
|
||||
--model_name_or_path Qwen/Qwen3-0.6B \
|
||||
--dataset_name HuggingFaceH4/OpenR1-Math-220k-default-verified \
|
||||
--output_dir grpo-full-qwen3-0.6b \
|
||||
--learning_rate 1.0e-6 \
|
||||
--lr_scheduler_type cosine \
|
||||
--warmup_ratio 0.0 \
|
||||
--max_grad_norm 1.0 \
|
||||
--beta 0.0 \
|
||||
--max_prompt_length 1024 \
|
||||
--max_completion_length 4096 \
|
||||
--num_generations 16 \
|
||||
--generation_batch_size 16 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--per_device_train_batch_size 1 \
|
||||
--num_train_epochs 1 \
|
||||
--lora_r 1 \
|
||||
--lora_alpha 32 \
|
||||
--lora_dropout 0.0 \
|
||||
--lora_target_modules all-linear \
|
||||
--vllm_mode colocate \
|
||||
--save_strategy steps \
|
||||
--save_steps 50 \
|
||||
--save_total_limit 1 \
|
||||
--logging_steps 1 \
|
||||
--max_steps 200 \
|
||||
--report_to trackio
|
||||
```
|
||||
|
||||
To run the script locally, you will need to have `uv` installed. Check out the [uv documentation](https://docs.astral.sh/uv/) for more details.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
The GRPO reinforcement learning run uses a custom TRL script built around the reward function shown above. You can review it at [`grpo.py`](https://huggingface.co/datasets/burtenshaw/lora-without-regrets/blob/main/grpo.py), which applies the LoRA best practices described in this guide.
|
||||
|
||||
## Key findings in optimizing LoRA
|
||||
|
||||
The authors recommend applying LoRA to all weight matrices rather than limiting it to attention layers, as increasing the rank does not compensate for this restriction. In TRL, this can be configured using `--lora_target_modules all-linear` to apply LoRA to all weight matrices.
|
||||
|
||||
We were able to reproduce the results of the blog post using TRL and the SmolLM3 model. We trained the model for 500 steps on the [Math 220k dataset](https://huggingface.co/datasets/HuggingFaceH4/OpenR1-Math-220k-default-verified) with the reward function and configuration above. As you can see in the figure below, the LoRA model's average train reward curve matches the full fine-tuning curve.
|
||||
|
||||

|
||||
|
||||
And most importantly, the LoRA model uses significantly less memory than the full fine-tuning model, as we can see in the figure below.
|
||||
|
||||

|
||||
|
||||
Here are the parameters we used to train the above models:
|
||||
|
||||
| Parameter | LoRA | Full FT |
|
||||
| --- | --- | --- |
|
||||
| `--model_name_or_path` | HuggingFaceTB/SmolLM3-3B | HuggingFaceTB/SmolLM3-3B |
|
||||
| `--dataset_name` | HuggingFaceH4/OpenR1-Math-220k-default-verified | HuggingFaceH4/OpenR1-Math-220k-default-verified |
|
||||
| `--learning_rate` | 1.0e-5 | 1.0e-6 |
|
||||
| `--max_prompt_length` | 1024 | 1024 |
|
||||
| `--max_completion_length` | 4096 | 4096 |
|
||||
| `--lora_r` | 1 | - |
|
||||
| `--lora_alpha` | 32 | - |
|
||||
| `--lora_dropout` | 0.0 | - |
|
||||
| `--lora_target_modules` | all-linear | - |
|
||||
|
||||
Let's break down the key findings of the blog post and how we were able to reproduce them.
|
||||
|
||||
### 1. *LoRA performs better when applied to all weight matrices*
|
||||
|
||||
The authors recommend applying LoRA to all weight matrices rather than limiting it to attention layers, as increasing the rank does not compensate for this restriction.
|
||||
|
||||

|
||||
|
||||
Attention-only LoRA underperforms even when using a higher rank to match parameter count. In TRL, this can be configured using `--lora_target_modules all-linear` to apply LoRA to all weight matrices. In Python, we can do this like so:
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
|
||||
peft_config = LoraConfig(target_modules="all-linear")
|
||||
```
|
||||
|
||||
### 2. *The adapter needs sufficient capacity to learn from the dataset*
|
||||
|
||||
The blog post recommends using a sufficient LoRA rank to learn from the dataset. The rank determines the number of trainable parameters in the LoRA adapter. Therefore, "For datasets that exceed LoRA capacity, LoRA underperforms FullFT".
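As a rough capacity estimate, each weight matrix \\( W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}} \\) that LoRA adapts adds \\( r\,(d_{\text{in}} + d_{\text{out}}) \\) trainable parameters, so the total adapter size grows linearly with the rank:

$$
\#\text{params}_{\text{LoRA}} \approx \sum_{W \text{ targeted}} r \, (d_{\text{in}} + d_{\text{out}})
$$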
|
||||
|
||||

|
||||
|
||||
In the TRL script, we could use `--lora_r` to set the rank and adapt it based on the task and dataset we're training on. The blog post recommends the following ranks based on the task and dataset size:
|
||||
|
||||
Reinforcement learning tasks typically require lower capacity, so smaller LoRA ranks can be used. This is because policy gradient algorithms extract roughly ~1 bit of information per episode, demanding minimal parameter capacity.
|
||||
|
||||
The blog post defines the ideal dataset size for LoRA to match full fine-tuning as "Post-training scale", which we can use to determine the recommended rank for SFT and RL LoRAs:
|
||||
|
||||
| Task Type | Dataset Size | Recommended Rank |
|
||||
| --- | --- | --- |
|
||||
| **SFT** | Post-training scale | 256 |
|
||||
| **RL** | Any size | 1-32 |
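As a concrete sketch, these recommendations map onto the `peft` configurations already used in this guide (the ranks are the blog's suggestions, not hard requirements):

```python
from peft import LoraConfig

# SFT on a post-training-scale dataset: higher rank for sufficient capacity
sft_peft_config = LoraConfig(r=256, lora_alpha=16, target_modules="all-linear")

# RL (e.g. GRPO): each episode carries little information, so a very small rank suffices
rl_peft_config = LoraConfig(r=1, lora_alpha=32, target_modules="all-linear")
```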
|
||||
|
||||
### 3. *"FullFT and high-rank LoRAs have similar learning curves"*
|
||||
|
||||
Counterintuitively, the blog post recommends using a higher learning rate than for full fine-tuning. In the table above, we used 1.0e-5 for LoRA and 1.0e-6 for full fine-tuning. In the TRL script, we could use `--learning_rate` to set the learning rate. The \\( \frac{1}{r} \\) scaling in LoRA makes the optimal learning rate approximately rank-independent.
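Recall the LoRA parametrization, which is where the \\( \frac{1}{r} \\) factor enters:

$$
h = W_0 x + \frac{\alpha}{r} B A x
$$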
|
||||
|
||||

|
||||
|
||||
### 4. *"In some scenarios, LoRA is less tolerant of large batch sizes than full fine-tuning."*
|
||||
|
||||
The blog post recommends using an effective batch size < 32 because the authors found LoRA to be less tolerant of large batch sizes. This could not be mitigated by increasing the LoRA rank. In the TRL script, we could use `--per_device_train_batch_size` and `--gradient_accumulation_steps` to set the batch size.
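For example, a minimal single-GPU sketch (the effective batch size is the product of the two values times the number of devices):

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # 2 * 8 = 16 effective on one device, below the ~32 threshold
)
```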
|
||||
|
||||

|
||||
|
||||
## Takeaways
|
||||
|
||||
Using TRL, you can efficiently implement LoRA adapters to match full fine-tuning performance, applying the core insights (targeting all weight matrices, choosing the right rank, and managing batch size and learning rate) without the heavy compute cost of FullFT.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{schulman2025lora,
|
||||
title = {{LoRA Without Regret}},
|
||||
author = {John Schulman and Thinking Machines Lab},
|
||||
year = 2025,
|
||||
journal = {Thinking Machines Lab: Connectionism},
|
||||
doi = {10.64434/tml.20250929},
|
||||
note = {https://thinkingmachines.ai/blog/lora/}
|
||||
}
|
||||
```
|
@ -8,7 +8,6 @@ With the `AutoModelForCausalLMWithValueHead` class TRL supports all decoder mode
|
||||
|
||||
## AutoModelForCausalLMWithValueHead
|
||||
|
||||
|
||||
[[autodoc]] AutoModelForCausalLMWithValueHead
|
||||
- __init__
|
||||
- forward
|
||||
@ -25,4 +24,4 @@ With the `AutoModelForCausalLMWithValueHead` class TRL supports all decoder mode
|
||||
|
||||
## create_reference_model
|
||||
|
||||
[[autodoc]] create_reference_model
|
||||
|
||||
|
@ -14,11 +14,11 @@ You need to address this approach in three stages that we summarize as follows:
|
||||
2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py)
|
||||
3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL")
|
||||
|
||||
Make sure to use the same model (i.e. same architecture and same weights) for the stages 2 & 3.
|
||||
|
||||
|
||||
## Quickstart
|
||||
|
||||
Let us assume you have trained your reward adapter on `llama-7b` model using `RewardTrainer` and pushed the weights on the hub under `trl-lib/llama-7b-hh-rm-adapter`.
|
||||
|
||||
When doing PPO, before passing the model to `PPOTrainer` create your model as follows:
|
||||
|
||||
```python
|
||||
@ -48,6 +48,7 @@ trainer = PPOTrainer(
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
Then inside your PPO training loop, call the `compute_reward_score` method by accessing the `model` attribute from `PPOTrainer`.
|
||||
|
||||
```python
|
||||
@ -56,9 +57,9 @@ rewards = trainer.model.compute_reward_score(**inputs)
|
||||
|
||||
## Advanced usage
|
||||
|
||||
### Control on the adapter name
|
||||
|
||||
|
||||
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is train multiple adapters on the same base model to fine-tune on different policies.
|
||||
|
||||
In this case, you want to be able to control the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`.
|
||||
|
||||
```python
|
||||
@ -71,6 +72,7 @@ rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_
|
||||
|
||||
For more memory efficient fine-tuning, you can load your base model in 8-bit or 4-bit while keeping the adapters in the default precision (float32).
|
||||
Just pass the appropriate arguments (i.e. `load_in_8bit=True` or `load_in_4bit=True`) to `AutoModelForCausalLMWithValueHead.from_pretrained` as follows (assuming you have installed `bitsandbytes`):
|
||||
|
||||
```python
|
||||
model_name = "llama-7b"
|
||||
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
|
||||
@ -88,7 +90,7 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model_name,
|
||||
peft_config=lora_config,
|
||||
reward_adapter=rm_adapter_id,
|
||||
|
||||
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
|
||||
)
|
||||
|
||||
...
|
||||
|
@ -1,16 +1,16 @@
|
||||
# Nash-MD Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=nash-md,trl)
|
||||
[](https://huggingface.co/models?other=nash-md,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Nash-MD was proposed in the paper [Nash Learning from Human Feedback](https://huggingface.co/papers/2312.00886) by Rémi Munos, [Michal Valko](https://huggingface.co/misovalko), Daniele Calandriello, Mohammad Gheshlaghi Azar, Mark Rowland, Daniel Guo, Yunhao Tang, Matthieu Geist, Thomas Mésnard, and Andrea Michi.
|
||||
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Reinforcement learning from human feedback (RLHF) has emerged as the main paradigm for aligning large language models (LLMs) with human preferences. Typically, RLHF involves the initial step of learning a reward model from human feedback, often expressed as preferences between pairs of text generations produced by a pre-trained LLM. Subsequently, the LLM's policy is fine-tuned by optimizing it to maximize the reward model through a reinforcement learning algorithm. However, an inherent limitation of current reward models is their inability to fully represent the richness of human preferences and their dependency on the sampling distribution. In this study, we introduce an alternative pipeline for the fine-tuning of LLMs using pairwise human feedback. Our approach entails the initial learning of a preference model, which is conditioned on two inputs given a prompt, followed by the pursuit of a policy that consistently generates responses preferred over those generated by any competing policy, thus defining the Nash equilibrium of this preference model. We term this approach Nash learning from human feedback (NLHF). In the context of a tabular policy representation, we present a novel algorithmic solution, Nash-MD, founded on the principles of mirror descent. This algorithm produces a sequence of policies, with the last iteration converging to the regularized Nash equilibrium. Additionally, we explore parametric representations of policies and introduce gradient descent algorithms for deep-learning architectures. To demonstrate the effectiveness of our approach, we present experimental results involving the fine-tuning of a LLM for a text summarization task. We believe NLHF offers a compelling avenue for preference learning and policy optimization with the potential of advancing the field of aligning LLMs with human preferences.
|
||||
|
||||
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Daniil Tiapkin](https://huggingface.co/dtiapkin), [Pierre Ménard](https://huggingface.co/menardprr), Daniele Calandriello and [Quentin Gallouédec](https://huggingface.co/qgallouedec).
|
||||
|
||||
|
||||
## Quick start
|
||||
|
||||
@ -85,11 +85,8 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht
|
||||
)
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
|
||||
|
||||
### Encourage EOS token generation
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
# Online DPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=online-dpo,trl)
|
||||
[](https://huggingface.co/models?other=online-dpo,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
|
||||
Online DPO was proposed in [Direct Language Model Alignment from Online AI Feedback](https://huggingface.co/papers/2402.04792) by Shangmin Guo, Biao Zhang, Tianlin Liu, Tianqi Liu, Misha Khalman, Felipe Llinares, Alexandre Rame, Thomas Mesnard, Yao Zhao, Bilal Piot, Johan Ferret, and Mathieu Blondel.
|
||||
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
@ -112,7 +112,6 @@ This callback logs the model's generated completions directly to Weights & Biase
|
||||
|
||||

|
||||
|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the online DPO method. The script is available in [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py)
|
||||
@ -153,8 +152,7 @@ While training and evaluating, we record the following reward metrics. Here is a
|
||||
|
||||
To validate the online DPO implementation works, we ran experiments with the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
|
||||
|
||||
|
||||
```shell
|
||||
# 1B Online DPO experiment
|
||||
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \
|
||||
examples/scripts/dpo_online.py \
|
||||
@ -213,9 +211,8 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
|
||||
|
||||
Checkpoints and experiment tracking are available at:
|
||||
|
||||
|
||||
|
||||
* [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07)
|
||||
* [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0)
|
||||
|
||||
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
|
||||
For more information on how to use judges, see [Judges](judges).
|
||||
@ -259,8 +256,6 @@ plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||

|
||||
|
||||
The online DPO checkpoint gets increasingly more win rate as we scale up the model sizes. This is a good sign that the online DPO implementation is working as intended.
|
||||
|
||||
## OnlineDPOTrainer
|
||||
|
@ -1,6 +1,6 @@
|
||||
# ORPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=orpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
|
||||
[](https://huggingface.co/models?other=orpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
|
||||
|
||||
## Overview
|
||||
|
||||
@ -54,7 +54,7 @@ accelerate launch train_orpo.py
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
|
||||
|
||||

|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
@ -64,11 +64,11 @@ What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-ORPO>:</span></strong>
|
||||
It's challenging to determine the best programming language as no one language is perfect, as the complexity of a task and the type of project are significant factors. Some popular languages include Java, Python, JavaScript, and
|
||||
C++. If you have specific needs or requirements for a specific project, it's important to choose the language that best suits those needs.
|
||||
|
||||
|
||||
Here are some other factors to consider when choosing a programming language for a project:
|
||||
|
||||
<strong><span style="color: green;">• Language proficiency:</span></strong> A good programming language is more likely to be easy to understand and use, and will allow developers to collaborate on projects more efficiently.
|
||||
<strong><span style="color: green;">• Language proficiency:</span></strong> A good programming language is more likely to be easy to understand and use, and will allow developers to collaborate on projects more efficiently.
|
||||
<strong><span style="color: green;">• Ease of use:</span></strong> There are tools and libraries available to make programming more accessible, so developers should choose a language that can help them get started easier.
|
||||
<strong><span style="color: green;">• Code readability:</span></strong> A clear and concise codebase should be easy to read and understand, especially when working with large projects.
|
||||
<strong><span style="color: green;">• Tool and framework support:</span></strong> There are numerous libraries available for Python, Java, and JavaScript, along with tools like IDEs and static code analysis tools.
|
||||
@ -118,7 +118,7 @@ While training and evaluating, we record the following reward metrics:
|
||||
- `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses
|
||||
- `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))`
|
||||
- `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses
|
||||
|
||||
|
||||
## ORPOTrainer
|
||||
|
||||
[[autodoc]] ORPOTrainer
|
||||
|
@ -1,10 +1,7 @@
|
||||
# Paper Index
|
||||
|
||||
> [!WARNING]
|
||||
> Section under construction. Feel free to contribute!
|
||||
|
||||
## Group Relative Policy Optimization
|
||||
|
||||
@ -32,6 +29,8 @@ training_args = GRPOConfig(
|
||||
|
||||
Note that this method only has an effect when training goes slightly off-policy—for example, when `steps_per_generation > gradient_accumulation_steps` or `num_iterations > 1`. Otherwise, it is effectively equivalent to no modification.
|
||||
|
||||
TRL also provides an experimental implementation of GSPO-token; see [Experimental - GSPO-Token](experimental#gspo-token).
|
||||
|
||||
#### Policy ratio: GRPO vs. GSPO
|
||||
|
||||
In GSPO, the policy ratio is defined at the sequence level. In other words, it is the ratio of the probability that the current policy generates a sequence to the probability that the old policy generates that same sequence.
|
||||
@ -171,7 +170,7 @@ $$
|
||||
}
|
||||
$$
|
||||
|
||||
Despite \\( \textcolor{red}{\pi_{\text{inference}}} \\) and \\( \textcolor{blue}{\pi_{\text{training}}} \\) sharing the same model parameters \\( \theta \\), they can produce significantly different token probabilities. This unexpected behavior implicitly breaks the on-policy assumption, and silently turns training off-policy.
|
||||
|
||||
|
||||
Truncated Importance Sampling (TIS) addresses this issue by adapting the model update via importance-sampling correction. The gradient computation of the aforementioned PPO objective becomes
|
||||
|
||||
@ -202,6 +201,12 @@ training_args = GRPOConfig(
|
||||
)
|
||||
```
|
||||
|
||||
### Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2508.09726
|
||||
|
||||
See [Experimental - GFPO](experimental#gfpo).
|
||||
|
||||
## Direct Policy Optimization
|
||||
|
||||
Papers relating to the [`DPOTrainer`]
|
||||
@ -333,7 +338,7 @@ training_args = DPOConfig(
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
For the unpaired version, the user should utilize [`BCOConfig`] and [`BCOTrainer`].
|
||||
|
||||
### Self-Play Preference Optimization for Language Model Alignment
|
||||
|
||||
@ -453,10 +458,7 @@ trainer = SFTTrainer(
|
||||
Dynamic Fine-Tuning (DFT) improves the generalization of Large Language Models (LLMs) by dynamically rescaling gradients, outperforming standard Supervised Fine-Tuning (SFT) and showing competitive results in offline reinforcement learning.
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{DFT}}(\theta) = \mathbb{E}_{(x,y) \sim \mathcal{D}} \left[ - \sum_{t=1}^{|y|} \textcolor{red}{\text{sg}\big(\pi_\theta(y_t \mid y_{<t}, x)\big)} \; \log \pi_\theta(y_t \mid y_{<t}, x) \right]
|
||||
$$
|
||||
|
||||
where \\( \text{sg}(\cdot) \\) is the stop-gradient operator. To use DFT with SFT as described in the paper, you can use the `loss_type="dft"` argument:
|
||||
@ -528,3 +530,53 @@ training_args = CPOConfig(
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
## Reward Modeling
|
||||
|
||||
Papers relating to the [`RewardTrainer`]
|
||||
|
||||
### Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2312.09244
|
||||
|
||||
This paper proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs and thereby resolving the issue of underdetermination.
|
||||
|
||||
$$
|
||||
\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \textcolor{red}{- \eta \cdot (r_\theta(x, y^+) + r_\theta(x, y^-))^2} \right].
|
||||
$$
|
||||
|
||||
To use this auxiliary loss with [`RewardTrainer`], you can use the `center_rewards_coefficient` argument in [`RewardConfig`] as follows:
|
||||
|
||||
```python
|
||||
from trl import RewardConfig
|
||||
|
||||
training_args = RewardConfig(
|
||||
center_rewards_coefficient=0.01, # η in the paper
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
### Llama 2: Open Foundation and Fine-Tuned Chat Models
|
||||
|
||||
**📜 Paper**: https://huggingface.co/papers/2307.09288
|
||||
|
||||
In this paper, the authors propose to leverage their preference ratings being decomposed as a scale of four points (e.g., _significantly better_) to provide more informative feedback to the reward model. This is done by adding a margin to the loss function, which encourages the reward model to assign larger gaps in scores for pairs with higher preference ratings.
|
||||
|
||||
$$
|
||||
\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-,\textcolor{red}{m}) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-) \textcolor{red}{- m}) \right].
|
||||
$$
|
||||
|
||||
You can add a margin to the loss by adding a `margin` column to the dataset. The following example shows how to set up the "Margin Small" setting of the paper.
|
||||
|
||||
```python
|
||||
def add_margin(example):
|
||||
preference_to_margin = {
|
||||
"significantly better": 1.0,
|
||||
"better": 2.0/3.0,
|
||||
"slightly better": 1.0/3.0,
|
||||
"negligibly better / unsure": 0.0,
|
||||
}
|
||||
return {"margin": preference_to_margin[example["preference_label"]]}
|
||||
|
||||
dataset = dataset.map(add_margin)
|
||||
```
|
||||
|
@ -3,17 +3,10 @@
|
||||
The notebooks and scripts in these examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory-efficient manner. Most PEFT methods in the `peft` library are supported, but note that some, such as prompt tuning, are not.
|
||||
For more information on LoRA, see the [original paper](https://huggingface.co/papers/2106.09685).
|
||||
|
||||
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
|
||||
|
||||
| File | Task | Description | Colab link |
|
||||
|---|---|---|---|
|
||||
| [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
|
||||
| [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
|
||||
| [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
|
||||
|
||||
## Installation
|
||||
|
||||
Note: peft is in active development, so we install directly from their Github page.
|
||||
Peft also relies on the latest version of transformers.
|
||||
|
||||
|
||||
```bash
|
||||
pip install trl[peft]
|
||||
@ -27,7 +20,7 @@ Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scr
|
||||
|
||||
## How to use it?
|
||||
|
||||
|
||||
Simply declare a [`~peft.PeftConfig`] object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model.
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
@ -47,7 +40,9 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
peft_config=lora_config,
|
||||
)
|
||||
```
|
||||
|
||||
And if you want to load your model in 8bit precision:
|
||||
|
||||
```python
|
||||
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
@ -55,7 +50,9 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
peft_config=lora_config,
|
||||
)
|
||||
```
|
||||
|
||||
... or in 4bit precision:
|
||||
|
||||
```python
|
||||
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
@ -64,7 +61,6 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
## Launch scripts
|
||||
|
||||
The `trl` library is powered by `accelerate`. As such it is best to configure and launch trainings with the following commands:
|
||||
accelerate launch examples/scripts/ppo.py --use_peft # launches training
|
||||
## Using `trl` + `peft` and Data Parallelism
|
||||
|
||||
You can scale up to as many GPUs as you want, as long as you are able to fit the training process in a single device. The only tweak you need to apply is to load the model as follows:
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
...
|
||||
@ -94,7 +91,9 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
peft_config=lora_config,
|
||||
)
|
||||
```
|
||||
|
||||
And if you want to load your model in 8bit precision:
|
||||
|
||||
```python
|
||||
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
@ -102,7 +101,9 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
load_in_8bit=True,
|
||||
)
|
||||
```
|
||||
|
||||
... or in 4bit precision:
|
||||
|
||||
```python
|
||||
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
@ -110,21 +111,20 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
load_in_4bit=True,
|
||||
)
|
||||
```
|
||||
|
||||
Finally, make sure that the rewards are computed on the correct device as well; for that, you can use `ppo_trainer.model.current_device`.
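A minimal sketch, assuming `ppo_trainer` is your `PPOTrainer` instance and `raw_scores` is a list of Python floats from your reward computation:

```python
import torch

device = ppo_trainer.model.current_device
# keep reward tensors on the same device as the model
rewards = [torch.tensor(score, device=device) for score in raw_scores]
```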
|
||||
|
||||
## Naive pipeline parallelism (NPP) for large models (>60B models)
|
||||
|
||||
The `trl` library also supports naive pipeline parallelism (NPP) for large models (>60B models). This is a simple way to parallelize the model across multiple GPUs.
|
||||
|
||||
With this paradigm, termed "Naive Pipeline Parallelism" (NPP), the model and the adapters are loaded across multiple GPUs, and the activations and gradients are naively communicated between them. This supports `int8` models as well as other `dtype` models.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-npp.png">
|
||||
</div>
|
||||

|
||||
|
||||
### How to use NPP?
|
||||
|
||||
Simply load your model with a custom `device_map` argument on the `from_pretrained` to split your model across multiple devices. Check out this [nice tutorial](https://github.com/huggingface/blog/blob/main/accelerate-large-models.md) on how to properly create a `device_map` for your model.
|
||||
|
||||
|
||||
|
||||
Also make sure to have the `lm_head` module on the first GPU device, as it may throw an error otherwise. At the time of writing, you need to install the `main` branch of `accelerate`: `pip install git+https://github.com/huggingface/accelerate.git@main` and `peft`: `pip install git+https://github.com/huggingface/peft.git@main`.
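A minimal sketch (the checkpoint name and `device_map` value are placeholders; adjust them to your model and hardware, keeping `lm_head` on the first device):

```python
from peft import LoraConfig
from trl import AutoModelForCausalLMWithValueHead

lora_config = LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32, lora_dropout=0.05)

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "huggyllama/llama-7b",
    device_map="auto",  # or a handcrafted dict mapping module names to GPU indices
    peft_config=lora_config,
)
```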
|
||||
|
||||
### Launch scripts
|
||||
|
@ -1,10 +1,11 @@
|
||||
# PPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=ppo,trl)
|
||||
[](https://huggingface.co/models?other=ppo,trl)
|
||||
|
||||
TRL supports training LLMs with [Proximal Policy Optimization (PPO)](https://huggingface.co/papers/1707.06347).
|
||||
|
||||
References:
|
||||
|
||||
- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences)
|
||||
- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback)
|
||||
- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo)
|
||||
@ -31,49 +32,45 @@ python examples/scripts/ppo/ppo.py \
|
||||
--missing_eos_penalty 1.0
|
||||
```
|
||||
|
||||
|
||||
## Explanation of the logged metrics
|
||||
|
||||
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35).
|
||||
|
||||
* `eps`: Tracks the number of episodes per second.
|
||||
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
|
||||
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
|
||||
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
|
||||
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
|
||||
* `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
|
||||
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
|
||||
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
|
||||
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
|
||||
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
|
||||
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
|
||||
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
|
||||
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
|
||||
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
|
||||
* `lr`: lr: The current learning rate used by the optimizer.
|
||||
* `episode`: episode: The current episode count in the training process.
|
||||
|
||||
- `eps`: Tracks the number of episodes per second.
|
||||
- `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
|
||||
- `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
|
||||
- `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
|
||||
- `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
|
||||
- `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
- `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
|
||||
- `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
|
||||
- `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
|
||||
- `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
|
||||
- `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
|
||||
- `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
|
||||
- `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
|
||||
- `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
|
||||
- `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
|
||||
- `lr`: The current learning rate used by the optimizer.
|
||||
- `episode`: The current episode count in the training process.
|
||||
|
||||
## Cookbook
|
||||
|
||||
* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
|
||||
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try understand why this is happening and try to fix it.
|
||||
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
|
||||
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
|
||||
* Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
|
||||
|
||||
- Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
|
||||
- Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high, like 2.0 or 1000.0, or too small, like 0.1, it means the updates between consecutive policies are too drastic. You should try to understand why this is happening and fix it.
|
||||
- Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
|
||||
- Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
|
||||
- Usage TIP: We recommend using the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions (see the configuration sketch below).
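A sketch of the equivalent Python configuration (values are illustrative and mirror the CLI flags discussed above):

```python
from trl import PPOConfig

# Illustrative values mirroring the CLI flags above.
training_args = PPOConfig(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    missing_eos_penalty=1.0,  # "EOS trick": penalize completions without an EOS token
)
```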
|
||||
|
||||
## What is my model doing exactly?
|
||||
|
||||
To help you understand what your model is doing, we periodically log some sample completions from the model. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), they look like the following, allowing you to see the model's responses at different stages of training. By default, we generate `--num_sample_generations 10` sample completions during training, but you can customize the number of generations.
|
||||
|
||||

|
||||

|
||||
|
||||
In the logs the sampled generations look like
|
||||
|
||||
In the logs the sampled generations look like
|
||||
|
||||
```
|
||||
```txt
|
||||
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
|
||||
┃ query ┃ model response ┃ score ┃
|
||||
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
|
||||
@ -177,7 +174,7 @@ This PPO implementation is based on the [The N+ Implementation Details of RLHF w
|
||||
|
||||
To validate that the PPO implementation works, we ran an experiment on the 1B model. Here is the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
|
||||
```
|
||||
```shell
|
||||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
|
||||
examples/scripts/ppo/ppo_tldr.py \
|
||||
--output_dir models/minimal/ppo_tldr \
|
||||
@ -212,8 +209,7 @@ The PPO checkpoint gets a 64.7% preferred rate vs the 33.0% preference rate of t
|
||||
|
||||
Metrics:
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
```bash
|
||||
# pip install openrlbenchmark==0.2.1a5
|
||||
|
@ -1,12 +1,9 @@
|
||||
# PRM Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=prm,trl)
|
||||
[](https://huggingface.co/models?other=prm,trl)
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
PRM Trainer is an experimental API which is subject to change at any time.
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> PRM Trainer is an experimental API which is subject to change at any time.
|
||||
|
||||
## Overview
|
||||
|
||||
@ -18,7 +15,6 @@ The abstract from the paper is the following:
|
||||
|
||||
This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Agustín Piqueres](https://huggingface.co/plaguss).
|
||||
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the PRM method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model and the stepwise supervision data from the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd). You can view the data in the dataset here:
|
||||
@ -57,7 +53,6 @@ Distributed across 8 GPUs, the training takes approximately 1 hour.
|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward-Math-Sheperd) performs, you can use the following script.
|
||||
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from transformers import pipeline
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Quickstart
|
||||
|
||||
TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO).
|
||||
TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), and Direct Preference Optimization (DPO).
|
||||
|
||||
## Quick Examples
|
||||
|
||||
@ -51,6 +51,21 @@ trainer = DPOTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Reward Modeling
|
||||
|
||||
```python
|
||||
from trl import RewardTrainer
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
trainer = RewardTrainer(
|
||||
model="Qwen/Qwen2.5-0.5B-Instruct",
|
||||
train_dataset=dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Command Line Interface
|
||||
|
||||
Skip the code entirely - train directly from your terminal:
|
||||
@ -63,6 +78,10 @@ trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
# DPO: Align with preferences
|
||||
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized
|
||||
|
||||
# Reward: Train a reward model
|
||||
trl reward --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--dataset_name trl-lib/ultrafeedback_binarized
|
||||
```
|
||||
|
||||
## What's Next?
|
||||
@ -72,7 +91,6 @@ trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
- [SFT Trainer](sft_trainer) - Complete SFT guide
|
||||
- [DPO Trainer](dpo_trainer) - Preference alignment
|
||||
- [GRPO Trainer](grpo_trainer) - Group relative policy optimization
|
||||
- [Training FAQ](how_to_train) - Common questions
|
||||
|
||||
### 🚀 Scale Up
|
||||
|
||||
@ -122,4 +140,4 @@ Try adjusting the learning rate:
|
||||
training_args = SFTConfig(learning_rate=2e-5) # Good starting point
|
||||
```
|
||||
|
||||
For more help, see our [Training FAQ](how_to_train) or open an [issue on GitHub](https://github.com/huggingface/trl/issues).
|
||||
For more help, open an [issue on GitHub](https://github.com/huggingface/trl/issues).
|
||||
|
@ -1,18 +1,13 @@
|
||||
# Reducing Memory Usage
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Section under construction. Feel free to contribute!
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> Section under construction. Feel free to contribute!
|
||||
|
||||
## Truncation
|
||||
|
||||
Sequence lengths in the dataset can vary widely. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/why_you_should_truncate.png" alt="Truncation prompt-completion" width="600"/>
|
||||
</div>
|
||||

|
||||
|
||||
To reduce memory usage, it's important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
|
||||
|
||||
@ -21,9 +16,7 @@ To reduce memory usage, it's important to truncate sequences to a reasonable len
|
||||
|
||||
DPO truncation is applied first to the prompt and to the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_prompt_completion.png" alt="Truncation prompt-completion" width="600"/>
|
||||
</div>
|
||||

|
||||
|
||||
To set the truncation parameters, use the following code snippet:
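A minimal sketch with illustrative values:

```python
from trl import DPOConfig

training_args = DPOConfig(
    max_prompt_length=512,       # applied to the prompt first
    max_completion_length=512,   # applied to the completion first
    max_length=1024,             # then applied to the resulting prompt + completion sequence
)
```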
|
||||
|
||||
@ -46,9 +39,7 @@ training_args = DPOConfig(..., max_completion_length=...)
|
||||
|
||||
SFT truncation is applied to the input sequence via the `max_length` parameter.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_input_ids.png" alt="Truncation input ids" width="600"/>
|
||||
</div>
|
||||

|
||||
|
||||
To set the truncation parameter, use the following code snippet:
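A minimal sketch with an illustrative value:

```python
from trl import SFTConfig

training_args = SFTConfig(max_length=512)  # truncate input sequences to 512 tokens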
|
||||
|
||||
@ -71,30 +62,22 @@ To help you choose an appropriate value, we provide a utility to visualize the s
|
||||
|
||||
## Packing
|
||||
|
||||
<Tip>
|
||||
|
||||
This technique applies only to SFT.
|
||||
|
||||
</Tip>
|
||||
|
||||
> [!TIP]
|
||||
> This technique applies only to SFT.
|
||||
|
||||
[Truncation](#truncation) has several drawbacks:
|
||||
|
||||
1. **Loss of information**: Key data at the end of a sequence may be discarded.
|
||||
2. **Choosing truncation length**: Too short loses data; too long undermines efficiency.
|
||||
|
||||
Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.10683), addresses these issues by grouping sequences instead of truncating. It concatenates and splits dataset sequences into the desired lengths.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/packing_2.png" alt="Packing" width="600"/>
|
||||
</div>
|
||||

|
||||
|
||||
Packing reduces padding by merging several sequences into one row when possible. We use an advanced method to pack the dataset in a near-optimal way. To enable packing, use `packing=True` in the [`SFTConfig`].
|
||||
|
||||
<Tip>
|
||||
|
||||
In TRL 0.18 and earlier, packing used a more aggressive method that reduced padding to almost nothing, but had the downside of breaking sequence continuity for a large fraction of the dataset. To revert to this strategy, use `packing_strategy="wrapped"` in `SFTConfig`.
|
||||
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> In TRL 0.18 and earlier, packing used a more aggressive method that reduced padding to almost nothing, but had the downside of breaking sequence continuity for a large fraction of the dataset. To revert to this strategy, use `packing_strategy="wrapped"` in [`SFTConfig`].
|
||||
|
||||
```python
|
||||
from trl import SFTConfig
|
||||
@ -102,11 +85,8 @@ from trl import SFTConfig
|
||||
training_args = SFTConfig(..., packing=True, max_length=512)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see [#1230](https://github.com/huggingface/trl/issues/1230).
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see [#1230](https://github.com/huggingface/trl/issues/1230).
|
||||
|
||||
## Liger for reducing peak memory usage
|
||||
|
||||
@ -154,15 +134,10 @@ training_args = KTOConfig(..., use_liger_loss=True)
|
||||
|
||||
Padding-free batching is an alternative approach for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/padding-free.png" alt="Padding-free batching" width="600"/>
|
||||
</div>
|
||||

|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
It's highly recommended to use padding-free batching with **FlashAttention 2** or **FlashAttention 3**. Otherwise, you may encounter batch contamination issues.
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> It's highly recommended to use padding-free batching with **FlashAttention 2** or **FlashAttention 3**. Otherwise, you may encounter batch contamination issues.
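For SFT, a minimal sketch of enabling it might look like the following (assuming FlashAttention 2 is installed; the exact options for each trainer are shown in the tabs below):

```python
from trl import SFTConfig

# Sketch only: `padding_free` and `model_init_kwargs` follow the SFT trainer's
# documented options; FlashAttention 2 is assumed to be available.
training_args = SFTConfig(
    padding_free=True,
    model_init_kwargs={"attn_implementation": "flash_attention_2"},
)
```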
|
||||
|
||||
<hfoptions id="padding-free">
|
||||
<hfoption id="DPO">
|
||||
@ -197,21 +172,19 @@ from trl import SFTConfig
|
||||
training_args = SFTConfig(..., activation_offloading=True)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
When using activation offloading with models that use Liger kernels, you must disable Liger cross entropy due to compatibility issues. The issue occurs specifically with `use_liger_kernel=True` because Liger cross entropy performs in-place operations which conflict with activation offloading. The default setting (`use_liger_kernel=False`) works:
|
||||
|
||||
```python
|
||||
# When using activation offloading with a model that uses Liger kernels:
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(
|
||||
activation_offloading=True,
|
||||
use_liger_kernel=False, # Disable Liger cross entropy
|
||||
# Other parameters...
|
||||
)
|
||||
```
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> When using activation offloading with models that use Liger kernels, you must disable Liger cross entropy due to compatibility issues. The issue occurs specifically with `use_liger_kernel=True` because Liger cross entropy performs in-place operations which conflict with activation offloading. The default setting (`use_liger_kernel=False`) works:
|
||||
>
|
||||
> ```python
|
||||
> # When using activation offloading with a model that uses Liger kernels:
|
||||
> from trl import SFTConfig
|
||||
>
|
||||
> training_args = SFTConfig(
|
||||
> activation_offloading=True,
|
||||
> use_liger_kernel=False, # Disable Liger cross entropy
|
||||
> # Other parameters...
|
||||
> )
|
||||
> ```
|
||||
|
||||
Under the hood, activation offloading implements PyTorch's [`saved_tensors_hooks`](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#hooks-for-autograd-saved-tensors) to intercept activations during the forward pass. It intelligently manages which tensors to offload based on size and context, avoiding offloading output tensors which would be inefficient. For performance optimization, it can optionally use CUDA streams to overlap computation with CPU-GPU transfers.
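As a toy illustration of that mechanism (not TRL's actual implementation), `saved_tensors_hooks` lets you move activations to CPU when they are saved for backward and bring them back when the backward pass needs them:

```python
import torch

def pack_to_cpu(tensor):
    # called when autograd saves an activation for the backward pass
    return tensor.to("cpu", non_blocking=True)

def unpack_to_gpu(tensor):
    # called when the backward pass needs the activation again
    return tensor.to("cuda", non_blocking=True)

x = torch.randn(8, 8, device="cuda", requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_to_gpu):
    loss = (x * x).sum()
loss.backward()
```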
|
||||
|
||||
@ -262,107 +235,6 @@ training_args = RLOOConfig(..., ds3_gather_for_generation=False)
|
||||
|
||||
This adjustment prevents model weights from being gathered, avoiding OOM errors, but it may result in slower generation speeds.
|
||||
|
||||
## Context Parallelism
|
||||
|
||||
Context Parallelism (CP) is a parallelization technique that enables training with longer sequences by splitting the sequence dimension across multiple GPUs. Each GPU processes a portion of the sequence, allowing you to train with sequences longer than what would fit on a single GPU's memory.
|
||||
|
||||
For more details on CP, see the [Ultrascale Playbook - Context Parallelism](https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=context_parallelism).
|
||||
|
||||
CP is particularly useful when:
|
||||
|
||||
- You want to train with very long sequences (>32k tokens)
|
||||
- Single GPU memory is insufficient for your desired sequence length
|
||||
- You need to maintain sequence coherence across the full context
|
||||
|
||||
### Requirements and Limitations
|
||||
|
||||
CP has specific requirements:
|
||||
|
||||
1. **Accelerate 1.10 or higher** is required
|
||||
2. **FSDP2 (PyTorch FSDP v2)** is required as the distributed training backend
|
||||
3. **SDPA attention** - Flash Attention is currently not supported with CP
|
||||
4. **Sequence length divisibility** - sequences must be divisible by `cp_size * 2`. This is now automatically handled using the `pad_to_multiple_of` parameter in the data collator, which works seamlessly with both standard and padding-free modes.
|
||||
|
||||
### Configuration
|
||||
|
||||
To enable CP, you need to configure both Accelerate and your training arguments:
|
||||
|
||||
#### Accelerate Configuration
|
||||
|
||||
Use one of the provided accelerate config files (e.g. `fsdp_context_parallel_2gpu.yaml` for 2 GPUs):
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
fsdp_config:
|
||||
fsdp_activation_checkpointing: false
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_version: 2
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 2 # Number of GPUs
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
parallelism_config:
|
||||
parallelism_config_dp_replicate_size: 1
|
||||
parallelism_config_dp_shard_size: 1
|
||||
parallelism_config_tp_size: 1
|
||||
parallelism_config_cp_size: 2 # Context parallel size
|
||||
```
|
||||
|
||||
#### Training Configuration
|
||||
|
||||
```python
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(
|
||||
# required
|
||||
pad_to_multiple_of=4, # ensures divisibility by cp_size * 2
|
||||
# to get the most out of CP
|
||||
max_length=16384, # long sequence length
|
||||
packing=True, # use packing to reduce padding
|
||||
use_liger_kernel=True, # compatible with CP
|
||||
per_device_train_batch_size=1,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
Then, launch your training script with the appropriate accelerate config file:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file fsdp_context_parallel_2gpu.yaml train.py
|
||||
```
|
||||
|
||||
### Best Practices
|
||||
|
||||
1. **Use the `pad_to_multiple_of` parameter** - This is now the recommended way to ensure sequence length divisibility:
|
||||
- For `cp_size=2`: use `pad_to_multiple_of=4` (since `cp_size * 2 = 4`)
|
||||
- For `cp_size=4`: use `pad_to_multiple_of=8` (since `cp_size * 2 = 8`)
|
||||
- The data collator automatically pads sequences to the required multiple, ensuring compatibility with CP
|
||||
|
||||
2. **Use packing with padding** - The default BFD (Best Fit Decreasing) strategy works perfectly:
|
||||
- Preserves sequence boundaries and maintains training quality
|
||||
- Works seamlessly with both `padding_free=True` and standard padding modes
|
||||
|
||||
3. **Combine with other memory optimizations** like Liger kernels, bfloat16, and gradient checkpointing
|
||||
|
||||
4. **Start with smaller context parallel sizes** (2-4 GPUs) before scaling up
|
||||
|
||||
5. **Monitor memory usage** across all GPUs to ensure balanced workload
|
||||
|
||||
## vLLM sleep mode
|
||||
|
||||
When using vLLM as the generation backend, you can enable _sleep mode_ to offload vLLM parameters and cache to CPU RAM during the optimization step and reload them back to GPU VRAM when needed for weight synchronization and generation.
|
||||
@ -373,7 +245,7 @@ When using vLLM as the generation backend, you can enable _sleep mode_ to offloa
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(..., vllm_sleep_enabled=True)
|
||||
training_args = GRPOConfig(..., vllm_enable_sleep_mode=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
@ -382,7 +254,7 @@ training_args = GRPOConfig(..., vllm_sleep_enabled=True)
|
||||
```python
|
||||
from trl import RLOOConfig
|
||||
|
||||
training_args = RLOOConfig(..., vllm_sleep_enabled=True)
|
||||
training_args = RLOOConfig(..., vllm_enable_sleep_mode=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
|
@ -1,85 +1,226 @@
|
||||
# Reward Modeling
|
||||
|
||||
[](https://huggingface.co/models?other=reward-trainer,trl)
|
||||
[](https://huggingface.co/models?other=reward-trainer,trl)
|
||||
|
||||
TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model.
|
||||
## Overview
|
||||
|
||||
Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py).
|
||||
TRL supports the Outcome-supervised Reward Modeling (ORM) Trainer for training reward models.
|
||||
|
||||
## Expected dataset type
|
||||
This post-training method was contributed by [Younes Belkada](https://huggingface.co/ybelkada).
|
||||
|
||||
The [`RewardTrainer`] requires a [*implicit prompt* preference dataset](dataset_formats#preference). It means that the dataset should only contain the columns `"chosen"` and `"rejected"` (and not `"prompt"`).
|
||||
The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
## Quick start
|
||||
|
||||
You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`.
|
||||
|
||||
## Using the `RewardTrainer`
|
||||
|
||||
After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers.
|
||||
You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training.
|
||||
|
||||
### Leveraging 🤗 PEFT to train a reward model
|
||||
|
||||
Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model!
|
||||
This example demonstrates how to train a reward model using the [`RewardTrainer`] from TRL. We train a [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) model on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), a large-scale, fine-grained, and diverse preference dataset.
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, TaskType
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from trl import RewardTrainer, RewardConfig
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.SEQ_CLS,
|
||||
inference_mode=False,
|
||||
r=8,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.1,
|
||||
)
|
||||
|
||||
...
|
||||
from trl import RewardTrainer
|
||||
from datasets import load_dataset
|
||||
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
processing_class=tokenizer,
|
||||
model="Qwen/Qwen3-0.6B",
|
||||
train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"),
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&metrics=train*&sidebar=hidden&runs=reward_qwen3-0.6B_ultrafeedback2" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>
|
||||
|
||||
## Expected dataset type and format
|
||||
|
||||
[`RewardTrainer`] supports the [preference](dataset_formats#preference) dataset type (with both implicit and explicit prompts). The [`RewardTrainer`] is compatible with both [standard](dataset_formats#standard) and [conversational](dataset_formats#conversational) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
```python
|
||||
# Standard preference (implicit prompt)
|
||||
{"chosen": "The sky is blue.",
|
||||
"rejected": "The sky is green."}
|
||||
|
||||
# Conversational preference (implicit prompt)
|
||||
{"chosen": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "user", "content": "What color is the sky?"},
|
||||
{"role": "assistant", "content": "It is green."}]}
|
||||
|
||||
# Standard preference (explicit prompt)
|
||||
{"prompt": "The sky is",
|
||||
"chosen": " blue.",
|
||||
"rejected": " green."}
|
||||
|
||||
# Conversational preference (explicit prompt)
|
||||
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"chosen": [{"role": "assistant", "content": "It is blue."}],
|
||||
"rejected": [{"role": "assistant", "content": "It is green."}]}
|
||||
```
|
||||
|
||||
If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. Here is an example with the [lmarena-ai/arena-human-preference-55k](https://huggingface.co/datasets/lmarena-ai/arena-human-preference-55k) dataset:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
import json
|
||||
|
||||
dataset = load_dataset("lmarena-ai/arena-human-preference-55k")
|
||||
|
||||
# Filter out ties
|
||||
dataset = dataset.filter(lambda example: example["winner_tie"] == 0)
|
||||
|
||||
# Create 'chosen' and 'rejected' fields based on the winner column
|
||||
def response_a_b_to_chosen_rejected(example):
|
||||
if example["winner_model_a"] == 1:
|
||||
example["chosen"] = example["response_a"]
|
||||
example["rejected"] = example["response_b"]
|
||||
else:
|
||||
example["chosen"] = example["response_b"]
|
||||
example["rejected"] = example["response_a"]
|
||||
return example
|
||||
|
||||
dataset = dataset.map(response_a_b_to_chosen_rejected)
|
||||
|
||||
# Convert to conversational format
|
||||
def make_conversation(example):
|
||||
prompt = json.loads(example["prompt"])[0] # '["What color is the sky?"]' -> "What color is the sky?"
|
||||
chosen = json.loads(example["chosen"])[0]
|
||||
rejected = json.loads(example["rejected"])[0]
|
||||
return {
|
||||
"chosen": [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen}],
|
||||
"rejected": [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected}],
|
||||
}
|
||||
|
||||
|
||||
dataset = dataset.map(make_conversation)
|
||||
|
||||
# Keep only necessary columns
|
||||
dataset = dataset.select_columns(["chosen", "rejected"])
|
||||
|
||||
print(next(iter(dataset["train"])))
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"chosen": [
|
||||
{"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"},
|
||||
{"role": "assistant", "content": "The question of whether it is morally right to aim for a certain percentage of females..."},
|
||||
],
|
||||
"rejected": [
|
||||
{"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"},
|
||||
{"role": "assistant", "content": "As an AI, I don't have personal beliefs or opinions. However, ..."},
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
## Looking deeper into the training method
|
||||
|
||||
Reward Models (RMs) are typically trained using supervised learning on datasets containing pairs of preferred and non-preferred responses. The goal is to learn a function that assigns higher scores to preferred responses, enabling the model to rank outputs based on preferences.
|
||||
|
||||
This section breaks down how reward modeling works in practice, covering the key steps: **preprocessing** and **loss computation**.
|
||||
|
||||
### Preprocessing and tokenization
|
||||
|
||||
During training, each example is expected to contain a **chosen** and **rejected** field. For more details on the expected formats, see [Dataset formats - Preference](dataset_formats#preference).
|
||||
The [`RewardTrainer`] tokenizes each input using the model's tokenizer. If prompts and completions (chosen and rejected) are provided separately (explicit prompt case), they are concatenated before tokenization.
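As a rough sketch of that step for an explicit-prompt example (not the trainer's exact internals):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
# explicit prompt: concatenate prompt and completion before tokenizing
chosen_ids = tokenizer(example["prompt"] + example["chosen"])["input_ids"]
rejected_ids = tokenizer(example["prompt"] + example["rejected"])["input_ids"]
```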
|
||||
|
||||
### Computing the loss
|
||||
|
||||
Let \\( x \\) be the input sequence (prompt) and \\( y^+ \\) and \\( y^- \\) be the chosen and rejected sequences respectively. Under the Bradley-Terry model ([Bradley & Terry, 1952](https://www.jstor.org/stable/2334029)), the probability that \\( y^+ \\) is preferred over \\( y^- \\) given a reward function \\( r \\) is \\( p(y^+ \succ y^- \mid x) = \sigma(r(x, y^+) - r(x, y^-)) \\), where \\( \sigma \\) is the sigmoid function.
|
||||
|
||||
The reward model \\( r_\theta(x, y) \\) is trained to assign higher scores to preferred responses \\( y^+ \\) over non-preferred ones \\( y^- \\). The loss is then defined as the negative log-likelihood of the observed preferences:
|
||||
|
||||
$$
|
||||
\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \right].
|
||||
$$
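In code, a minimal sketch of this loss (scores are illustrative; `rewards_chosen` and `rewards_rejected` stand for \\( r_\theta(x, y^+) \\) and \\( r_\theta(x, y^-) \\) for a batch of pairs):

```python
import torch
import torch.nn.functional as F

rewards_chosen = torch.tensor([1.2, 0.3, 2.1])    # r_theta(x, y+)
rewards_rejected = torch.tensor([0.4, 0.9, 1.0])  # r_theta(x, y-)

loss = -F.logsigmoid(rewards_chosen - rewards_rejected).mean()
# optional auxiliary centering term (see the tip below),
# weighted by `center_rewards_coefficient`:
loss = loss + 1e-2 * (rewards_chosen + rewards_rejected).pow(2).mean()
```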
|
||||
|
||||
> [!TIP]
|
||||
> The Bradley-Terry model is underdetermined, meaning that adding a constant to all rewards does not change the preference probabilities. To address this, [Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking](https://huggingface.co/papers/2312.09244) proposes adding an auxiliary loss term that encourages the rewards to be centered around zero. This is controlled by the `center_rewards_coefficient` parameter in the [`RewardConfig`]. The recommended value is `1e-2`.
|
||||
|
||||
## Logged metrics
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
|
||||
* `global_step`: The total number of optimizer steps taken so far.
|
||||
* `epoch`: The current epoch number, based on dataset iteration.
|
||||
* `num_tokens`: The total number of tokens processed so far.
|
||||
* `loss`: The average loss over the last logging interval.
|
||||
* `accuracy`: The proportion of correct predictions (i.e., the model assigned a higher score to the chosen response than to the rejected one) averaged over the last logging interval.
|
||||
* `min_reward`: The minimum reward score assigned by the model. This value is averaged over the logging interval.
|
||||
* `mean_reward`: The average reward score assigned by the model over the last logging interval.
|
||||
* `max_reward`: The maximum reward score assigned by the model. This value is averaged over the logging interval.
|
||||
* `margin`: The average margin (difference between chosen and rejected rewards) over the last logging interval.
|
||||
* `learning_rate`: The current learning rate, which may change dynamically if a scheduler is used.
|
||||
* `grad_norm`: The L2 norm of the gradients, computed before gradient clipping.
|
||||
|
||||
## Customization
|
||||
|
||||
### Model initialization
|
||||
|
||||
You can directly pass the kwargs of the [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] method to the [`RewardConfig`]. For example, if you want to load a model in a different precision, analogous to
|
||||
|
||||
```python
|
||||
model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
you can do so by passing the `model_init_kwargs={"dtype": torch.bfloat16}` argument to the [`RewardConfig`].
|
||||
|
||||
```python
|
||||
import torch
from trl import RewardConfig
|
||||
|
||||
training_args = RewardConfig(
|
||||
model_init_kwargs={"dtype": torch.bfloat16},
|
||||
)
|
||||
```
|
||||
|
||||
Note that all keyword arguments of [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] are supported, except for `num_labels`, which is automatically set to 1.
|
||||
|
||||
### Train adapters with PEFT
|
||||
|
||||
We support tight integration with 🤗 PEFT library, allowing any user to conveniently train adapters and share them on the Hub, rather than training the entire model.
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import RewardTrainer
|
||||
from peft import LoraConfig
|
||||
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
|
||||
trainer = RewardTrainer(
|
||||
"Qwen/Qwen3-4B",
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config,
|
||||
peft_config=LoraConfig(modules_to_save=["score"]) # important to include the score head when base model is not a sequence classification model
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
```
|
||||
|
||||
### Adding a margin to the loss
|
||||
|
||||
As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly.
|
||||
You can also continue training your [`~peft.PeftModel`]. For that, first load a `PeftModel` outside [`RewardTrainer`] and pass it directly to the trainer without providing the `peft_config` argument.
|
||||
|
||||
```python
|
||||
def add_margin(row):
|
||||
# Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin
|
||||
return {'margin': row['score_chosen'] - row['score_rejected']}
|
||||
from datasets import load_dataset
|
||||
from trl import RewardTrainer
|
||||
from peft import AutoPeftModelForCausalLM
|
||||
|
||||
dataset = dataset.map(add_margin)
|
||||
```
|
||||
model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-Reward-LoRA", is_trainable=True)
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
### Centering rewards
|
||||
|
||||
In many scenarios, it's preferable to ensure that a reward model's output is mean zero. This is often done by first calculating the model's average score and then subtracting it.
|
||||
|
||||
[[Eisenstein et al., 2023]](https://huggingface.co/papers/2312.09244) proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs:
|
||||
|
||||
$$\Big( R(p, r_1) + R(p, r_2) \Big)^2 $$
|
||||
|
||||
This auxiliary loss is combined with the main loss function, weighted by the parameter `center_rewards_coefficient` in the `[RewardConfig]`. By default, this feature is deactivated (`center_rewards_coefficient = None`).
|
||||
|
||||
```python
|
||||
training_args = RewardConfig(
|
||||
center_rewards_coefficient=0.01,
|
||||
...
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
For reference results, please refer PR [#1932](https://github.com/huggingface/trl/pull/1932).
|
||||
> [!TIP]
|
||||
> When training adapters, you typically use a higher learning rate (≈1e-3) since only new parameters are being learned.
|
||||
>
|
||||
> ```python
|
||||
> RewardConfig(learning_rate=1e-3, ...)
|
||||
> ```
|
||||
|
||||
## Tool Calling with Reward Modeling
|
||||
|
||||
The [`RewardTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include:
|
||||
|
||||
* The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages)
|
||||
* The list of available tools in the `tools` column, typically provided as JSON schemas
|
||||
|
||||
For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section.
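A hypothetical example of a single row (structure only; the linked section above is the authoritative reference for the schema):

```python
# Illustrative tool-calling preference example; field values are made up.
example = {
    "chosen": [
        {"role": "user", "content": "What is the weather in Paris?"},
        {"role": "assistant", "tool_calls": [{"type": "function", "function": {"name": "get_weather", "arguments": {"city": "Paris"}}}]},
        {"role": "tool", "name": "get_weather", "content": "Sunny, 22°C"},
        {"role": "assistant", "content": "It is sunny and 22°C in Paris."},
    ],
    "rejected": [
        {"role": "user", "content": "What is the weather in Paris?"},
        {"role": "assistant", "content": "I cannot access live weather data."},
    ],
    "tools": [
        {"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}},
    ],
}
```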
|
||||
|
||||
## RewardTrainer
|
||||
|
||||
@ -91,3 +232,7 @@ For reference results, please refer PR [#1932](https://github.com/huggingface/tr
|
||||
## RewardConfig
|
||||
|
||||
[[autodoc]] RewardConfig
|
||||
|
||||
## DataCollatorForPreference
|
||||
|
||||
[[autodoc]] trainer.reward_trainer.DataCollatorForPreference
|
||||
|
@ -2,14 +2,14 @@
|
||||
|
||||
This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`].
|
||||
|
||||
## Format rewards
|
||||
## accuracy_reward
|
||||
|
||||
### think_format_reward
|
||||
[[autodoc]] rewards.accuracy_reward
|
||||
|
||||
## think_format_reward
|
||||
|
||||
[[autodoc]] rewards.think_format_reward
|
||||
|
||||
## Other rewards
|
||||
|
||||
### get_soft_overlong_punishment
|
||||
## get_soft_overlong_punishment
|
||||
|
||||
[[autodoc]] rewards.get_soft_overlong_punishment
|
||||
|
@ -1,6 +1,6 @@
|
||||
# RLOO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=rloo,trl)
|
||||
[](https://huggingface.co/models?other=rloo,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
@ -70,7 +70,7 @@ At each training step, we sample a batch of prompts and generate a set of \\( G
|
||||
In RLOO, the reward consists of two components: the reward provided by the reward model (or reward function) and a KL penalty that discourages the policy from deviating too far from a fixed reference policy:
|
||||
|
||||
1. For each of the \\( G \\) generated sequences \\( o_i = (o_{i,1}, \dots, o_{i,T}) \\) conditioned on a query \\( q \\), we compute a scalar reward using a reward model \\( R(o_i, q) \\).
|
||||
2. Concurenlty, we estimate the KL divergence between the current policy \\( \pi_\theta \\) and the fixed reference policy \\( \pi_{\text{ref}} \\) over the sequence. The KL estimate for sequence \\( o_i \\) is:
|
||||
2. Concurrently, we estimate the KL divergence between the current policy \\( \pi_\theta \\) and the fixed reference policy \\( \pi_{\text{ref}} \\) over the sequence. The KL estimate for sequence \\( o_i \\) is:
|
||||
|
||||
$$
|
||||
\mathbb{D}_{\mathrm{KL}}\!\left[\pi_\theta\|\pi_{\mathrm{ref}}\right] = \sum_{t=1}^T \log \frac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}.
|
||||
@ -84,34 +84,30 @@ $$
|
||||
|
||||
where \\( \beta > 0 \\) controls the strength of the KL penalty.
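A toy numeric sketch of this penalized reward (log-probabilities, score, and \\( \beta \\) are illustrative):

```python
import torch

logprobs_policy = torch.tensor([-1.2, -0.8, -2.0])  # log pi_theta(o_t | q, o_<t)
logprobs_ref = torch.tensor([-1.0, -1.1, -1.5])     # log pi_ref(o_t | q, o_<t)
score = 0.7                                         # reward model score R(o, q)
beta = 0.05

kl = (logprobs_policy - logprobs_ref).sum()         # sequence-level KL estimate
reward = score - beta * kl
```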
|
||||
|
||||
<Tip>
|
||||
|
||||
In a purely online setting (`num_iterations = 1`, default), the data are generated by the current policy. In this case, the KL penalty is computed directly using the current policy.
|
||||
|
||||
In the more general setting (e.g., multiple gradient steps per batch), the data are instead generated by an earlier snapshot \\( \pi_{\text{old}} \\). To keep the penalty consistent with the sampling distribution, the KL is defined with respect to this policy:
|
||||
|
||||
$$
|
||||
\mathbb{D}_{\mathrm{KL}}\!\left[\pi_{\text{old}} \,\|\, \pi_{\text{ref}}\right].
|
||||
$$
|
||||
|
||||
Equivalently, for a sampled sequence $o$, the Monte Carlo estimate is
|
||||
|
||||
$$
|
||||
\mathbb{D}_{\mathrm{KL}}\!\left[\pi_{\text{old}} \|\pi_{\mathrm{ref}}\right] = \sum_{t=1}^T \log \frac{\pi_{\text{old}}(o_{i,t} \mid q, o_{i,<t})}{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}.
|
||||
$$
|
||||
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> In a purely online setting (`num_iterations = 1`, default), the data are generated by the current policy. In this case, the KL penalty is computed directly using the current policy.
|
||||
>
|
||||
> In the more general setting (e.g., multiple gradient steps per batch), the data are instead generated by an earlier snapshot \\( \pi_{\text{old}} \\). To keep the penalty consistent with the sampling distribution, the KL is defined with respect to this policy:
|
||||
>
|
||||
> $$
|
||||
> \mathbb{D}_{\mathrm{KL}}\!\left[\pi_{\text{old}} \,\|\, \pi_{\text{ref}}\right].
|
||||
> $$
|
||||
>
|
||||
> Equivalently, for a sampled sequence $o$, the Monte Carlo estimate is
|
||||
>
|
||||
> $$
|
||||
> \mathbb{D}_{\mathrm{KL}}\!\left[\pi_{\text{old}} \|\pi_{\mathrm{ref}}\right] = \sum_{t=1}^T \log \frac{\pi_{\text{old}}(o_{i,t} \mid q, o_{i,<t})}{\pi_{\mathrm{ref}}(o_{i,t} \mid q, o_{i,<t})}.
|
||||
> $$
|
||||
|
||||
### Computing the advantage
|
||||
|
||||
Once the rewards for each completion have been computed, we calculate a baseline as the average reward of all other samples in the same batch, excluding the current sample. This baseline is used to reduce the variance of the policy gradient estimate. The advantage for each completion is then obtained as the difference between its own reward and this leave-one-out baseline.
|
||||
Once the rewards for each completion have been computed, we calculate a baseline as the average reward of all other samples in the same batch, excluding the current sample. This baseline is used to reduce the variance of the policy gradient estimate. The advantage for each completion is then obtained as the difference between its own reward and this leave-one-out baseline.
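A small numeric sketch of this leave-one-out baseline (reward values are illustrative):

```python
import torch

rewards = torch.tensor([0.5, 1.0, 0.0, 2.0])    # one reward per completion in the group
G = rewards.numel()
baseline = (rewards.sum() - rewards) / (G - 1)  # mean of the other G-1 rewards
advantages = rewards - baseline
```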
|
||||
|
||||
Formally, for a batch of \\( G \\) completions, the baseline for completion \\( i \\) is:
|
||||
$$
|
||||
b_i = \frac{1}{G-1} \sum_{j \neq i} r_j
|
||||
$$
|
||||
|
||||
|
||||
and then the advantage for each completion is computed as the difference between its reward and the baseline:
|
||||
|
||||
$$
|
||||
@ -145,7 +141,7 @@ While training and evaluating, we record the following reward metrics:
|
||||
- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS.
|
||||
- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS.
|
||||
- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS.
|
||||
- `completions/clipped_ratio` : The ratio of truncated (clipped) completions.
|
||||
- `completions/clipped_ratio`: The ratio of truncated (clipped) completions.
|
||||
- `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
|
||||
- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
|
||||
- `reward`: The overall average reward after applying reward weights.
|
||||
@ -154,9 +150,9 @@ While training and evaluating, we record the following reward metrics:
|
||||
- `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, tokens of masked sequences are excluded.)
|
||||
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
|
||||
- `clip_ratio/region_mean`: The ratio of sequence probabilities where the RLOO objective is clipped to stay within the trust region:
|
||||
$$
|
||||
\text{clip}\left( r_{i}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i}(\theta) = \frac{\pi_\theta(o_{i} \mid q)}{\pi_{\theta_{\text{old}}}(o_{i} \mid q)}\,.
|
||||
$$
|
||||
$$
|
||||
\text{clip}\left( r_{i}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i}(\theta) = \frac{\pi_\theta(o_{i} \mid q)}{\pi_{\theta_{\text{old}}}(o_{i} \mid q)}\,.
|
||||
$$
|
||||
|
||||
A higher value means more samples are clipped, which constrains how much the policy $\pi_\theta$ can change.
|
||||
- `clip_ratio/low_mean`: The average ratio of sequence probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
|
||||
@ -169,6 +165,7 @@ $$
|
||||
### Speed up training with vLLM-powered generation
|
||||
|
||||
Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with
|
||||
|
||||
```shell
|
||||
pip install trl[vllm]
|
||||
```
|
||||
@ -180,11 +177,13 @@ We support two ways of using vLLM during training: **server mode** and **colocat
|
||||
In this mode, vLLM runs in a separate process (and using separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.
|
||||
|
||||
1. **Start the vLLM server**:
|
||||
|
||||
```bash
|
||||
trl vllm-serve --model <model_name>
|
||||
```
|
||||
|
||||
2. **Enable server mode in your training script**:
|
||||
|
||||
```python
|
||||
from trl import RLOOConfig
|
||||
|
||||
@ -195,11 +194,8 @@ In this mode, vLLM runs in a separate process (and using separate GPUs) and comm
|
||||
)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable.
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable.
|
||||
|
||||
#### 🧩 Option 2: Colocate mode
|
||||
|
||||
@ -215,28 +211,19 @@ training_args = RLOOConfig(
|
||||
)
|
||||
```
|
||||
|
||||
<Tip>
|
||||
> [!TIP]
|
||||
> Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`RLOOConfig`] to avoid underutilization or out-of-memory errors.
|
||||
>
|
||||
> We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-memory) to help estimate the recommended GPU memory utilization based on your model configuration and experiment settings. Simply use it as follows to get a `vllm_gpu_memory_utilization` recommendation:
|
||||
>
|
||||
> <iframe src="https://trl-lib-recommend-vllm-memory.hf.space" frameborder="0" width="850" height="450"></iframe>
|
||||
>
|
||||
> If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.
|
||||
>
|
||||
> If you still find that you are getting out-of-memory errors, set `vllm_enable_sleep_mode` to `True`, and the vLLM parameters and cache will be offloaded during the optimization step. For more information, see [Reducing Memory Usage with vLLM Sleep Mode](reducing_memory_usage#vllm-sleep-mode).
|
||||
|
||||
Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`RLOOConfig`] to avoid underutilization or out-of-memory errors.
|
||||
|
||||
We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-memory) to help estimate the recommended GPU memory utilization based on your model configuration and experiment settings. Simply use it as follows to get `vllm_gpu_memory_utilization` recommendation:
|
||||
|
||||
<iframe
|
||||
src="https://trl-lib-recommend-vllm-memory.hf.space"
|
||||
frameborder="0"
|
||||
width="850"
|
||||
height="450"
|
||||
></iframe>
|
||||
|
||||
If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
By default, RLOO uses `MASTER_ADDR=localhost` and `MASTER_PORT=12345` for vLLM, but you can override these values by setting the environment variables accordingly.
|
||||
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> By default, RLOO uses `MASTER_ADDR=localhost` and `MASTER_PORT=12345` for vLLM, but you can override these values by setting the environment variables accordingly.
|
||||
|
||||
For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
|
||||
|
||||
@ -244,7 +231,7 @@ For more information, see [Speeding up training with vLLM](speeding_up_training#
|
||||
|
||||
When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:
|
||||
|
||||
- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such model. For more details, see [DeepSpeed Integration](deepspeed_integration).
|
||||
- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see [DeepSpeed Integration](deepspeed_integration).
|
||||
- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training).
|
||||
- **vLLM**: See the previous section on how to use vLLM to speed up generation.
|
||||
|
||||
@ -323,7 +310,7 @@ The [`RLOOTrainer`] supports using custom reward functions instead of dense rewa
|
||||
- `completions` (contains the generated completions),
|
||||
- `completions_ids` (contains the tokenized completions),
|
||||
- `trainer_state` ([`~transformers.TrainerState`]): The current state of the trainer. This can be used to implement dynamic reward functions, such as curriculum learning, where the reward is adjusted based on the training progress.
|
||||
- All columns names (but `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
|
||||
- All column names (but `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
|
||||
|
||||
The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
|
||||
- Depending on the dataset format, the input will vary:
|
||||
@ -352,7 +339,7 @@ You can test it as follows:
|
||||
[2.0, 4.0]
|
||||
```
|
||||
|
||||
#### Example 1.1: Reward longer completions (based in the number of characters)
|
||||
#### Example 1.1: Reward longer completions (based on the number of characters)
|
||||
|
||||
Same as the previous example, but this time the reward function is based on the number of characters instead of tokens.
|
||||
|
||||
@ -372,10 +359,10 @@ You can test it as follows:
|
||||
[6.0, 12.0]
|
||||
```
|
||||
|
||||
#### Example 2: Reward completions with specific format
|
||||
#### Example 2: Reward completions with a specific format
|
||||
|
||||
Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948).
|
||||
It is designed for conversational format, where prompts and completions consist of structured messages.
|
||||
It is designed for a conversational format, where prompts and completions consist of structured messages.
|
||||
|
||||
```python
|
||||
import re
|
||||
@ -428,6 +415,7 @@ You can test this function as follows:
|
||||
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
|
||||
[1.0, 0.0]
|
||||
```
|
||||
|
||||
#### Example 4: Multi-task reward functions
|
||||
|
||||
Below is an example of using multiple reward functions in the [`RLOOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works.
|
||||
@ -484,12 +472,10 @@ trainer = RLOOTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return `None` and the [`RLOOTrainer`] will continue with the valid functions and tasks. This allows the [`RLOOTrainer`] to handle multiple reward functions with different applicability.
|
||||
In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return `None`, and the [`RLOOTrainer`] will continue with the valid functions and tasks. This allows the [`RLOOTrainer`] to handle multiple reward functions with different applicability.
|
||||
|
||||
Note that the [`RLOOTrainer`] will ignore the `None` rewards returned by the reward functions and only consider the rewards returned by the relevant functions. This ensures that the model is trained on the relevant tasks and ignores the tasks for which there is no relevant reward function.
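To make the `None` mechanism concrete, a task-aware reward function might be sketched as follows (the `task` column values and the correctness check are assumptions made for illustration):

```python
def math_reward_func(completions, task, **kwargs):
    rewards = []
    for completion, t in zip(completions, task):
        if t == "math":
            # Hypothetical correctness check; a real implementation would verify the answer properly
            rewards.append(1.0 if "42" in completion else 0.0)
        else:
            # Not a math sample: return None so the trainer ignores this reward for this sample
            rewards.append(None)
    return rewards
```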
|
||||
|
||||
|
||||
|
||||
#### Passing the reward function to the trainer
|
||||
|
||||
To use your custom reward function, pass it to the [`RLOOTrainer`] as follows:
|
||||
@ -518,6 +504,64 @@ and the reward will be computed as the sum of the rewards from each function, or
|
||||
|
||||
Note that [`RLOOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details.
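For instance, combining two custom reward functions with explicit weights might look like the sketch below (the function bodies, weights, and dataset are placeholders, and `reward_weights` is assumed to behave as in the other online trainers):

```python
from datasets import load_dataset
from trl import RLOOConfig, RLOOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

def format_reward(completions, **kwargs):
    # Placeholder formatting check
    return [1.0 if completion.strip().endswith(".") else 0.0 for completion in completions]

def length_reward(completions, **kwargs):
    # Placeholder length-based reward
    return [float(len(completion)) for completion in completions]

training_args = RLOOConfig(
    output_dir="rloo-multi-reward",
    reward_weights=[0.8, 0.2],  # weight format_reward higher than length_reward
)
trainer = RLOOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    args=training_args,
    reward_funcs=[format_reward, length_reward],
    train_dataset=dataset,
)
```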
|
||||
|
||||
## Vision-Language Model (VLM) Training
|
||||
|
||||
RLOO supports training Vision-Language Models (VLMs) on multimodal datasets containing both text and images.
|
||||
|
||||
### Supported Models
|
||||
|
||||
Tested with:
|
||||
|
||||
- **Gemma3** — e.g., `google/gemma-3-4b-it`
|
||||
- **LLaVA-NeXT** — e.g., `llava-hf/llava-v1.6-mistral-7b-hf`
|
||||
- **Qwen2-VL** — e.g., `Qwen/Qwen2-VL-2B-Instruct`
|
||||
- **Qwen2.5-VL** — e.g., `Qwen/Qwen2.5-VL-3B-Instruct`
|
||||
- **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct`
|
||||
|
||||
> [!TIP]
|
||||
> Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes.
|
||||
|
||||
### Quick Start
|
||||
|
||||
Use [rloo\_vlm.py](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo_vlm.py) to fine-tune a VLM. Example command for training on [`lmms-lab/multimodal-open-r1-8k-verified`](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified):
|
||||
|
||||
```bash
|
||||
accelerate launch \
|
||||
--config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/rloo_vlm.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-VL-3B-Instruct \
|
||||
--output_dir rloo-Qwen2.5-VL-3B-Instruct \
|
||||
--learning_rate 1e-5 \
|
||||
--gradient_checkpointing \
|
||||
--dtype bfloat16 \
|
||||
--max_prompt_length 2048 \
|
||||
--max_completion_length 1024 \
|
||||
--use_vllm \
|
||||
--vllm_mode colocate \
|
||||
--use_peft \
|
||||
--lora_target_modules "q_proj", "v_proj" \
|
||||
--log_completions
|
||||
```
|
||||
|
||||
### Configuration Tips
|
||||
|
||||
> [!WARNING]
|
||||
> VLM training may fail if image tokens are truncated. We highly recommend disabling truncation by setting `max_prompt_length` to `None`.
|
||||
|
||||
- Use LoRA on vision-language projection layers
|
||||
- Enable 4-bit quantization to reduce memory usage
|
||||
- VLMs are memory-intensive — start with smaller batch sizes
|
||||
- Most models are compatible with vLLM (`server` and `colocate` modes)
|
||||
|
||||
### Dataset Format
|
||||
|
||||
Each training sample should include:
|
||||
|
||||
- `prompt`: Text formatted via the processor's chat template
|
||||
- `image`/`images`: PIL Image or list of PIL Images
|
||||
|
||||
The trainer automatically handles image-to-tensor conversion via the model’s image processor.
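A single illustrative sample might therefore look like this (the prompt text and image are placeholders, and in practice the prompt is formatted with the processor's chat template):

```python
from datasets import Dataset
from PIL import Image

sample = {
    "prompt": "What is shown in this image?",  # placeholder; format via the processor's chat template
    "image": Image.new("RGB", (224, 224), color="white"),  # placeholder PIL image
}
dataset = Dataset.from_list([sample])
```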
|
||||
|
||||
## RLOOTrainer
|
||||
|
||||
[[autodoc]] RLOOTrainer
|
||||
@ -540,7 +584,7 @@ Note that [`RLOOTrainer`] supports multiple reward functions of different types.
|
||||
|
||||
## Migration Guide from the old implementation (0.21 and below)
|
||||
|
||||
With the release of version 0.22.0, we have revamped the [`RLOOTrainer`] to be more alinged with other online trainers in the library like [`GRPOTrainer`]. This new implementation introduces several changes to the configuration parameters and overall structure of the trainer.
|
||||
With the release of version 0.22.0, we have revamped the [`RLOOTrainer`] to be more aligned with other online trainers in the library, like [`GRPOTrainer`]. This new implementation introduces several changes to the configuration parameters and overall structure of the trainer.
|
||||
Below is a summary of the key changes for [`RLOOConfig`]:
|
||||
|
||||
| TRL ≤ 0.21.x | TRL ≥ 0.22.0 |
|
||||
|
@ -4,15 +4,11 @@ The notebooks and scripts in these examples show how to fine-tune a model with a
|
||||
|
||||
Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
|
||||
|
||||
|
||||
|
||||
| File | Description |
|
||||
|------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------|
|
||||
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset |
|
||||
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
|
||||
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.
|
||||
|
||||
|
||||
| File | Description |
|
||||
| --- |--- |
|
||||
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset |
|
||||
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
|
||||
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
|
||||
|
||||
## Usage
|
||||
|
||||
@ -30,7 +26,6 @@ python examples/scripts/ppo.py --log_with wandb --mini_batch_size 1 --gradient_a
|
||||
|
||||
Note: if you don't want to log with `wandb`, remove `log_with="wandb"` from the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
|
||||
|
||||
## Few notes on multi-GPU
|
||||
|
||||
## Few notes on multi-GPU
|
||||
|
||||
To run in multi-GPU setup with DDP (distributed Data Parallel) change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism you can use `device_map="auto"`.
|
||||
To run in a multi-GPU setup with DDP (Distributed Data Parallel), change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism, you can use `device_map="auto"`.
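As a rough sketch (the model name here is only a placeholder), that change looks like:

```python
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

# Give each DDP process its own full copy of the model on its own GPU
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder model name
    device_map={"": Accelerator().process_index},
)
```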
|
||||
|
@ -23,7 +23,7 @@ trainer = SFTTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&metrics=train/loss,train/mean_token_accuracy,train/num_tokens&sidebar=hidden" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>
|
||||
<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&metrics=train*&runs=sft_qwen3-0.6B_capybara" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>
|
||||
|
||||
## Expected dataset type and format
|
||||
|
||||
@ -105,11 +105,8 @@ $$
|
||||
|
||||
where \\( y_t \\) is the target token at timestep \\( t \\), and the model is trained to predict the next token given the previous ones. In practice, padding tokens are masked out during loss computation.
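In code, this objective is a standard token-level cross-entropy in which masked positions (labels set to `-100`) are skipped. A minimal self-contained sketch (label shifting for next-token prediction is omitted here and covered below):

```python
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 2, 6, 11
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, -2:] = -100  # padding (masked) positions do not contribute to the loss

loss = F.cross_entropy(
    logits.view(-1, vocab_size),  # (batch * seq_len, vocab)
    labels.view(-1),              # (batch * seq_len,)
    ignore_index=-100,
)
```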
|
||||
|
||||
<Tip>
|
||||
|
||||
[On the Generalization of SFT: A Reinforcement Learning Perspective with Reward Rectification](https://huggingface.co/papers/2508.05629) proposes an alternative loss function, called **Dynamic Fine-Tuning (DFT)**, which aims to improve generalization by rectifying the reward signal. This method can be enabled by setting `loss_type="dft"` in the [`SFTConfig`]. For more details, see [Paper Index - Dynamic Fine-Tuning](paper_index#on-the-generalization-of-sft-a-reinforcement-learning-perspective-with-reward-rectification).
|
||||
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> [On the Generalization of SFT: A Reinforcement Learning Perspective with Reward Rectification](https://huggingface.co/papers/2508.05629) proposes an alternative loss function, called **Dynamic Fine-Tuning (DFT)**, which aims to improve generalization by rectifying the reward signal. This method can be enabled by setting `loss_type="dft"` in the [`SFTConfig`]. For more details, see [Paper Index - Dynamic Fine-Tuning](paper_index#on-the-generalization-of-sft-a-reinforcement-learning-perspective-with-reward-rectification).
|
||||
|
||||
### Label shifting and masking
|
||||
|
||||
@ -180,9 +177,8 @@ To train on completion only, use a [prompt-completion](dataset_formats#prompt-co
|
||||
|
||||

|
||||
|
||||
<Tip>
|
||||
Training on completion only is compatible with training on assistant messages only. In this case, use a [conversational](dataset_formats#conversational) [prompt-completion](dataset_formats#prompt-completion) dataset and set `assistant_only_loss=True` in the [`SFTConfig`].
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> Training on completion only is compatible with training on assistant messages only. In this case, use a [conversational](dataset_formats#conversational) [prompt-completion](dataset_formats#prompt-completion) dataset and set `assistant_only_loss=True` in the [`SFTConfig`].
|
||||
|
||||
### Train adapters with PEFT
|
||||
|
||||
@ -204,7 +200,7 @@ trainer = SFTTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
You can also continue training your [`peft.PeftModel`]. For that, first load a `PeftModel` outside [`SFTTrainer`] and pass it directly to the trainer without the `peft_config` argument being passed.
|
||||
You can also continue training your [`~peft.PeftModel`]. For that, first load a `PeftModel` outside [`SFTTrainer`] and pass it directly to the trainer without the `peft_config` argument being passed.
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
@ -222,15 +218,12 @@ trainer = SFTTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
When training adapters, you typically use a higher learning rate (≈1e‑4) since only new parameters are being learned.
|
||||
|
||||
```python
|
||||
SFTConfig(learning_rate=1e-4, ...)
|
||||
```
|
||||
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> When training adapters, you typically use a higher learning rate (≈1e‑4) since only new parameters are being learned.
|
||||
>
|
||||
> ```python
|
||||
> SFTConfig(learning_rate=1e-4, ...)
|
||||
> ```
|
||||
|
||||
### Train with Liger Kernel
|
||||
|
||||
@ -313,17 +306,14 @@ trainer = SFTTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
<Tip>
|
||||
|
||||
For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set `max_length=None` in the [`SFTConfig`]. This allows the model to process the full sequence length without truncating image tokens.
|
||||
|
||||
```python
|
||||
SFTConfig(max_length=None, ...)
|
||||
```
|
||||
|
||||
Only use `max_length` when you've verified that truncation won't remove image tokens for the entire dataset.
|
||||
|
||||
</Tip>
|
||||
> [!TIP]
|
||||
> For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set `max_length=None` in the [`SFTConfig`]. This allows the model to process the full sequence length without truncating image tokens.
|
||||
>
|
||||
> ```python
|
||||
> SFTConfig(max_length=None, ...)
|
||||
> ```
|
||||
>
|
||||
> Only use `max_length` when you've verified that truncation won't remove image tokens for the entire dataset.
|
||||
|
||||
## SFTTrainer
|
||||
|
||||
|
@ -1,10 +1,7 @@
|
||||
# Speeding Up Training
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Section under construction. Feel free to contribute!
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> Section under construction. Feel free to contribute!
|
||||
|
||||
## vLLM for fast generation in online methods
|
||||
|
||||
@ -14,13 +11,7 @@ To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm)
|
||||
To use [vLLM](https://github.com/vllm-project/vllm), first install it using:
|
||||
|
||||
```bash
|
||||
pip install vllm
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```bash
|
||||
pip install "trl[vllm]"
|
||||
pip install trl[vllm]
|
||||
```
|
||||
|
||||
<hfoptions id="vllm examples">
|
||||
@ -53,21 +44,20 @@ training_args = GRPOConfig(..., use_vllm=True)
|
||||
|
||||
You can customize the server configuration by passing additional arguments. For more information, see [vLLM integration](vllm_integration).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.
|
||||
|
||||
Set GPUs **0-3** for vLLM generation:
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name>
|
||||
```
|
||||
|
||||
And GPUs **4-7** for training:
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
```
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.
|
||||
>
|
||||
> Set GPUs **0-3** for vLLM generation:
|
||||
>
|
||||
> ```sh
|
||||
> CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name>
|
||||
> ```
|
||||
>
|
||||
> And GPUs **4-7** for training:
|
||||
>
|
||||
> ```sh
|
||||
> CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
> ```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="RLOO">
|
||||
@ -88,21 +78,20 @@ training_args = RLOOConfig(..., use_vllm=True)
|
||||
|
||||
You can customize the server configuration by passing additional arguments. For more information, see [vLLM integration](vllm_integration).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.
|
||||
|
||||
Set GPUs **0-3** for vLLM generation:
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name>
|
||||
```
|
||||
|
||||
And GPUs **4-7** for training:
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
```
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.
|
||||
>
|
||||
> Set GPUs **0-3** for vLLM generation:
|
||||
>
|
||||
> ```sh
|
||||
> CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name>
|
||||
> ```
|
||||
>
|
||||
> And GPUs **4-7** for training:
|
||||
>
|
||||
> ```sh
|
||||
> CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
> ```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
@ -64,4 +64,4 @@ trainer.train()
|
||||
|
||||
will give you a hosted dashboard at https://huggingface.co/spaces/trl-lib/trackio.
|
||||
|
||||
<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&sidebar=hidden" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>
|
||||
<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&sidebar=hidden&runs=sft_qwen3-0.6B_capybara" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>
|
||||
|
@ -1,159 +0,0 @@
|
||||
# Using LLaMA models with TRL
|
||||
|
||||
We've begun rolling out examples to use Meta's LLaMA models in `trl` (see [Meta's LLaMA release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) for the original LLaMA model).
|
||||
|
||||
## Efficient training strategies
|
||||
|
||||
Even training the smallest LLaMA model requires an enormous amount of memory. Some quick math: in bf16, every parameter uses 2 bytes (in fp32 4 bytes) in addition to 8 bytes used, e.g., in the Adam optimizer (see the [performance docs](https://huggingface.co/docs/transformers/perf_train_gpu_one#optimizer) in Transformers for more info). So a 7B parameter model would use `(2+8)*7B=70GB` just to fit in memory and would likely need more when you compute intermediate values such as attention scores. So you couldn’t train the model even on a single 80GB A100 like that. You can use some tricks, like more efficient optimizers or half-precision training, to squeeze a bit more into memory, but you’ll run out sooner or later.
|
||||
|
||||
Another option is to use Parameter-Efficient Fine-Tuning (PEFT) techniques, such as the [`peft`](https://github.com/huggingface/peft) library, which can perform low-rank adaptation (LoRA) on a model loaded in 8-bit.
|
||||
For more on `peft` + `trl`, see the [Peft integration](peft_integration) docs.
|
||||
|
||||
Loading the model in 8-bit reduces the memory footprint drastically since you only need one byte per parameter for the weights (e.g., a 7B LLaMA model takes 7GB of memory).
|
||||
Instead of training the original weights directly, LoRA adds small adapter layers on top of some specific layers (usually the attention layers); thus, the number of trainable parameters is drastically reduced.
|
||||
|
||||
In this scenario, a rule of thumb is to allocate ~1.2-1.4GB per billion parameters (depending on the batch size and sequence length) to fit the entire fine-tuning setup.
|
||||
This enables fine-tuning larger models (up to 50-60B scale models on a NVIDIA A100 80GB) at low cost.
|
||||
|
||||
Now we can fit very large models into a single GPU, but the training might still be very slow.
|
||||
The simplest strategy in this scenario is data parallelism: we replicate the same training setup into separate GPUs and pass different batches to each GPU.
|
||||
With this, you can parallelize the forward/backward passes of the model and scale with the number of GPUs.
|
||||
|
||||

|
||||
|
||||
We use either the `transformers.Trainer` or `accelerate`, which both support data parallelism without any code changes, by simply passing arguments when calling the scripts with `torchrun` or `accelerate launch`. The following runs a training script with 8 GPUs on a single machine with `accelerate` and `torchrun`, respectively.
|
||||
|
||||
```bash
|
||||
accelerate launch --multi_gpu --num_machines 1 --num_processes 8 my_accelerate_script.py
|
||||
torchrun --nnodes 1 --nproc_per_node 8 my_torch_script.py
|
||||
```
|
||||
|
||||
## Supervised fine-tuning
|
||||
|
||||
Before we start training reward models and tuning our model with RL, it helps if the model is already good in the domain we are interested in.
|
||||
In our case, we want it to answer questions, while for other use cases, we might want it to follow instructions, in which case instruction tuning is a great idea.
|
||||
The easiest way to achieve this is by continuing to train the language model with the language modeling objective on texts from the domain or task.
|
||||
The [StackExchange dataset](https://huggingface.co/datasets/HuggingFaceH4/stack-exchange-preferences) is enormous (over 10 million instructions), so we can easily train the language model on a subset of it.
|
||||
|
||||
There is nothing special about fine-tuning the model before doing RLHF - it’s just the causal language modeling objective from pretraining that we apply here.
|
||||
To use the data efficiently, we use a technique called packing: instead of having one text per sample in the batch and then padding to either the longest text or the maximal context of the model, we concatenate a lot of texts with an EOS token in between and cut chunks of the context size to fill the batch without any padding.
|
||||
|
||||

|
||||
|
||||
With this approach, the training is much more efficient, as each token that is passed through the model is also trained, in contrast to padding tokens, which are usually masked from the loss.
If you don't have much data and are more concerned about occasionally cutting off some tokens that overflow the context, you can also use a classical data loader.
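For reference, the packing idea itself can be sketched roughly as follows (leaving out the streaming and shuffling that the real dataset implementation handles):

```python
def pack_examples(texts, tokenizer, context_length):
    """Concatenate tokenized texts, separated by EOS, and cut fixed-size chunks."""
    token_ids = []
    for text in texts:
        token_ids.extend(tokenizer(text)["input_ids"])
        token_ids.append(tokenizer.eos_token_id)
    # keep only full chunks; the remainder is simply dropped in this simplified version
    n_chunks = len(token_ids) // context_length
    return [token_ids[i * context_length : (i + 1) * context_length] for i in range(n_chunks)]
```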
|
||||
|
||||
|
||||
```python
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM

# load the model in 8-bit, placing it on the GPU assigned to the current process
model = AutoModelForCausalLM.from_pretrained(
    args.model_path,
    load_in_8bit=True,
    device_map={"": Accelerator().local_process_index},
)
model = prepare_model_for_kbit_training(model)

# add LoRA adapters on top of the frozen 8-bit model
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
```
|
||||
|
||||
We train the model for a few thousand steps with the causal language modeling objective and save the model.
|
||||
Since we will tune the model again with different objectives, we merge the adapter weights with the original model weights.
|
||||
|
||||
**Disclaimer:** due to LLaMA's license, we release only the adapter weights for this and the model checkpoints in the following sections.
|
||||
You can apply for access to the base model's weights by filling out Meta AI's [form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) and then converting them to the 🤗 Transformers format by running this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py).
|
||||
Note that you'll also need to install 🤗 Transformers from source until `v4.28` is released.
|
||||
|
||||
Now that we have fine-tuned the model for the task, we are ready to train a reward model.
|
||||
|
||||
## Reward modeling and human preferences
|
||||
|
||||
In principle, we could fine-tune the model using RLHF directly with the human annotations.
|
||||
However, this would require us to send some samples to humans for rating after each optimization iteration.
|
||||
This is expensive and slow, due to the number of training samples needed for convergence and the inherent latency of human reading and annotation.
|
||||
|
||||
A trick that works well instead of direct feedback is training a reward model on human annotations collected before the RL loop.
|
||||
The goal of the reward model is to imitate how a human would rate a text. There are several possible strategies to build a reward model: the most straightforward way would be to predict the annotation (e.g. a rating score or a binary value for “good”/”bad”).
|
||||
In practice, what works better is to predict the ranking of two examples, where the reward model is presented with two candidates `(y_k, y_j)` for a given prompt `x` and has to predict which one would be rated higher by a human annotator.
|
||||
|
||||
With the StackExchange dataset, we can infer which of the two answers was preferred by the users based on the score.
|
||||
With that information and a pairwise ranking loss, we can then modify the `transformers.Trainer` by adding a custom loss function, shown below.
|
||||
|
||||
```python
from torch import nn
from transformers import Trainer


class RewardTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
        rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
        # pairwise ranking loss: the preferred candidate (j) should receive a higher reward than (k)
        loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
        if return_outputs:
            return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
        return loss
```
|
||||
|
||||
We use a subset of 100,000 pairs of candidates and evaluate on a held-out set of 50,000. With a modest training batch size of 4, we train the LLaMA model with the LoRA `peft` adapter for a single epoch, using the Adam optimizer with BF16 precision. Our LoRA configuration is:
|
||||
|
||||
```python
from peft import LoraConfig, TaskType

peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
```
|
||||
As detailed in the next section, the resulting adapter can be merged into the frozen model and saved for further downstream use.
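The merge itself can be sketched as follows (the paths are placeholders):

```python
from peft import AutoPeftModelForSequenceClassification

model = AutoPeftModelForSequenceClassification.from_pretrained("path/to/reward-adapter")  # placeholder path
merged_model = model.merge_and_unload()  # fold the LoRA weights into the frozen base model
merged_model.save_pretrained("path/to/merged-reward-model")
```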
|
||||
|
||||
## Reinforcement Learning from Human Feedback
|
||||
|
||||
With the fine-tuned language model and the reward model at hand, we are now ready to run the RL loop. It follows roughly three steps:
|
||||
|
||||
1. Generate responses from prompts,
|
||||
2. Rate the responses with the reward model,
|
||||
3. Run a reinforcement learning policy-optimization step with the ratings.
|
||||
|
||||
The Query and Response prompts are templated as follows before being tokenized and passed to the model:
|
||||
|
||||
```bash
|
||||
Question: <Query>
|
||||
|
||||
Answer: <Response>
|
||||
```
|
||||
|
||||
The same template was used for the SFT, RM, and RLHF stages.
|
||||
Once more, we utilize `peft` for memory-efficient training, which offers an extra advantage in the RLHF context.
|
||||
Here, the reference model and policy share the same base, the SFT model, which we load in 8-bit and freeze during training.
|
||||
We exclusively optimize the policy's LoRA weights using PPO while sharing the base model's weights.
|
||||
|
||||
```python
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    question_tensors = batch["input_ids"]

    # sample from the policy to generate responses
    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        length_sampler=output_length_sampler,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    # Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]

    # Run PPO step
    stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
    # Log stats to Wandb
    ppo_trainer.log_stats(stats, batch, rewards)
```
|
||||
|
||||
For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
|
@ -1,14 +1,26 @@
|
||||
# vLLM Integration
|
||||
|
||||
This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. Let's go! 🔥
|
||||
This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood.
|
||||
|
||||
> [!WARNING]
|
||||
> TRL currently only supports vLLM version `0.10.2`. Please ensure you have this version installed to avoid compatibility issues.
|
||||
|
||||
> [!TIP]
|
||||
> The following trainers currently support generation with vLLM:
|
||||
>
|
||||
> - [`GRPOTrainer`]
|
||||
> - [`OnlineDPOTrainer`]
|
||||
> - [`NashMDTrainer`]
|
||||
> - [`XPOTrainer`]
|
||||
> - [`RLOOTrainer`]
|
||||
|
||||
## 🚀 How can I use vLLM with TRL to speed up training?
|
||||
|
||||
💡 **Note**: Resources required for this specific example: a single node with 8 GPUs.
|
||||
|
||||
<Tip warning={true}>
|
||||
vLLM server and TRL trainer must use different CUDA devices to avoid conflicts.
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> When using vLLM with TRL, the **vLLM server** and the **trainer** must run on **separate CUDA devices** to prevent conflicts.
|
||||
> For guidance on configuring this properly, see [Modes of using vLLM during training](#modes-of-using-vllm-during-training).
|
||||
|
||||
First, install vLLM using the following command:
|
||||
|
||||
@ -22,12 +34,15 @@ Then run the server on specific GPUs (e.g., GPUs 0-3):
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2
|
||||
```
|
||||
|
||||
Once the server is running, you can use it to generate completions for training. In the example below, we are using the `GRPOTrainer` to train a model using the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs.
|
||||
Once the server is running, you can use it to generate completions for training. In the examples below, we use each of the supported trainers with the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs.
|
||||
|
||||
In this example, we are sharding two copies of the model across 4 GPUs. Increasing data parallelism increases throughput, while increasing tensor parallelism allows for serving larger models. Then, run the training script on different GPUs (e.g., GPUs 4-7) by passing `use_vllm=True` in the training arguments as follows:
|
||||
|
||||
Sample of a simple `train.py` script:
|
||||
|
||||
<hfoptions id="vllm examples">
|
||||
<hfoption id="GRPO">
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
@ -55,21 +70,148 @@ trainer = GRPOTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="OnlineDPO">
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import OnlineDPOTrainer, OnlineDPOConfig
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Dummy reward function: count the number of unique characters in the completions
|
||||
def reward_num_unique_chars(completions, **kwargs):
|
||||
return [len(set(c)) for c in completions]
|
||||
|
||||
training_args = OnlineDPOConfig(
|
||||
output_dir="my_test",
|
||||
use_vllm=True,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
trainer = OnlineDPOTrainer(
|
||||
model="Qwen/Qwen2.5-7B",
|
||||
args=training_args,
|
||||
reward_funcs=reward_num_unique_chars,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="NashMD">
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import NashMDTrainer, NashMDConfig
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Dummy reward function: count the number of unique characters in the completions
|
||||
def reward_num_unique_chars(completions, **kwargs):
|
||||
return [len(set(c)) for c in completions]
|
||||
|
||||
training_args = NashMDConfig(
|
||||
output_dir="my_test",
|
||||
use_vllm=True,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
trainer = NashMDTrainer(
|
||||
model="Qwen/Qwen2.5-7B",
|
||||
args=training_args,
|
||||
reward_funcs=reward_num_unique_chars,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="XPO">
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import XPOTrainer, XPOConfig
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Dummy reward function: count the number of unique characters in the completions
|
||||
def reward_num_unique_chars(completions, **kwargs):
|
||||
return [len(set(c)) for c in completions]
|
||||
|
||||
training_args = XPOConfig(
|
||||
output_dir="my_test",
|
||||
use_vllm=True,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
trainer = XPOTrainer(
|
||||
model="Qwen/Qwen2.5-7B",
|
||||
args=training_args,
|
||||
reward_funcs=reward_num_unique_chars,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="RLOO">
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import RLOOTrainer, RLOOConfig
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Dummy reward function: count the number of unique characters in the completions
|
||||
def reward_num_unique_chars(completions, **kwargs):
|
||||
return [len(set(c)) for c in completions]
|
||||
|
||||
training_args = RLOOConfig(
|
||||
output_dir="my_test",
|
||||
use_vllm=True,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
trainer = RLOOTrainer(
|
||||
model="Qwen/Qwen2.5-7B",
|
||||
args=training_args,
|
||||
reward_funcs=reward_num_unique_chars,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
And the train command on separate GPUs from the server:
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
```
|
||||
|
||||
## 🎬 Flashback: Why do we need to use vLLM in online methods?
|
||||
## Why using vLLM?
|
||||
|
||||
### 🎬 Flashback: Why do we need to use vLLM in online methods?
|
||||
|
||||
Online methods like GRPO or Online DPO require the model to generate completions during training, which are then used to compute reward signals. However, generation can be extremely time-consuming, especially with large or reasoning models. In the default setup (without vLLM), completions are generated using the [(unwrapped) model's `generate` method](https://github.com/huggingface/trl/blob/f3e8c2304428ef16e9ae5de9e5741ed84d533b7b/trl/trainer/grpo_trainer.py#L965C39-L965C66). This approach quickly becomes a major bottleneck — generation is slow and inefficient, particularly for large batches or models. As a result, training times increase significantly, and overall efficiency drops. To address this, we turn to vLLM, which enables much faster and more scalable generation, helping eliminate this bottleneck in online methods.
|
||||
|
||||
## 🤔 How does vLLM solve the slow generation issue?
|
||||
### 🤔 How does vLLM solve the slow generation issue?
|
||||
|
||||
If you've ever done autoregressive decoder training, you know all the input tokens to the LLM produce their attention key and value tensors, and these tensors are kept in GPU memory to later generate subsequent tokens based on them. These cached key and value tensors are often referred to as the KV cache. However, storing the KV cache occupies a lot of memory, so vLLM uses a technique called **PagedAttention** to solve this problem. PagedAttention, which is inspired by the OS’s virtual memory concept, stores continuous keys and values in **non-contiguous memory space**, which is much more efficient. The details of this are beyond the scope of this document, but in short, it allows the model to store the keys and values in a more efficient way, reducing the memory footprint and speeding up the generation process. If you are interested, make sure to check out the [vLLM PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html) for more details.
|
||||
|
||||
## 🤔 What exactly happens when you run `trl vllm-serve --model <model_name>`?
|
||||
## How vLLM Works (Under the Hood) 🔍
|
||||
|
||||
### 🤔 What exactly happens when you run `trl vllm-serve --model <model_name>`?
|
||||
|
||||
When you run for example
|
||||
|
||||
@ -90,18 +232,18 @@ Each worker operates independently and processes a chunk of the incoming request
|
||||
This GPU-to-GPU communication is managed efficiently by NVIDIA’s NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests — it’s lightweight and doesn’t interfere with generation itself.
|
||||
Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (like in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you would end up with 16 completions total — regardless of the number of GPUs or parallelism settings.
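As a small configuration sketch (the values are illustrative):

```python
from trl import GRPOConfig

# With 8 prompts per generation round and num_generations=2, you end up with 16 completions in total
training_args = GRPOConfig(
    output_dir="grpo-vllm-demo",
    use_vllm=True,
    num_generations=2,
)
```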
|
||||
|
||||
## 🥸 More detail on what happens under the hood when running the server
|
||||
### 🥸 More detail on what happens under the hood when running the server
|
||||
|
||||
* The vLLM server starts by running the command: `trl vllm-serve --model Qwen/Qwen2.5-7B`.
|
||||
* Once the server is running, it generates completions based on requests from the client (trainer) using `vllm_client.generate` [here](https://github.com/huggingface/trl/blob/cc044e35b285be7dc062764b3364e1e684db4c7c/trl/trainer/grpo_trainer.py#L1025-L1035).
|
||||
* The client (trainer) then requests these completions from the server.
|
||||
* These completions are used to compute the reward signal.
|
||||
* Based on the reward signal and the model’s output, the loss is computed, and the backward pass is performed to update the model’s weights.
|
||||
* **Note**: The server only handles completion generation — it doesn’t train the model. Therefore, the model’s weights aren’t updated on the server. Once the backward pass is complete, the client sends the updated weights to the server using `vllm_client.update_named_param(name, param.data)`.
|
||||
- The vLLM server starts by running the command: `trl vllm-serve --model Qwen/Qwen2.5-7B`.
|
||||
- Once the server is running, it generates completions based on requests from the client (trainer) using `vllm_client.generate` (see [these lines](https://github.com/huggingface/trl/blob/cc044e35b285be7dc062764b3364e1e684db4c7c/trl/trainer/grpo_trainer.py#L1025-L1035)).
|
||||
- The client (trainer) then requests these completions from the server.
|
||||
- These completions are used to compute the reward signal.
|
||||
- Based on the reward signal and the model’s output, the loss is computed, and the backward pass is performed to update the model’s weights.
|
||||
- **Note**: The server only handles completion generation — it doesn’t train the model. Therefore, the model’s weights aren’t updated on the server. Once the backward pass is complete, the client sends the updated weights to the server using `vllm_client.update_named_param(name, param.data)`.
|
||||
|
||||
When using vLLM, ensure the GPUs assigned for training and generation are separate to avoid NCCL communication conflicts. If you do not set the `CUDA_VISIBLE_DEVICES` environment variable, the training script will use all available GPUs by default, which may lead to device conflicts. Starting with the first TRL release after v0.19.1, the code automatically detects and prevents same-device usage, raising an error in the vLLM server process:
|
||||
|
||||
```
|
||||
```log
|
||||
RuntimeError: Attempting to use the same CUDA device for multiple distinct roles/ranks within the same communicator.
|
||||
Ensure that trainer is using different devices than vLLM server.
|
||||
```
|
||||
@ -112,19 +254,21 @@ For example, if you want to use GPUs 4–7 for training while the server runs on
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
```
|
||||
|
||||
## 🍷 More customization options with vLLM?
|
||||
## Advanced usage
|
||||
|
||||
### 🍷 More customization options with vLLM?
|
||||
|
||||
You can customize the server configuration by passing additional arguments.
|
||||
|
||||
```
|
||||
```txt
|
||||
$ trl vllm-serve --help
|
||||
usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE]
|
||||
[--data_parallel_size DATA_PARALLEL_SIZE] [--host HOST] [--port PORT]
|
||||
[--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE] [--max_model_len MAX_MODEL_LEN]
|
||||
[--enable_prefix_caching ENABLE_PREFIX_CACHING] [--enforce_eager ENFORCE_EAGER] [--log_level LOG_LEVEL]
|
||||
usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE] [--data_parallel_size DATA_PARALLEL_SIZE] [--host HOST]
|
||||
[--port PORT] [--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE] [--max_model_len MAX_MODEL_LEN]
|
||||
[--enable_prefix_caching ENABLE_PREFIX_CACHING] [--enforce_eager [ENFORCE_EAGER]] [--kv_cache_dtype KV_CACHE_DTYPE]
|
||||
[--trust_remote_code [TRUST_REMOTE_CODE]] [--log_level LOG_LEVEL] [--vllm_model_impl VLLM_MODEL_IMPL]
|
||||
|
||||
options:
|
||||
-h, --help Show this help message and exit
|
||||
-h, --help show this help message and exit
|
||||
--model MODEL Model name or path to load the model from. (default: None)
|
||||
--revision REVISION Revision to use for the model. If not specified, the default branch will be used. (default: None)
|
||||
--tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE
|
||||
@ -134,63 +278,222 @@ options:
|
||||
--host HOST Host address to run the server on. (default: 0.0.0.0)
|
||||
--port PORT Port to run the server on. (default: 8000)
|
||||
--gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION
|
||||
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the device
|
||||
dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus improve the
|
||||
model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors during
|
||||
initialization. (default: 0.9)
|
||||
--dtype DTYPE Data type to use for vLLM generation. If set to 'auto', the data type will be automatically determined based on
|
||||
the model configuration. Find the supported values in the vLLM documentation. (default: auto)
|
||||
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the device dedicated to generation
|
||||
powered by vLLM. Higher values will increase the KV cache size and thus improve the model's throughput. However, if the value is too high,
|
||||
it may cause out-of-memory (OOM) errors during initialization. (default: 0.9)
|
||||
--dtype DTYPE Data type to use for vLLM generation. If set to 'auto', the data type will be automatically determined based on the model configuration.
|
||||
Find the supported values in the vLLM documentation. (default: auto)
|
||||
--max_model_len MAX_MODEL_LEN, --max-model-len MAX_MODEL_LEN
|
||||
If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
|
||||
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model context
|
||||
size, which might be much larger than the KV cache, leading to inefficiencies. (default: None)
|
||||
If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced `vllm_gpu_memory_utilization`, leading to a
|
||||
reduced KV cache size. If not set, vLLM will use the model context size, which might be much larger than the KV cache, leading to
|
||||
inefficiencies. (default: None)
|
||||
--enable_prefix_caching ENABLE_PREFIX_CACHING, --enable-prefix-caching ENABLE_PREFIX_CACHING
|
||||
Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support this
|
||||
feature. (default: None)
|
||||
--enforce_eager ENFORCE_EAGER, --enforce-eager ENFORCE_EAGER
|
||||
Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the model
|
||||
in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid. (default:
|
||||
None)
|
||||
Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support this feature. (default: None)
|
||||
--enforce_eager [ENFORCE_EAGER], --enforce-eager [ENFORCE_EAGER]
|
||||
Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the model in eager mode. If `False`
|
||||
(default behavior), we will use CUDA graph and eager execution in hybrid. (default: False)
|
||||
--kv_cache_dtype KV_CACHE_DTYPE, --kv-cache-dtype KV_CACHE_DTYPE
|
||||
Data type to use for KV cache. If set to 'auto', the dtype will default to the model data type. (default: auto)
|
||||
--trust_remote_code [TRUST_REMOTE_CODE], --trust-remote-code [TRUST_REMOTE_CODE]
|
||||
Whether to trust remote code when loading models. Set to True to allow executing code from model repositories. This is required for some
|
||||
custom models but introduces security risks. (default: False)
|
||||
--log_level LOG_LEVEL, --log-level LOG_LEVEL
|
||||
Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', 'trace'. (default:
|
||||
info)
|
||||
Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', 'trace'. (default: info)
|
||||
--vllm_model_impl VLLM_MODEL_IMPL, --vllm-model-impl VLLM_MODEL_IMPL
|
||||
Model implementation to use for vLLM. Must be one of `transformers` or `vllm`. `transformers`: Use the `transformers` backend for model
|
||||
implementation. `vllm`: Use the `vllm` library for model implementation. (default: vllm)
|
||||
```
|
||||
|
||||
## 🥳 Okay, now that we have the server running, how can we use it to generate completions?
|
||||
### 💆🏻♀️ What's the best distributed setup?
|
||||
|
||||
Run the training script and pass `use_vllm=True` in the training arguments:
|
||||

|
||||

|
||||
|
||||
First and foremost, always remember that the optimal setup depends on:
|
||||
|
||||
- The model size
|
||||
- The number of GPUs you have
|
||||
- The GPU memory size
|
||||
- The batch size you are using
|
||||
- The number of requests you are sending to the server (prompts)
|
||||
- The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size)
|
||||
- The number of completions you are generating for each request (`num_generations`)
|
||||
|
||||
Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that:
|
||||
|
||||
- For reasonable-sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results.
|
||||
- For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window.
|
||||
|
||||
### vLLM with Transformers Backend
|
||||
|
||||
vLLM can use the **Transformers backend** for model implementations, which works for both LLMs and VLMs.
|
||||
To enable this, set `vllm_model_impl="transformers"` in your configuration or pass it via the command-line argument.
|
||||
|
||||
For more details, check out [vLLM Transformers Backend](https://blog.vllm.ai/2025/04/11/transformers-backend.html).
|
||||
|
||||
Example:
|
||||
|
||||
```sh
|
||||
CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-VL-3B-Instruct --tensor-parallel-size 1 --port 8000 --enforce_eager --vllm_model_impl transformers
|
||||
```
|
||||
|
||||
### Modes of Using vLLM During Training
|
||||
|
||||
TRL supports **two modes** for integrating vLLM during training: **server mode** and **colocate mode**.
|
||||
|
||||
#### Server Mode
|
||||
|
||||
In **server mode**, vLLM runs as a separate process on dedicated GPUs and communicates with the trainer via HTTP.
|
||||
This setup is ideal if you have GPUs dedicated to inference.
|
||||
|
||||
Example configuration:
|
||||
|
||||
<hfoptions id="vllm examples">
|
||||
<hfoption id="GRPO">
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(..., use_vllm=True)
|
||||
training_args = GRPOConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="server", # default value, can be omitted
|
||||
)
|
||||
```
|
||||
|
||||
## 💆🏻♀️ What's the best distributed setup?
|
||||
</hfoption>
|
||||
<hfoption id="OnlineDPO">
|
||||
|
||||

|
||||

|
||||
|
||||
First and foremost, always remember that the optimal setup depends on:
|
||||
|
||||
* The model size
|
||||
* The number of GPUs you have
|
||||
* The GPU memory size
|
||||
* The batch size you are using
|
||||
* The number of requests you are sending to the server (prompts)
|
||||
* The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size)
|
||||
* The number of completions you are generating for each request (`num_generations`)
|
||||
|
||||
Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that:
|
||||
|
||||
* For reasonable-sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results.
|
||||
* For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window.
|
||||
|
||||
## vLLM with Transformers Backend
|
||||
|
||||
vLLM now supports the Transformers backend for model implementations. Simply pass `transformers` as `vllm_model_impl` in the configuration, or through the argument parser, to use the Transformers backend. This works for both LLMs and VLMs. See the example below; you can get more information [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html).
|
||||
```python
|
||||
from trl import OnlineDPOConfig
|
||||
|
||||
training_args = OnlineDPOConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="server", # default value, can be omitted
|
||||
)
|
||||
```
|
||||
CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen2.5-VL-3B-Instruct --tensor-parallel-size 1 --port 8000 --enforce_eager --vllm_model_impl transformers
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="NashMD">
|
||||
|
||||
```python
|
||||
from trl import NashMDConfig
|
||||
|
||||
training_args = NashMDConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="server", # default value, can be omitted
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="XPO">
|
||||
|
||||
```python
|
||||
from trl import XPOConfig
|
||||
|
||||
training_args = XPOConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="server", # default value, can be omitted
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="RLOO">
|
||||
|
||||
```python
|
||||
from trl import RLOOConfig
|
||||
|
||||
training_args = RLOOConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="server", # default value, can be omitted
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
#### Colocate Mode
|
||||
|
||||
In **colocate mode**, vLLM runs inside the trainer process and shares GPU memory with the training model.
|
||||
This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs.
|
||||
|
||||
Example configuration:
|
||||
|
||||
<hfoptions id="vllm examples">
|
||||
<hfoption id="GRPO">
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="colocate",
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="OnlineDPO">
|
||||
|
||||
```python
|
||||
from trl import OnlineDPOConfig
|
||||
|
||||
training_args = OnlineDPOConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="colocate",
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="NashMD">
|
||||
|
||||
```python
|
||||
from trl import NashMDConfig
|
||||
|
||||
training_args = NashMDConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="colocate",
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="XPO">
|
||||
|
||||
```python
|
||||
from trl import XPOConfig
|
||||
|
||||
training_args = XPOConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="colocate",
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="RLOO">
|
||||
|
||||
```python
|
||||
from trl import RLOOConfig
|
||||
|
||||
training_args = RLOOConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="colocate",
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
> [!WARNING]
|
||||
> Check the documentation of the trainer you are using for specific details on vLLM usage and parameters.
|
||||
|
||||
> [!WARNING]
|
||||
> To reduce GPU memory usage when running vLLM, consider [enabling vLLM sleep mode](reducing_memory_usage#vllm-sleep-mode).
|
||||
|

@ -1,6 +1,6 @@

# XPO Trainer

[](https://huggingface.co/models?other=xpo,trl)

## Overview

@ -57,7 +57,7 @@ To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) pe

What is the best programming language?

<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-XPO>:</span></strong>
The best programming language depends on individual preferences and familiarity with coding concepts. Some popular languages include Python, Java, C++, and JavaScript.
</code></pre>

## Expected dataset type

@ -84,11 +84,8 @@ Instead of a judge, you can chose to use a reward model -- see [Reward Bench](ht
)
```

<Tip warning={true}>

Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.

</Tip>
> [!WARNING]
> Make sure that the SFT model and reward model use the _same_ chat template and the same tokenizer. Otherwise, you may find the model completions are scored incorrectly during training.
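A quick way to guard against such a mismatch is to compare the two checkpoints' chat templates and vocabularies before training. The snippet below is a minimal sketch with placeholder model names (replace them with your actual SFT and reward model checkpoints); `chat_template` and `get_vocab()` are standard tokenizer attributes in transformers.

```python
from transformers import AutoTokenizer

# Placeholder checkpoints: substitute your SFT model and reward model.
sft_tok = AutoTokenizer.from_pretrained("my-org/my-sft-model")
rm_tok = AutoTokenizer.from_pretrained("my-org/my-reward-model")

# Both the chat template and the vocabulary should match; otherwise completions
# may be scored on a different tokenization than the one used for generation.
assert sft_tok.chat_template == rm_tok.chat_template, "Chat templates differ!"
assert sft_tok.get_vocab() == rm_tok.get_vocab(), "Tokenizers differ!"
```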

### Encourage EOS token generation

@ -151,7 +148,6 @@ While training and evaluating we record the following reward metrics:

* `alpha`: The weight of the XPO loss term. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`].
* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`] (see the sketch below).
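As a concrete illustration, the following is a minimal sketch of making both terms dynamic by passing lists to [`XPOConfig`]. Treating each list entry as the value used for one training epoch is the assumed behavior here, so check the [`XPOConfig`] reference for the exact semantics.

```python
from trl import XPOConfig

# Sketch: list-valued `alpha` and `beta` make the coefficients dynamic
# (assumed here to be one value per epoch); scalars keep them fixed.
training_args = XPOConfig(
    ...,
    alpha=[1e-5, 1e-5, 1e-6],
    beta=[0.1, 0.05, 0.0],
)
```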

## XPOTrainer

[[autodoc]] XPOTrainer

@ -1,3 +1,3 @@

# Examples

Please check out https://huggingface.co/docs/trl/example_overview for documentation on our examples.
@ -5,7 +5,7 @@ distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_activation_checkpointing: true # Enable activation checkpointing for memory efficiency
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
@ -16,7 +16,7 @@ machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
num_processes: 2 # Number of GPUs
rdzv_backend: static
same_network: true
tpu_env: []
@ -27,4 +27,4 @@ parallelism_config:
  parallelism_config_dp_replicate_size: 1
  parallelism_config_dp_shard_size: 1
  parallelism_config_tp_size: 1
  parallelism_config_cp_size: 2
  parallelism_config_cp_size: 2 # Context parallel size
@ -31,7 +31,7 @@ class ScriptArguments:
            Whether to push the dataset to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/hh-rlhf-helpful-base"`):
            Hugging Face repository ID to push the dataset to.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
        dataset_num_proc (`int`, *optional*):
            Number of workers to use for dataset processing.
    """
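For context, these docstring entries describe dataclass fields that, in TRL's dataset scripts, typically follow the pattern sketched below. The declarations here illustrate that pattern (the defaults mirror the docstring above); they are an assumption, not the exact contents of the diff.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ScriptArguments:
    # Illustrative sketch of the fields the docstring above documents; the
    # defaults mirror the docstring and are assumptions, not the exact diff.
    push_to_hub: bool = field(
        default=False,
        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
    )
    repo_id: str = field(
        default="trl-lib/hh-rlhf-helpful-base",
        metadata={"help": "Hugging Face repository ID to push the dataset to."},
    )
    dataset_num_proc: Optional[int] = field(
        default=None,
        metadata={"help": "Number of workers to use for dataset processing."},
    )
```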
@ -31,7 +31,7 @@ class ScriptArguments:
|
||||
Whether to push the dataset to the Hugging Face Hub.
|
||||
repo_id (`str`, *optional*, defaults to `"trl-lib/llava-instruct-mix"`):
|
||||
Hugging Face repository ID to push the dataset to.
|
||||
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
||||
dataset_num_proc (`int`, *optional*):
|
||||
Number of workers to use for dataset processing.
|
||||
"""
|
||||
|
||||
|
@ -30,7 +30,7 @@ class ScriptArguments:
|
||||
Whether to push the dataset to the Hugging Face Hub.
|
||||
repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-descriptiveness"`):
|
||||
Hugging Face repository ID to push the dataset to.
|
||||
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
||||
dataset_num_proc (`int`, *optional*):
|
||||
Number of workers to use for dataset processing.
|
||||
"""
|
||||
|
||||
|
@ -30,7 +30,7 @@ class ScriptArguments:
|
||||
Whether to push the dataset to the Hugging Face Hub.
|
||||
repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-sentiment"`):
|
||||
Hugging Face repository ID to push the dataset to.
|
||||
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
||||
dataset_num_proc (`int`, *optional*):
|
||||
Number of workers to use for dataset processing.
|
||||
"""
|
||||
|
||||
|
@ -32,7 +32,7 @@ class ScriptArguments:
|
||||
Whether to push the dataset to the Hugging Face Hub.
|
||||
repo_id (`str`, *optional*, defaults to `"trl-lib/math_shepherd"`):
|
||||
Hugging Face repository ID to push the dataset to.
|
||||
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
||||
dataset_num_proc (`int`, *optional*):
|
||||
Number of workers to use for dataset processing.
|
||||
"""
|
||||
|
||||
|
@ -30,7 +30,7 @@ class ScriptArguments:
|
||||
Whether to push the dataset to the Hugging Face Hub.
|
||||
repo_id (`str`, *optional*, defaults to `"trl-lib/prm800k"`):
|
||||
Hugging Face repository ID to push the dataset to.
|
||||
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
||||
dataset_num_proc (`int`, *optional*):
|
||||
Number of workers to use for dataset processing.
|
||||
"""
|
||||
|
||||
|
@ -30,7 +30,7 @@ class ScriptArguments:
|
||||
Whether to push the dataset to the Hugging Face Hub.
|
||||
repo_id (`str`, *optional*, defaults to `"trl-lib/rlaif-v"`):
|
||||
Hugging Face repository ID to push the dataset to.
|
||||
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
||||
dataset_num_proc (`int`, *optional*):
|
||||
Number of workers to use for dataset processing.
|
||||
"""
|
||||
|
||||
|
@ -30,7 +30,7 @@ class ScriptArguments:
|
||||
Whether to push the dataset to the Hugging Face Hub.
|
||||
repo_id (`str`, *optional*, defaults to `"trl-lib/tldr"`):
|
||||
Hugging Face repository ID to push the dataset to.
|
||||
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
||||
dataset_num_proc (`int`, *optional*):
|
||||
Number of workers to use for dataset processing.
|
||||
"""
|
||||
|
||||
|
@ -30,7 +30,7 @@ class ScriptArguments:
|
||||
Whether to push the dataset to the Hugging Face Hub.
|
||||
repo_id (`str`, *optional*, defaults to `"trl-lib/tldr-preference"`):
|
||||
Hugging Face repository ID to push the dataset to.
|
||||
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
||||
dataset_num_proc (`int`, *optional*):
|
||||
Number of workers to use for dataset processing.
|
||||
"""
|
||||
|
||||
|
@ -30,7 +30,7 @@ class ScriptArguments:
|
||||
Whether to push the dataset to the Hugging Face Hub.
|
||||
repo_id (`str`, *optional*, defaults to `"trl-lib/ultrafeedback-prompt"`):
|
||||
Hugging Face repository ID to push the dataset to.
|
||||
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
||||
dataset_num_proc (`int`, *optional*):
|
||||
Number of workers to use for dataset processing.
|
||||
"""
|
||||
|
||||
|
@ -34,7 +34,7 @@ class ScriptArguments:
|
||||
Whether to push the dataset to the Hugging Face Hub.
|
||||
repo_id (`str`, *optional*, defaults to `"trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness"`):
|
||||
Hugging Face repository ID to push the dataset to.
|
||||
dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
|
||||
dataset_num_proc (`int`, *optional*):
|
||||
Number of workers to use for dataset processing.
|
||||
"""
|
||||
|
||||
|
694
examples/notebooks/grpo_qwen3_vl.ipynb
Normal file
@ -0,0 +1,694 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "-J8iGzLf4rUJ"
|
||||
},
|
||||
"source": [
|
||||
"# GRPO Qwen3-VL with QLoRA using TRL\n",
|
||||
"\n",
|
||||
"[](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_qwen3_vl.ipynb)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can fine-tune cutting edge vision language models. It comes with support for quantized parameter efficient fine-tuning technique **QLoRA**, so we can use free Colab (T4 GPU) to fine-tune models like [Qwen3-VL](https://huggingface.co/collections/Qwen/qwen3-vl-68d2a7c1b8a8afce4ebd2dbe).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n",
|
||||
"- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n",
|
||||
"- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)\n",
|
||||
"- [More Qwen3-VL Fine-tuning Examples (including TRL scripts)](https://github.com/QwenLM/Qwen3-VL/tree/main/qwen-vl-finetune/)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "NvrzGRnu48Vz"
|
||||
},
|
||||
"source": [
|
||||
"## Install dependencies\n",
|
||||
"\n",
|
||||
"We'll install **TRL** with the **PEFT** extra, which ensures all main dependencies such as **Transformers** and **PEFT** (a package for parameter-efficient fine-tuning, e.g., LoRA/QLoRA) are included. Additionally, we'll install **trackio** to log and monitor our experiments, and **bitsandbytes** to enable quantization of LLMs, reducing memory consumption for both inference and training."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "8CfZlUevmkg7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install -Uq \"trl[peft]\" bitsandbytes trackio math_verify"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gpzI6omi7728"
|
||||
},
|
||||
"source": [
|
||||
"### Log in to Hugging Face\n",
|
||||
"\n",
|
||||
"Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "4Ncx0wYtnYCW"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from huggingface_hub import notebook_login\n",
|
||||
"\n",
|
||||
"notebook_login()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "V_Zylc4t79-n"
|
||||
},
|
||||
"source": [
|
||||
"## Load dataset\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"We'll load the [**lmms-lab/multimodal-open-r1-8k-verified**](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset from the Hugging Face Hub using the `datasets` library.\n",
|
||||
"\n",
|
||||
"This dataset contains maths problems with the image representing the problem, along with the solution in thinking format specially tailored for VLMs. By training our model with this dataset, it'll improve its maths and thinking reasoning.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "TzXogU24F_QR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"dataset_id = 'lmms-lab/multimodal-open-r1-8k-verified'\n",
|
||||
"train_dataset = load_dataset(dataset_id, split='train[:5%]')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gVV7RoRN8zk5"
|
||||
},
|
||||
"source": [
|
||||
"In addition to the `problem` and `image` columns, we also include a custom system prompt to tell the model how we'd like the generation.\n",
|
||||
"\n",
|
||||
"The system prompt is extracted from DeepSeek R1. Refer to [this previous recipe](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) for more details.\n",
|
||||
"\n",
|
||||
"We convert the dataset samples into conversation samples, including the system prompt and one image and problem description per sample, since this is how the GRPO trainer expects them.\n",
|
||||
"\n",
|
||||
"We also set `padding_side=\"left\"` to ensure that generated completions during training are concatenated directly after the prompt, which is essential for GRPO to correctly compare token-level probabilities between preferred and rejected responses."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ZT1JfiiTGExB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import AutoProcessor\n",
|
||||
"\n",
|
||||
"model_name = \"Qwen/Qwen3-VL-4B-Instruct\" # \"Qwen/Qwen3-VL-8B-Instruct\"\n",
|
||||
"processor = AutoProcessor.from_pretrained(model_name, padding_side=\"left\")\n",
|
||||
"\n",
|
||||
"SYSTEM_PROMPT = (\n",
|
||||
" \"You are a helpful AI Assistant that provides well-reasoned and detailed responses. \"\n",
|
||||
" \"You first think about the reasoning process as an internal monologue and then provide the user with the answer. \"\n",
|
||||
" \"Respond in the following format: <think>\\n...\\n</think>\\n<answer>\\n...\\n</answer>\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def make_conversation(example):\n",
|
||||
" conversation = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"system\",\n",
|
||||
" \"content\": [{\"type\": \"text\", \"text\": SYSTEM_PROMPT}],\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\"type\": \"image\", \"image\": example[\"image\"]},\n",
|
||||
" {\"type\": \"text\", \"text\": example[\"problem\"]},\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
" ]\n",
|
||||
" prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)\n",
|
||||
" return {\n",
|
||||
" \"prompt\": prompt,\n",
|
||||
" \"image\": example[\"image\"],\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
"train_dataset = train_dataset.map(make_conversation)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "5txAuMAa8ock"
|
||||
},
|
||||
"source": [
|
||||
"Let's review one example to understand the internal structure:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "PDXQd5Jk2Bqe"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "hzSR_56wxKDA"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset = train_dataset.remove_columns(['problem', 'original_question', 'original_answer'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "T9rCkeqDODba"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "YY3uMp909Eqy"
|
||||
},
|
||||
"source": [
|
||||
"## Load model and configure LoRA/QLoRA\n",
|
||||
"\n",
|
||||
"This notebook can be used with two fine-tuning methods. By default, it is set up for **QLoRA**, which includes quantization using `BitsAndBytesConfig`. If you prefer to use standard **LoRA** without quantization, simply comment out the `BitsAndBytesConfig` configuration."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "gt05dgXgm9QR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import Qwen3VLForConditionalGeneration, BitsAndBytesConfig\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"model = Qwen3VLForConditionalGeneration.from_pretrained(\n",
|
||||
" model_name, dtype=\"auto\",\n",
|
||||
" device_map=\"auto\",\n",
|
||||
" quantization_config=BitsAndBytesConfig(\n",
|
||||
" load_in_4bit=True,\n",
|
||||
" bnb_4bit_use_double_quant=True,\n",
|
||||
" bnb_4bit_quant_type=\"nf4\",\n",
|
||||
" bnb_4bit_compute_dtype=torch.float16\n",
|
||||
" ),\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "WZGf-GF09Gsc"
|
||||
},
|
||||
"source": [
|
||||
"The following cell defines LoRA (or QLoRA if needed). When training with LoRA/QLoRA, we use a **base model** (the one selected above) and, instead of modifying its original weights, we fine-tune a **LoRA adapter** — a lightweight layer that enables efficient and memory-friendly training. The **`target_modules`** specify which parts of the model (e.g., attention or projection layers) will be adapted by LoRA during fine-tuning."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "ME1im5gh2LFg"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from peft import LoraConfig\n",
|
||||
"\n",
|
||||
"# You may need to update `target_modules` depending on the architecture of your chosen model.\n",
|
||||
"# For example, different VLMs might have different attention/projection layer names.\n",
|
||||
"peft_config = LoraConfig(\n",
|
||||
" r=8,\n",
|
||||
" lora_alpha=32,\n",
|
||||
" lora_dropout=0.1,\n",
|
||||
" target_modules=[\"q_proj\", \"v_proj\"],\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "mDq4V6dN9MGk"
|
||||
},
|
||||
"source": [
|
||||
"## Train model\n",
|
||||
"\n",
|
||||
"We'll configure **GRPO** using `GRPOConfig`, keeping the parameters minimal so the training fits on a free Colab instance. You can adjust these settings if more resources are available. For full details on all available parameters, check the [TRL GRPOConfig documentation](https://huggingface.co/docs/trl/sft_trainer#trl.GRPOConfig).\n",
|
||||
"\n",
|
||||
"First, we need to define the rewards functions that the training algorithm will use to improve the model. In this case, we'll include two reward functions.\n",
|
||||
"We'll use a format reward that will reward the model when the output includes `<think>` and `<answer>` tags and additionally a length-based reward to discourage overthinking. Both functions have been extracted from [here](https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Dqp3TfUwHUxW"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import re\n",
|
||||
"\n",
|
||||
"def format_reward(completions, **kwargs):\n",
|
||||
" \"\"\"Reward function that checks if the reasoning process is enclosed within <think> and </think> tags, while the final answer is enclosed within <answer> and </answer> tags.\"\"\"\n",
|
||||
" pattern = r\"^<think>\\n.*?\\n</think>\\n<answer>\\n.*?\\n</answer>$\"\n",
|
||||
" matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completions]\n",
|
||||
" return [1.0 if match else 0.0 for match in matches]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rxNPUp7RBFcz"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from math_verify import LatexExtractionConfig, parse, verify\n",
|
||||
"from latex2sympy2_extended import NormalizationConfig\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def len_reward(completions, solution, **kwargs) -> float:\n",
|
||||
" \"\"\"Compute length-based rewards to discourage overthinking and promote token efficiency.\n",
|
||||
"\n",
|
||||
" Taken from the Kimi 1.5 tech report: https://huggingface.co/papers/2501.12599\n",
|
||||
"\n",
|
||||
" Args:\n",
|
||||
" completions: List of model completions\n",
|
||||
" solution: List of ground truth solutions\n",
|
||||
"\n",
|
||||
" Returns:\n",
|
||||
" List of rewards where:\n",
|
||||
" - For correct answers: reward = 0.5 - (len - min_len)/(max_len - min_len)\n",
|
||||
" - For incorrect answers: reward = min(0, 0.5 - (len - min_len)/(max_len - min_len))\n",
|
||||
" \"\"\"\n",
|
||||
" contents = completions\n",
|
||||
"\n",
|
||||
" # First check correctness of answers\n",
|
||||
" correctness = []\n",
|
||||
" for content, sol in zip(contents, solution):\n",
|
||||
" gold_parsed = parse(\n",
|
||||
" sol,\n",
|
||||
" extraction_mode=\"first_match\",\n",
|
||||
" extraction_config=[LatexExtractionConfig()],\n",
|
||||
" )\n",
|
||||
" if len(gold_parsed) == 0:\n",
|
||||
" # Skip unparseable examples\n",
|
||||
" correctness.append(True) # Treat as correct to avoid penalizing\n",
|
||||
" print(\"Failed to parse gold solution: \", sol)\n",
|
||||
" continue\n",
|
||||
"\n",
|
||||
" answer_parsed = parse(\n",
|
||||
" content,\n",
|
||||
" extraction_config=[\n",
|
||||
" LatexExtractionConfig(\n",
|
||||
" normalization_config=NormalizationConfig(\n",
|
||||
" nits=False,\n",
|
||||
" malformed_operators=False,\n",
|
||||
" basic_latex=True,\n",
|
||||
" equations=True,\n",
|
||||
" boxed=True,\n",
|
||||
" units=True,\n",
|
||||
" ),\n",
|
||||
" boxed_match_priority=0,\n",
|
||||
" try_extract_without_anchor=False,\n",
|
||||
" )\n",
|
||||
" ],\n",
|
||||
" extraction_mode=\"first_match\",\n",
|
||||
" )\n",
|
||||
" correctness.append(verify(answer_parsed, gold_parsed))\n",
|
||||
"\n",
|
||||
" # Calculate lengths\n",
|
||||
" lengths = [len(content) for content in contents]\n",
|
||||
" min_len = min(lengths)\n",
|
||||
" max_len = max(lengths)\n",
|
||||
"\n",
|
||||
" # If all responses have the same length, return zero rewards\n",
|
||||
" if max_len == min_len:\n",
|
||||
" return [0.0] * len(completions)\n",
|
||||
"\n",
|
||||
" rewards = []\n",
|
||||
" for length, is_correct in zip(lengths, correctness):\n",
|
||||
" lambda_val = 0.5 - (length - min_len) / (max_len - min_len)\n",
|
||||
"\n",
|
||||
" if is_correct:\n",
|
||||
" reward = lambda_val\n",
|
||||
" else:\n",
|
||||
" reward = min(0, lambda_val)\n",
|
||||
"\n",
|
||||
" rewards.append(float(reward))\n",
|
||||
"\n",
|
||||
" return rewards\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "9xBL7Rni9LZb"
|
||||
},
|
||||
"source": [
|
||||
"After defining the reward function(s), we can define the `GRPOConfig`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "OEmRM0rIHXQ4"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from trl import GRPOConfig\n",
|
||||
"\n",
|
||||
"output_dir = \"Qwen3-VL-4B-Instruct-trl-grpo\"\n",
|
||||
"\n",
|
||||
"# Configure training arguments using GRPOConfig\n",
|
||||
"training_args = GRPOConfig(\n",
|
||||
" learning_rate=2e-5,\n",
|
||||
" #num_train_epochs=1,\n",
|
||||
" max_steps=100, # Number of dataset passes. For full trainings, use `num_train_epochs` instead\n",
|
||||
"\n",
|
||||
" # Parameters that control the data preprocessing\n",
|
||||
" per_device_train_batch_size=2,\n",
|
||||
" max_completion_length=1024, # default: 256 # Max completion length produced during training\n",
|
||||
" num_generations=2, # 2, # default: 8 # Number of generations produced during trainig for comparison\n",
|
||||
" max_prompt_length=2048, # default: 512 # Max prompt lenght of the input prompt used for generation during training\n",
|
||||
"\n",
|
||||
" fp16=True,\n",
|
||||
"\n",
|
||||
" # Parameters related to reporting and saving\n",
|
||||
" output_dir=output_dir, # Where to save model checkpoints and logs\n",
|
||||
" logging_steps=1, # Log training metrics every N steps\n",
|
||||
" report_to=\"trackio\", # Experiment tracking tool\n",
|
||||
"\n",
|
||||
" # Hub integration\n",
|
||||
" push_to_hub=True,\n",
|
||||
" log_completions=True\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "O0q3myQg927v"
|
||||
},
|
||||
"source": [
|
||||
"Configure the GRPO Trainer. We pass the previously configured `training_args`. We don't use eval dataset to maintain memory usage low but you can configure it."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "z5JxkmS9HqD5",
|
||||
"outputId": "2b39338e-2194-4829-fc54-5e286566fd28"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/usr/local/lib/python3.12/dist-packages/peft/mapping_func.py:73: UserWarning: You are trying to modify a model with PEFT for a second time. If you want to reload the model with a different config, make sure to call `.unload()` before.\n",
|
||||
" warnings.warn(\n",
|
||||
"/usr/local/lib/python3.12/dist-packages/peft/tuners/tuners_utils.py:196: UserWarning: Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from trl import GRPOTrainer\n",
|
||||
"\n",
|
||||
"trainer = GRPOTrainer(\n",
|
||||
" model=model,\n",
|
||||
" reward_funcs=[format_reward, len_reward],\n",
|
||||
" args=training_args,\n",
|
||||
" train_dataset=train_dataset,\n",
|
||||
" peft_config=peft_config,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "kQC7Q5kg95xq"
|
||||
},
|
||||
"source": [
|
||||
"Show memory stats before training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "naG_7qlYyBP6"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gpu_stats = torch.cuda.get_device_properties(0)\n",
|
||||
"start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
|
||||
"max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n",
|
||||
"\n",
|
||||
"print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n",
|
||||
"print(f\"{start_gpu_memory} GB of memory reserved.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "YazYtLAe97Dc"
|
||||
},
|
||||
"source": [
|
||||
"And train!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "pbJXrhA0ywra"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer_stats = trainer.train()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "SmcYN5yW99IP"
|
||||
},
|
||||
"source": [
|
||||
"Show memory stats after training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "TrrwP4ADMmrp"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n",
|
||||
"used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n",
|
||||
"used_percentage = round(used_memory / max_memory * 100, 3)\n",
|
||||
"lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n",
|
||||
"\n",
|
||||
"print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n",
|
||||
"print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n",
|
||||
"print(f\"Peak reserved memory = {used_memory} GB.\")\n",
|
||||
"print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n",
|
||||
"print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n",
|
||||
"print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "saarW87Y9_-R"
|
||||
},
|
||||
"source": [
|
||||
"## Saving fine tuned model\n",
|
||||
"\n",
|
||||
"In this step, we save the fine-tuned model both **locally** and to the **Hugging Face Hub** using the credentials from your account."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "71A8aqEyyETA"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"trainer.save_model(output_dir)\n",
|
||||
"trainer.push_to_hub(dataset_name=dataset_id)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "nfqvO0qw-OvS"
|
||||
},
|
||||
"source": [
|
||||
"## Load the fine-tuned model and run inference\n",
|
||||
"\n",
|
||||
"Now, let's test our fine-tuned model by loading the **LoRA/QLoRA adapter** and performing **inference**. We'll start by loading the **base model**, then attach the adapter to it, creating the final fine-tuned model ready for evaluation."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "R8T2uFQVyFeH"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from transformers import Qwen3VLForConditionalGeneration, AutoProcessor\n",
|
||||
"from peft import PeftModel\n",
|
||||
"\n",
|
||||
"base_model = model_name\n",
|
||||
"adapter_model = f\"{output_dir}\" # Replace with your HF username or organization\n",
|
||||
"\n",
|
||||
"model = Qwen3VLForConditionalGeneration.from_pretrained(base_model, dtype=\"auto\", device_map=\"auto\")\n",
|
||||
"model = PeftModel.from_pretrained(model, adapter_model)\n",
|
||||
"\n",
|
||||
"processor = AutoProcessor.from_pretrained(base_model)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "dPBHP0CpLa6K"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset[0]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "cG5-ccGRyHgo"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"dataset_id = 'lmms-lab/multimodal-open-r1-8k-verified'\n",
|
||||
"train_dataset = load_dataset(dataset_id, split='train[:5%]')\n",
|
||||
"\n",
|
||||
"problem = train_dataset[0]['problem']\n",
|
||||
"image = train_dataset[0]['image']\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
" {\n",
|
||||
" \"role\": \"system\", \"content\": [\n",
|
||||
" {\"type\": \"text\", \"text\": SYSTEM_PROMPT}\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": [\n",
|
||||
" {\"type\": \"image\", \"image\": image},\n",
|
||||
" {\"type\": \"text\", \"text\": problem},\n",
|
||||
" ],\n",
|
||||
" },\n",
|
||||
"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "r_70q_8lLgfV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"messages"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "PX92MjqlyIwB"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"inputs = processor.apply_chat_template(\n",
|
||||
" messages,\n",
|
||||
" tokenize=True,\n",
|
||||
" add_generation_prompt=True,\n",
|
||||
" return_dict=True,\n",
|
||||
" return_tensors=\"pt\"\n",
|
||||
").to(model.device)\n",
|
||||
"\n",
|
||||
"# Inference: Generation of the output\n",
|
||||
"generated_ids = model.generate(**inputs, max_new_tokens=500)\n",
|
||||
"generated_ids_trimmed = [\n",
|
||||
" out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)\n",
|
||||
"]\n",
|
||||
"output_text = processor.batch_decode(\n",
|
||||
" generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False\n",
|
||||
")\n",
|
||||
"print(output_text)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"gpuType": "T4",
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
515
examples/notebooks/sft_qwen_vl.ipynb
Normal file
File diff suppressed because one or more lines are too long
@ -1,7 +0,0 @@
|
||||
# Research projects that use TRL
|
||||
|
||||
Welcome to the research projects folder! Here you can find the scripts used for some research projects that used TRL and maintained by the developers and the community (LM de-toxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information!
|
||||
|
||||
- [De-detoxifying language models](https://github.com/huggingface/trl/tree/main/examples/research_projects/toxicity)
|
||||
- [Stack-Llama](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama)
|
||||
- [Stack-Llama-2](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2)
|
@ -1,15 +0,0 @@
|
||||
# LayerSkip Training Recipe
|
||||
|
||||
Implements the training recipe as described in the [LayerSkip paper](https://huggingface.co/papers/2404.16710).
|
||||
|
||||
## Run training
|
||||
```
|
||||
cd scripts
|
||||
python layer_skip_sft.py
|
||||
```
|
||||
|
||||
## Run benchmark
|
||||
```
|
||||
cd scripts
|
||||
python benchmark_layer_skip.py
|
||||
```
|
@ -1,77 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import config
|
||||
import torch
|
||||
from torch.utils import benchmark
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
def generate_tokens(model, inputs):
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
do_sample=False,
|
||||
max_new_tokens=64,
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
def generate_tokens_with_assistance(model, inputs, assistant_early_exit):
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
assistant_early_exit=assistant_early_exit,
|
||||
do_sample=False,
|
||||
max_new_tokens=64,
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ckpt = config.hub_model_id
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt)
|
||||
|
||||
prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: "
|
||||
|
||||
results = []
|
||||
label = "Generation Times"
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="generate_tokens(model, inputs)",
|
||||
setup="from __main__ import generate_tokens",
|
||||
globals={"model": model, "inputs": inputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
label=label,
|
||||
sub_label="no layer skip",
|
||||
description="generation",
|
||||
).blocked_autorange()
|
||||
)
|
||||
|
||||
for i in range(1, model.config.num_hidden_layers):
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="generate_tokens_with_assistance(model, inputs, assistant_early_exit)",
|
||||
setup="from __main__ import generate_assistant_tokens",
|
||||
globals={"model": model, "assistant_early_exit": i, "inputs": inputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
label=label,
|
||||
sub_label=f"layer skip {i}",
|
||||
description="generation",
|
||||
).blocked_autorange()
|
||||
)
|
||||
|
||||
benchmark.Compare(results).print()
|
@ -1,48 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from trl import SFTTrainer
|
||||
|
||||
|
||||
class LayerSkipSFTTrainer(SFTTrainer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.early_exit_layer = 0 # initialize with 0
|
||||
self.always_last_layer = True
|
||||
self.early_exit_loss_scale = 1.0
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
self.early_exit_layer = (
|
||||
self.early_exit_layer % (model.config.num_hidden_layers - 1)
|
||||
) + 1 # rotates between [1, num_hidden_layers-1]
|
||||
bs, seqlen = inputs.input_ids.shape
|
||||
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs, output_hidden_states=True)
|
||||
|
||||
hidden_state = outputs["hidden_states"][self.early_exit_layer].to(model.dtype)
|
||||
if self.early_exit_layer != model.config.num_hidden_layers:
|
||||
hidden_state = model.model.norm(hidden_state)
|
||||
logits = model.lm_head(hidden_state)
|
||||
loss_early = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size)
|
||||
|
||||
if self.always_last_layer:
|
||||
loss_last = model.loss_function(logits=outputs["logits"], labels=labels, vocab_size=model.vocab_size)
|
||||
loss = self.early_exit_loss_scale * loss_early.to(loss_last.device) + 1.0 * loss_last
|
||||
# normalize loss scales
|
||||
loss = loss / (1.0 + self.early_exit_loss_scale)
|
||||
else:
|
||||
loss = loss_early
|
||||
|
||||
return loss
|
@ -1,90 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import config
|
||||
import torch
|
||||
from custom_trainer import LayerSkipSFTTrainer
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from trl import DataCollatorForCompletionOnlyLM, SFTConfig
|
||||
|
||||
|
||||
def formatting_prompts_func(example):
|
||||
text = f"### Instruction: {example['utterance']}\n ### Response: {example['semantic_parse']}"
|
||||
|
||||
# Inject eos_token as a string before tokenization, because they are not always added
|
||||
# See: https://github.com/huggingface/transformers/issues/22794 and
|
||||
# https://github.com/huggingface/trl/issues/1623
|
||||
if tokenizer.eos_token: # usually something like "</s>" for GPT2 or "<|endoftext|>"
|
||||
text += f"{tokenizer.eos_token}"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# load the dataset
|
||||
print("[INFO] loading the dataset...")
|
||||
train_dataset = load_dataset(config.dataset_name, split="train")
|
||||
|
||||
print(f"output_root_dir: {config.output_root_dir}")
|
||||
print(f"hub_model_id: {config.hub_model_id}")
|
||||
|
||||
# load the model and tokenizer
|
||||
print("[INFO] loading the model and tokenizer...")
|
||||
model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, add_eos_token=True)
|
||||
|
||||
# adding pad and eos tokens if not provided in the tokenizer
|
||||
if tokenizer.pad_token is None:
|
||||
# Add '[PAD]' token if it doesn't exist
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
model.config.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
if tokenizer.eos_token is None or tokenizer.eos_token == tokenizer.bos_token:
|
||||
# Add '[EOS]' token if it doesn't exist
|
||||
tokenizer.add_special_tokens({"eos_token": "[EOS]"})
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
response_template = " ### Response:"
|
||||
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
|
||||
|
||||
args = SFTConfig(
|
||||
do_train=True,
|
||||
bf16=True,
|
||||
max_seq_length=None,
|
||||
per_device_train_batch_size=config.per_device_train_batch_size,
|
||||
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
||||
learning_rate=config.learning_rate,
|
||||
packing=False,
|
||||
num_train_epochs=1.0,
|
||||
report_to="none",
|
||||
push_to_hub=True,
|
||||
hub_model_id=config.hub_model_id,
|
||||
output_dir=config.output_dir,
|
||||
save_steps=1000,
|
||||
save_total_limit=2,
|
||||
)
|
||||
|
||||
trainer = LayerSkipSFTTrainer(
|
||||
model,
|
||||
train_dataset=train_dataset,
|
||||
args=args,
|
||||
formatting_func=formatting_prompts_func,
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
@ -1,18 +0,0 @@
|
||||
# RLHF pipeline for the creation of StackLLaMa: a Stack exchange llama-7b model.
|
||||
There were three main steps to the training process:
|
||||
1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se:
|
||||
- `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path=<LLAMA_MODEL_PATH> --streaming --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se`
|
||||
2. Reward modeling using dialog pairs from the SE dataset using the llama-7b-se to create llama-7b-se-rm:
|
||||
- `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/reward_modeling.py --model_name=<LLAMA_SE_MODEL>`
|
||||
3. RL fine-tuning of llama-7b-se with the llama-7b-se-rm reward model:
|
||||
- `accelerate launch --multi_gpu --num_machines 1 --num_processes 8 examples/research_projects/stack_llama/scripts/rl_training.py --log_with=wandb --model_name=<LLAMA_SE_MODEL> --reward_model_name=<LLAMA_SE_RM_MODEL> --adafactor=False --tokenizer_name=<LLAMA_TOKENIZER> --save_freq=100 --output_max_length=128 --batch_size=8 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam`
|
||||
|
||||
|
||||
LoRA layers were using at all stages to reduce memory requirements.
|
||||
At each stage the peft adapter layers were merged with the base model, using:
|
||||
```shell
|
||||
python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ
|
||||
```
|
||||
Note that this script requires `peft>=0.3.0`.
|
||||
|
||||
For access to the base llama-7b model, please see Meta's [release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) and [request form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform).
|
@ -1,60 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from peft import PeftConfig, PeftModel
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the
|
||||
merged model.
|
||||
"""
|
||||
|
||||
adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"})
|
||||
base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
|
||||
output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge"
|
||||
assert script_args.base_model_name is not None, "please provide the name of the Base model"
|
||||
assert script_args.output_name is not None, "please provide the output name of the merged model"
|
||||
|
||||
peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name)
|
||||
if peft_config.task_type == "SEQ_CLS":
|
||||
# The sequence classification task is used for the reward model in PPO
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
script_args.base_model_name, num_labels=1, dtype=torch.bfloat16
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(script_args.base_model_name, return_dict=True, dtype=torch.bfloat16)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name)
|
||||
|
||||
# Load the PEFT model
|
||||
model = PeftModel.from_pretrained(model, script_args.adapter_model_name)
|
||||
model.eval()
|
||||
|
||||
model = model.merge_and_unload()
|
||||
|
||||
model.save_pretrained(f"{script_args.output_name}")
|
||||
tokenizer.save_pretrained(f"{script_args.output_name}")
|
||||
model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False)
|
@ -1,321 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
|
||||
# Define and parse arguments.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
|
||||
"""
|
||||
|
||||
local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"})
|
||||
resume_from_checkpoint: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "If you want to resume training where it left off."},
|
||||
)
|
||||
deepspeed: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU."
|
||||
},
|
||||
)
|
||||
per_device_train_batch_size: Optional[int] = field(default=4)
|
||||
per_device_eval_batch_size: Optional[int] = field(default=1)
|
||||
gradient_accumulation_steps: Optional[int] = field(default=1)
|
||||
learning_rate: Optional[float] = field(default=2e-5)
|
||||
weight_decay: Optional[float] = field(default=0.001)
|
||||
model_name: Optional[str] = field(
|
||||
default="gpt2",
|
||||
metadata={
|
||||
"help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
|
||||
},
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The tokenizer for your model, if left empty will use the default for your model",
|
||||
},
|
||||
)
|
||||
bf16: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."
|
||||
},
|
||||
)
|
||||
num_train_epochs: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "The number of training epochs for the reward model."},
|
||||
)
|
||||
train_subset: Optional[int] = field(
|
||||
default=100000,
|
||||
metadata={"help": "The size of the subset of the training data to use"},
|
||||
)
|
||||
eval_subset: Optional[int] = field(
|
||||
default=50000,
|
||||
metadata={"help": "The size of the subset of the eval data to use"},
|
||||
)
|
||||
gradient_checkpointing: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enables gradient checkpointing."},
|
||||
)
|
||||
optim: Optional[str] = field(
|
||||
default="adamw_hf",
|
||||
metadata={"help": "The optimizer to use."},
|
||||
)
|
||||
lr_scheduler_type: Optional[str] = field(
|
||||
default="linear",
|
||||
metadata={"help": "The lr scheduler"},
|
||||
)
|
||||
max_length: Optional[int] = field(default=512)
|
||||
eval_first_step: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to run eval after the first step"},
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
|
||||
)
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
set_seed(script_args.seed)
|
||||
# Load the human stack-exchange-paired dataset for tuning the reward model.
|
||||
train_dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired", data_dir="data/reward", split="train", verification_mode="no_checks"
|
||||
)
|
||||
if script_args.train_subset > 0:
|
||||
train_dataset = train_dataset.select(range(script_args.train_subset))
|
||||
eval_dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train", verification_mode="no_checks"
|
||||
)
|
||||
if script_args.eval_subset > 0:
|
||||
eval_dataset = eval_dataset.select(range(script_args.eval_subset))
|
||||
# Define the training args. Needs to be done before the model is loaded if you are using deepspeed.
|
||||
model_name_split = script_args.model_name.split("/")[-1]
|
||||
output_name = (
|
||||
f"{model_name_split}_peft_stack-exchange-paired_rmts__{script_args.train_subset}_{script_args.learning_rate}"
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir=output_name,
|
||||
learning_rate=script_args.learning_rate,
|
||||
per_device_train_batch_size=script_args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
|
||||
num_train_epochs=script_args.num_train_epochs,
|
||||
weight_decay=script_args.weight_decay,
|
||||
eval_strategy="steps",
|
||||
eval_steps=500,
|
||||
save_strategy="steps",
|
||||
save_steps=500,
|
||||
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    gradient_checkpointing=script_args.gradient_checkpointing,
    deepspeed=script_args.deepspeed,
    local_rank=script_args.local_rank,
    remove_unused_columns=False,
    label_names=[],
    bf16=script_args.bf16,
    logging_strategy="steps",
    optim=script_args.optim,
    lr_scheduler_type=script_args.lr_scheduler_type,
    seed=script_args.seed,
)

# Load the value-head model and tokenizer.
tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)
tokenizer.pad_token = tokenizer.eos_token

peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

model = AutoModelForSequenceClassification.from_pretrained(script_args.model_name, num_labels=1, dtype=torch.bfloat16)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Need to do this for gpt2, because it doesn't have an official pad token.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
model.config.use_cache = not script_args.gradient_checkpointing
num_proc = 24  # Can adjust to be higher if you have more processors.
original_columns = train_dataset.column_names


# Turn the dataset into pairs of question + answers, where text_j is the preferred question + answer and text_k is the other.
# Then tokenize the dataset.
def preprocess_function(examples):
    new_examples = {
        "input_ids_j": [],
        "attention_mask_j": [],
        "input_ids_k": [],
        "attention_mask_k": [],
    }
    for question, response_j, response_k in zip(examples["question"], examples["response_j"], examples["response_k"]):
        tokenized_j = tokenizer("Question: " + question + "\n\nAnswer: " + response_j, truncation=True)
        tokenized_k = tokenizer("Question: " + question + "\n\nAnswer: " + response_k, truncation=True)

        new_examples["input_ids_j"].append(tokenized_j["input_ids"])
        new_examples["attention_mask_j"].append(tokenized_j["attention_mask"])
        new_examples["input_ids_k"].append(tokenized_k["input_ids"])
        new_examples["attention_mask_k"].append(tokenized_k["attention_mask"])

    return new_examples


# Preprocess the dataset and filter out QAs that are longer than script_args.max_length.
train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=num_proc,
    remove_columns=original_columns,
)
train_dataset = train_dataset.filter(
    lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
    num_proc=num_proc,
)

eval_dataset = eval_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=num_proc,
    remove_columns=original_columns,
)
eval_dataset = eval_dataset.filter(
    lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
    num_proc=num_proc,
)


# We need to define a special data collator that batches the data in our j vs k format.
@dataclass
class RewardDataCollatorWithPadding:
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        features_j = []
        features_k = []
        for feature in features:
            features_j.append(
                {
                    "input_ids": feature["input_ids_j"],
                    "attention_mask": feature["attention_mask_j"],
                }
            )
            features_k.append(
                {
                    "input_ids": feature["input_ids_k"],
                    "attention_mask": feature["attention_mask_k"],
                }
            )
        batch_j = self.tokenizer.pad(
            features_j,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch_k = self.tokenizer.pad(
            features_k,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        batch = {
            "input_ids_j": batch_j["input_ids"],
            "attention_mask_j": batch_j["attention_mask"],
            "input_ids_k": batch_k["input_ids"],
            "attention_mask_k": batch_k["attention_mask"],
            "return_loss": True,
        }
        return batch


# Define the metric that we'll use for validation.
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    predictions, _ = eval_pred
    # Here, predictions is rewards_j and rewards_k.
    # We want to see how much of the time rewards_j > rewards_k.
    predictions = np.argmax(predictions, axis=0)
    labels = np.zeros(predictions.shape)
    return accuracy.compute(predictions=predictions, references=labels)


class RewardTrainer(Trainer):
    # Define how to compute the reward loss. We use the InstructGPT pairwise logloss: https://huggingface.co/papers/2203.02155
    def compute_loss(self, model, inputs, return_outputs=False):
        rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
        rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
        loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
        if return_outputs:
            return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
        return loss


# Train the model, woohoo.
trainer = RewardTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    data_collator=RewardDataCollatorWithPadding(tokenizer=tokenizer),
)


if script_args.eval_first_step:

    class EvaluateFirstStepCallback(TrainerCallback):
        def on_step_end(self, args, state, control, **kwargs):
            if state.global_step == 1:
                control.should_evaluate = True

    trainer.add_callback(EvaluateFirstStepCallback())

trainer.train(script_args.resume_from_checkpoint)

print("Saving last checkpoint of the model")
model.save_pretrained(output_name + "_peft_last_checkpoint")
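The `compute_loss` above is the InstructGPT pairwise objective: it maximizes log sigmoid(r_j - r_k), pushing the reward of the preferred answer above the rejected one. A minimal numeric sketch of that loss on made-up reward scores (for intuition only, not part of the script):

```python
import torch
import torch.nn as nn

# Made-up scalar rewards for a batch of two (preferred, rejected) pairs.
rewards_j = torch.tensor([1.2, 0.3])  # preferred answers
rewards_k = torch.tensor([0.4, 0.9])  # rejected answers

# InstructGPT pairwise log loss: -log sigmoid(r_j - r_k), averaged over the batch.
loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
print(loss)  # small when r_j > r_k, large when the ranking is wrong
```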
@ -1,270 +0,0 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline, set_seed

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from trl.core import LengthSampler


tqdm.pandas()


@dataclass
class ScriptArguments:
    """
    The name of the Causal LM model we wish to fine-tune with PPO
    """

    # NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
    # models like gpt-neo* models are more suitable.
    model_name: Optional[str] = field(default="", metadata={"help": "the model name"})
    tokenizer_name: Optional[str] = field(default="", metadata={"help": "the tokenizer name"})
    reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"})
    log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
    output_max_length: Optional[int] = field(default=128, metadata={"help": "maximum length for generation"})
    mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
    batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
    ppo_epochs: Optional[int] = field(default=4, metadata={"help": "the number of ppo epochs"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    adafactor: Optional[bool] = field(default=False, metadata={"help": "whether to use the adafactor optimizer"})
    early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
    target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
    reward_baseline: Optional[float] = field(
        default=0.0,
        metadata={"help": "a baseline value that is subtracted from the reward"},
    )
    batched_gen: Optional[bool] = field(default=False, metadata={"help": "whether to use the batched text gen"})
    save_freq: Optional[int] = field(default=None, metadata={"help": "n steps to save the model"})
    output_dir: Optional[str] = field(default="runs/", metadata={"help": "the output directory"})
    seed: Optional[int] = field(default=0, metadata={"help": "the seed"})
    steps: Optional[int] = field(default=20000, metadata={"help": "number of training steps"})
    init_kl_coef: Optional[float] = field(
        default=0.2,
        metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
    )

    adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
    load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 8bit"})


parser = HfArgumentParser(ScriptArguments)
script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0]
reward_model_name = script_args.reward_model_name
dataset_name = "lvwerra/stack-exchange-paired"
config = PPOConfig(
    steps=script_args.steps,
    model_name=script_args.model_name,
    learning_rate=script_args.learning_rate,
    log_with=script_args.log_with,
    batch_size=script_args.batch_size,
    mini_batch_size=script_args.mini_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    optimize_device_cache=True,
    early_stopping=script_args.early_stopping,
    target_kl=script_args.target_kl,
    ppo_epochs=script_args.ppo_epochs,
    seed=script_args.seed,
    init_kl_coef=script_args.init_kl_coef,
    adap_kl_ctrl=script_args.adap_kl_ctrl,
)

train_dataset = load_dataset(
    "lvwerra/stack-exchange-paired", data_dir="data/rl", split="train", verification_mode="no_checks"
)
train_dataset = train_dataset.select(range(100000))
original_columns = train_dataset.column_names

# We then define the arguments to pass to the reward pipeline.
# We set `return_all_scores` to True to get the score for every class rather than just the top one.
sent_kwargs = {
    "return_all_scores": True,
    "function_to_apply": "none",
    "batch_size": 16,
    "truncation": True,
}

tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name)
# GPT-2 tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token,
# only for this model.

if getattr(tokenizer, "pad_token", None) is None:
    tokenizer.pad_token = tokenizer.eos_token


# Below is an example function to build the dataset. In our case, we use the stack-exchange-paired dataset
# from the Hub via the `datasets` library. One should customize this function to train the model on
# its own dataset.
def build_dataset(
    tokenizer,
    dataset_name="lvwerra/stack-exchange-paired",
):
    """
    Build dataset for training. This builds the dataset from `load_dataset`, one should
    customize this function to train the model on its own dataset.

    Args:
        tokenizer (`transformers.PreTrainedTokenizer`):
            The tokenizer used for the model.
        dataset_name (`str`):
            The name of the dataset to be loaded.

    Returns:
        dataloader (`torch.utils.data.DataLoader`):
            The dataloader for the dataset.
    """

    num_proc = 24

    def preprocess_function(examples):
        new_examples = {
            "query": [],
            "input_ids": [],
        }
        for question in examples["question"]:
            query = "Question: " + question + "\n\nAnswer: "
            tokenized_question = tokenizer(query, truncation=True)
            new_examples["query"].append(query)
            new_examples["input_ids"].append(tokenized_question["input_ids"])

        return new_examples

    ds = train_dataset.map(
        preprocess_function,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )
    ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False, num_proc=num_proc)

    ds.set_format(type="torch")
    return ds


# We retrieve the dataloader by calling the `build_dataset` function.
dataset = build_dataset(tokenizer)


def collator(data):
    return {key: [d[key] for d in data] for key in data[0]}


# set seed before initializing value head for deterministic eval
set_seed(config.seed)

# Now let's build the model, the reference model, and the tokenizer.
current_device = Accelerator().local_process_index

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    config.model_name,
    load_in_8bit=script_args.load_in_8bit,
    device_map={"": current_device},
    peft_config=lora_config,
)

optimizer = None
if script_args.adafactor:
    optimizer = Adafactor(
        filter(lambda p: p.requires_grad, model.parameters()),
        scale_parameter=False,
        relative_step=False,
        warmup_init=False,
        lr=config.learning_rate,
    )
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
ppo_trainer = PPOTrainer(
    config,
    model,
    ref_model=None,
    tokenizer=tokenizer,
    dataset=dataset,
    data_collator=collator,
    optimizer=optimizer,
)

# We then build the sentiment analysis pipeline using our reward model, passing the
# model name and the sentiment analysis pipeline arguments. Let's also make sure to
# set the device to the same device as the PPOTrainer.
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"  # to avoid a `pipeline` bug
sentiment_pipe = pipeline(
    "sentiment-analysis",
    model=reward_model_name,
    device_map={"": current_device},
    model_kwargs={"load_in_8bit": script_args.load_in_8bit},
    tokenizer=tokenizer,
    return_token_type_ids=False,
)

if sentiment_pipe.model.config.pad_token_id is None:
    sentiment_pipe.model.config.pad_token_id = sentiment_pipe.model.config.eos_token_id
# We then define the arguments to pass to the `generate` function. These arguments
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
# the `generate` function of the trained model.
generation_kwargs = {
    # "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "eos_token_id": 100_000,
}
output_min_length = 32
output_max_length = script_args.output_max_length
output_length_sampler = LengthSampler(output_min_length, output_max_length)

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    if epoch >= config.total_ppo_epochs:
        break

    question_tensors = batch["input_ids"]

    response_tensors = ppo_trainer.generate(
        question_tensors,
        return_prompt=False,
        length_sampler=output_length_sampler,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    # Compute reward score (using the reward model via the sentiment analysis pipeline)
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]

    # Run PPO step
    stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

    if script_args.save_freq and epoch and epoch % script_args.save_freq == 0:
        ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}")
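The reward fed into each PPO step above is simply the raw reward-model score for query + response (the pipeline runs with `"function_to_apply": "none"`, so no softmax is applied), shifted by `reward_baseline`. A small sketch of that conversion on a fake pipeline output; the label names and scores are made up:

```python
import torch

# Fake output of the reward pipeline for a batch of two texts: one list of
# {label, score} dicts per text, with the score we care about in position 0.
pipe_outputs = [
    [{"label": "LABEL_0", "score": 0.83}],
    [{"label": "LABEL_0", "score": -0.27}],
]
reward_baseline = 0.0  # same role as `script_args.reward_baseline`

# One scalar reward tensor per response, as passed to `ppo_trainer.step` above.
rewards = [torch.tensor(out[0]["score"] - reward_baseline) for out in pipe_outputs]
print(rewards)
```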
@ -1,222 +0,0 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, logging, set_seed

from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset


"""
Fine-Tune Llama-7b on SE paired dataset
"""


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="")
    parser.add_argument("--dataset_name", type=str, default="lvwerra/stack-exchange-paired")
    parser.add_argument("--subset", type=str, default="data/finetune")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--size_valid_set", type=int, default=4000)
    parser.add_argument("--streaming", action="store_true")
    parser.add_argument("--shuffle_buffer", type=int, default=5000)

    parser.add_argument("--seq_length", type=int, default=1024)
    parser.add_argument("--max_steps", type=int, default=10000)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--eos_token_id", type=int, default=49152)

    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--num_warmup_steps", type=int, default=100)
    parser.add_argument("--weight_decay", type=float, default=0.05)

    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--bf16", action="store_true", default=False)
    parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--num_workers", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="./checkpoints")
    parser.add_argument("--log_freq", default=1, type=int)
    parser.add_argument("--eval_freq", default=1000, type=int)
    parser.add_argument("--save_freq", default=1000, type=int)

    return parser.parse_args()


def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.
    """
    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        text = prepare_sample_text(example)
        total_characters += len(text)
        if tokenizer.is_fast:
            total_tokens += len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))

    return total_characters / total_tokens


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


def prepare_sample_text(example):
    """Prepare the text from a sample of the dataset."""
    text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
    return text


def create_datasets(tokenizer, args):
    dataset = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        use_auth_token=True,
        num_proc=args.num_workers if not args.streaming else None,
        streaming=args.streaming,
    )
    if args.streaming:
        print("Loading the dataset in streaming mode")
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
    else:
        dataset = dataset.train_test_split(test_size=0.005, seed=args.seed)
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")

    chars_per_token = chars_token_ratio(train_data, tokenizer)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        formatting_func=prepare_sample_text,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        formatting_func=prepare_sample_text,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    return train_dataset, valid_dataset


def run_training(args, train_data, val_data):
    print("Loading the model")

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    train_data.start_iteration = 0

    print("Starting main loop")

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        dataloader_drop_last=True,
        eval_strategy="steps",
        max_steps=args.max_steps,
        eval_steps=args.eval_freq,
        save_steps=args.save_freq,
        logging_steps=args.log_freq,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=args.gradient_checkpointing,
        fp16=args.fp16,
        bf16=args.bf16,
        weight_decay=args.weight_decay,
        run_name="llama-7b-finetuned",
        report_to="wandb",
        ddp_find_unused_parameters=False,
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, load_in_8bit=True, device_map={"": Accelerator().process_index}
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_data,
        eval_dataset=val_data,
        peft_config=lora_config,
        packing=True,
    )

    print_trainable_parameters(trainer.model)

    print("Training...")
    trainer.train()

    print("Saving last checkpoint of the model")
    trainer.model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))


def main(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    train_dataset, eval_dataset = create_datasets(tokenizer, args)
    run_training(args, train_dataset, eval_dataset)


if __name__ == "__main__":
    args = get_args()
    assert args.model_path != "", "Please provide the llama model path"

    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    logging.set_verbosity_error()

    main(args)
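`chars_token_ratio` matters because `ConstantLengthDataset` buffers roughly `seq_length * chars_per_token` characters of raw text per packed sequence before tokenizing. A minimal sketch of the same estimate on two toy strings; the gpt2 tokenizer is only an example choice:

```python
from transformers import AutoTokenizer

# Any fast tokenizer works for this estimate; gpt2 is just a small example here.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

texts = [
    "Question: How do I reverse a list in Python?\n\nAnswer: Use slicing: xs[::-1].",
    "Question: What does HTTP 404 mean?\n\nAnswer: The requested resource was not found.",
]

# Same computation as `chars_token_ratio`, just on an in-memory list instead of a dataset.
total_characters = sum(len(t) for t in texts)
total_tokens = sum(len(tokenizer(t).tokens()) for t in texts)
chars_per_token = total_characters / total_tokens

seq_length = 1024
print(f"~{chars_per_token:.2f} chars/token -> buffer ~{int(seq_length * chars_per_token)} chars per packed sequence")
```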
@ -1,75 +0,0 @@
# DPO pipeline for the creation of StackLlaMa 2: a Stack Exchange llama-v2-7b model

## Prerequisites

Install all the dependencies in the `requirements.txt`:

```
$ pip install -U -r requirements.txt
```

Since we will use `accelerate` for training, make sure to run:

```
$ accelerate config
```

## Training

There are two main steps to the DPO training process:

1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:

```
accelerate launch examples/research_projects/stack_llama_2/scripts/sft_llama2.py \
    --output_dir="./sft" \
    --max_steps=500 \
    --save_steps=10 \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=1 \
    --gradient_accumulation_steps=2 \
    --gradient_checkpointing=False \
    --group_by_length=False \
    --learning_rate=1e-4 \
    --lr_scheduler_type="cosine" \
    --warmup_steps=100 \
    --weight_decay=0.05 \
    --optim="paged_adamw_32bit" \
    --bf16=True \
    --remove_unused_columns=False \
    --run_name="sft_llama2" \
    --report_to="wandb"
```

2. Run the DPO trainer using the model saved by the previous step:

```
accelerate launch examples/research_projects/stack_llama_2/scripts/dpo_llama2.py \
    --model_name_or_path="sft/final_checkpoint" \
    --output_dir="dpo"
```

## Merging the adapters

To merge the adapters into the base model, we can use the `merge_peft_adapter.py` helper script that comes with TRL:

```
python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo/final_checkpoint/" --output_name="stack-llama-2"
```

which will also push the merged model to your Hugging Face Hub account.
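For reference, a minimal sketch of what the merge step does using `peft` directly; the model and adapter paths are placeholders, and unlike `merge_peft_adapter.py` this sketch only saves locally instead of pushing to the Hub:

```py
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model and attach the DPO-trained LoRA adapter (paths are placeholders).
base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, "dpo/final_checkpoint/")

# Fold the adapter weights into the base weights and drop the PEFT wrappers.
merged = model.merge_and_unload()

# Save the stand-alone merged model together with its tokenizer.
merged.save_pretrained("stack-llama-2")
AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf").save_pretrained("stack-llama-2")
```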
## Running the model

We can load the DPO-trained LoRA adapters that were saved by the DPO training step via:

```py
import torch
from peft import AutoPeftModelForCausalLM


model = AutoPeftModelForCausalLM.from_pretrained(
    "dpo/final_checkpoint",
    low_cpu_mem_usage=True,
    dtype=torch.float16,
    load_in_4bit=True,
)

model.generate(...)
```
@ -1,252 +0,0 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# 0. imports
import os
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed

from trl import DPOConfig, DPOTrainer


# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    The arguments for the DPO training script.
    """

    # data parameters
    beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
    # `num_proc` was missing from the arguments but is used for dataset filtering below.
    num_proc: Optional[int] = field(default=24, metadata={"help": "the number of processes for dataset map/filter"})

    # training parameters
    model_name_or_path: Optional[str] = field(
        default="../sft/results/final_checkpoint",
        metadata={"help": "the location of the SFT model name or path"},
    )
    learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})

    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
    per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
    gradient_accumulation_steps: Optional[int] = field(
        default=4, metadata={"help": "the number of gradient accumulation steps"}
    )
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )

    gradient_checkpointing_use_reentrant: Optional[bool] = field(
        default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
    )

    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})

    max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
    max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
    max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
    logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
    save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
    eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})

    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
    load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
    model_dtype: Optional[str] = field(
        default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
    )

    # instrumentation
    report_to: Optional[str] = field(
        default="wandb",
        metadata={
            "help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
            '`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
            'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
        },
    )
    # debug argument for distributed training
    ignore_bias_buffers: Optional[bool] = field(
        default=False,
        metadata={
            "help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type, inplace operation. See"
            "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
        },
    )
    seed: Optional[int] = field(
        default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
    )


def get_stack_exchange_paired(
    data_dir: str = "data/rl",
    cache_dir: Optional[str] = None,
    num_proc=24,
) -> Dataset:
    """Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': list[str],
        'chosen': list[str],
        'rejected': list[str],
    }

    Prompts are structured as follows:
      "Question: " + <prompt> + "\n\nAnswer: "
    """
    dataset = load_dataset(
        "lvwerra/stack-exchange-paired",
        split="train",
        cache_dir=cache_dir,
        data_dir=data_dir,
        verification_mode="no_checks",
    )
    original_columns = dataset.column_names

    def return_prompt_and_responses(samples) -> dict[str, str]:
        return {
            "prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
            "chosen": samples["response_j"],
            "rejected": samples["response_k"],
        }

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        num_proc=num_proc,
        remove_columns=original_columns,
    )


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    set_seed(script_args.seed)

    # 1. load a pretrained model
    dtype = torch.float
    if script_args.model_dtype == "float16":
        dtype = torch.float16
    elif script_args.model_dtype == "bfloat16":
        dtype = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=True,
        dtype=dtype,
        load_in_4bit=script_args.load_in_4bit,
        device_map={"": Accelerator().local_process_index},
    )
    model.config.use_cache = False

    if script_args.ignore_bias_buffers:
        # torch distributed hack
        model._ddp_params_and_buffers_to_ignore = [
            name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
        ]

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    tokenizer.pad_token = tokenizer.eos_token

    # 2. Load the Stack-exchange paired dataset
    train_dataset = get_stack_exchange_paired(data_dir="data/rl")
    train_dataset = train_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
        num_proc=script_args.num_proc,
    )

    # 3. Load evaluation dataset
    eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation")
    eval_dataset = eval_dataset.filter(
        lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
        and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
        num_proc=script_args.num_proc,
    )

    # 4. initialize training arguments:
    training_args = DPOConfig(
        per_device_train_batch_size=script_args.per_device_train_batch_size,
        per_device_eval_batch_size=script_args.per_device_eval_batch_size,
        max_steps=script_args.max_steps,
        logging_steps=script_args.logging_steps,
        save_steps=script_args.save_steps,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        gradient_checkpointing=script_args.gradient_checkpointing,
        learning_rate=script_args.learning_rate,
        eval_strategy="steps",
        eval_steps=script_args.eval_steps,
        output_dir=script_args.output_dir,
        report_to=script_args.report_to,
        lr_scheduler_type=script_args.lr_scheduler_type,
        warmup_steps=script_args.warmup_steps,
        optim=script_args.optimizer_type,
        bf16=True,
        remove_unused_columns=False,
        run_name="dpo_llama2",
        gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
        seed=script_args.seed,
    )

    peft_config = LoraConfig(
        r=script_args.lora_r,
        lora_alpha=script_args.lora_alpha,
        lora_dropout=script_args.lora_dropout,
        target_modules=[
            "q_proj",
            "v_proj",
            "k_proj",
            "out_proj",
            "fc_in",
            "fc_out",
            "wte",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

    # 5. initialize the DPO trainer
    dpo_trainer = DPOTrainer(
        model,
        ref_model=None,
        args=training_args,
        beta=script_args.beta,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=tokenizer,
        peft_config=peft_config,
        max_prompt_length=script_args.max_prompt_length,
        max_length=script_args.max_length,
    )

    # 6. train
    dpo_trainer.train()
    dpo_trainer.save_model(script_args.output_dir)

    # 7. save
    output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
    dpo_trainer.model.save_pretrained(output_dir)
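The `beta` argument above scales the implicit reward in the DPO objective. As an aid to reading the trainer's logs, here is a minimal, self-contained sketch of the sigmoid DPO loss itself on made-up log-probabilities; `DPOTrainer` computes the real per-sequence log-probabilities internally, so none of these numbers come from the script:

```python
import torch
import torch.nn.functional as F

# Made-up per-example log-probabilities of the chosen/rejected answers under the
# policy and under the frozen reference model.
policy_chosen_logps = torch.tensor([-12.0, -15.0])
policy_rejected_logps = torch.tensor([-14.0, -14.5])
ref_chosen_logps = torch.tensor([-13.0, -15.5])
ref_rejected_logps = torch.tensor([-13.5, -14.0])

beta = 0.1  # same role as `script_args.beta` above
chosen_logratios = policy_chosen_logps - ref_chosen_logps
rejected_logratios = policy_rejected_logps - ref_rejected_logps

# Sigmoid DPO loss: -log sigmoid(beta * (chosen log-ratio - rejected log-ratio)).
loss = -F.logsigmoid(beta * (chosen_logratios - rejected_logratios)).mean()
print(loss)
```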
@ -1,7 +0,0 @@
transformers
trl
peft
accelerate
datasets
bitsandbytes
wandb
@ -1,212 +0,0 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine-Tune Llama2-7b on SE paired dataset
import os
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    is_torch_npu_available,
    is_torch_xpu_available,
    set_seed,
)

from trl import SFTConfig, SFTTrainer
from trl.trainer import ConstantLengthDataset


@dataclass
class ScriptArguments:
    model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
    subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"})
    split: Optional[str] = field(default="train", metadata={"help": "the split to use"})
    size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"})
    streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"})
    shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
    seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
    num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
    use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"})

    # LoraConfig
    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})


parser = HfArgumentParser((ScriptArguments, SFTConfig))
script_args, training_args = parser.parse_args_into_dataclasses()
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

if training_args.group_by_length and training_args.packing:
    raise ValueError("Cannot use both packing and group by length")

# `gradient_checkpointing` was True by default until `1f3314`, but it's actually not used.
# `gradient_checkpointing=True` will cause a `Variable._execution_engine.run_backward` error.
if training_args.gradient_checkpointing:
    raise ValueError("gradient_checkpointing not supported")

set_seed(training_args.seed)


def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.
    """
    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        text = prepare_sample_text(example)
        total_characters += len(text)
        if tokenizer.is_fast:
            total_tokens += len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))

    return total_characters / total_tokens


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


def prepare_sample_text(example):
    """Prepare the text from a sample of the dataset."""
    text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
    return text


def create_datasets(tokenizer, args, seed=None):
    dataset = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        use_auth_token=True,
        num_proc=args.num_workers if not args.streaming else None,
        streaming=args.streaming,
    )
    if args.streaming:
        print("Loading the dataset in streaming mode")
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=seed)
    else:
        dataset = dataset.train_test_split(test_size=0.005, seed=seed)
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")

    chars_per_token = chars_token_ratio(train_data, tokenizer)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        formatting_func=prepare_sample_text,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        formatting_func=prepare_sample_text,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    return train_dataset, valid_dataset


bnb_config = None
if script_args.use_bnb:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

base_model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    quantization_config=bnb_config,
    device_map={"": Accelerator().local_process_index},
    trust_remote_code=True,
    use_auth_token=True,
)
base_model.config.use_cache = False


tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training

train_dataset, eval_dataset = create_datasets(tokenizer, script_args, seed=training_args.seed)

trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    max_length=None,
    formatting_func=prepare_sample_text,
    processing_class=tokenizer,
    args=training_args,
)
trainer.train()
trainer.save_model(training_args.output_dir)

output_dir = os.path.join(training_args.output_dir, "final_checkpoint")
trainer.model.save_pretrained(output_dir)

# Free memory for merging weights
del base_model
if is_torch_xpu_available():
    torch.xpu.empty_cache()
elif is_torch_npu_available():
    torch.npu.empty_cache()
else:
    torch.cuda.empty_cache()

model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", dtype=torch.bfloat16)
model = model.merge_and_unload()

output_merged_dir = os.path.join(training_args.output_dir, "final_merged_checkpoint")
model.save_pretrained(output_merged_dir, safe_serialization=True)
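The merged checkpoint written above is a plain causal LM. A minimal sketch of reloading it for generation; the paths follow the README's `--output_dir="./sft"` and are placeholders, and the sampling settings are arbitrary:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder path: the script writes the merged weights to "<output_dir>/final_merged_checkpoint".
model = AutoModelForCausalLM.from_pretrained("sft/final_merged_checkpoint", device_map="auto", dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # tokenizer of the default base model

# Same prompt template as `prepare_sample_text` used during fine-tuning.
prompt = "Question: How do I merge two dictionaries in Python?\n\nAnswer: "
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=True, top_p=0.9)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```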
@ -1,7 +0,0 @@
# Detoxifying language models

To run this code, do the following:

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file {CONFIG} examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py --log_with wandb
```
@ -1,146 +0,0 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import csv

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, is_torch_npu_available, is_torch_xpu_available


toxicity = evaluate.load("ybelkada/toxicity", "DaNLP/da-electra-hatespeech-detection", module_type="measurement")
ds = load_dataset("OxAISH-AL-LLM/wiki_toxic", split="test")

parser = argparse.ArgumentParser(description="Evaluate de-toxified models")
parser.add_argument("--model_type", default="all", type=str, help="Which models to evaluate: 'all', 'gpt-neo', 'gpt-j', or a model id")
parser.add_argument("--output_file", default="toxicity.csv", type=str, help="Relative path to the output csv file")
parser.add_argument("--batch_size", default=64, type=int, help="Batch size")
parser.add_argument("--num_samples", default=400, type=int, help="Number of samples")
parser.add_argument("--context_length", default=2000, type=int, help="Maximum context length")
parser.add_argument("--max_new_tokens", default=30, type=int, help="Max new tokens for generation")
args = parser.parse_args()


if args.model_type == "all":
    MODELS_TO_TEST = [
        "ybelkada/gpt-neo-125m-detox",
        "EleutherAI/gpt-neo-125M",
        "EleutherAI/gpt-neo-2.7B",
        "ybelkada/gpt-neo-2.7B-detox",
        "ybelkada/gpt-j-6b-sharded-bf16",
        "ybelkada/gpt-j-6b-detox",
    ]
elif args.model_type == "gpt-neo":
    MODELS_TO_TEST = [
        "ybelkada/gpt-neo-125m-detox",
        "EleutherAI/gpt-neo-125M",
        "EleutherAI/gpt-neo-2.7B",
        "ybelkada/gpt-neo-2.7B-detox",
    ]
elif args.model_type == "gpt-j":
    MODELS_TO_TEST = [
        "ybelkada/gpt-j-6b-sharded-bf16",
        "ybelkada/gpt-j-6b-detox",
    ]
else:
    MODELS_TO_TEST = [args.model_type]
NUM_SAMPLES = args.num_samples
BATCH_SIZE = args.batch_size
output_file = args.output_file
max_new_tokens = args.max_new_tokens
context_length = args.context_length
if is_torch_xpu_available():
    device = torch.xpu.current_device()
elif is_torch_npu_available():
    device = torch.npu.current_device()
else:
    device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"

# consider only toxic prompts
ds = ds.filter(lambda x: x["label"] == 1)

toxicities = {}

# open a csv file
file = open(f"{output_file}", "w", newline="")
writer = csv.writer(file)
# add the header row
writer.writerow(["model_id", "mean_toxicity", "std_toxicity"])


for model_id in tqdm(MODELS_TO_TEST):
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": device}, dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    input_texts = []

    for i, example in enumerate(ds):
        # set seed
        torch.manual_seed(42)

        input_text = example["comment_text"]
        input_texts.append(input_text[:2000])

        if i > NUM_SAMPLES:
            break

        if (i + 1) % BATCH_SIZE == 0:
            inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(device)
            # keep at most `context_length` tokens of context (truncate the token dimension)
            inputs.input_ids = inputs.input_ids[:, :context_length]
            inputs.attention_mask = inputs.attention_mask[:, :context_length]
            outputs = model.generate(**inputs, do_sample=True, max_new_tokens=max_new_tokens, use_cache=True)
            generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            generated_texts = [
                generated_text.replace(input_texts[i], "") for i, generated_text in enumerate(generated_texts)
            ]
            toxicity_score = toxicity.compute(predictions=generated_texts)
            input_texts = []

            if model_id not in toxicities:
                toxicities[model_id] = []
            toxicities[model_id].extend(toxicity_score["toxicity"])

    # last batch
    inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(device)
    outputs = model.generate(**inputs, do_sample=True, max_new_tokens=30)
    generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    generated_texts = [generated_text.replace(input_texts[i], "") for i, generated_text in enumerate(generated_texts)]
    toxicity_score = toxicity.compute(predictions=generated_texts)
    toxicities[model_id].extend(toxicity_score["toxicity"])

    # compute mean & std using np
    mean = np.mean(toxicities[model_id])
    std = np.std(toxicities[model_id])

    # save to file
    writer.writerow([model_id, mean, std])

    # print
    print(f"Model: {model_id} - Mean: {mean} - Std: {std}")

    model = None
    if is_torch_xpu_available():
        torch.xpu.empty_cache()
    elif is_torch_npu_available():
        torch.npu.empty_cache()
    else:
        torch.cuda.empty_cache()

# close file
file.close()
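The resulting CSV has one row per model with the columns `model_id`, `mean_toxicity`, and `std_toxicity`. A small sketch of reading it back and ranking the models, assuming the default `--output_file`:

```python
import csv

# Read the results written by the evaluation script and sort models by mean toxicity.
with open("toxicity.csv", newline="") as f:
    rows = list(csv.DictReader(f))

for row in sorted(rows, key=lambda r: float(r["mean_toxicity"])):
    print(f"{row['model_id']}: mean={float(row['mean_toxicity']):.4f} std={float(row['std_toxicity']):.4f}")
```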
Some files were not shown because too many files have changed in this diff.