mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Compare commits
253 Commits
v0.15.2
...
grpo-log-e
Author | SHA1 | Date | |
---|---|---|---|
9c46200672 | |||
ac6dc65fdd | |||
d1174adc5b | |||
cd838417e4 | |||
c7e3f096a5 | |||
5c08897570 | |||
3ef9faf257 | |||
9ac614fb08 | |||
29401e790e | |||
31bf3f9244 | |||
7f32792c07 | |||
3d8727918a | |||
65245f6be8 | |||
a528b9c465 | |||
e0dd525021 | |||
64aa06499b | |||
be93a0c30c | |||
f9fbd91ea9 | |||
54d4f6b13a | |||
05bc43e960 | |||
d3dc8ff654 | |||
21738c3732 | |||
eab175d434 | |||
4da4dc9117 | |||
6b3a02385d | |||
abbbb93d6a | |||
cafa663c84 | |||
fd04a5461a | |||
56e5766205 | |||
89d44caece | |||
adfa7fd59a | |||
cf5183db7f | |||
1954c02d86 | |||
45f4c58832 | |||
cc044e35b2 | |||
999acd53ec | |||
8606b1ad09 | |||
a673da5773 | |||
00b8e311aa | |||
c163cf5081 | |||
bc9c019c43 | |||
18596cf232 | |||
280d35301b | |||
13fa8402a3 | |||
09b669fbf7 | |||
01d0be15cb | |||
3a42af1c78 | |||
aaf39604ba | |||
2bf48478e8 | |||
a8cfca6d01 | |||
1bca49515e | |||
39e96394a9 | |||
8e6ed93dfd | |||
29c5e05e3a | |||
a9b27f82d6 | |||
cd6b3de356 | |||
36685c8bba | |||
89556c8cbf | |||
f3e8c23044 | |||
9ee6c3aa56 | |||
ef05331752 | |||
05e2ba6e01 | |||
1b4f189e09 | |||
1faa7f9b36 | |||
66e6eab9bb | |||
27af0aaf4a | |||
b4ffda769e | |||
0dad4eb7ca | |||
c82f626f94 | |||
33add19161 | |||
294f35bf3c | |||
9874b3aa04 | |||
1e61f6cc5a | |||
27adc30162 | |||
df737f99c1 | |||
c04e84c454 | |||
d625c5533a | |||
6cdd24a360 | |||
8b38570258 | |||
95b1a9f612 | |||
5c1511423b | |||
5e2e9cb442 | |||
227df8271e | |||
ae1581474e | |||
47b9515fb1 | |||
c4891dcfee | |||
055cee255a | |||
73a2fb0554 | |||
982ba08092 | |||
e03e7acc5c | |||
9df19e8a75 | |||
1d7b8c4f70 | |||
7e170612a4 | |||
559724ee2c | |||
a5a46725c8 | |||
b6bcafb8bb | |||
4bfb8eb0d1 | |||
4d66bad208 | |||
e90117b3e1 | |||
31b54a6237 | |||
17e33cdaa0 | |||
5a0cebc786 | |||
65308cfd84 | |||
1755e03f6f | |||
793735a698 | |||
e70a0efeca | |||
7eaca76ed1 | |||
657f9ce6ee | |||
485852c942 | |||
9f3702f6be | |||
e751a16df5 | |||
582bc5684b | |||
c5ba70d4fc | |||
5b586da3cc | |||
488025cd87 | |||
2594cb39de | |||
2fe2337067 | |||
f6b4d6e569 | |||
26d86757a7 | |||
9771f259ed | |||
7bdedd4075 | |||
a069a2f19c | |||
ea45f513f3 | |||
a91023990a | |||
1a9387b922 | |||
1884ff1bb8 | |||
bfe2075608 | |||
6067e2a669 | |||
dee37342a8 | |||
8037f18cdf | |||
a0a53171cc | |||
23a635ed61 | |||
9b38b0b5ee | |||
0f26049ea2 | |||
7511aa4e36 | |||
f713f614e9 | |||
a34987956c | |||
0f88c179e3 | |||
beda4328cc | |||
07cfe1677e | |||
9f7755d8ed | |||
4e3f569eb8 | |||
979fda1548 | |||
f6fb6a88a9 | |||
6cbf8fbc9f | |||
5cb390cd30 | |||
b3c391e628 | |||
1b85ca6147 | |||
e7a1290b0a | |||
3822edd67b | |||
230455cab0 | |||
08f014d559 | |||
10740333bd | |||
058a733c30 | |||
3f193972d8 | |||
b575596b89 | |||
118c43f0e0 | |||
40b1c33edf | |||
1a2e74cc5a | |||
80f7dcb16d | |||
4404ccd24a | |||
39f77ca2d8 | |||
52085dd96b | |||
c7a1c95017 | |||
3003058418 | |||
a759cee2e0 | |||
0a3bad44f0 | |||
bb5b96a823 | |||
8466c7273e | |||
a871ec8e91 | |||
f7572221db | |||
8ec2e42833 | |||
218d493d11 | |||
1a9f78eb3a | |||
a10978ebdf | |||
87fbb831d3 | |||
52f39d6a24 | |||
931f7a14d2 | |||
9951105a90 | |||
5a6e23aac9 | |||
d9104c8b0d | |||
d5a5840307 | |||
f3cbd41e2c | |||
d41a32f619 | |||
fc4dae256d | |||
e4e5671e80 | |||
7c76f103da | |||
aad18ef52a | |||
b55d9f0412 | |||
4871c82b0c | |||
fd9e5a7cab | |||
5463e49a55 | |||
22759c8208 | |||
2ee6fd369f | |||
844a9c665f | |||
04f6597377 | |||
e3244d2d09 | |||
6a02c69789 | |||
a1c58aa42a | |||
3f0695a4ca | |||
a72b50b772 | |||
ea1d9be2a7 | |||
402187baab | |||
5858ceab7e | |||
7442d42c21 | |||
98de0e7c62 | |||
491921c1a4 | |||
ad6a35bdd5 | |||
7bc9858a8f | |||
b882f57d93 | |||
ac7bde5832 | |||
3d94e4e25c | |||
1a303cca8e | |||
ac327d5e84 | |||
c0854c32c9 | |||
aa18ecfde7 | |||
6849c050b9 | |||
27a6f2201b | |||
f074dcdc86 | |||
0caff61600 | |||
019fc6dbaa | |||
69ad852e56 | |||
45ccdefac4 | |||
703484a8c2 | |||
9b76d5f2e9 | |||
cbe0681ba1 | |||
4e0cf01aef | |||
5c05913196 | |||
caba04da42 | |||
be5a088337 | |||
38861475e6 | |||
f69707dab4 | |||
76f00fc394 | |||
8453017622 | |||
3608709529 | |||
21f0055893 | |||
013d360b8f | |||
e5ae703d35 | |||
a92e00e810 | |||
9b3c5bf64f | |||
15fec312d5 | |||
be1e34003c | |||
6aaf379a82 | |||
49adf74833 | |||
6c54f023ae | |||
963243a7d1 | |||
aafd8cbea5 | |||
822653824b | |||
ba036576d4 | |||
293b620950 | |||
ae3bd0d07a | |||
6d9fc11fd6 | |||
ffcb9f4aee |
3
.github/PULL_REQUEST_TEMPLATE.md
vendored
3
.github/PULL_REQUEST_TEMPLATE.md
vendored
@ -21,8 +21,7 @@ Fixes # (issue)
|
||||
Pull Request section?
|
||||
- [ ] Was this discussed/approved via a GitHub issue? Please add a link
|
||||
to it if that's the case.
|
||||
- [ ] Did you make sure to update the documentation with your changes? Here are the
|
||||
[documentation guidelines](https://github.com/huggingface/trl/tree/main/docs).
|
||||
- [ ] Did you make sure to update the documentation with your changes?
|
||||
- [ ] Did you write any new necessary tests?
|
||||
|
||||
|
||||
|
19
.github/codeql/custom-queries.qls
vendored
Normal file
19
.github/codeql/custom-queries.qls
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
import codeql
|
||||
|
||||
from WorkflowString interpolation, Workflow workflow
|
||||
where
|
||||
interpolation.getStringValue().matches("${{ github.event.issue.title }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.issue.body }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.pull_request.title }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.pull_request.body }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.review.body }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.comment.body }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.inputs.* }}") or
|
||||
interpolation.getStringValue().matches("${{ github.event.head_commit.message }}")
|
||||
interpolation.getStringValue().matches("${{ github.event.* }}") and
|
||||
(
|
||||
step.getKey() = "run" or // Injection in run
|
||||
step.getKey() = "env" or // Injection via env
|
||||
step.getKey() = "with" // Injection via with
|
||||
)
|
||||
select workflow, "🚨 Do not use directly as input of action"
|
1
.github/workflows/build_pr_documentation.yml
vendored
1
.github/workflows/build_pr_documentation.yml
vendored
@ -9,6 +9,7 @@ concurrency:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
if: github.event.pull_request.draft == false
|
||||
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
|
||||
with:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
|
26
.github/workflows/codeQL.yml
vendored
Normal file
26
.github/workflows/codeQL.yml
vendored
Normal file
@ -0,0 +1,26 @@
|
||||
name: "CodeQL Analysis - Workflows"
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
analyze:
|
||||
name: "Analyze GitHub Workflows"
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
security-events: write
|
||||
actions: read
|
||||
contents: read
|
||||
|
||||
steps:
|
||||
- name: "Checkout repository"
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: "Initialize CodeQL"
|
||||
uses: github/codeql-action/init@v2
|
||||
with:
|
||||
languages: "yaml"
|
||||
queries: +security-and-quality, ./.github/codeql/custom-queries.qls
|
||||
|
||||
- name: "Perform CodeQL Analysis"
|
||||
uses: github/codeql-action/analyze@v2
|
127
.github/workflows/pr_style_bot.yml
vendored
Normal file
127
.github/workflows/pr_style_bot.yml
vendored
Normal file
@ -0,0 +1,127 @@
|
||||
name: PR Style Bot
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
run-style-bot:
|
||||
if: >
|
||||
contains(github.event.comment.body, '@bot /style') &&
|
||||
github.event.issue.pull_request != null
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Extract PR details
|
||||
id: pr_info
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const prNumber = context.payload.issue.number;
|
||||
const { data: pr } = await github.rest.pulls.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: prNumber
|
||||
});
|
||||
|
||||
// We capture both the branch ref and the "full_name" of the head repo
|
||||
// so that we can check out the correct repository & branch (including forks).
|
||||
core.setOutput("prNumber", prNumber);
|
||||
core.setOutput("headRef", pr.head.ref);
|
||||
core.setOutput("headRepoFullName", pr.head.repo.full_name);
|
||||
|
||||
- name: Check out PR branch
|
||||
uses: actions/checkout@v3
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
with:
|
||||
# Instead of checking out the base repo, use the contributor's repo name
|
||||
repository: ${{ env.HEADREPOFULLNAME }}
|
||||
ref: ${{ env.HEADREF }}
|
||||
# You may need fetch-depth: 0 for being able to push
|
||||
fetch-depth: 0
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Debug
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
||||
run: |
|
||||
echo "PR number: ${{ env.PRNUMBER }}"
|
||||
echo "Head Ref: ${{ env.HEADREF }}"
|
||||
echo "Head Repo Full Name: ${{ env.HEADREPOFULLNAME }}"
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install ruff pre-commit
|
||||
|
||||
- name: Download Makefile from main branch
|
||||
run: |
|
||||
curl -o main_Makefile https://raw.githubusercontent.com/huggingface/trl/main/Makefile
|
||||
|
||||
- name: Compare Makefiles
|
||||
run: |
|
||||
if ! diff -q main_Makefile Makefile; then
|
||||
echo "Error: The Makefile has changed. Please ensure it matches the main branch."
|
||||
exit 1
|
||||
fi
|
||||
echo "No changes in Makefile. Proceeding..."
|
||||
rm -rf main_Makefile
|
||||
|
||||
- name: Run make style and make quality
|
||||
run: |
|
||||
make precommit || true
|
||||
|
||||
- name: Commit and push changes
|
||||
id: commit_and_push
|
||||
env:
|
||||
HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
|
||||
HEADREF: ${{ steps.pr_info.outputs.headRef }}
|
||||
PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
echo "HEADREPOFULLNAME: ${{ env.HEADREPOFULLNAME }}, HEADREF: ${{ env.HEADREF }}"
|
||||
# Configure git with the Actions bot user
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
# Make sure your 'origin' remote is set to the contributor's fork
|
||||
git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${{ env.HEADREPOFULLNAME }}.git"
|
||||
|
||||
# If there are changes after running style/quality, commit them
|
||||
if [ -n "$(git status --porcelain)" ]; then
|
||||
git add .
|
||||
git commit -m "Apply style fixes"
|
||||
# Push to the original contributor's forked branch
|
||||
git push origin HEAD:${{ env.HEADREF }}
|
||||
echo "changes_pushed=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "No changes to commit."
|
||||
echo "changes_pushed=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Comment on PR with workflow run link
|
||||
if: steps.commit_and_push.outputs.changes_pushed == 'true'
|
||||
uses: actions/github-script@v6
|
||||
with:
|
||||
script: |
|
||||
const prNumber = parseInt(process.env.prNumber, 10);
|
||||
const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
|
||||
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: prNumber,
|
||||
body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
|
||||
});
|
||||
env:
|
||||
prNumber: ${{ steps.pr_info.outputs.prNumber }}
|
182
.github/workflows/tests.yml
vendored
182
.github/workflows/tests.yml
vendored
@ -21,11 +21,9 @@ jobs:
|
||||
check_code_quality:
|
||||
name: Check code quality
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.pull_request.draft == false
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: recursive
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@ -38,126 +36,216 @@ jobs:
|
||||
name: Tests
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.9', '3.10', '3.11', '3.12']
|
||||
os: ['ubuntu-latest', 'windows-latest']
|
||||
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
|
||||
fail-fast: false
|
||||
runs-on: ${{ matrix.os }}
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
if: github.event.pull_request.draft == false
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
setup.py
|
||||
requirements.txt
|
||||
|
||||
- name: Install Make and Git
|
||||
run: |
|
||||
apt-get update && apt-get install -y make git curl
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Create Python virtual environment
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install --upgrade setuptools wheel
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ".[dev]"
|
||||
source .venv/bin/activate
|
||||
uv pip install ".[dev]"
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
make test
|
||||
|
||||
- name: Post to Slack
|
||||
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: Results with Python ${{ matrix.python-version }} on ${{ matrix.os }} with lastest dependencies
|
||||
title: Results with Python ${{ matrix.python-version }} and latest dependencies
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
tests_dev:
|
||||
name: Tests with dev dependencies
|
||||
runs-on: 'ubuntu-latest'
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
if: github.event.pull_request.draft == false
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
setup.py
|
||||
requirements.txt
|
||||
|
||||
- name: Install Make and Git
|
||||
run: |
|
||||
apt-get update && apt-get install -y make git curl
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Create Python virtual environment
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install --upgrade setuptools wheel
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -U git+https://github.com/huggingface/accelerate.git
|
||||
python -m pip install -U git+https://github.com/huggingface/datasets.git
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers.git
|
||||
python -m pip install ".[dev]"
|
||||
source .venv/bin/activate
|
||||
uv pip install -U git+https://github.com/huggingface/accelerate.git
|
||||
uv pip install -U git+https://github.com/huggingface/datasets.git
|
||||
uv pip install -U git+https://github.com/huggingface/transformers.git
|
||||
uv pip install ".[dev]"
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
make test
|
||||
|
||||
- name: Post to Slack
|
||||
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: Results with Python 3.12 on ubuntu-latest with dev dependencies
|
||||
title: Results with Python 3.12 and dev dependencies
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
tests_wo_optional_deps:
|
||||
name: Tests without optional dependencies
|
||||
runs-on: 'ubuntu-latest'
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
if: github.event.pull_request.draft == false
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
setup.py
|
||||
requirements.txt
|
||||
|
||||
- name: Install Make and Git
|
||||
run: |
|
||||
apt-get update && apt-get install -y make git curl
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Create Python virtual environment
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install --upgrade setuptools wheel
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install ".[test]"
|
||||
source .venv/bin/activate
|
||||
uv pip install ".[test]"
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
make test
|
||||
|
||||
- name: Post to Slack
|
||||
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: Results with Python 3.12 on ubuntu-latest without optional dependencies
|
||||
title: Results with Python 3.12 without optional dependencies
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
tests_min_versions:
|
||||
name: Tests with minimum versions
|
||||
runs-on: 'ubuntu-latest'
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
if: github.event.pull_request.draft == false
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
setup.py
|
||||
requirements.txt
|
||||
|
||||
- name: Install Make and Git
|
||||
run: |
|
||||
apt-get update && apt-get install -y make git curl
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Create Python virtual environment
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install --upgrade setuptools wheel
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install accelerate==0.34.0
|
||||
python -m pip install datasets==2.21.0
|
||||
python -m pip install transformers==4.46.0
|
||||
python -m pip install ".[dev]"
|
||||
source .venv/bin/activate
|
||||
uv pip install accelerate==0.34.0
|
||||
uv pip install datasets==3.0.0
|
||||
uv pip install transformers==4.46.0
|
||||
uv pip install ".[dev]"
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
make test
|
||||
|
||||
- name: Post to Slack
|
||||
if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: Results with Python 3.12 on ubuntu-latest with minimum versions
|
||||
title: Results with Python 3.12 and minimum dependencies versions
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
45
.github/workflows/tests_latest.yml
vendored
45
.github/workflows/tests_latest.yml
vendored
@ -13,33 +13,54 @@ env:
|
||||
jobs:
|
||||
tests:
|
||||
name: Tests latest TRL release with dev dependencies
|
||||
runs-on: 'ubuntu-latest'
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
with: { ref: v0.15-release }
|
||||
with: { ref: v0.17-release }
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.12'
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
setup.py
|
||||
requirements.txt
|
||||
|
||||
- name: Install Make and Git
|
||||
run: |
|
||||
apt-get update && apt-get install -y make git curl
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
- name: Create Python virtual environment
|
||||
run: |
|
||||
uv venv
|
||||
uv pip install --upgrade setuptools wheel
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -U git+https://github.com/huggingface/accelerate.git
|
||||
python -m pip install -U git+https://github.com/huggingface/datasets.git
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers.git
|
||||
python -m pip install ".[dev]"
|
||||
source .venv/bin/activate
|
||||
uv pip install -U git+https://github.com/huggingface/accelerate.git
|
||||
uv pip install -U git+https://github.com/huggingface/datasets.git
|
||||
uv pip install -U git+https://github.com/huggingface/transformers.git
|
||||
uv pip install ".[dev]"
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
make test
|
||||
|
||||
- name: Post to Slack
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: Results of latest TRL with Python 3.12 on ubuntu-latest with dev dependencies
|
||||
title: Results of latest TRL with Python 3.12 and dev dependencies
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
5
.github/workflows/trufflehog.yml
vendored
5
.github/workflows/trufflehog.yml
vendored
@ -12,4 +12,7 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
||||
uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
|
||||
with:
|
||||
# exclude buggy postgres detector that is causing false positives and not relevant to our codebase
|
||||
extra_args: --results=verified,unknown --exclude-detectors=postgres
|
||||
|
2
.gitignore
vendored
2
.gitignore
vendored
@ -142,4 +142,4 @@ checklink/cookies.txt
|
||||
# wandb files
|
||||
nbs/wandb/
|
||||
examples/notebooks/wandb/
|
||||
wandb/
|
||||
wandb/
|
@ -1,8 +1,8 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.6.3
|
||||
rev: v0.11.10
|
||||
hooks:
|
||||
- id: ruff
|
||||
- id: ruff-check
|
||||
types_or: [ python, pyi ]
|
||||
args: [ --fix ]
|
||||
- id: ruff-format
|
||||
|
@ -31,4 +31,4 @@ keywords:
|
||||
- pytorch
|
||||
- transformers
|
||||
license: Apache-2.0
|
||||
version: 0.15
|
||||
version: 0.17
|
||||
|
193
CONTRIBUTING.md
193
CONTRIBUTING.md
@ -171,8 +171,7 @@ Follow these steps to start contributing:
|
||||
$ pytest tests/<TEST_TO_RUN>.py
|
||||
```
|
||||
|
||||
> For the following commands leveraging the `make` utility, we recommend using the WSL system when running on
|
||||
> Windows. More information [here](https://docs.microsoft.com/en-us/windows/wsl/about).
|
||||
> For the following commands leveraging the `make` utility.
|
||||
|
||||
You can also run the full suite with the following command.
|
||||
|
||||
@ -457,3 +456,193 @@ Warnings play a critical role in guiding users toward resolving potential issues
|
||||
```
|
||||
|
||||
By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.
|
||||
|
||||
|
||||
## Making a release
|
||||
|
||||
> [!NOTE]
|
||||
> VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We need to follow this convention to be able to retrieve versioned scripts.
|
||||
|
||||
To create the package for PyPI.
|
||||
|
||||
#### 0. Prerequisites
|
||||
|
||||
- Dependencies:
|
||||
- twine: `pip install build twine`
|
||||
- Create an account in (and join the `trl` project):
|
||||
- PyPI: https://pypi.org/
|
||||
- Test PyPI: https://test.pypi.org/
|
||||
|
||||
#### 1. Ensure your local repository is up to date with the upstream repository
|
||||
|
||||
```bash
|
||||
git checkout main
|
||||
git pull origin main
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Do not merge other pull requests into `main` until the release is done. This is to ensure that the release is stable and does not include any untested changes. Announce internally (#trl-internal) to other maintainers that you are doing a release and that they must not merge PRs until the release is done.
|
||||
|
||||
#### 2. Create a release branch from main
|
||||
|
||||
```bash
|
||||
git checkout -b release-v{major}.{minor}
|
||||
```
|
||||
|
||||
#### 3. Change the version in the following files
|
||||
|
||||
- `.github/workflows/tests_latest.yml`:
|
||||
```diff
|
||||
- with: { ref: v{major}.{minor-1}-release }
|
||||
+ with: { ref: v{major}.{minor}-release }
|
||||
```
|
||||
- `CITATION.cff`
|
||||
```diff
|
||||
- version: {major}.{minor-1}
|
||||
+ version: {major}.{minor}
|
||||
```
|
||||
- `__init__.py`
|
||||
```diff
|
||||
- __version__ = "{major}.{minor}.0.dev0"
|
||||
+ __version__ = "{major}.{minor}.0"
|
||||
```
|
||||
- `setup.cfg`
|
||||
```diff
|
||||
- version = {major}.{minor}.0.dev0
|
||||
+ version = {major}.{minor}.0
|
||||
```
|
||||
|
||||
#### 4. Commit and push these changes
|
||||
|
||||
```shell
|
||||
git commit -m 'Release: {major}.{minor}'
|
||||
git push origin release-v{major}.{minor}
|
||||
```
|
||||
|
||||
#### 5. Create a pull request
|
||||
|
||||
from `release-v{major}.{minor}` to `main`, named `Release: v{major}.{minor}`, wait for tests to pass, and request a review.
|
||||
|
||||
#### 6. Once the pull request is approved, merge it into `main`
|
||||
|
||||
#### 7. Add a tag in git to mark the release
|
||||
|
||||
```shell
|
||||
git checkout main
|
||||
git pull origin main
|
||||
git tag -a v{major}.{minor}.0 -m 'Adds tag v{major}.{minor}.0 for PyPI'
|
||||
git push origin v{major}.{minor}.0
|
||||
```
|
||||
|
||||
#### 8. Create a branch `v{major}.{minor}-release` for future patch releases.
|
||||
|
||||
```shell
|
||||
git checkout -b v{major}.{minor}-release
|
||||
git push origin v{major}.{minor}-release
|
||||
```
|
||||
|
||||
This ensures that future patch releases (`v{major}.{minor}.1`, `v{major}.{minor}.2`, etc.) can be made separately from `main`.
|
||||
|
||||
#### 9. Create the wheels for your release
|
||||
|
||||
These are the artifacts that will be uploaded to PyPI and installed by users via `pip install trl`.
|
||||
|
||||
Clean previous builds:
|
||||
|
||||
```shell
|
||||
rm -rf build dist
|
||||
```
|
||||
|
||||
At the root of your repo, run
|
||||
|
||||
```bash
|
||||
python -m build .
|
||||
```
|
||||
|
||||
This will create a folders named `dist` with the new versions of your package.
|
||||
|
||||
#### 10. Upload the package to PyPI Test
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Do not skip this step. It is important to test the package before uploading it to the main PyPI server.
|
||||
|
||||
```shell
|
||||
twine upload dist/* -r testpypi
|
||||
```
|
||||
|
||||
Then in a fresh environment containing all dependencies you need, try to install your new package from the PyPI test server.
|
||||
|
||||
```bash
|
||||
pip install -i https://test.pypi.org/simple/ trl
|
||||
```
|
||||
|
||||
You might get errors for missing dependencies since the PyPI test server does not contain all packages like PyPI does. To make sure you have everything you can do:
|
||||
|
||||
```bash
|
||||
pip install trl
|
||||
pip uninstall trl
|
||||
```
|
||||
|
||||
(the second line will remove trl but keep all its dependencies).
|
||||
|
||||
Also make sure you can actually use the package! Run the following line:
|
||||
|
||||
```bash
|
||||
python -c "from trl import *"
|
||||
```
|
||||
|
||||
along with anything that tests:
|
||||
|
||||
- the core feature of your package
|
||||
- the new features you’re adding in the release
|
||||
|
||||
#### 11. Publish on PyPI
|
||||
|
||||
> [!WARNING]
|
||||
> This can't be reverted. Make sure you have tested everything before doing this step.
|
||||
|
||||
```shell
|
||||
twine upload dist/*
|
||||
```
|
||||
|
||||
#### 12. Create a GitHub Release
|
||||
|
||||
1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub.
|
||||
2. Click **Draft a new release**.
|
||||
3. Select the `v{major}.{minor}.0` tag you just created in step 7.
|
||||
4. Add a title (`v{major}.{minor}.0`) and a short description of what’s new.
|
||||
5. Click **Publish Release**.
|
||||
|
||||
#### 13. Bump to dev version
|
||||
|
||||
1. Create a branch `bump-dev-version-{major}.{minor+1}` from `main` and checkout to it.
|
||||
|
||||
```shell
|
||||
git checkout -b bump-dev-version-{major}.{minor+1}
|
||||
```
|
||||
|
||||
2. Change the version in the following files:
|
||||
1. `__init__.py`
|
||||
```diff
|
||||
- __version__ = "{major}.{minor}.0"
|
||||
+ __version__ = "{major}.{minor+1}.0.dev0"
|
||||
```
|
||||
2. `setup.cfg`
|
||||
```diff
|
||||
- version = {major}.{minor}.0
|
||||
+ version = {major}.{minor+1}.0.dev0
|
||||
```
|
||||
|
||||
3. Commit and push these changes
|
||||
|
||||
```shell
|
||||
git add trl/__init__.py setup.cfg
|
||||
git commit -m '⬆️ Bump dev version'
|
||||
git push origin bump-dev-version-{major}.{minor+1}
|
||||
```
|
||||
|
||||
4. Create a pull request from `bump-dev-version-{major}.{minor+1}` to `main`, named `⬆️ Bump dev version`, and request urgent review.
|
||||
|
||||
5. Once the pull request is approved, merge it into `main`.
|
||||
|
||||
6. The codebase is now ready for the next development cycle, inform the team in the #trl-internal channel.
|
||||
|
2
LICENSE
2
LICENSE
@ -186,7 +186,7 @@
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
Copyright 2020-2025 The HuggingFace Team
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
@ -1,6 +1,6 @@
|
||||
include settings.ini
|
||||
include LICENSE
|
||||
include CONTRIBUTING.md
|
||||
include README.md
|
||||
recursive-exclude * __pycache__
|
||||
include trl/templates/*.md
|
||||
include trl/templates/*.md
|
||||
include trl/accelerate_configs/*.yaml
|
9
Makefile
9
Makefile
@ -6,17 +6,14 @@ ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
|
||||
COMMAND_FILES_PATH = `pwd`/commands
|
||||
|
||||
test:
|
||||
python -m pytest -n auto --dist=loadfile -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' ./tests/
|
||||
pytest -n auto -m "not slow and not low-priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/
|
||||
|
||||
precommit:
|
||||
pre-commit run --all-files
|
||||
python scripts/add_copyrights.py
|
||||
|
||||
tests_gpu:
|
||||
python -m pytest tests/test_* $(if $(IS_GITHUB_CI),--report-log "common_tests.log",)
|
||||
pre-commit run --all-files
|
||||
|
||||
slow_tests:
|
||||
python -m pytest tests/slow/test_* $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
|
||||
pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
|
||||
|
||||
test_examples:
|
||||
touch temp_results_sft_tests.txt
|
||||
|
145
README.md
145
README.md
@ -12,8 +12,9 @@
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
|
||||
<a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
|
||||
<a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website?label=documentation&url=https%3A%2F%2Fhuggingface.co%2Fdocs%2Ftrl%2Findex&down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
|
||||
<a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
|
||||
<a href="https://huggingface.co/trl-lib"><img alt="Hugging Face Hub" src="https://img.shields.io/badge/🤗%20Hub-trl--lib-yellow"></a>
|
||||
</p>
|
||||
|
||||
## Overview
|
||||
@ -22,16 +23,14 @@ TRL is a cutting-edge library designed for post-training foundation models using
|
||||
|
||||
## Highlights
|
||||
|
||||
- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer) and more.
|
||||
|
||||
- **Efficient and scalable**:
|
||||
- Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like DDP and DeepSpeed.
|
||||
- Full integration with [`PEFT`](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
|
||||
- Integrates [Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
|
||||
- Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like [DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [DeepSpeed](https://github.com/deepspeedai/DeepSpeed).
|
||||
- Full integration with [🤗 PEFT](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
|
||||
- Integrates [🦥 Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
|
||||
|
||||
- **Command Line Interface (CLI)**: A simple interface lets you fine-tune and interact with models without needing to write code.
|
||||
|
||||
- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`ORPOTrainer`](https://huggingface.co/docs/trl/orpo_trainer) and more.
|
||||
|
||||
- **AutoModels**: Use pre-defined model classes like [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) to simplify reinforcement learning (RL) with LLMs.
|
||||
- **Command Line Interface (CLI)**: A simple interface lets you fine-tune with models without needing to write code.
|
||||
|
||||
## Installation
|
||||
|
||||
@ -59,60 +58,75 @@ If you want to use the examples you can clone the repository with the following
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
## Command Line Interface (CLI)
|
||||
## Quick Start
|
||||
|
||||
You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), or vibe check your model with the chat CLI:
|
||||
|
||||
**SFT:**
|
||||
|
||||
```bash
|
||||
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name trl-lib/Capybara \
|
||||
--output_dir Qwen2.5-0.5B-SFT
|
||||
```
|
||||
|
||||
**DPO:**
|
||||
|
||||
```bash
|
||||
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--dataset_name argilla/Capybara-Preferences \
|
||||
--output_dir Qwen2.5-0.5B-DPO
|
||||
```
|
||||
|
||||
**Chat:**
|
||||
|
||||
```bash
|
||||
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
|
||||
```
|
||||
|
||||
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
|
||||
|
||||
## How to use
|
||||
|
||||
For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.
|
||||
|
||||
### `SFTTrainer`
|
||||
|
||||
Here is a basic example of how to use the `SFTTrainer`:
|
||||
Here is a basic example of how to use the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer):
|
||||
|
||||
```python
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
from trl import SFTTrainer
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
|
||||
trainer = SFTTrainer(
|
||||
args=training_args,
|
||||
model="Qwen/Qwen2.5-0.5B",
|
||||
train_dataset=dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### `GRPOTrainer`
|
||||
|
||||
[`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer) implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Dummy reward function: count the number of unique characters in the completions
|
||||
def reward_num_unique_chars(completions, **kwargs):
|
||||
return [len(set(c)) for c in completions]
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_num_unique_chars,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### `DPOTrainer`
|
||||
|
||||
[`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer) implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train [Llama 3](https://huggingface.co/papers/2407.21783) and many other models. Here is a basic example of how to use the `DPOTrainer`:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
processing_class=tokenizer
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### `RewardTrainer`
|
||||
|
||||
Here is a basic example of how to use the `RewardTrainer`:
|
||||
Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):
|
||||
|
||||
```python
|
||||
from trl import RewardConfig, RewardTrainer
|
||||
@ -137,47 +151,28 @@ trainer = RewardTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### `GRPOTrainer`
|
||||
## Command Line Interface (CLI)
|
||||
|
||||
`GRPOTrainer` implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).
|
||||
You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOConfig, GRPOTrainer
|
||||
**SFT:**
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Dummy reward function: rewards completions that are close to 20 characters
|
||||
def reward_len(completions, **kwargs):
|
||||
return [-abs(20 - len(completion)) for completion in completions]
|
||||
|
||||
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=reward_len,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```bash
|
||||
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name trl-lib/Capybara \
|
||||
--output_dir Qwen2.5-0.5B-SFT
|
||||
```
|
||||
|
||||
### `DPOTrainer`
|
||||
**DPO:**
|
||||
|
||||
`DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the `DPOTrainer`:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
||||
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
|
||||
trainer.train()
|
||||
```bash
|
||||
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
--dataset_name argilla/Capybara-Preferences \
|
||||
--output_dir Qwen2.5-0.5B-DPO
|
||||
```
|
||||
|
||||
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
|
||||
|
||||
## Development
|
||||
|
||||
If you want to contribute to `trl` or customize it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install:
|
||||
|
@ -42,7 +42,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_steps $MAX_STEPS \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--max_seq_length $SEQ_LEN \
|
||||
--max_length $SEQ_LEN \
|
||||
$EXTRA_TRAINING_ARGS
|
||||
"""
|
||||
|
||||
|
@ -23,6 +23,8 @@
|
||||
title: Reducing Memory Usage
|
||||
- local: speeding_up_training
|
||||
title: Speeding Up Training
|
||||
- local: distributing_training
|
||||
title: Distributing Training
|
||||
- local: use_model
|
||||
title: Using Trained Models
|
||||
title: How-to guides
|
||||
@ -35,6 +37,8 @@
|
||||
title: PEFT
|
||||
- local: unsloth_integration
|
||||
title: Unsloth
|
||||
- local: vllm_integration
|
||||
title: vLLM
|
||||
title: Integrations
|
||||
- sections:
|
||||
- local: example_overview
|
||||
@ -47,10 +51,10 @@
|
||||
title: Training StackLlama
|
||||
- local: detoxifying_a_lm
|
||||
title: Detoxifying a Language Model
|
||||
- local: learning_tools
|
||||
title: Learning to Use Tools
|
||||
- local: multi_adapter_rl
|
||||
title: Multi Adapter RLHF
|
||||
- local: training_vlm_sft
|
||||
title: Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
|
||||
title: Examples
|
||||
- sections:
|
||||
- sections: # Sorted alphabetically
|
||||
@ -93,6 +97,8 @@
|
||||
title: Trainers
|
||||
- local: models
|
||||
title: Model Classes
|
||||
- local: model_utils
|
||||
title: Model Utilities
|
||||
- local: best_of_n
|
||||
title: Best of N Sampling
|
||||
- local: judges
|
||||
@ -101,8 +107,10 @@
|
||||
title: Callbacks
|
||||
- local: data_utils
|
||||
title: Data Utilities
|
||||
- local: text_environments
|
||||
title: Text Environments
|
||||
- local: rewards
|
||||
title: Reward Functions
|
||||
- local: script_utils
|
||||
title: Script Utilities
|
||||
- local: others
|
||||
title: Others
|
||||
title: API
|
||||
|
@ -1,105 +1,231 @@
|
||||
# Command Line Interfaces (CLIs)
|
||||
|
||||
You can use TRL to fine-tune your Language Model with Supervised Fine-Tuning (SFT) or Direct Policy Optimization (DPO) or even chat with your model using the TRL CLIs.
|
||||
TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly.
|
||||
|
||||
Currently supported CLIs are:
|
||||
Currently supported commands are:
|
||||
|
||||
#### Training commands
|
||||
#### Training Commands
|
||||
|
||||
- `trl dpo`: fine-tune a LLM with DPO
|
||||
- `trl grpo`: fine-tune a LLM with GRPO
|
||||
- `trl kto`: fine-tune a LLM with KTO
|
||||
- `trl sft`: fine-tune a LLM with SFT
|
||||
|
||||
#### Other commands
|
||||
#### Other Commands
|
||||
|
||||
- `trl chat`: quickly spin up an LLM fine-tuned for chatting
|
||||
- `trl env`: get the system information
|
||||
- `trl vllm-serve`: serve a model with vLLM
|
||||
|
||||
## Fine-tuning with the CLI
|
||||
## Fine-Tuning with the TRL CLI
|
||||
|
||||
Before getting started, pick up a Language Model from Hugging Face Hub. Supported models can be found with the filter "text-generation" within models. Also make sure to pick up a relevant dataset for your task.
|
||||
### Basic Usage
|
||||
|
||||
You can launch training directly from the CLI by specifying required arguments like the model and dataset:
|
||||
|
||||
<hfoptions id="command_line">
|
||||
<hfoption id="SFT">
|
||||
|
||||
Before using the `sft` or `dpo` commands make sure to run:
|
||||
```bash
|
||||
accelerate config
|
||||
trl sft \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name stanfordnlp/imdb
|
||||
```
|
||||
and pick up the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
|
||||
|
||||
We also recommend you passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with `trl sft` command.
|
||||
</hfoption>
|
||||
<hfoption id="DPO">
|
||||
|
||||
```bash
|
||||
trl dpo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name anthropic/hh-rlhf
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Using Configuration Files
|
||||
|
||||
To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file:
|
||||
|
||||
<hfoptions id="config_file">
|
||||
<hfoption id="SFT">
|
||||
|
||||
```yaml
|
||||
model_name_or_path:
|
||||
Qwen/Qwen2.5-0.5B
|
||||
dataset_name:
|
||||
stanfordnlp/imdb
|
||||
report_to:
|
||||
none
|
||||
learning_rate:
|
||||
0.0001
|
||||
lr_scheduler_type:
|
||||
cosine
|
||||
# sft_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: stanfordnlp/imdb
|
||||
```
|
||||
|
||||
Save that config in a `.yaml` and get started immediately! An example CLI config is available as `examples/cli_configs/example_config.yaml`. Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder:
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
|
||||
trl sft --config sft_config.yaml
|
||||
```
|
||||
|
||||
Will force-use `cosine_with_restarts` for `lr_scheduler_type`.
|
||||
</hfoption>
|
||||
<hfoption id="DPO">
|
||||
|
||||
### Supported Arguments
|
||||
```yaml
|
||||
# dpo_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: anthropic/hh-rlhf
|
||||
```
|
||||
|
||||
We do support all arguments from `transformers.TrainingArguments`, for loading your model, we support all arguments from `~trl.ModelConfig`:
|
||||
|
||||
[[autodoc]] ModelConfig
|
||||
|
||||
You can pass any of these arguments either to the CLI or the YAML file.
|
||||
|
||||
### Supervised Fine-tuning (SFT)
|
||||
|
||||
Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb
|
||||
trl dpo --config dpo_config.yaml
|
||||
```
|
||||
|
||||
The SFT CLI is based on the `trl/scripts/sft.py` script.
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Direct Policy Optimization (DPO)
|
||||
### Scaling Up with Accelerate
|
||||
|
||||
To use the DPO CLI, you need to have a dataset in the TRL format such as
|
||||
TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs, machines, or use advanced setups like DeepSpeed — all from the same CLI.
|
||||
|
||||
* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
|
||||
* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
|
||||
You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch).
|
||||
|
||||
These datasets always have at least three columns `prompt, chosen, rejected`:
|
||||
|
||||
* `prompt` is a list of strings.
|
||||
* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
|
||||
* `rejected` is the rejected response [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
|
||||
|
||||
|
||||
To do a quick start, you can run the following command:
|
||||
<hfoptions id="launch_args">
|
||||
<hfoption id="SFT inline">
|
||||
|
||||
```bash
|
||||
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
|
||||
trl sft \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name stanfordnlp/imdb \
|
||||
--num_processes 4
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="SFT w/ config file">
|
||||
|
||||
The DPO CLI is based on the `trl/scripts/dpo.py` script.
|
||||
```yaml
|
||||
# sft_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: stanfordnlp/imdb
|
||||
num_processes: 4
|
||||
```
|
||||
|
||||
|
||||
#### Custom preference dataset
|
||||
|
||||
Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
|
||||
trl sft --config sft_config.yaml
|
||||
```
|
||||
|
||||
## Chat interface
|
||||
</hfoption>
|
||||
<hfoption id="DPO inline">
|
||||
|
||||
```bash
|
||||
trl dpo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name anthropic/hh-rlhf \
|
||||
--num_processes 4
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPO w/ config file">
|
||||
|
||||
```yaml
|
||||
# dpo_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: anthropic/hh-rlhf
|
||||
num_processes: 4
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl dpo --config dpo_config.yaml
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Using `--accelerate_config` for Accelerate Configuration
|
||||
|
||||
The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either:
|
||||
|
||||
* the name of a predefined config profile (built into TRL), or
|
||||
* a path to a custom Accelerate YAML config file.
|
||||
|
||||
#### Predefined Config Profiles
|
||||
|
||||
TRL provides several ready-to-use Accelerate configs to simplify common training setups:
|
||||
|
||||
| Name | Description |
|
||||
| ------------ | ----------------------------------- |
|
||||
| `fsdp1` | Fully Sharded Data Parallel Stage 1 |
|
||||
| `fsdp2` | Fully Sharded Data Parallel Stage 2 |
|
||||
| `zero1` | DeepSpeed ZeRO Stage 1 |
|
||||
| `zero2` | DeepSpeed ZeRO Stage 2 |
|
||||
| `zero3` | DeepSpeed ZeRO Stage 3 |
|
||||
| `multi_gpu` | Multi-GPU training |
|
||||
| `single_gpu` | Single-GPU training |
|
||||
|
||||
To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_config/`.
|
||||
|
||||
#### Example Usage
|
||||
|
||||
<hfoptions id="accelerate_config">
|
||||
<hfoption id="SFT inline">
|
||||
|
||||
```bash
|
||||
trl sft \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name stanfordnlp/imdb \
|
||||
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="SFT w/ config file">
|
||||
|
||||
```yaml
|
||||
# sft_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: stanfordnlp/imdb
|
||||
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl sft --config sft_config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPO inline">
|
||||
|
||||
```bash
|
||||
trl dpo \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name anthropic/hh-rlhf \
|
||||
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="DPO w/ config file">
|
||||
|
||||
```yaml
|
||||
# dpo_config.yaml
|
||||
model_name_or_path: Qwen/Qwen2.5-0.5B
|
||||
dataset_name: anthropic/hh-rlhf
|
||||
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
|
||||
```
|
||||
|
||||
Launch with:
|
||||
|
||||
```bash
|
||||
trl dpo --config dpo_config.yaml
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Chat Interface
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
The chat interface is deprecated and will be removed in TRL 0.19. Use `transformers-cli chat` instead. For more information, see the [Transformers documentation, chat with text generation models](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
</Tip>
|
||||
|
||||
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
|
||||
|
||||
@ -124,7 +250,7 @@ Besides talking to the model there are a few commands you can use:
|
||||
- `save` or `save {SAVE_NAME}`: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
|
||||
- `exit`: closes the interface
|
||||
|
||||
## Getting the system information
|
||||
## Getting the System Information
|
||||
|
||||
You can get the system information by running the following command:
|
||||
|
||||
@ -132,7 +258,7 @@ You can get the system information by running the following command:
|
||||
trl env
|
||||
```
|
||||
|
||||
This will print out the system information including the GPU information, the CUDA version, the PyTorch version, the transformers version, and the TRL version, and any optional dependencies that are installed.
|
||||
This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the transformers version, the TRL version, and any optional dependencies that are installed.
|
||||
|
||||
```txt
|
||||
Copy-paste the following information when reporting an issue:
|
||||
@ -140,7 +266,7 @@ Copy-paste the following information when reporting an issue:
|
||||
- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
|
||||
- Python version: 3.11.9
|
||||
- PyTorch version: 2.4.1
|
||||
- CUDA device: NVIDIA H100 80GB HBM3
|
||||
- accelerator(s): NVIDIA H100 80GB HBM3
|
||||
- Transformers version: 4.45.0.dev0
|
||||
- Accelerate version: 0.34.2
|
||||
- Accelerate config:
|
||||
@ -171,6 +297,7 @@ Copy-paste the following information when reporting an issue:
|
||||
- LLM-Blender version: 0.0.2
|
||||
- OpenAI version: 1.46.0
|
||||
- PEFT version: 0.12.0
|
||||
- vLLM version: not installed
|
||||
```
|
||||
|
||||
This information are required when reporting an issue.
|
||||
This information is required when reporting an issue.
|
||||
|
@ -2,48 +2,6 @@
|
||||
|
||||
TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers.
|
||||
|
||||
## Train on multiple GPUs / nodes
|
||||
|
||||
The trainers in TRL use 🤗 Accelerate to enable distributed training across multiple GPUs or nodes. To do so, first create an 🤗 Accelerate config file by running
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
and answering the questions according to your multi-gpu / multi-node setup. You can then launch distributed training by running:
|
||||
|
||||
```bash
|
||||
accelerate launch your_script.py
|
||||
```
|
||||
|
||||
We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
|
||||
```
|
||||
|
||||
Refer to the [examples page](https://github.com/huggingface/trl/tree/main/examples) for more details.
|
||||
|
||||
### Distributed training with DeepSpeed
|
||||
|
||||
All of the trainers in TRL can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run:
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_your_script.py --all_arguments_of_the_script
|
||||
```
|
||||
|
||||
Note that for ZeRO-3, a small tweak is needed to initialize your reward model on the correct device via the `zero3_init_context_manager()` context manager. In particular, this is needed to avoid DeepSpeed hanging after a fixed number of training steps. Here is a snippet of what is involved from the [`sentiment_tuning`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) example:
|
||||
|
||||
```python
|
||||
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
|
||||
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
|
||||
with ds_plugin.zero3_init_context_manager(enable=False):
|
||||
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
|
||||
else:
|
||||
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
|
||||
```
|
||||
|
||||
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
|
||||
|
||||
|
||||
## Use different optimizers and schedulers
|
||||
|
@ -12,6 +12,10 @@
|
||||
|
||||
[[autodoc]] maybe_apply_chat_template
|
||||
|
||||
## maybe_convert_to_chatml
|
||||
|
||||
[[autodoc]] maybe_convert_to_chatml
|
||||
|
||||
## extract_prompt
|
||||
|
||||
[[autodoc]] extract_prompt
|
||||
@ -31,3 +35,11 @@
|
||||
## pack_examples
|
||||
|
||||
[[autodoc]] pack_examples
|
||||
|
||||
## pack_dataset
|
||||
|
||||
[[autodoc]] pack_dataset
|
||||
|
||||
## truncate_dataset
|
||||
|
||||
[[autodoc]] truncate_dataset
|
||||
|
@ -279,7 +279,7 @@ Choosing the right dataset type depends on the task you are working on and the s
|
||||
| [`PPOTrainer`] | Tokenized language modeling |
|
||||
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
|
||||
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
|
||||
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
|
||||
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
|
||||
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
|
||||
<Tip>
|
||||
|
@ -4,4 +4,36 @@
|
||||
|
||||
Section under construction. Feel free to contribute!
|
||||
|
||||
</Tip>
|
||||
</Tip>
|
||||
|
||||
TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more.
|
||||
|
||||
DeepSpeed integrates the [Zero Redundancy Optimizer (ZeRO)](https://huggingface.co/papers/1910.02054), which allows to scale the model size proportional to the number of devices with sustained high efficiency.
|
||||
|
||||

|
||||
|
||||
## Installation
|
||||
|
||||
To use DeepSpeed with TRL, install it using the following command:
|
||||
|
||||
```bash
|
||||
pip install deepspeed
|
||||
```
|
||||
|
||||
## Running Training Scripts with DeepSpeed
|
||||
|
||||
No modifications to your training script are required. Simply run it with the DeepSpeed configuration file:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file <ACCELERATE_WITH_DEEPSPEED_CONFIG_FILE.yaml> train.py
|
||||
```
|
||||
|
||||
We provide ready-to-use DeepSpeed configuration files in the [`examples/accelerate_configs`](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) directory. For example, to run training with ZeRO Stage 2, use the following command:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml train.py
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
|
||||
|
60
docs/source/distributing_training.md
Normal file
60
docs/source/distributing_training.md
Normal file
@ -0,0 +1,60 @@
|
||||
# Distributing Training
|
||||
|
||||
<Tip warning={true}>
|
||||
Section under construction. Feel free to contribute!
|
||||
</Tip>
|
||||
|
||||
## Multi-GPU Training with TRL
|
||||
|
||||
The trainers in TRL use [🤗 Accelerate](https://github.com/huggingface/accelerate) to enable distributed training across multiple GPUs or nodes. To do so, first create an [🤗 Accelerate](https://github.com/huggingface/accelerate) config file by running
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
and answering the questions according to your multi-GPU / multi-node setup. You can then launch distributed training by running:
|
||||
|
||||
```bash
|
||||
accelerate launch train.py
|
||||
```
|
||||
|
||||
We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml train.py <SCRIPT_ARGS>
|
||||
```
|
||||
|
||||
This automatically distributes the workload across all available GPUs.
|
||||
|
||||
Under the hood, [🤗 Accelerate](https://github.com/huggingface/accelerate) creates one model per GPU. Each process:
|
||||
- Processes its own batch of data
|
||||
- Computes the loss and gradients for that batch
|
||||
- Shares gradient updates across all GPUs
|
||||
|
||||

|
||||
|
||||
The effective batch size is calculated as:
|
||||
|
||||
$$
|
||||
\text{Batch Size} = \text{per\_device\_train\_batch\_size} \times \text{num\_devices} \times \text{gradient\_accumulation\_steps}
|
||||
$$
|
||||
|
||||
To maintain a consistent batch size when scaling to multiple GPUs, make sure to update `per_device_train_batch_size` and `gradient_accumulation_steps` accordingly.
|
||||
|
||||
Example, these configurations are equivalent, and should yield the same results:
|
||||
|
||||
| Number of GPUs | Per device batch size | Gradient accumulation steps | Comments |
|
||||
| --- | --- | --- | --- |
|
||||
| 1 | 32 | 1 | Possibly high memory usage, but faster training |
|
||||
| 1 | 4 | 8 | Lower memory usage, slower training |
|
||||
| 8 | 4 | 1 | Multi-GPU to get the best of both worlds |
|
||||
|
||||
<Tip>
|
||||
|
||||
Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration.md) guide for more details.
|
||||
|
||||
</Tip>
|
||||
|
||||
## Multi-Nodes Training
|
||||
|
||||
We're working on a guide for multi-node training. Stay tuned! 🚀
|
@ -61,9 +61,9 @@ Distributed across 8 GPUs, the training takes approximately 3 minutes. You can v
|
||||
|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-DPO
|
||||
<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-DPO
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
|
@ -33,26 +33,37 @@ Then, it is encouraged to launch jobs with `accelerate launch`!
|
||||
|
||||
Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly.
|
||||
|
||||
| File | Description |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
|
||||
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
|
||||
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
|
||||
| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
|
||||
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
|
||||
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
|
||||
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a reward model on your own dataset. |
|
||||
| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. |
|
||||
| File | Description |
|
||||
| --- | --- |
|
||||
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
|
||||
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
|
||||
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
|
||||
| [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
|
||||
| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
|
||||
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
|
||||
| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
|
||||
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a Outcome Reward Model (ORM) on your own dataset. |
|
||||
| [`examples/scripts/rloo/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
|
||||
| [`examples/scripts/rloo/rloo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo_tldr.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
|
||||
| [`examples/scripts/sft_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model. |
|
||||
| [`examples/scripts/sft_video_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_video_llm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Video Language Model. |
|
||||
| [`examples/scripts/sft_vlm_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model on vision to text tasks. |
|
||||
| [`examples/scripts/sft_vlm_smol_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a SmolVLM model. |
|
||||
| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. |
|
||||
| [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`XPOTrainer`] to fine-tune a model. |
|
||||
|
||||
Here are also some easier-to-run colab notebooks that you can use to get started with TRL:
|
||||
|
||||
| File | Description |
|
||||
| --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------- |
|
||||
| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
|
||||
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
|
||||
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
|
||||
| File | Description |
|
||||
| --- | --- |
|
||||
| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
|
||||
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
|
||||
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
|
||||
|
||||
|
||||
We also have some other examples that are less maintained but can be used as a reference:
|
||||
|
@ -68,11 +68,17 @@ At each training step, we sample a batch of prompts and generate a set of \\( G
|
||||
|
||||
### Computing the advantage
|
||||
|
||||
For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:
|
||||
For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:
|
||||
|
||||
$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$
|
||||
$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$
|
||||
|
||||
This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.
|
||||
This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.
|
||||
|
||||
<Tip>
|
||||
|
||||
It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`].
|
||||
|
||||
</Tip>
|
||||
|
||||
### Estimating the KL divergence
|
||||
|
||||
@ -83,46 +89,215 @@ $$
|
||||
|
||||
### Computing the loss
|
||||
|
||||
The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:
|
||||
The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
|
||||
\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
|
||||
$$
|
||||
|
||||
where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.
|
||||
where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.
|
||||
|
||||
In the original paper, this formulation is generalized to account for multiple updates after each generation by leveraging the **clipped surrogate objective**:
|
||||
<Tip>
|
||||
|
||||
Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that this introduces a response-level length bias. More details in [loss types](#loss-types).
|
||||
|
||||
</Tip>
|
||||
|
||||
In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
|
||||
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
|
||||
$$
|
||||
|
||||
where \\(\text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) \\) ensures that updates do not deviate excessively from the reference policy by bounding the policy ratio between \\( 1 - \epsilon \\) and \\( 1 + \epsilon \\).
|
||||
In TRL though, as in the original paper, we only do one update per generation, so we can simplify the loss to the first form.
|
||||
When \\( \mu = 1 \\) (default in TRL), the clipped surrogate objective simplifies to the original objective.
|
||||
|
||||
#### Loss Types
|
||||
|
||||
Several formulations of the objective have been proposed in the literature. Initially, the objective of GRPO was defined as follows:
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} l_{i,t},
|
||||
$$
|
||||
|
||||
where
|
||||
|
||||
$$
|
||||
l_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right].
|
||||
$$
|
||||
|
||||
The DAPO paper highlights the limitations of the GRPO algorithm’s sample-level loss in long-CoT scenarios, where longer responses are under-penalized, leading to poorer quality outputs. The proposed solution is a token-level normalization, which better handles longer sequences by assigning more balanced rewards to individual tokens, regardless of response length:
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
|
||||
$$
|
||||
|
||||
|
||||
Furthermore, it was demonstrated in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that the initial GRPO formulation introduces a response length bias. They show that while the DAPO formulation reduces this bias, it does not eliminate it completely. To fully remove this bias, they propose dividing by a constant instead of the sequence length, resulting in the following formulation:
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{Dr. GRPO}}(\theta) = - \frac{1}{LG} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
|
||||
$$
|
||||
|
||||
This constant is recommended to be the maximum completion length. To use this formulation, set `loss_type="dr_grpo"` in the [`GRPOConfig`].
|
||||
|
||||
## Logged metrics
|
||||
|
||||
The GRPO Trainer logs the following metrics:
|
||||
|
||||
- `completion_length`: The average completion length.
|
||||
- `reward/{reward_func_name}`: The reward computed by each reward function.
|
||||
- `reward`: The average reward.
|
||||
- `reward_std` : The average standard deviation within reward groups.
|
||||
- `kl` : The average KL divergence between the model and the reference model calculated on completions.
|
||||
- `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
|
||||
- `completions/mean_length`: The average length of generated completions.
|
||||
- `completions/min_length`: The minimun length of generated completions.
|
||||
- `completions/max_length`: The maximum length of generated completions.
|
||||
- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS.
|
||||
- `completions/min_terminated_length`: The minimun length of generated completions that terminate with EOS.
|
||||
- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS.
|
||||
- `completions/clipped_ratio` : The ratio of truncated (clipped) completions.
|
||||
- `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
|
||||
- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
|
||||
- `reward`: The overall average reward after applying reward weights.
|
||||
- `reward_std`: The standard deviation of the overall reward within each batch after applying reward weights.
|
||||
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
|
||||
- `clip_ratio/region_mean`: The ratio of token probabilities where the GRPO objective is clipped to stay within the trust region:
|
||||
$$
|
||||
\text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,.
|
||||
$$
|
||||
A higher value means more tokens are clipped, which constrains how much the policy $\pi_\theta$ can change.
|
||||
- `clip_ratio/low_mean`: The average ratio of token probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
|
||||
- `clip_ratio/low_min`: The minimum ratio of token probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
|
||||
- `clip_ratio/high_mean`: The average ratio of token probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\)
|
||||
- `clip_ratio/high_max`: The maximum ratio of token probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).
|
||||
|
||||
## Customization
|
||||
|
||||
## Speed up training with vLLM-powered generation
|
||||
### Speed up training with vLLM-powered generation
|
||||
|
||||
Generation is often the main bottleneck that makes training slow with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation. To enable it, pass `use_vllm=True` in the training arguments.
|
||||
Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with
|
||||
```shell
|
||||
pip install trl[vllm]
|
||||
```
|
||||
|
||||
We support two ways of using vLLM during training: **server mode** and **colocate mode**.
|
||||
|
||||
#### 🔌 Option 1: Server mode
|
||||
|
||||
In this mode, vLLM runs in a separate process (and using separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.
|
||||
|
||||
1. **Start the vLLM server**:
|
||||
```bash
|
||||
trl vllm-serve --model <model_name>
|
||||
```
|
||||
|
||||
2. **Enable server mode in your training script**:
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="server", # default value, can be omitted
|
||||
)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable.
|
||||
|
||||
</Tip>
|
||||
|
||||
#### 🧩 Option 2: Colocate mode
|
||||
|
||||
In this mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs.
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(..., use_vllm=True)
|
||||
```
|
||||
training_args = GRPOConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_mode="colocate",
|
||||
)
|
||||
```
|
||||
|
||||
For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
|
||||
<Tip>
|
||||
|
||||
Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`GRPOConfig`] to avoid underutilization or out-of-memory errors.
|
||||
|
||||
</Tip>
|
||||
|
||||
For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).
|
||||
|
||||
### GRPO at scale: train a 70B+ Model on multiple nodes
|
||||
|
||||
When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:
|
||||
|
||||
- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such model. For more details, see [DeepSpeed Integration](deepspeed_integration).
|
||||
- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training).
|
||||
- **vLLM**: See the previous section on how to use vLLM to speed up generation.
|
||||
|
||||
Below is an example SLURM script to train a 70B model with GRPO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation.
|
||||
|
||||
```sh
|
||||
#!/bin/bash
|
||||
#SBATCH --nodes=5
|
||||
#SBATCH --gres=gpu:8
|
||||
|
||||
# Get the list of allocated nodes
|
||||
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
|
||||
|
||||
# Assign the first 4 nodes for training and the 5th node for vLLM
|
||||
TRAIN_NODES="${NODELIST[@]:0:4}" # Nodes 0, 1, 2, 3 for training
|
||||
VLLM_NODE="${NODELIST[4]}" # Node 4 for vLLM
|
||||
|
||||
# Run training on the first 4 nodes (Group 1)
|
||||
srun --nodes=4 --ntasks=4 --nodelist="${NODELIST[@]:0:4}" accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
--num_processes 32 \
|
||||
--num_machines 4 \
|
||||
--main_process_ip ${NODELIST[0]} \
|
||||
--machine_rank $SLURM_PROCID \
|
||||
--rdzv_backend c10d \
|
||||
train_grpo.py \
|
||||
--server_ip $VLLM_NODE &
|
||||
|
||||
# Run vLLM server on the 5th node (Group 2)
|
||||
srun --nodes=1 --ntasks=1 --nodelist="${NODELIST[4]}" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &
|
||||
|
||||
wait
|
||||
```
|
||||
|
||||
```python
|
||||
import argparse
|
||||
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Example dataset from TLDR
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Dummy reward function: count the number of unique characters in the completions
|
||||
def reward_num_unique_chars(completions, **kwargs):
|
||||
return [len(set(c)) for c in completions]
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir="Qwen2.5-72B-GRPO",
|
||||
per_device_train_batch_size=4,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
logging_steps=10,
|
||||
use_vllm=True,
|
||||
vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."), # from ip-X-X-X-X to X.X.X.X
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(model="Qwen/Qwen2.5-72B", args=training_args, reward_funcs=reward_num_unique_chars, train_dataset=dataset)
|
||||
trainer.train()
|
||||
|
||||
if __name__=="__main__":
|
||||
main()
|
||||
```
|
||||
|
||||
### Using a custom reward function
|
||||
|
||||
@ -132,6 +307,7 @@ The [`GRPOTrainer`] supports using custom reward functions instead of dense rewa
|
||||
- The function must accept the following as keyword arguments:
|
||||
- `prompts` (contains the prompts),
|
||||
- `completions` (contains the generated completions),
|
||||
- `completions_ids` (contains the tokenized completions),
|
||||
- All columns names (but `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.
|
||||
|
||||
The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
|
||||
@ -145,9 +321,29 @@ The [`GRPOTrainer`] supports using custom reward functions instead of dense rewa
|
||||
|
||||
Below is an example of a reward function for a standard format that rewards longer completions:
|
||||
|
||||
```python
|
||||
def reward_func(completions_ids, **kwargs):
|
||||
"""Reward function that assigns higher scores to longer completions (in terms of token count)."""
|
||||
return [float(len(ids)) for ids in completions_ids]
|
||||
```
|
||||
|
||||
You can test it as follows:
|
||||
|
||||
```python
|
||||
>>> prompts = ["The sky is", "The sun is"] # not used in the reward function, but the trainer will pass it
|
||||
>>> completions = [" blue.", " in the sky."] # not used in the reward function, but the trainer will pass it
|
||||
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]
|
||||
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
|
||||
[2.0, 4.0]
|
||||
```
|
||||
|
||||
#### Example 1.1: Reward longer completions (based in the number of characters)
|
||||
|
||||
Same as the previous example, but this time the reward function is based on the number of characters instead of tokens.
|
||||
|
||||
```python
|
||||
def reward_func(completions, **kwargs):
|
||||
"""Reward function that gives higher scores to longer completions."""
|
||||
"""Reward function that assigns higher scores to longer completions (in terms of character count)."""
|
||||
return [float(len(completion)) for completion in completions]
|
||||
```
|
||||
|
||||
@ -156,7 +352,8 @@ You can test it as follows:
|
||||
```python
|
||||
>>> prompts = ["The sky is", "The sun is"]
|
||||
>>> completions = [" blue.", " in the sky."]
|
||||
>>> print(reward_func(prompts=prompts, completions=completions))
|
||||
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]] # not used in the reward function, but the trainer will pass it
|
||||
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
|
||||
[6.0, 12.0]
|
||||
```
|
||||
|
||||
@ -216,6 +413,67 @@ You can test this function as follows:
|
||||
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
|
||||
[1.0, 0.0]
|
||||
```
|
||||
#### Example 4: Multi-task reward functions
|
||||
|
||||
Below is an example of using multiple reward functions in the [`GRPOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works.
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import GRPOTrainer
|
||||
|
||||
# Define a dataset that contains both math and coding problems
|
||||
dataset = Dataset.from_list(
|
||||
[
|
||||
{"prompt": "What is 2+2?", "task": "math"},
|
||||
{"prompt": "Write a function that returns the sum of two numbers.", "task": "code"},
|
||||
{"prompt": "What is 3*4?", "task": "math"},
|
||||
{"prompt": "Write a function that returns the product of two numbers.", "task": "code"},
|
||||
]
|
||||
)
|
||||
|
||||
# Math-specific reward function
|
||||
def math_reward_func(prompts, completions, task, **kwargs):
|
||||
rewards = []
|
||||
for prompt, completion, t in zip(prompts, completions, task):
|
||||
if t == "math":
|
||||
# Calculate math-specific reward
|
||||
correct = check_math_solution(prompt, completion)
|
||||
reward = 1.0 if correct else -1.0
|
||||
rewards.append(reward)
|
||||
else:
|
||||
# Return None for non-math tasks
|
||||
rewards.append(None)
|
||||
return rewards
|
||||
|
||||
# Coding-specific reward function
|
||||
def coding_reward_func(prompts, completions, task, **kwargs):
|
||||
rewards = []
|
||||
for prompt, completion, t in zip(prompts, completions, task):
|
||||
if t == "coding":
|
||||
# Calculate coding-specific reward
|
||||
works = test_code_solution(prompt, completion)
|
||||
reward = 1.0 if works else -1.0
|
||||
rewards.append(reward)
|
||||
else:
|
||||
# Return None for non-coding tasks
|
||||
rewards.append(None)
|
||||
return rewards
|
||||
|
||||
# Use both task-specific reward functions
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2-0.5B-Instruct",
|
||||
reward_funcs=[math_reward_func, coding_reward_func],
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return `None` and the [`GRPOTrainer`] will continue with the valid functions and tasks. This allows the [`GRPOTrainer`] to handle multiple reward functions with different applicability.
|
||||
|
||||
Note that the [`GRPOTrainer`] will ignore the `None` rewards returned by the reward functions and only consider the rewards returned by the relevant functions. This ensures that the model is trained on the relevant tasks and ignores the tasks for which there is no relevant reward function.
|
||||
|
||||
|
||||
|
||||
#### Passing the reward function to the trainer
|
||||
|
||||
|
@ -4,37 +4,35 @@
|
||||
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
|
||||
TRL is a full stack library where we provide a set of tools to train transformer language models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step.
|
||||
TRL is a full stack library where we provide a set of tools to train transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more.
|
||||
The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).
|
||||
|
||||
You can also explore TRL-related models, datasets, and demos in the [TRL Hugging Face organization](https://huggingface.co/trl-lib).
|
||||
|
||||
## Learn
|
||||
|
||||
Learn post-training with TRL and other libraries in 🤗 [smol course](https://github.com/huggingface/smol-course).
|
||||
|
||||
## API documentation
|
||||
## Contents
|
||||
|
||||
- [Model Classes](models): *A brief overview of what each public model class does.*
|
||||
- [`SFTTrainer`](sft_trainer): *Supervise Fine-tune your model easily with `SFTTrainer`*
|
||||
- [`RewardTrainer`](reward_trainer): *Train easily your reward model using `RewardTrainer`.*
|
||||
- [`PPOTrainer`](ppo_trainer): *Further fine-tune the supervised fine-tuned model using PPO algorithm*
|
||||
- [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
|
||||
- [`DPOTrainer`](dpo_trainer): *Direct Preference Optimization training using `DPOTrainer`.*
|
||||
- [`TextEnvironment`](text_environments): *Text environment to train your model using tools with RL.*
|
||||
|
||||
## Examples
|
||||
|
||||
- [Sentiment Tuning](sentiment_tuning): *Fine tune your model to generate positive movie contents*
|
||||
- [Training with PEFT](lora_tuning_peft): *Memory efficient RLHF training using adapters with PEFT*
|
||||
- [Detoxifying LLMs](detoxifying_a_lm): *Detoxify your language model through RLHF*
|
||||
- [StackLlama](using_llama_models): *End-to-end RLHF training of a Llama model on Stack exchange dataset*
|
||||
- [Learning with Tools](learning_tools): *Walkthrough of using `TextEnvironments`*
|
||||
- [Multi-Adapter Training](multi_adapter_rl): *Use a single base model and multiple adapters for memory efficient end-to-end training*
|
||||
The documentation is organized into the following sections:
|
||||
|
||||
- **Getting Started**: installation and quickstart guide.
|
||||
- **Conceptual Guides**: dataset formats, training FAQ, and understanding logs.
|
||||
- **How-to Guides**: reducing memory usage, speeding up training, distributing training, etc.
|
||||
- **Integrations**: DeepSpeed, Liger Kernel, PEFT, etc.
|
||||
- **Examples**: example overview, community tutorials, etc.
|
||||
- **API**: trainers, utils, etc.
|
||||
|
||||
## Blog posts
|
||||
|
||||
<div class="mt-10">
|
||||
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/open-r1">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/open-r1/thumbnails.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on January 28, 2025</p>
|
||||
<p class="text-gray-700">Open-R1: a fully open reproduction of DeepSeek-R1</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo_vlm">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/dpo_vlm/thumbnail.png" alt="thumbnail" class="mt-0">
|
||||
<p class="text-gray-500 text-sm">Published on July 10, 2024</p>
|
||||
|
@ -2,56 +2,138 @@
|
||||
|
||||
[](https://huggingface.co/models?other=iterative-sft,trl)
|
||||
|
||||
|
||||
Iterative fine-tuning is a training method that enables to perform custom actions (generation and filtering for example) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code.
|
||||
|
||||
## Usage
|
||||
## Quickstart
|
||||
|
||||
To get started quickly, instantiate an instance a model, and a tokenizer.
|
||||
To get started quickly, you can either pass a model identifier or a pre-instantiated model to the trainer:
|
||||
|
||||
```python
|
||||
from trl import IterativeSFTConfig, IterativeSFTTrainer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
# Using a model identifier
|
||||
trainer = IterativeSFTTrainer(
|
||||
"facebook/opt-350m",
|
||||
args=IterativeSFTConfig(
|
||||
max_length=512,
|
||||
output_dir="./output",
|
||||
),
|
||||
)
|
||||
|
||||
# Or using a pre-instantiated model
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
|
||||
trainer = IterativeSFTTrainer(
|
||||
model,
|
||||
tokenizer
|
||||
args=IterativeSFTConfig(
|
||||
max_length=512,
|
||||
output_dir="./output",
|
||||
),
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
You have the choice to either provide a list of strings or a list of tensors to the step function.
|
||||
## Usage
|
||||
|
||||
#### Using a list of tensors as input:
|
||||
The [`IterativeSFTTrainer`] supports two ways of providing input data to the `step` function:
|
||||
|
||||
### Using a list of tensors as input:
|
||||
|
||||
```python
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
trainer.step(**inputs)
|
||||
|
||||
```
|
||||
|
||||
#### Using a list of strings as input:
|
||||
### Using a list of strings as input:
|
||||
|
||||
```python
|
||||
|
||||
inputs = {
|
||||
"texts": texts
|
||||
"texts": texts,
|
||||
"texts_labels": texts_labels, # Optional, defaults to texts
|
||||
}
|
||||
|
||||
trainer.step(**inputs)
|
||||
|
||||
```
|
||||
|
||||
For causal language models, labels will automatically be created from input_ids or from texts. When using sequence to sequence models you will have to provide your own labels or text_labels.
|
||||
For causal language models, labels will automatically be created from `input_ids` or from `texts`. When using sequence to sequence models you will have to provide your own labels or `text_labels`.
|
||||
|
||||
## IterativeTrainer
|
||||
## Configuration
|
||||
|
||||
The [`IterativeSFTConfig`] class provides several parameters to customize the training:
|
||||
|
||||
```python
|
||||
from trl import IterativeSFTConfig
|
||||
|
||||
config = IterativeSFTConfig(
|
||||
# Model initialization parameters
|
||||
model_init_kwargs={"torch_dtype": "bfloat16"},
|
||||
|
||||
# Data preprocessing parameters
|
||||
max_length=512,
|
||||
truncation_mode="keep_end",
|
||||
|
||||
# Training parameters
|
||||
output_dir="./output",
|
||||
learning_rate=2e-5,
|
||||
per_device_train_batch_size=4,
|
||||
gradient_accumulation_steps=4,
|
||||
max_steps=1000,
|
||||
logging_steps=10,
|
||||
save_steps=100,
|
||||
optim="adamw_torch",
|
||||
report_to="wandb",
|
||||
)
|
||||
```
|
||||
|
||||
### Model Initialization
|
||||
|
||||
You can control how the model is initialized by passing keyword arguments to `model_init_kwargs`:
|
||||
|
||||
```python
|
||||
config = IterativeSFTConfig(
|
||||
model_init_kwargs={
|
||||
"torch_dtype": "bfloat16",
|
||||
"device_map": "auto",
|
||||
"trust_remote_code": True,
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Data Preprocessing
|
||||
|
||||
The trainer supports two truncation modes:
|
||||
|
||||
- `keep_end`: Truncates from the start of the sequence
|
||||
- `keep_start`: Truncates from the end of the sequence
|
||||
|
||||
```python
|
||||
config = IterativeSFTConfig(
|
||||
max_length=512,
|
||||
truncation_mode="keep_end", # or "keep_start"
|
||||
)
|
||||
```
|
||||
|
||||
### Training Optimization
|
||||
|
||||
You can optimize CUDA cache usage for more memory-efficient training:
|
||||
|
||||
```python
|
||||
config = IterativeSFTConfig(
|
||||
optimize_device_cache=True,
|
||||
)
|
||||
```
|
||||
|
||||
## IterativeSFTTrainer
|
||||
|
||||
[[autodoc]] IterativeSFTTrainer
|
||||
|
||||
## IterativeSFTConfig
|
||||
|
||||
[[autodoc]] IterativeSFTConfig
|
||||
|
@ -53,9 +53,9 @@ Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. Y
|
||||
|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-KTO
|
||||
<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-KTO
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
@ -121,14 +121,14 @@ By default, they are both 1. However, if you have more of one or the other, then
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
|
||||
- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
|
||||
- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
|
||||
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
- `logps/chosen`: the mean log probabilities of the chosen completions
|
||||
- `logps/rejected`: the mean log probabilities of the rejected completions
|
||||
- `logits/chosen`: the mean logits of the chosen completions
|
||||
- `logits/rejected`: the mean logits of the rejected completions
|
||||
- `kl`: the KL divergence between the policy model and the reference model
|
||||
- `rewards/chosen_sum`: the sum of log probabilities of the policy model for the chosen responses scaled by beta
|
||||
- `rewards/rejected_sum`: the sum of log probabilities of the policy model for the rejected responses scaled by beta
|
||||
- `logps/chosen_sum`: the sum of log probabilities of the chosen completions
|
||||
- `logps/rejected_sum`: the sum of log probabilities of the rejected completions
|
||||
- `logits/chosen_sum`: the sum of logits of the chosen completions
|
||||
- `logits/rejected_sum`: the sum of logits of the rejected completions
|
||||
- `count/chosen`: the count of chosen samples in a batch
|
||||
- `count/rejected`: the count of rejected samples in a batch
|
||||
|
||||
## KTOTrainer
|
||||
|
||||
|
@ -1,233 +0,0 @@
|
||||
# Learning Tools (Experimental 🧪)
|
||||
|
||||
Using Large Language Models (LLMs) with tools has been a popular topic recently with awesome works such as [ToolFormer](https://huggingface.co/papers/2302.04761) and [ToolBench](https://huggingface.co/papers/2305.16504). In TRL, we provide a simple example of how to teach LLM to use tools with reinforcement learning.
|
||||
|
||||
|
||||
Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):
|
||||
|
||||
| File | Description |
|
||||
|---|---|
|
||||
| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. |
|
||||
| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. |
|
||||
| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. |
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
|
||||
</Tip>
|
||||
|
||||
|
||||
## Learning to Use a Calculator
|
||||
|
||||
|
||||
The rough idea is as follows:
|
||||
|
||||
1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parse a text calculation like `"14 + 34"` and return the calculated number:
|
||||
```python
|
||||
from transformers import AutoTokenizer, load_tool
|
||||
tool = load_tool("ybelkada/simple-calculator")
|
||||
tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places
|
||||
```
|
||||
1. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.
|
||||
1. Create a prompt on how to use the tools
|
||||
```python
|
||||
# system prompt
|
||||
prompt = """\
|
||||
What is 13.1-3?
|
||||
|
||||
<request><SimpleCalculatorTool>13.1-3<call>10.1<response>
|
||||
|
||||
Result=10.1<submit>
|
||||
|
||||
What is 4*3?
|
||||
|
||||
<request><SimpleCalculatorTool>4*3<call>12<response>
|
||||
|
||||
Result=12<submit>
|
||||
|
||||
What is 12.1+1?
|
||||
|
||||
<request><SimpleCalculatorTool>12.1+1<call>13.1<response>
|
||||
|
||||
Result=13.1<submit>
|
||||
|
||||
What is 12.1-20?
|
||||
|
||||
<request><SimpleCalculatorTool>12.1-20<call>-7.9<response>
|
||||
|
||||
Result=-7.9<submit>"""
|
||||
```
|
||||
3. Create a `trl.TextEnvironment` with the model
|
||||
```python
|
||||
env = TextEnvironment(
|
||||
model,
|
||||
tokenizer,
|
||||
{"SimpleCalculatorTool": tool_fn},
|
||||
reward_fn,
|
||||
prompt,
|
||||
generation_kwargs=generation_kwargs,
|
||||
)
|
||||
```
|
||||
4. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will show visualize the tokens.
|
||||

|
||||
1. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss, make sure to pass that argument to `step`.
|
||||
|
||||
## Experiment results
|
||||
|
||||
We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.
|
||||
|
||||
```
|
||||
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
|
||||
--command "python examples/research_projects/tools/calculator.py" \
|
||||
--num-seeds 10 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 8 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
```
|
||||
|
||||
We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot.
|
||||
```
|
||||
# pip install openrlbenchmark==0.2.1a5
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
'wandb?tag=calculator_final&cl=calculator_mask' \
|
||||
--env-ids trl \
|
||||
--check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename static/0compare \
|
||||
--scan-history
|
||||
```
|
||||
|
||||

|
||||
|
||||
As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near perfect proficiency in the calculator task.
|
||||
|
||||
|
||||
## (Early Experiments 🧪): learning to use a wiki tool for question answering
|
||||
|
||||
In the [ToolFormer](https://huggingface.co/papers/2302.04761) paper, it shows an interesting use case that utilizes a Wikipedia Search tool to help answer questions. In this section, we attempt to perform similar experiments but uses RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.
|
||||
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
**Note that many settings are different so the results are not directly comparable.**
|
||||
</Tip>
|
||||
|
||||
|
||||
|
||||
|
||||
### Building a search index
|
||||
|
||||
Since [ToolFormer](https://huggingface.co/papers/2302.04761) did not open source, we needed to first replicate the search index. It is mentioned in their paper that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT)
|
||||
|
||||
Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.
|
||||
|
||||
```python
|
||||
from pyserini.search.lucene import LuceneSearcher
|
||||
import json
|
||||
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
|
||||
def search(query):
|
||||
hits = searcher.search(query, k=1)
|
||||
hit = hits[0]
|
||||
contents = json.loads(hit.raw)['contents']
|
||||
return contents
|
||||
print(search("tennis racket"))
|
||||
```
|
||||
```
|
||||
Racket (sports equipment)
|
||||
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.
|
||||
|
||||
The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
|
||||
...
|
||||
```
|
||||
|
||||
We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later.
|
||||
|
||||

|
||||
|
||||
### Experiment settings
|
||||
|
||||
We use the following settings:
|
||||
|
||||
* use the `bigcode/starcoderbase` model as the base model
|
||||
* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only uses the first paragraphs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
|
||||
* test if the response contain the answer string, if so, give a reward of 1, otherwise, give a reward of 0.
|
||||
* notice this is a simplified evaluation criteria. In [ToolFormer](https://huggingface.co/papers/2302.04761), the authors checks if the first 20 words of the response contain the correct answer.
|
||||
* used the following prompt that demonstrates the usage of the wiki tool.
|
||||
```python
|
||||
prompt = """\
|
||||
Answer the following question:
|
||||
|
||||
Q: In which branch of the arts is Patricia Neary famous?
|
||||
A: Ballets
|
||||
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
|
||||
Result=Ballets<submit>
|
||||
|
||||
Q: Who won Super Bowl XX?
|
||||
A: Chicago Bears
|
||||
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
|
||||
Result=Chicago Bears<submit>
|
||||
|
||||
Q: """
|
||||
```
|
||||
|
||||
|
||||
### Result and Discussion
|
||||
|
||||
|
||||
Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves would go up mostly, but one of the experiment did crash.
|
||||
|
||||

|
||||
|
||||
Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.
|
||||
|
||||
|
||||
Note that the correct rate of the trained model is on the low end, which could be due to the following reasons:
|
||||
|
||||
* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"` if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988.` But a correct search should be `Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles.[1][2]"
|
||||
|
||||
|
||||

|
||||
|
||||
* **unnecessarily long response**: The wiki tool by default sometimes output very long sequences. E.g., when the wiki tool searches for "Brown Act"
|
||||
* Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
|
||||
* [ToolFormer](https://huggingface.co/papers/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.
|
||||
|
||||

|
||||
|
||||
|
||||
## (Early Experiments 🧪): solving math puzzles with python interpreter
|
||||
|
||||
In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:
|
||||
|
||||
```python
|
||||
prompt = """\
|
||||
Example of using a Python API to solve math questions.
|
||||
|
||||
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
|
||||
|
||||
<request><PythonInterpreter>
|
||||
def solution():
|
||||
money_initial = 23
|
||||
bagels = 5
|
||||
bagel_cost = 3
|
||||
money_spent = bagels * bagel_cost
|
||||
money_left = money_initial - money_spent
|
||||
result = money_left
|
||||
return result
|
||||
print(solution())
|
||||
<call>8<response>
|
||||
|
||||
Result = 8 <submit>
|
||||
|
||||
Q: """
|
||||
```
|
||||
|
||||
|
||||
Training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y
|
||||
|
||||

|
@ -1,74 +1,99 @@
|
||||
# Logging
|
||||
|
||||
As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
|
||||
By default, the TRL [`PPOTrainer`] saves a lot of relevant information to wandb or tensorboard.
|
||||
By default, TRL trainers like [`PPOTrainer`] and [`GRPOTrainer`] save a lot of relevant information to supported experiment trackers like Weights & Biases (wandb) or TensorBoard.
|
||||
|
||||
Upon initialization, pass one of these two options to the [`PPOConfig`]:
|
||||
Upon initialization, pass the `report_to` argument to the respective configuration object (e.g., [`PPOConfig`] for `PPOTrainer`, or [`GRPOConfig`] for `GRPOTrainer`):
|
||||
|
||||
```
|
||||
training_args = PPOConfig(..., report_to="wandb") # or "tensorboard"
|
||||
```python
|
||||
# For PPOTrainer
|
||||
ppo_config = PPOConfig(
|
||||
# ...,
|
||||
report_to="wandb" # or "tensorboard"
|
||||
)
|
||||
|
||||
# For GRPOTrainer
|
||||
grpc_config = GRPOConfig(
|
||||
# ...,
|
||||
report_to="wandb" # or "tensorboard"
|
||||
)
|
||||
```
|
||||
|
||||
If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.
|
||||
If you want to log with TensorBoard, you might also need to specify logging directories, for example, by adding `logging_dir=PATH_TO_LOGS` to the configuration object (e.g., `PPOConfig` or `GRPOConfig`).
|
||||
|
||||
## PPO Logging
|
||||
|
||||
Here's a brief explanation for the logged metrics provided in the data:
|
||||
|
||||
Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
|
||||
1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is used to specifically monitor the reward model.
|
||||
1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias ``ppo/std_scores`, which is used to specifically monitor the reward model.
|
||||
1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment.
|
||||
1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function.
|
||||
1. `objective/kl_dist`: The histogram distribution of the `objective/kl`.
|
||||
1. `objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function.
|
||||
1. `ppo/mean_non_score_reward`: The **KL penalty** calculated by `objective/kl * objective/kl_coef` as the total reward for optimization to prevent the new policy from deviating too far from the old policy.
|
||||
1. `objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration.
|
||||
|
||||
Training stats:
|
||||
1. `ppo/learning_rate`: The learning rate for the PPO algorithm.
|
||||
1. `ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy.
|
||||
1. `ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process.
|
||||
1. `ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html
|
||||
1. `ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html
|
||||
1. `ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.
|
||||
1. `ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.
|
||||
1. `ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`.
|
||||
1. `ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details.
|
||||
1. `ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance.
|
||||
1. `ppo/val/mean`: The mean of the values, used to monitor the value function's performance.
|
||||
1. `ppo/val/var` : The variance of the values, used to monitor the value function's performance.
|
||||
1. `ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance.
|
||||
1. `ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped.
|
||||
1. `ppo/val/vpred`: The predicted values from the value function.
|
||||
1. `ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance.
|
||||
1. `ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm.
|
||||
1. `ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards.
|
||||
1. `ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss.
|
||||
|
||||
|
||||
Stats on queries, responses, and logprobs:
|
||||
1. `tokens/queries_len_mean`: The average length of the queries tokens.
|
||||
1. `tokens/queries_len_std`: The standard deviation of the length of the queries tokens.
|
||||
1. `tokens/queries_dist`: The histogram distribution of the length of the queries tokens.
|
||||
1. `tokens/responses_len_mean`: The average length of the responses tokens.
|
||||
1. `tokens/responses_len_std`: The standard deviation of the length of the responses tokens.
|
||||
1. `tokens/responses_dist`: The histogram distribution of the length of the responses tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`)
|
||||
1. `objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model.
|
||||
1. `objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model.
|
||||
|
||||
|
||||
* `eps`: Tracks the number of episodes per second.
|
||||
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
|
||||
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
|
||||
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
|
||||
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
|
||||
* `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
|
||||
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
|
||||
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
|
||||
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
|
||||
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to `policy/clipfrac_avg` but for the value function.
|
||||
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
|
||||
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
|
||||
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
|
||||
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
|
||||
* `lr`: The current learning rate used by the optimizer.
|
||||
* `episode`: The current episode count in the training process.
|
||||
|
||||
### Crucial values
|
||||
During training, many values are logged, here are the most important ones:
|
||||
|
||||
1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment" / reward model
|
||||
1. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
|
||||
1. `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
1. `objective/rlhf_reward`: The mean RLHF reward. This is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
|
||||
1. `objective/non_score_reward`: The mean reward from non-score-related sources (e.g., KL penalty).
|
||||
|
||||
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
|
||||
|
||||
1. `ppo/loss/value`: it will spike / NaN when not going well.
|
||||
1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
|
||||
1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well.
|
||||
1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy.
|
||||
1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.
|
||||
1. `loss/value_avg`: The average value loss. It will spike / NaN when not going well.
|
||||
1. `val/ratio`: The mean ratio of the current policy probability to the old policy probability. This number should float around 1.0. If this `ratio` is too high (e.g., 2.0 or 1000.0) or too small (e.g., 0.1), it means the updates between consecutive policies are too drastic.
|
||||
1. `policy/clipfrac_avg` and `policy/approxkl_avg`: If `val/ratio` is too high, the `ratio` is going to get clipped, resulting in high `policy/clipfrac_avg` and high `policy/approxkl_avg` as well.
|
||||
1. `objective/kl`: The mean KL divergence. It should stay positive and ideally not too large, so that the policy is not too far away from the reference policy.
|
||||
|
||||
## GRPO Logging
|
||||
|
||||
Here's a brief explanation for the logged metrics provided in the data for the GRPO trainer:
|
||||
|
||||
* `num_tokens`: Total number of input tokens processed during training so far.
|
||||
|
||||
**Completions:**
|
||||
* `completions/mean_length`: Mean length of all generated completions (including those not ending with an EOS token).
|
||||
* `completions/min_length`: Minimum length among all generated completions.
|
||||
* `completions/max_length`: Maximum length among all generated completions.
|
||||
* `completions/clipped_ratio`: The ratio of completions that did not end with an EOS token before reaching the maximum generation length (i.e., they were truncated).
|
||||
* `completions/mean_terminated_length`: Mean length of only those completions that successfully ended with an EOS token.
|
||||
* `completions/min_terminated_length`: Minimum length among completions that ended with an EOS token.
|
||||
* `completions/max_terminated_length`: Maximum length among completions that ended with an EOS token.
|
||||
|
||||
**Rewards:**
|
||||
* `rewards/{reward_func_name}/mean`: The mean reward obtained from a specific, named reward function (e.g., `rewards/my_custom_reward/mean`). This is logged for each reward function used.
|
||||
* `rewards/{reward_func_name}/std`: The standard deviation of rewards from a specific, named reward function.
|
||||
* `reward`: The overall mean of the (potentially weighted and, if `args.scale_rewards` is true, normalized) rewards, after group-wise normalization (advantages).
|
||||
* `reward_std`: The standard deviation of the (potentially weighted) rewards *before* group-wise normalization for advantages.
|
||||
|
||||
**Policy and Loss Metrics:**
|
||||
* `kl`: The mean Kullback-Leibler (KL) divergence between the current policy and the reference policy. This is logged only if `beta` (the KL coefficient in `GRPOConfig`) is non-zero.
|
||||
* If Liger GRPOLoss is used (`use_liger_loss: True` in `GRPOConfig`):
|
||||
* `clip_ratio`: The fraction of policy updates where the probability ratio was clipped according to the GRPO loss's epsilon bounds.
|
||||
* If standard GRPOLoss is used (`use_liger_loss: False`):
|
||||
* `clip_ratio/low_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the lower bound `1 - epsilon_low` (occurs when advantage is negative and ratio is below the bound).
|
||||
* `clip_ratio/low_min`: The minimum observed fraction for `clip_ratio/low_mean` across batches/processes.
|
||||
* `clip_ratio/high_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the upper bound `1 + epsilon_high` (occurs when advantage is positive and ratio is above the bound).
|
||||
* `clip_ratio/high_max`: The maximum observed fraction for `clip_ratio/high_mean` across batches/processes.
|
||||
* `clip_ratio/region_mean`: The mean fraction of instances where the probability ratio was clipped at either the lower or upper bound.
|
||||
|
||||
### Crucial GRPO values
|
||||
During GRPO training, monitor these values for insights into performance and stability:
|
||||
|
||||
1. `reward`: This is the primary objective. It reflects the (group-wise normalized) rewards the policy is achieving. It should generally increase during successful training.
|
||||
1. `kl`: If `beta > 0`, this tracks the divergence from the reference model. Keep an eye on it to ensure the policy doesn't stray too far, which can lead to instability.
|
||||
1. `clip_ratio/*` (either `clip_ratio` for Liger loss or the more detailed `clip_ratio/...` metrics for standard loss): These indicate how often the policy updates are being constrained by the GRPO clipping mechanism. Very high values might suggest that the policy is trying to change too drastically (potentially due to large advantages or a learning rate that's too high) or that the epsilon clipping range is too restrictive.
|
||||
1. `completions/clipped_ratio`: A high ratio here indicates that the model is frequently generating completions that are cut off by `max_completion_length` rather than naturally ending with an EOS token. This might suggest issues with learning sequence termination or that `max_completion_length` is too short.
|
||||
1. `rewards/{reward_func_name}/mean`: Monitoring the mean of individual reward functions can help diagnose which aspects of the desired behavior the model is learning or struggling with, especially when using multiple reward sources.
|
||||
|
5
docs/source/model_utils.md
Normal file
5
docs/source/model_utils.md
Normal file
@ -0,0 +1,5 @@
|
||||
# Model Utilities
|
||||
|
||||
## get_act_offloading_ctx_manager
|
||||
|
||||
[[autodoc]] models.get_act_offloading_ctx_manager
|
@ -51,9 +51,9 @@ accelerate launch train_nash_md.py
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 3 hours.
|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [TRL Chat CLI](clis#chat-interface).
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-NashMD
|
||||
<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-NashMD
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
|
@ -53,9 +53,9 @@ Distributed across 8 GPUs, the training takes approximately 1 hour. You can veri
|
||||
|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-OnlineDPO
|
||||
<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-OnlineDPO
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
|
@ -56,9 +56,9 @@ Distributed across 8 GPUs, the training takes approximately 30 minutes. You can
|
||||
|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-ORPO
|
||||
<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-ORPO
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
|
9
docs/source/others.md
Normal file
9
docs/source/others.md
Normal file
@ -0,0 +1,9 @@
|
||||
# Other
|
||||
|
||||
## profiling_decorator
|
||||
|
||||
[[autodoc]] extras.profiling.profiling_decorator
|
||||
|
||||
## profiling_context
|
||||
|
||||
[[autodoc]] extras.profiling.profiling_context
|
@ -52,7 +52,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
|
||||
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
|
||||
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
|
||||
* `lr`: lr: The current learning rate used by the optimizer.
|
||||
* `episode`: episode: The current global step or episode count in the training process.
|
||||
* `episode`: episode: The current episode count in the training process.
|
||||
|
||||
|
||||
## Cookbook
|
||||
|
@ -14,7 +14,7 @@ Sequence lengths in the dataset can vary widely. When data is batched, sequences
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/why_you_should_truncate.png" alt="Truncation prompt completion" width="600"/>
|
||||
</div>
|
||||
|
||||
To reduce memory usage, it’s important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
|
||||
To reduce memory usage, it's important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
|
||||
|
||||
<hfoptions id="dpo">
|
||||
<hfoption id="DPO">
|
||||
@ -44,7 +44,7 @@ training_args = DPOConfig(..., max_completion_length=...)
|
||||
</hfoption>
|
||||
<hfoption id="SFT">
|
||||
|
||||
SFT truncation is applied to the input sequence via the `max_seq_length` parameter.
|
||||
SFT truncation is applied to the input sequence via the `max_length` parameter.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_input_ids.png" alt="Truncation input ids" width="600"/>
|
||||
@ -55,7 +55,7 @@ To set the truncation parameter, use the following code snippet:
|
||||
```python
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(..., max_seq_length=...)
|
||||
training_args = SFTConfig(..., max_length=...)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
@ -85,7 +85,7 @@ Packing eliminates padding, preserves all sequence information, and allows for f
|
||||
```python
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(..., packing=True, max_seq_length=512)
|
||||
training_args = SFTConfig(..., packing=True, max_length=512)
|
||||
```
|
||||
|
||||
<Tip warning={true}>
|
||||
@ -94,6 +94,77 @@ Packing may cause batch contamination, where adjacent sequences influence one an
|
||||
|
||||
</Tip>
|
||||
|
||||
## Padding-free
|
||||
|
||||
Padding-free batching is an alternative approach for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/padding-free.png" alt="Padding-free batching" width="600"/>
|
||||
</div>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
It's highly recommended to use padding-free batching with **Flash Attention 2**. Otherwise, you may encounter batch contamination issues.
|
||||
|
||||
</Tip>
|
||||
|
||||
<hfoptions id="padding-free">
|
||||
<hfoption id="DPO">
|
||||
|
||||
```python
|
||||
from trl import DPOConfig
|
||||
|
||||
training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="SFT">
|
||||
|
||||
```python
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Activation offloading
|
||||
|
||||
Activation offloading is a memory efficiency technique that reduces GPU VRAM usage by temporarily moving activation tensors to CPU RAM during the forward pass and bringing them back only when needed for the backward pass. This significantly reduces peak memory usage at the cost of slightly increased training time.
|
||||
|
||||
To enable activation offloading in your SFT training configuration:
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="SFT">
|
||||
|
||||
```python
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(..., activation_offloading=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
When using activation offloading with models that use Liger kernels, you must disable Liger cross entropy due to compatibility issues. The issue occurs specifically with `use_liger_kernel=True` because Liger cross entropy performs in-place operations which conflict with activation offloading. The default setting (`use_liger_kernel=False`) works:
|
||||
|
||||
```python
|
||||
# When using activation offloading with a model that uses Liger kernels:
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(
|
||||
activation_offloading=True,
|
||||
use_liger_kernel=False, # Disable Liger cross entropy
|
||||
# Other parameters...
|
||||
)
|
||||
```
|
||||
</Tip>
|
||||
|
||||
Under the hood, activation offloading implements PyTorch's [`saved_tensors_hooks`](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#hooks-for-autograd-saved-tensors) to intercept activations during the forward pass. It intelligently manages which tensors to offload based on size and context, avoiding offloading output tensors which would be inefficient. For performance optimization, it can optionally use CUDA streams to overlap computation with CPU-GPU transfers.
|
||||
|
||||
## Disabling model gathering for generation in online methods
|
||||
|
||||
When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).
|
||||
@ -101,6 +172,15 @@ When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Onl
|
||||
If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter:
|
||||
|
||||
<hfoptions id="ds3_gather_for_generation">
|
||||
<hfoption id="GRPO">
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(..., ds3_gather_for_generation=False)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Online DPO">
|
||||
|
||||
```python
|
||||
|
9
docs/source/rewards.md
Normal file
9
docs/source/rewards.md
Normal file
@ -0,0 +1,9 @@
|
||||
# Reward Functions
|
||||
|
||||
This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`].
|
||||
|
||||
## Format rewards
|
||||
|
||||
### think_format_reward
|
||||
|
||||
[[autodoc]] rewards.think_format_reward
|
@ -2,10 +2,7 @@
|
||||
|
||||
[](https://huggingface.co/models?other=sft,trl) [](https://github.com/huggingface/smol-course/tree/main/1_instruction_tuning)
|
||||
|
||||
Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset.
|
||||
|
||||
Check out a complete flexible example at [`trl/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/trl/scripts/sft.py).
|
||||
Experimental support for Vision Language Models is also included in the example [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_vlm.py).
|
||||
Supervised fine-tuning (SFT) is the most common step in post-training foundation models, and also one of the most effective. In TRL, we provide a simple API to train models with SFT in a few lines of code; for a complete training script, check out [`trl/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/trl/scripts/sft.py). Experimental support for Vision Language Models is also included in [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_vlm.py).
|
||||
|
||||
## Quickstart
|
||||
|
||||
@ -19,7 +16,7 @@ from trl import SFTConfig, SFTTrainer
|
||||
dataset = load_dataset("stanfordnlp/imdb", split="train")
|
||||
|
||||
training_args = SFTConfig(
|
||||
max_seq_length=512,
|
||||
max_length=512,
|
||||
output_dir="/tmp",
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
@ -29,7 +26,7 @@ trainer = SFTTrainer(
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
|
||||
Make sure to pass the correct value for `max_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
|
||||
|
||||
You can also construct a model outside of the trainer and pass it as follows:
|
||||
|
||||
@ -53,123 +50,24 @@ trainer = SFTTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults pass in your modification to the `SFTConfig` constructor and pass them to the trainer via the `args` argument.
|
||||
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults, pass in your modification to the `SFTConfig` constructor and pass it to the trainer via the `args` argument.
|
||||
|
||||
## Advanced usage
|
||||
|
||||
### Train on completions only
|
||||
|
||||
You can use the `DataCollatorForCompletionOnlyLM` to train your model on the generated prompts only. Note that this works only in the case when `packing=False`.
|
||||
To instantiate that collator for instruction data, pass a response template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on completions only on the CodeAlpaca dataset:
|
||||
To train on completions only, simply use a [prompt-completion](dataset_formats#prompt-completion) dataset. In this mode, loss is computed solely on the completion part.
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from datasets import load_dataset
|
||||
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
|
||||
|
||||
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
|
||||
def formatting_prompts_func(example):
|
||||
output_texts = []
|
||||
for i in range(len(example['instruction'])):
|
||||
text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
|
||||
output_texts.append(text)
|
||||
return output_texts
|
||||
|
||||
response_template = " ### Answer:"
|
||||
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
train_dataset=dataset,
|
||||
args=SFTConfig(output_dir="/tmp"),
|
||||
formatting_func=formatting_prompts_func,
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
To instantiate that collator for assistant style conversation data, pass a response template, an instruction template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on assistant completions only on the Open Assistant Guanaco dataset:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from datasets import load_dataset
|
||||
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
|
||||
|
||||
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
|
||||
instruction_template = "### Human:"
|
||||
response_template = "### Assistant:"
|
||||
collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args=SFTConfig(output_dir="/tmp"),
|
||||
train_dataset=dataset,
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Make sure to have a `pad_token_id` which is different from `eos_token_id` which can result in the model not properly predicting EOS (End of Sentence) tokens during generation.
|
||||
|
||||
#### Using token_ids directly for `response_template`
|
||||
|
||||
Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending on whether they have context or not. For example:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
|
||||
def print_tokens_with_ids(txt):
|
||||
tokens = tokenizer.tokenize(txt, add_special_tokens=False)
|
||||
token_ids = tokenizer.encode(txt, add_special_tokens=False)
|
||||
print(list(zip(tokens, token_ids)))
|
||||
|
||||
prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?"""
|
||||
print_tokens_with_ids(prompt) # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...]
|
||||
|
||||
response_template = "### Assistant:"
|
||||
print_tokens_with_ids(response_template) # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)]
|
||||
```
|
||||
|
||||
In this case, and due to lack of context in `response_template`, the same string ("### Assistant:") is tokenized differently:
|
||||
|
||||
- Text (with context): `[2277, 29937, 4007, 22137, 29901]`
|
||||
- `response_template` (without context): `[835, 4007, 22137, 29901]`
|
||||
|
||||
This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text:
|
||||
|
||||
```
|
||||
RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])
|
||||
```
|
||||
|
||||
|
||||
To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
|
||||
|
||||
```python
|
||||
response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer
|
||||
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:] # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]`
|
||||
|
||||
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
|
||||
```
|
||||
If you’d like to compute loss on both the prompt **and** the completion while still using a prompt-completion dataset, set `completion_only_loss=False` in the [`SFTConfig`]. This is equivalent to [converting the dataset to a language modeling](dataset_formats#from-prompt-completion-to-language-modeling-dataset) format.
|
||||
|
||||
### Add Special Tokens for Chat Format
|
||||
|
||||
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
|
||||
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system, and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
|
||||
The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
|
||||
- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
|
||||
- Adds special tokens to the tokenizer, e.g., `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
|
||||
- Resizes the model’s embedding layer to accommodate the new tokens.
|
||||
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
|
||||
- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g. 64. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
|
||||
- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g., `64`. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
@ -179,11 +77,13 @@ from trl import setup_chat_format
|
||||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
|
||||
# Set up the chat format with default 'chatml' format
|
||||
# Set up the chat format with the default 'chatml' format
|
||||
model, tokenizer = setup_chat_format(model, tokenizer)
|
||||
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply `setup_chat_format()`, as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in `SFTConfig`; for example, for `Qwen/Qwen2.5-1.5B`, one should set `eos_token="<|im_end|>"`.
|
||||
|
||||
With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
|
||||
|
||||
### Dataset format support
|
||||
@ -226,7 +126,7 @@ trainer = SFTTrainer(
|
||||
)
|
||||
```
|
||||
|
||||
If the dataset is not in one of those format you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
|
||||
If the dataset is not in one of those formats, you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
|
||||
|
||||
|
||||
### Format your input prompts
|
||||
@ -246,26 +146,23 @@ Let us assume your dataset has two fields, `question` and `answer`. Therefore yo
|
||||
```python
|
||||
...
|
||||
def formatting_prompts_func(example):
|
||||
output_texts = []
|
||||
for i in range(len(example['question'])):
|
||||
text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
|
||||
output_texts.append(text)
|
||||
return output_texts
|
||||
return f"### Question: {example['question']}\n ### Answer: {example['answer']}"
|
||||
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
formatting_func=formatting_prompts_func,
|
||||
formatting_func=formatting_prompt_func,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
|
||||
To properly format your input, make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on the alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
|
||||
|
||||
### Packing dataset ([`ConstantLengthDataset`])
|
||||
### Packing dataset
|
||||
|
||||
[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTConfig`] constructor.
|
||||
[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTConfig`] constructor.
|
||||
|
||||
```python
|
||||
...
|
||||
@ -280,12 +177,12 @@ trainer = SFTTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
|
||||
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments, you will probably train your models for more than a few epochs, depending on the way you have configured the packed dataset and the training protocol. Double-check that you know and understand what you are doing.
|
||||
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTConfig` init method.
|
||||
|
||||
#### Customize your prompts using packed dataset
|
||||
|
||||
If your dataset has several fields that you want to combine, for example if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:
|
||||
If your dataset has several fields that you want to combine, for example, if the dataset has `question` and `answer` fields and you want to combine them, you can pass a formatting function to the trainer that will take care of that. For example:
|
||||
|
||||
```python
|
||||
def formatting_func(example):
|
||||
@ -302,7 +199,6 @@ trainer = SFTTrainer(
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTConfig`] constructor. Please refer to that class' signature for more information.
|
||||
|
||||
### Control over the pretrained model
|
||||
|
||||
@ -360,7 +256,7 @@ trainer.train()
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsense generations. If the chat template doesn't contain special tokens (e.g. Alpaca), then the `modules_to_save` argument can be ignored or set to `None`.
|
||||
> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsensical generations. If the chat template doesn't contain special tokens (e.g., Alpaca), then the `modules_to_save` argument can be ignored or set to `None`.
|
||||
|
||||
|
||||
You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed.
|
||||
@ -425,15 +321,15 @@ Once you have loaded your model, wrap the `trainer.train()` call under the `with
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration.
|
||||
Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore, you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration.
|
||||
|
||||
Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB.
|
||||
|
||||
| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step |
|
||||
| ---------------- | ----------------- | ----------- | ---------- | ---------------------- |
|
||||
| x | facebook/opt-350m | 2048 | 8 | ~59.1s |
|
||||
| ✓ | facebook/opt-350m | 2048 | 8 | ~59.1s |
|
||||
| | facebook/opt-350m | 2048 | 8 | **OOM** |
|
||||
| x | facebook/opt-350m | 2048 | 4 | ~30.3s |
|
||||
| ✓ | facebook/opt-350m | 2048 | 4 | ~30.3s |
|
||||
| | facebook/opt-350m | 2048 | 4 | ~148.9s |
|
||||
|
||||
### Using Flash Attention-2
|
||||
@ -455,12 +351,12 @@ model = AutoModelForCausalLM.from_pretrained(
|
||||
```
|
||||
|
||||
If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device.
|
||||
After loading your model, you can either train it as it is, or attach adapters and train adapters on it in case your model is quantized.
|
||||
After loading your model, you can either train it as it is or attach adapters and train adapters on it in case your model is quantized.
|
||||
|
||||
In contrast to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
|
||||
|
||||
|
||||
### Using model creation utility
|
||||
### Using the model creation utility
|
||||
|
||||
We included a utility function to create your model.
|
||||
|
||||
@ -495,17 +391,17 @@ trainer = SFTTrainer(
|
||||
)
|
||||
```
|
||||
|
||||
### Enhance the model's performances using NEFTune
|
||||
### Enhance the model's performance using NEFTune
|
||||
|
||||
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
|
||||
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) from Jain et al. It consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
|
||||
|
||||
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune.
|
||||
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF, such as LLaMA-2-Chat, benefit from additional training with NEFTune.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/neft-screenshot.png">
|
||||
</div>
|
||||
|
||||
To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer.
|
||||
To use it in `SFTTrainer`, simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to revert to the original behaviour of the embedding layer.
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
@ -534,7 +430,7 @@ Note however, that the amount of performance gain is _dataset dependent_ and in
|
||||
|
||||
### Accelerate fine-tuning 2x using `unsloth`
|
||||
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently, `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:
|
||||
|
||||
| 1 A100 40GB | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|
||||
| --------------- | --------- | --- | --------------------- | --------- | ------------ |
|
||||
@ -543,19 +439,19 @@ You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [
|
||||
| Mistral 7b | Slim Orca | 1x | 1.17x | **1.88x** | -65.9% |
|
||||
| Tiny Llama 1.1b | Alpaca | 1x | 1.55x | **2.74x** | -57.8% |
|
||||
|
||||
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
|
||||
First, install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number
|
||||
max_length = 2048 # Supports automatic RoPE Scaling, so choose any number
|
||||
|
||||
# Load model
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/mistral-7b",
|
||||
max_seq_length=max_seq_length,
|
||||
max_seq_length=max_length,
|
||||
dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
|
||||
load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
@ -581,7 +477,7 @@ model = FastLanguageModel.get_peft_model(
|
||||
random_state=3407,
|
||||
)
|
||||
|
||||
training_args = SFTConfig(output_dir="./output", max_seq_length=max_seq_length)
|
||||
training_args = SFTConfig(output_dir="./output", max_length=max_length)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
@ -593,28 +489,29 @@ trainer.train()
|
||||
|
||||
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
|
||||
|
||||
## Liger-Kernel: Increase 20% throughput and reduces 60% memory for multi-GPU training
|
||||
## Liger-Kernel: Increase 20% throughput and reduce 60% memory for multi-GPU training
|
||||
|
||||
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduces memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face Compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
|
||||
[Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face Compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernel works out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
|
||||
|
||||
With great memory reduction, you can potentially turn off cpu_offloading or gradient checkpointing to further boost the performance.
|
||||
With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost the performance.
|
||||
|
||||
| Speed Up | Memory Reduction |
|
||||
|--------------------------|-------------------------|
|
||||
|  |  |
|
||||
|
||||
|
||||
1. To use Liger-Kernel in `SFTTrainer`, first install by
|
||||
1. To use Liger-Kernel in [`SFTTrainer`], first install it by:
|
||||
|
||||
```bash
|
||||
pip install liger-kernel
|
||||
```
|
||||
|
||||
2. Once installed, set `use_liger` in [`SFTConfig`]. No other changes are needed!
|
||||
2. Once installed, set `use_liger_kernel` in [`SFTConfig`]. No other changes are needed!
|
||||
|
||||
```python
|
||||
training_args = SFTConfig(
|
||||
use_liger=True
|
||||
use_liger_kernel=True,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
@ -624,14 +521,14 @@ To learn more about Liger-Kernel, visit their [official repository](https://gith
|
||||
|
||||
Pay attention to the following best practices when training a model with that trainer:
|
||||
|
||||
- [`SFTTrainer`] always truncates by default the sequences to the `max_seq_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
|
||||
- [`SFTTrainer`] always truncates by default the sequences to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to check it before training.
|
||||
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
|
||||
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
|
||||
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
|
||||
- If you create a model outside the trainer, make sure not to pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work you must also check the following:
|
||||
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work, you must also check the following:
|
||||
- If you're using gradient_checkpointing, add the following to the TrainingArguments: `gradient_checkpointing_kwargs={'use_reentrant':False}` (more info [here](https://github.com/huggingface/transformers/issues/26969)
|
||||
- Ensure that the model is placed on the correct device:
|
||||
```python
|
||||
@ -649,7 +546,7 @@ You may experience some issues with GPTQ Quantization after completing training.
|
||||
|
||||
## Extending `SFTTrainer` for Vision Language Models
|
||||
|
||||
`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
|
||||
`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py), which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
|
||||
|
||||
### Preparing the Data
|
||||
|
||||
@ -768,10 +665,6 @@ A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vs
|
||||
|
||||
## Datasets
|
||||
|
||||
In the SFTTrainer we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
|
||||
In the SFTTrainer, we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
|
||||
|
||||
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to re-use it directly.
|
||||
|
||||
### ConstantLengthDataset
|
||||
|
||||
[[autodoc]] trainer.ConstantLengthDataset
|
||||
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to reuse it directly.
|
||||
|
@ -37,7 +37,13 @@ training_args = OnlineDPOConfig(..., use_vllm=True)
|
||||
</hfoption>
|
||||
<hfoption id="GRPO">
|
||||
|
||||
Then, enable it by passing `use_vllm=True` in the training arguments.
|
||||
First, start a vLLM server by running:
|
||||
|
||||
```bash
|
||||
trl vllm-serve --model <model_name>
|
||||
```
|
||||
|
||||
Then, run the training script and pass `use_vllm=True` in the training arguments.
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
@ -45,31 +51,23 @@ from trl import GRPOConfig
|
||||
training_args = GRPOConfig(..., use_vllm=True)
|
||||
```
|
||||
|
||||
The strategy here is to use a dedicated GPU for generation powered by vLLM, while using the remainder for training.
|
||||
You can customize the server configuration by passing additional arguments. For more information, see [vLLM integration](vllm_integration).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
When using vLLM, an additional GPU is required exclusively for generation. This means you need at least two available GPUs and must ensure that one remains unused by the trainer. To achieve this, run the training with `--num_processes <NUMBER_OF_GPUs - 1>`.
|
||||
When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.
|
||||
|
||||
For example, if you have 4 GPUs, set `--num_processes 3` to allocate three GPUs for training while reserving one for generation.
|
||||
```bash
|
||||
accelerate launch --multi_gpu --num_processes 3 train_grpo.py
|
||||
```
|
||||
Set GPUs **0-3** for vLLM generation:
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name>
|
||||
```
|
||||
|
||||

|
||||
And GPUs **4-7** for training:
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
You can further tune the vLLM configuration by setting a specific `vllm_device` and `vllm_gpu_memory_utilization` in the [`GRPOConfig`].
|
||||
|
||||
```python
|
||||
training_args = GRPOConfig(
|
||||
...,
|
||||
use_vllm=True,
|
||||
vllm_device="cuda:4",
|
||||
vllm_gpu_memory_utilization=0.7,
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
@ -1,197 +0,0 @@
|
||||
# Text Environments
|
||||
|
||||
Text environments provide a learning ground for language agents. It allows a language model to use tools to accomplish a task such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the models itself but can be trivial for the appropriate tools. A good example is arithmetics of large numbers that become a simple copy-paste task once you have access to a calculator.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/textenv.png">
|
||||
</div>
|
||||
|
||||
Let's dive into how text environments work and start with tools!
|
||||
|
||||
## Tools
|
||||
|
||||
One of the core building blocks of text environments are tools that the model can use to solve tasks. In general tools can be any Python function that takes a string as input and returns string. The `TextEnvironment` offers two options for tools: either go with predefined tools from `transformers.Tool` or define your own function or class with `__call__` method. Let's have a look at both!
|
||||
|
||||
### `transformers.Tool`
|
||||
|
||||
Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared
|
||||
|
||||
```Python
|
||||
from transformers import load_tool
|
||||
|
||||
# simple calculator tool that runs +-/* operations
|
||||
calc_tool = load_tool("ybelkada/simple-calculator")
|
||||
|
||||
# python interpreter that executes program and returns outputs
|
||||
py_tool = load_tool("lvwerra/python-interpreter")
|
||||
|
||||
# wikipedia search index that returns best search match
|
||||
wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
|
||||
```
|
||||
|
||||
These tools are either loaded from the hub or from a local folder. Using the tool is as simple as calling them with a text query:
|
||||
|
||||
```Python
|
||||
calc_tool("1/2")
|
||||
>>> "0.5"
|
||||
```
|
||||
|
||||
Note that both input and return values are strings to enable easy usage with a language model.
|
||||
|
||||
### Custom Tools
|
||||
|
||||
The following is an example of a tool that adds two integers:
|
||||
|
||||
```Python
|
||||
def add(text):
|
||||
int_1, int_2 = text.split("+")
|
||||
result = int(int_1) + int(int_2)
|
||||
return str(result)
|
||||
|
||||
print(add("1+1"))
|
||||
>>> "2"
|
||||
```
|
||||
|
||||
We looked at basic examples such as a calculator but the principle holds for more complex tools as well such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax.
|
||||
|
||||
### Call syntax
|
||||
|
||||
In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows:
|
||||
|
||||
```python
|
||||
"<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"
|
||||
```
|
||||
|
||||
There are a few special tokens involved so let's decompose it: First the model can signal that it wants to use a tool by emitting the `<request>` token. After that we want to know the name of the tool to call which is done by enclosing the tool name with `<>` brackets. Once we know which tool to call the tool query follows which is in free text form. The `<call>` tokens signifies the end of the query and stops the model generation. At this point the model output is parsed and the query sent to the tool. The environment appends the tool response to the string followed by the `<response>` token to show the end the tool output.
|
||||
|
||||
Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later):
|
||||
|
||||
```python
|
||||
"<request><Calculator>1/2<call>0.5<response>"
|
||||
```
|
||||
|
||||
Finally, the episode is ended and generation stops when the model generates `<submit>` which marks the interaction as completed.
|
||||
|
||||
Now let's have a look how we can create a new text environment!
|
||||
|
||||
## Create a `TextEnvironment`
|
||||
|
||||
|
||||
```python
|
||||
prompt = """\
|
||||
What is 13-3?
|
||||
<request><SimpleCalculatorTool>13-3<call>10.0<response>
|
||||
Result=10<submit>
|
||||
"""
|
||||
|
||||
def reward_fn(result, answer):
|
||||
"""Simplified reward function returning 1 if result matches answer and 0 otherwise."""
|
||||
result_parsed = result.split("=")[1].split("<")[0]
|
||||
return int(result_parsed==answer)
|
||||
|
||||
text_env = TextEnvironemnt(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
tools= {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
|
||||
reward_fn=exact_match_reward,
|
||||
prompt=prompt,
|
||||
max_turns=1
|
||||
max_tool_response=100
|
||||
generation_kwargs={"do_sample": "true"}
|
||||
)
|
||||
```
|
||||
|
||||
Let's decompose the settings:
|
||||
|
||||
| Argument | Description |
|
||||
|:-------------------|:----------------|
|
||||
| `model` | Language model to interact with the environment and generate requests. |
|
||||
| `tokenizer` | Tokenizer of language model handling tokenization of strings. |
|
||||
| `tools` | `list` of `dict` of tools. If former the name of the tool is inferred from class name and otherwise it's the keys of the dictionary.|
|
||||
| `reward_fn` | A function that takes a string as input and returns. Can have extra arguments that are passed to `.run()` such as ground truth.|
|
||||
| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. |
|
||||
| `max_turns` | Maximum number of interactions between model and tools before episode ends.|
|
||||
| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
|
||||
| `max_length` | The maximum number of tokens to allow in an episode. |
|
||||
| `generation_kwargs`| Generation settings used by the language model. |
|
||||
|
||||
You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!
|
||||
|
||||
|
||||
## Run an Episode
|
||||
|
||||
To run a set of queries through the text environment one can simply use the `run` method.
|
||||
|
||||
```python
|
||||
queries = ["What is 1/2?"]
|
||||
answers = ["0.5"]
|
||||
|
||||
queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)
|
||||
```
|
||||
|
||||
This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached or to maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function.
|
||||
|
||||
There are five objects that are returned by `run`:
|
||||
|
||||
- `queries`: a list of the tokenized queries
|
||||
- `responses`: all tokens that have been generated withing the environment including model and tool tokens
|
||||
- `masks`: mask that indicates which tokens have been generated by the model and which tokens are generated by the tool
|
||||
- `rewards`: a list of reward for each query/response
|
||||
- `histories`: list of `TextHistory` objects, which are useful objects containing all the above and also the text equivalents
|
||||
|
||||
The masks are crucial for training as we don't want to optimize tokens that the model has not generated which are tokens produced by the tools.
|
||||
|
||||
Next, we'll train a PPO step with the generated responses!
|
||||
|
||||
|
||||
### Train
|
||||
Training on episodes from the `TextEnvironment` is straight forward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method:
|
||||
|
||||
```python
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
```
|
||||
|
||||
## `TextHistory`
|
||||
|
||||
The `TextHistory` object stores the interactions between the model and the text environment. It stores tokens and text generated in each turn and their source in each turn (model or system) as well as rewards. Let's go through the class attributes and methods.
|
||||
|
||||
### Attributes
|
||||
|
||||
The following table summarises the available attributes of the `TextEnvironment` class:
|
||||
|
||||
| Attribute | Description |
|
||||
|:-------------------|:----------------|
|
||||
| `text` | The full string of the text generated in the text environment with both model and system generated text. |
|
||||
| `text_spans` | A list of tuples with the spans for each model or system generated text segment. |
|
||||
| `system_spans` | A list of boolean values indicating if the segment is model or system generated. |
|
||||
| `tokens` | All tokens generated in text environment with both model and system generated tokens. |
|
||||
| `token_spans` | Similar to `text_spans` the `token_spans` indicate the boundaries of model andsystem generated tokens. |
|
||||
| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. |
|
||||
| `completed` | Indicates if the interaction with the environment has completed. |
|
||||
| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. |
|
||||
|
||||
With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look!
|
||||
|
||||
### Visualization
|
||||
|
||||
When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` library](https://github.com/Textualize/rich) (make sure to install it before using these methods).
|
||||
|
||||
You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue and in addition to the pure text output the reward is displayed as additional text in plum. Here an example of `show_text`:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/textenv_show_text.png" width=600>
|
||||
</div>
|
||||
|
||||
Sometimes there can be tricky tokenization related issues that are hidden when showing the decoded text. Thus `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/textenv_show_tokens.png" width=800>
|
||||
</div>
|
||||
|
||||
Note that you can turn on the colour legend by passing `show_legend=True`.
|
||||
|
||||
## API Documentation
|
||||
|
||||
[[autodoc]] TextEnvironment
|
||||
|
||||
[[autodoc]] TextHistory
|
381
docs/source/training_vlm_sft.md
Normal file
381
docs/source/training_vlm_sft.md
Normal file
@ -0,0 +1,381 @@
|
||||
# Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
|
||||
|
||||

|
||||
|
||||
## Overview
|
||||
|
||||
This guide walks you through the process of fine-tuning a multimodal language model (e.g., **Gemma 3**) using **Supervised Fine-Tuning (SFT)**. We cover two cases:
|
||||
|
||||
- **Single Image + Text**
|
||||
- **Multi-Image + Text**
|
||||
|
||||
This guide serves as a **detailed walkthrough** and complements the existing [VLM SFT script](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py). If you're already familiar with the concepts, you can use the script directly.
|
||||
|
||||
We demonstrate the fine-tuning process using two datasets, but these principles extend to other **Vision-Language Models (VLMs)** and datasets.
|
||||
|
||||
## Understanding the Datasets
|
||||
|
||||
To address both **Single Image + Text** and **Multi-Image + Text** scenarios, we use two datasets that are well-suited for this task.
|
||||
|
||||
### HuggingFaceH4/llava-instruct-mix-vsft Dataset (Image + Text)
|
||||
|
||||
This dataset is a reformatted version of [LLaVA Instruct Mix](https://huggingface.co/datasets/theblackcat102/llava-instruct-mix). It consists of conversations where a user provides both **text** and a **single image** as input.
|
||||
|
||||
The model (referred to as the **"assistant"**) responds based on both the **visual and textual information** shared by the user. This dataset is particularly useful for training multimodal models to **understand and generate responses based on images and text**.
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft/embed/viewer/default/train"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
### FanqingM/MMIU-Benchmark Dataset (Multi-Image + Text)
|
||||
|
||||
The **FanqingM/MMIU-Benchmark** dataset consists of:
|
||||
|
||||
- **Context:** Included in the system prompt.
|
||||
- **Question:** Provided as part of the user's input.
|
||||
- **Series of Images:** Multiple images related to the question.
|
||||
- **Answer:** The model's expected response.
|
||||
|
||||
This dataset is designed for tasks where the model must reason over multiple images to generate an informed response based on both visual and textual inputs.
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/FanqingM/MMIU-Benchmark/embed/viewer/default/test"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
## Developing a Fine-Tuning Script for Multimodal SFT
|
||||
|
||||
In this section, we build the script needed to fine-tune a multimodal model for both **Single Image + Text** and **Multi-Image + Text** use cases.
|
||||
|
||||
### Setting Up the Environment
|
||||
|
||||
Before fine-tuning, we need to install the required dependencies. Let's start by setting up the environment:
|
||||
|
||||
```bash
|
||||
# Install the required libraries. Futher details: https://huggingface.co/docs/trl/installation
|
||||
pip install -U -q trl bitsandbytes peft hf_xet tensorboard
|
||||
```
|
||||
|
||||
Once all dependencies are installed, we need to log in to the **Hugging Face Hub**. Since **Gemma 3** is a gated model, access permissions are required.
|
||||
|
||||
If you haven’t requested access yet, visit the [Model Card](https://huggingface.co/google/gemma-3-4b-it) and request it.
|
||||
|
||||
To log in, you’ll need to generate an [access token](https://huggingface.co/settings/tokens) from your Hugging Face account.
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
### **Loading the Data**
|
||||
|
||||
As mentioned earlier, we will cover two possible use cases. While the specific procedure may vary based on the dataset, the core principles remain consistent.
|
||||
|
||||
This guide supports both use cases, so refer to the **Single Image + Text** or **Multi-Image + Text** sections depending on your specific scenario.
|
||||
|
||||
#### **Single Image + Text**
|
||||
|
||||

|
||||
|
||||
In this case, each sample in a batch consists of a **single image paired with text**. Since the dataset is already formatted for supervised fine-tuning (SFT), we can directly load it using `load_dataset`.
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset_name = "HuggingFaceH4/llava-instruct-mix-vsft"
|
||||
|
||||
# Load Dataset
|
||||
dataset = load_dataset(dataset_name)
|
||||
```
|
||||
|
||||
#### **Multi-Image + Text (or Interleaving)**
|
||||
|
||||

|
||||
|
||||
Gemma 3 also supports **Multi-Image + Text** scenarios, where:
|
||||
|
||||
- The model receives a **list of images** alongside a user message.
|
||||
- The model processes **interleaved images and text** within a conversation.
|
||||
|
||||
For this dataset, some preprocessing is required before training.
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset_name = "FanqingM/MMIU-Benchmark"
|
||||
|
||||
# Load Dataset
|
||||
dataset = load_dataset(dataset_name)
|
||||
```
|
||||
|
||||
After loading the dataset, we need to preprocess and format it into a conversational structure. Here’s an example of how the data might look:
|
||||
|
||||
```python
|
||||
{"role": "system", "content": [{"type": "text", "text": "You are a judge in a photography competition, and now you are given the four images. Please examine the details and tell which one of them is most likely to be a real photograph.\nSelect from the following choices.\nA: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
|
||||
{"role": "user", "content": images_list + [{"type": "text", "text": "Which image is most likely to be a real photograph?"}]},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": "A: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
|
||||
```
|
||||
|
||||
Here, `images_list` is a list of images:
|
||||
|
||||
```python
|
||||
images_list = [
|
||||
{"type": "image", "image": <class 'PIL.Image.Image'>},
|
||||
{"type": "image", "image": <class 'PIL.Image.Image'>},
|
||||
{"type": "image", "image": <class 'PIL.Image.Image'>},
|
||||
{"type": "image", "image": <class 'PIL.Image.Image'>},
|
||||
{"type": "image", "image": <class 'PIL.Image.Image'>},
|
||||
]
|
||||
```
|
||||
|
||||
This structure can be translated into code like this:
|
||||
|
||||
```python
|
||||
import os
|
||||
import zipfile
|
||||
import io
|
||||
from datasets import DatasetDict
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
from PIL import Image
|
||||
|
||||
dataset_train_split = "test"
|
||||
|
||||
def format_data(samples: dict[str, any]) -> dict[str, list]:
|
||||
formatted_samples = {"messages": []}
|
||||
for cont in range(len(samples["question"])):
|
||||
images = []
|
||||
for img_path in samples["input_image_path"][cont]:
|
||||
try:
|
||||
with open(img_path, "rb") as f:
|
||||
img_bytes = f.read()
|
||||
image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
||||
images.append({"type": "image", "image": image})
|
||||
except Exception as e:
|
||||
print(f"Error processing image {img_path}: {e}")
|
||||
continue
|
||||
|
||||
formatted_samples["messages"].append(
|
||||
[
|
||||
{"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
|
||||
{"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
|
||||
]
|
||||
)
|
||||
return formatted_samples
|
||||
|
||||
# For multi-image example
|
||||
def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
|
||||
all_files = list_repo_files(dataset_name, repo_type="dataset")
|
||||
zip_files = [f for f in all_files if f.endswith(".zip")]
|
||||
|
||||
for zip_filename in zip_files:
|
||||
zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
|
||||
extract_folder = zip_filename.replace(".zip", "")
|
||||
os.makedirs(extract_folder, exist_ok=True)
|
||||
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(extract_folder)
|
||||
|
||||
dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
|
||||
return dataset
|
||||
|
||||
dataset = prepare_dataset(dataset, dataset_name, dataset_train_split)
|
||||
```
|
||||
|
||||
With this, your **Multi-Image + Text** dataset is now prepared for training.
|
||||
|
||||
### **Preparing for Training**
|
||||
|
||||
We start by loading the model and processor. In this example, we use `google/gemma-3-4b-it`, but the same process applies to its other variants and similar models.
|
||||
|
||||
To optimize memory usage, we configure `BitsAndBytes` to load the quantized version of the model.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
|
||||
|
||||
model_id = "google/gemma-3-4b-it"
|
||||
|
||||
# BitsAndBytesConfig int-4 config
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_quant_storage=torch.bfloat16,
|
||||
)
|
||||
|
||||
# Load model and tokenizer
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="eager", # Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934)
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
processor.tokenizer.padding_side = "right"
|
||||
```
|
||||
|
||||
Next, we set up [Quantized Low-Rank Adaptation (QLoRA)](https://huggingface.co/papers/2305.14314), an efficient fine-tuning technique for Large Language Models (LLMs) and Vision-Language Models (VLMs).
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
# Configure QLoRA
|
||||
peft_config = LoraConfig(
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.05,
|
||||
r=16,
|
||||
bias="none",
|
||||
target_modules="all-linear",
|
||||
task_type="CAUSAL_LM",
|
||||
modules_to_save=[
|
||||
"lm_head",
|
||||
"embed_tokens",
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
With QLoRA now set up, we need to define the training arguments for SFT. The [`SFTConfig`] class simplifies this process, providing an easy way to adjust parameters based on our specific needs.
|
||||
|
||||
```python
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(
|
||||
output_dir="gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft", # Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets).
|
||||
num_train_epochs=1, # Set the number of epochs to train the model.
|
||||
per_device_train_batch_size=8, # Batch size for each device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1
|
||||
gradient_accumulation_steps=4, # Number of steps before performing a backward/update pass to accumulate gradients. multi-image -> gradient_accumulation_steps=1
|
||||
gradient_checkpointing=True, # Enable gradient checkpointing to reduce memory usage during training.
|
||||
optim="adamw_torch_fused", # Use the fused AdamW optimizer for better performance.
|
||||
logging_steps=10, # Frequency of logging training progress (log every 10 steps).
|
||||
save_strategy="epoch", # Save checkpoints at the end of each epoch.
|
||||
learning_rate=2e-05, # Learning rate for training.
|
||||
bf16=True, # Enable bfloat16 precision for training to save memory and speed up computations.
|
||||
push_to_hub=True, # Automatically push the fine-tuned model to Hugging Face Hub after training.
|
||||
report_to="tensorboard", # Automatically report metrics to tensorboard.
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}, # Set gradient checkpointing to non-reentrant to avoid issues.
|
||||
dataset_kwargs={"skip_prepare_dataset": True}, # Skip dataset preparation to handle preprocessing manually.
|
||||
remove_unused_columns=False, # Ensure unused columns are not removed in the collator (important for batch processing).
|
||||
)
|
||||
```
|
||||
|
||||
The `collate_fn` is responsible for processing and preparing individual examples to form a batch.
|
||||
|
||||
Each example in the batch undergoes the following steps:
|
||||
|
||||
1. The **chat template** is applied to the text.
|
||||
2. The **processor tokenizes** both `texts` and `images`, encoding them into tensors.
|
||||
3. The **labels** for training are set as the `input_ids` of the example.
|
||||
4. Certain **special tokens** are **masked (ignored)** during loss computation:
|
||||
- `pad_token_id`
|
||||
- `<image_token_id>`
|
||||
- `<image_soft_token>` (corresponding to ID `262144`)
|
||||
|
||||
This process is similar across different dataset types, with a minor variation in how images are handled:
|
||||
|
||||
- **Single Image + Text** → A **list of images** is directly processed.
|
||||
- **Multi-Image + Text** → A **list of lists of images** is used, where each batch element contains multiple images.
|
||||
|
||||
```python
|
||||
from PIL import Image
|
||||
|
||||
# For multi-image cases
|
||||
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
|
||||
image_inputs = []
|
||||
for msg in messages:
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
|
||||
for element in content:
|
||||
if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
|
||||
if "image" in element:
|
||||
image = element["image"]
|
||||
else:
|
||||
image = element
|
||||
if image is not None:
|
||||
image = Image.open(io.BytesIO(image["bytes"]))
|
||||
image_inputs.append(image.convert("RGB"))
|
||||
return image_inputs
|
||||
|
||||
def collate_fn(examples):
|
||||
texts = [processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip() for example in examples]
|
||||
if "images" in examples[0]: # single-image
|
||||
images = [
|
||||
[img.convert("RGB") for img in example["images"]]
|
||||
for example in examples
|
||||
]
|
||||
else: # multi-image
|
||||
images = [process_vision_info(example["messages"]) for example in examples]
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(
|
||||
text=texts, images=images, return_tensors="pt", padding=True
|
||||
) # Encode texts and images into tensors
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||
labels = batch["input_ids"].clone() # Clone input IDs for labels
|
||||
# Mask image tokens
|
||||
image_token_id = [
|
||||
processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
|
||||
]
|
||||
# Mask tokens for not being used in the loss computation
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == image_token_id] = -100
|
||||
labels[labels == 262144] = -100
|
||||
|
||||
batch["labels"] = labels
|
||||
return batch # Return the prepared batch
|
||||
```
|
||||
|
||||
### **Training the Model**
|
||||
|
||||
With all the components set up, we now configure the `SFTTrainer` using the previously defined settings and start the training process.
|
||||
|
||||
``` python
|
||||
# Training
|
||||
from trl import SFTTrainer
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_collator=collate_fn,
|
||||
train_dataset=dataset["train"], # multi-image -> train_dataset=dataset["test"],
|
||||
processing_class=processor,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save the final model
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
We save the fine-tuned model to the Hub, making it easily accessible for future use. Additionally, TRL automatically logs the training results to **Weights & Biases (Wandb)** or **TensorBoard**, depending on the chosen configuration.
|
||||
|
||||
<!-- Add Wandb training results -->
|
||||
### Results
|
||||
|
||||
During and after trainig, we can inspect the results using **Weights & Biases (Wandb)** or **TensorBoard**. For example:
|
||||
|
||||
* [**gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft (Single Image+Text)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft)
|
||||
|
||||
* [**gemma-3-4b-it-trl-sft-MMIU-Benchmark (Multi-Images+Text or Interleaving)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-MMIU-Benchmark)
|
||||
|
||||
## Limitations
|
||||
|
||||
Currently, fine-tuning Gemma has some [known limitations](https://github.com/huggingface/trl/issues/3121). We recommend following the procedure outlined in this guide to ensure the best results.
|
||||
|
||||
## References
|
||||
|
||||
For further reading and complementary resources, check out the following:
|
||||
|
||||
- [Fine-Tuning Vision-Language Models with QLoRA](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora)
|
||||
- [Fine-Tuning a Vision Language Model (Qwen2-VL-7B) with the Hugging Face Ecosystem (TRL)](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl)
|
||||
|
@ -43,7 +43,6 @@ To use the data efficiently, we use a technique called packing: instead of havin
|
||||
With this approach the training is much more efficient as each token that is passed through the model is also trained in contrast to padding tokens which are usually masked from the loss.
|
||||
If you don't have much data and are more concerned about occasionally cutting off some tokens that are overflowing the context you can also use a classical data loader.
|
||||
|
||||
The packing is handled by the `ConstantLengthDataset` and we can then use the `Trainer` after loading the model with `peft`. First, we load the model in int8, prepare it for training, and then add the LoRA adapters.
|
||||
|
||||
```python
|
||||
# load model in 8bit
|
||||
|
185
docs/source/vllm_integration.md
Normal file
185
docs/source/vllm_integration.md
Normal file
@ -0,0 +1,185 @@
|
||||
# vLLM Integration
|
||||
|
||||
This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. Let's go! 🔥
|
||||
|
||||
## 🚀 How can I use vLLM with TRL to speed up training?
|
||||
|
||||
💡 **Note**: Resources required for this specific example: a single node with 8 GPUs.
|
||||
|
||||
First, install vLLM using the following command:
|
||||
|
||||
```bash
|
||||
pip install "trl[vllm]"
|
||||
```
|
||||
|
||||
Then run the server:
|
||||
|
||||
```sh
|
||||
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2
|
||||
```
|
||||
|
||||
Once the server is running, you can use it to generate completions for training. In the example below, we are using the `GRPOTrainer` to train a model using the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs.
|
||||
|
||||
In this example, we are sharding two copies of the model across 4 GPUs. Increasing data parallelism increases throughput, while increasing tensor parallelism allows for serving larger models. Then, run the training script by passing `use_vllm=True` in the training arguments as follows:
|
||||
|
||||
Sample of a simple `train.py` script:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Dummy reward function: count the number of unique characters in the completions
|
||||
def reward_num_unique_chars(completions, **kwargs):
|
||||
return [len(set(c)) for c in completions]
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir="my_test",
|
||||
use_vllm=True,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
logging_steps=10,
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2.5-7B",
|
||||
args=training_args,
|
||||
reward_funcs=reward_num_unique_chars,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
And the train command:
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
```
|
||||
|
||||
## 🎬 Flashback: Why do we need to use vLLM in online methods?
|
||||
|
||||
Online methods like GRPO or Online DPO require the model to generate completions during training, which are then used to compute reward signals. However, generation can be extremely time-consuming, especially with large or reasoning models. In the default setup (without vLLM), completions are generated using the [(unwrapped) model's `generate` method](https://github.com/huggingface/trl/blob/f3e8c2304428ef16e9ae5de9e5741ed84d533b7b/trl/trainer/grpo_trainer.py#L965C39-L965C66). This approach quickly becomes a major bottleneck — generation is slow and inefficient, particularly for large batches or models. As a result, training times increase significantly, and overall efficiency drops. To address this, we turn to vLLM, which enables much faster and more scalable generation, helping eliminate this bottleneck in online methods.
|
||||
|
||||
## 🤔 How does vLLM solve the slow generation issue?
|
||||
|
||||
If you've ever done autoregressive decoder training, you know all the input tokens to the LLM produce their attention key and value tensors, and these tensors are kept in GPU memory to later generate subsequent tokens based on them. These cached key and value tensors are often referred to as the KV cache. However, storing the KV cache occupies a lot of memory, so vLLM uses a technique called **PagedAttention** to solve this problem. PagedAttention, which is inspired by the OS’s virtual memory concept, stores continuous keys and values in **non-contiguous memory space**, which is much more efficient. The details of this are beyond the scope of this document, but in short, it allows the model to store the keys and values in a more efficient way, reducing the memory footprint and speeding up the generation process. If you are interested, make sure to check out the [vLLM PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html) for more details.
|
||||
|
||||
## 🤔 What exactly happens when you run `trl vllm-serve --model <model_name>`?
|
||||
|
||||
When you run for example
|
||||
|
||||
```sh
|
||||
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4
|
||||
```
|
||||
|
||||
the following happens:
|
||||
|
||||

|
||||
|
||||
1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (1 × 4).
|
||||
Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation. A key point to understand is that these 4 workers are running in parallel, and each one is responsible for handling a subset of the total incoming load.
|
||||
|
||||
2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model’s weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled. Meanwhile, data parallelism (controlled by `--data-parallel-size`) ensures that different sets of requests are processed independently across the workers. In short: tensor parallelism splits the model across GPUs, and data parallelism splits the batch of requests across different model replicas.
|
||||
|
||||
3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--data-parallel-size=4`, each GPU processes 2 prompts).
|
||||
This GPU-to-GPU communication is managed efficiently by NVIDIA’s NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests — it’s lightweight and doesn’t interfere with generation itself.
|
||||
Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (like in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you would end up with 16 completions total — regardless of the number of GPUs or parallelism settings.
|
||||
|
||||
## 🥸 More detail on what happens under the hood when running the server
|
||||
|
||||
* The vLLM server starts by running the command: `trl vllm-serve --model Qwen/Qwen2.5-7B`.
|
||||
* Once the server is running, it generates completions based on requests from the client (trainer) using `vllm_client.generate` [here](https://github.com/huggingface/trl/blob/cc044e35b285be7dc062764b3364e1e684db4c7c/trl/trainer/grpo_trainer.py#L1025-L1035).
|
||||
* The client (trainer) then requests these completions from the server.
|
||||
* These completions are used to compute the reward signal.
|
||||
* Based on the reward signal and the model’s output, the loss is computed, and the backward pass is performed to update the model’s weights.
|
||||
* **Note**: The server only handles completion generation — it doesn’t train the model. Therefore, the model’s weights aren’t updated on the server. Once the backward pass is complete, the client sends the updated weights to the server using `vllm_client.update_named_param(name, param.data)`.
|
||||
|
||||
When using vLLM, ensure the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation for training using `CUDA_VISIBLE_DEVICES`. See the example below:
|
||||
|
||||
* **Set GPUs *0–3* for vLLM generation:** Assume `CUDA_VISIBLE_DEVICES=0,1,2,3` are allocated for vLLM generation.
|
||||
|
||||
```sh
|
||||
trl vllm-serve --model <model_name> --tensor-parallel-size 1 --data-parallel-size 4
|
||||
```
|
||||
|
||||
* **And GPUs *4–7* for training:** If you do not set the `CUDA_VISIBLE_DEVICES` environment variable, the training script will use all available GPUs by default, which may lead to resource conflicts. To avoid this, you can specify which GPUs to use for training. For example, if you want to use GPUs 4–7 for training, set the environment variable as follows:
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
```
|
||||
|
||||
## 🍷 More customization options with vLLM?
|
||||
|
||||
You can customize the server configuration by passing additional arguments.
|
||||
|
||||
```
|
||||
$ trl vllm-serve --help
|
||||
usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE]
|
||||
[--data_parallel_size DATA_PARALLEL_SIZE] [--host HOST] [--port PORT]
|
||||
[--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE] [--max_model_len MAX_MODEL_LEN]
|
||||
[--enable_prefix_caching ENABLE_PREFIX_CACHING] [--enforce_eager ENFORCE_EAGER] [--log_level LOG_LEVEL]
|
||||
|
||||
options:
|
||||
-h, --help Show this help message and exit
|
||||
--model MODEL Model name or path to load the model from. (default: None)
|
||||
--revision REVISION Revision to use for the model. If not specified, the default branch will be used. (default: None)
|
||||
--tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE
|
||||
Number of tensor parallel workers to use. (default: 1)
|
||||
--data_parallel_size DATA_PARALLEL_SIZE, --data-parallel-size DATA_PARALLEL_SIZE
|
||||
Number of data parallel workers to use. (default: 1)
|
||||
--host HOST Host address to run the server on. (default: 0.0.0.0)
|
||||
--port PORT Port to run the server on. (default: 8000)
|
||||
--gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION
|
||||
Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the device
|
||||
dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus improve the
|
||||
model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors during
|
||||
initialization. (default: 0.9)
|
||||
--dtype DTYPE Data type to use for vLLM generation. If set to 'auto', the data type will be automatically determined based on
|
||||
the model configuration. Find the supported values in the vLLM documentation. (default: auto)
|
||||
--max_model_len MAX_MODEL_LEN, --max-model-len MAX_MODEL_LEN
|
||||
If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
|
||||
`vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model context
|
||||
size, which might be much larger than the KV cache, leading to inefficiencies. (default: None)
|
||||
--enable_prefix_caching ENABLE_PREFIX_CACHING, --enable-prefix-caching ENABLE_PREFIX_CACHING
|
||||
Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support this
|
||||
feature. (default: None)
|
||||
--enforce_eager ENFORCE_EAGER, --enforce-eager ENFORCE_EAGER
|
||||
Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the model
|
||||
in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid. (default:
|
||||
None)
|
||||
--log_level LOG_LEVEL, --log-level LOG_LEVEL
|
||||
Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', 'trace'. (default:
|
||||
info)
|
||||
```
|
||||
|
||||
## 🥳 Okay, now that we have the server running, how can we use it to generate completions?
|
||||
|
||||
Run the training script and pass `use_vllm=True` in the training arguments:
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(..., use_vllm=True)
|
||||
```
|
||||
|
||||
## 💆🏻♀️ What's the best distributed setup?
|
||||
|
||||

|
||||

|
||||
|
||||
First and foremost, always remember that the optimal setup depends on:
|
||||
|
||||
* The model size
|
||||
* The number of GPUs you have
|
||||
* The GPU memory size
|
||||
* The batch size you are using
|
||||
* The number of requests you are sending to the server (prompts)
|
||||
* The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size)
|
||||
* The number of completions you are generating for each request (`num_generations`)
|
||||
|
||||
Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that:
|
||||
|
||||
* For reasonable-sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results.
|
||||
* For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window.
|
@ -50,9 +50,9 @@ accelerate launch train_xpo.py
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 1 hour.
|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-XPO
|
||||
<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-XPO
|
||||
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
|
||||
What is the best programming language?
|
||||
|
||||
|
28
examples/accelerate_configs/fsdp1.yaml
Normal file
28
examples/accelerate_configs/fsdp1.yaml
Normal file
@ -0,0 +1,28 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
fsdp_config:
|
||||
fsdp_activation_checkpointing: false
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_backward_prefetch: BACKWARD_PRE
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_forward_prefetch: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_reshard_after_forward: FULL_SHARD
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_use_orig_params: true
|
||||
fsdp_version: 1
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
25
examples/accelerate_configs/fsdp2.yaml
Normal file
25
examples/accelerate_configs/fsdp2.yaml
Normal file
@ -0,0 +1,25 @@
|
||||
# Requires accelerate 1.7.0 or higher
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
fsdp_config:
|
||||
fsdp_activation_checkpointing: false
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_offload_params: false
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_version: 2
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
@ -1,25 +0,0 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_backward_prefetch: BACKWARD_PRE
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_forward_prefetch: false
|
||||
fsdp_offload_params: true
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_use_orig_params: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'bf16'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -4,4 +4,4 @@ This directory contains a collection of Jupyter notebooks that demonstrate how t
|
||||
|
||||
- [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO.
|
||||
- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook.
|
||||
- [`gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.
|
||||
- [`gpt2-sentiment-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.
|
||||
|
15
examples/research_projects/layer_skip/README.md
Normal file
15
examples/research_projects/layer_skip/README.md
Normal file
@ -0,0 +1,15 @@
|
||||
# LayerSkip Training Recipe
|
||||
|
||||
Implements the training recipe as described in the [LayerSkip paper](https://huggingface.co/papers/2404.16710).
|
||||
|
||||
## Run training
|
||||
```
|
||||
cd scripts
|
||||
python layer_skip_sft.py
|
||||
```
|
||||
|
||||
## Run benchmark
|
||||
```
|
||||
cd scripts
|
||||
python benchmark_layer_skip.py
|
||||
```
|
@ -0,0 +1,77 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import config
|
||||
import torch
|
||||
from torch.utils import benchmark
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
def generate_tokens(model, inputs):
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
do_sample=False,
|
||||
max_new_tokens=64,
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
def generate_tokens_with_assistance(model, inputs, assistant_early_exit):
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
assistant_early_exit=assistant_early_exit,
|
||||
do_sample=False,
|
||||
max_new_tokens=64,
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ckpt = config.hub_model_id
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt)
|
||||
|
||||
prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: "
|
||||
|
||||
results = []
|
||||
label = "Generation Times"
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="generate_tokens(model, inputs)",
|
||||
setup="from __main__ import generate_tokens",
|
||||
globals={"model": model, "inputs": inputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
label=label,
|
||||
sub_label="no layer skip",
|
||||
description="generation",
|
||||
).blocked_autorange()
|
||||
)
|
||||
|
||||
for i in range(1, model.config.num_hidden_layers):
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="generate_tokens_with_assistance(model, inputs, assistant_early_exit)",
|
||||
setup="from __main__ import generate_assistant_tokens",
|
||||
globals={"model": model, "assistant_early_exit": i, "inputs": inputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
label=label,
|
||||
sub_label=f"layer skip {i}",
|
||||
description="generation",
|
||||
).blocked_autorange()
|
||||
)
|
||||
|
||||
benchmark.Compare(results).print()
|
28
examples/research_projects/layer_skip/scripts/config.py
Normal file
28
examples/research_projects/layer_skip/scripts/config.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub import whoami
|
||||
|
||||
|
||||
model_name = "unsloth/Llama-3.2-3B"
|
||||
tokenizer_name = "unsloth/Llama-3.2-3B"
|
||||
dataset_name = "WillHeld/top_v2"
|
||||
|
||||
output_root_dir = "./checkpoints/"
|
||||
hub_model_id = f"{whoami()['name']}/layerskip-{model_name.split('/')[1]}-{dataset_name.split('/')[1]}"
|
||||
output_dir = f"{output_root_dir}/{hub_model_id}"
|
||||
|
||||
per_device_train_batch_size = 8
|
||||
gradient_accumulation_steps = 1
|
||||
learning_rate = 2e-5
|
@ -0,0 +1,48 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from trl import SFTTrainer
|
||||
|
||||
|
||||
class LayerSkipSFTTrainer(SFTTrainer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.early_exit_layer = 0 # initialize with 0
|
||||
self.always_last_layer = True
|
||||
self.early_exit_loss_scale = 1.0
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
self.early_exit_layer = (
|
||||
self.early_exit_layer % (model.config.num_hidden_layers - 1)
|
||||
) + 1 # rotates between [1, num_hidden_layers-1]
|
||||
bs, seqlen = inputs.input_ids.shape
|
||||
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs, output_hidden_states=True)
|
||||
|
||||
hidden_state = outputs["hidden_states"][self.early_exit_layer].to(model.dtype)
|
||||
if self.early_exit_layer != model.config.num_hidden_layers:
|
||||
hidden_state = model.model.norm(hidden_state)
|
||||
logits = model.lm_head(hidden_state)
|
||||
loss_early = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size)
|
||||
|
||||
if self.always_last_layer:
|
||||
loss_last = model.loss_function(logits=outputs["logits"], labels=labels, vocab_size=model.vocab_size)
|
||||
loss = self.early_exit_loss_scale * loss_early.to(loss_last.device) + 1.0 * loss_last
|
||||
# normalize loss scales
|
||||
loss = loss / (1.0 + self.early_exit_loss_scale)
|
||||
else:
|
||||
loss = loss_early
|
||||
|
||||
return loss
|
@ -0,0 +1,91 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import config
|
||||
import torch
|
||||
from custom_trainer import LayerSkipSFTTrainer
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from trl import DataCollatorForCompletionOnlyLM, SFTConfig
|
||||
|
||||
|
||||
def formatting_prompts_func(example):
|
||||
text = f"### Instruction: {example['utterance']}\n ### Response: {example['semantic_parse']}"
|
||||
|
||||
# Inject eos_token as a string before tokenization, because they are not always added
|
||||
# See: https://github.com/huggingface/transformers/issues/22794 and
|
||||
# https://github.com/huggingface/trl/issues/1623
|
||||
if tokenizer.eos_token: # usually something like "</s>" for GPT2 or "<|endoftext|>"
|
||||
text += f"{tokenizer.eos_token}"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# load the dataset
|
||||
print("[INFO] loading the dataset...")
|
||||
train_dataset = load_dataset(config.dataset_name, split="train")
|
||||
|
||||
print(f"output_root_dir: {config.output_root_dir}")
|
||||
print(f"hub_model_id: {config.hub_model_id}")
|
||||
|
||||
# load the model and tokenizer
|
||||
print("[INFO] loading the model and tokenizer...")
|
||||
model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", torch_dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, add_eos_token=True)
|
||||
|
||||
# adding pad and eos tokens if not provided in the tokenizer
|
||||
if tokenizer.pad_token is None:
|
||||
# Add '[PAD]' token if it doesn't exist
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
model.config.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
if tokenizer.eos_token is None or tokenizer.eos_token == tokenizer.bos_token:
|
||||
# Add '[EOS]' token if it doesn't exist
|
||||
tokenizer.add_special_tokens({"eos_token": "[EOS]"})
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
response_template = " ### Response:"
|
||||
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
|
||||
|
||||
args = SFTConfig(
|
||||
do_train=True,
|
||||
bf16=True,
|
||||
max_seq_length=None,
|
||||
per_device_train_batch_size=config.per_device_train_batch_size,
|
||||
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
||||
learning_rate=config.learning_rate,
|
||||
packing=False,
|
||||
num_train_epochs=1.0,
|
||||
report_to="none",
|
||||
push_to_hub=True,
|
||||
hub_model_id=config.hub_model_id,
|
||||
output_dir=config.output_dir,
|
||||
logging_steps=500,
|
||||
save_steps=1000,
|
||||
save_total_limit=2,
|
||||
)
|
||||
|
||||
trainer = LayerSkipSFTTrainer(
|
||||
model,
|
||||
train_dataset=train_dataset,
|
||||
args=args,
|
||||
formatting_func=formatting_prompts_func,
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -185,7 +185,7 @@ trainer = SFTTrainer(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
peft_config=peft_config,
|
||||
max_seq_length=None,
|
||||
max_length=None,
|
||||
formatting_func=prepare_sample_text,
|
||||
processing_class=tokenizer,
|
||||
args=training_args,
|
||||
|
@ -1,118 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, load_tool
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
|
||||
|
||||
|
||||
def generate_data(n):
|
||||
"""Generate random arithmetic tasks and answers."""
|
||||
tasks, answers = [], []
|
||||
for _ in range(n):
|
||||
a = np.random.randint(0, 50)
|
||||
b = np.random.randint(0, 50)
|
||||
op = np.random.choice(["-", "+", "*"])
|
||||
tasks.append(f"\n\nWhat is {a} {op} {b}?")
|
||||
if op == "-":
|
||||
answers.append(a - b)
|
||||
elif op == "+":
|
||||
answers.append(a + b)
|
||||
else:
|
||||
answers.append(a * b)
|
||||
return tasks, answers
|
||||
|
||||
|
||||
def exact_match_reward(responses, answers=None):
|
||||
"""Reward if generated response contains correct answer."""
|
||||
rewards = []
|
||||
pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>" # generated by chatGPT
|
||||
for response, answer in zip(responses, answers):
|
||||
reward = 0.0
|
||||
predicted_number = None
|
||||
match_pattern = re.findall(pattern, response)
|
||||
if match_pattern:
|
||||
predicted_number = float(match_pattern[0])
|
||||
if predicted_number is not None:
|
||||
if np.abs(predicted_number - answer) < 0.01:
|
||||
reward += 1.0
|
||||
rewards.append(torch.tensor(reward))
|
||||
return rewards
|
||||
|
||||
|
||||
# set up models
|
||||
model_id = "gpt2"
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# system prompt
|
||||
prompt = """\
|
||||
What is 13-3?
|
||||
|
||||
<request><SimpleCalculatorTool>13-3<call>10.0<response>
|
||||
|
||||
Result=10<submit>
|
||||
|
||||
What is 4*3?
|
||||
|
||||
<request><SimpleCalculatorTool>4*3<call>12.0<response>
|
||||
|
||||
Result=12<submit>"""
|
||||
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
"eos_token_id": -1,
|
||||
"max_new_tokens": 32,
|
||||
}
|
||||
|
||||
# trainer
|
||||
ppo_config = PPOConfig(
|
||||
batch_size=256,
|
||||
learning_rate=1.41e-5,
|
||||
mini_batch_size=64,
|
||||
log_with="wandb",
|
||||
)
|
||||
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)
|
||||
|
||||
# text env
|
||||
text_env = TextEnvironment(
|
||||
model,
|
||||
tokenizer,
|
||||
{"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
|
||||
exact_match_reward,
|
||||
prompt,
|
||||
generation_kwargs=generation_kwargs,
|
||||
)
|
||||
|
||||
# main training loop
|
||||
for _step in range(100):
|
||||
tasks, answers = generate_data(ppo_config.batch_size)
|
||||
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
|
||||
response_texts = [tokenizer.decode(response) for response in responses]
|
||||
query_texts = [tokenizer.decode(query) for query in queries]
|
||||
texts = {"query": [qt.split("<submit>")[-1].strip() for qt in query_texts], "response": response_texts}
|
||||
ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"])
|
||||
ppo_trainer.save_pretrained(model_id + "-calculator")
|
@ -1,193 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, HfArgumentParser, load_tool
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
|
||||
|
||||
|
||||
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"})
|
||||
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
|
||||
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
|
||||
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=16, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"})
|
||||
ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"})
|
||||
n_epochs: Optional[int] = field(default=32, metadata={"help": "max number of ppo epochs"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
|
||||
def exact_match_reward(responses, answers=None):
|
||||
"""Reward if generated response contains correct answer."""
|
||||
rewards = []
|
||||
pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>" # generated by chatGPT
|
||||
for response, answer in zip(responses, answers):
|
||||
reward = 0.0
|
||||
try:
|
||||
predicted_number = None
|
||||
match_pattern = re.findall(pattern, response)
|
||||
if match_pattern:
|
||||
predicted_number = float(match_pattern[0])
|
||||
if predicted_number is not None:
|
||||
if np.abs(predicted_number - float(answer)) < 0.1:
|
||||
reward += 1.0
|
||||
except Exception:
|
||||
pass
|
||||
rewards.append(torch.tensor(reward))
|
||||
return rewards
|
||||
|
||||
|
||||
def evaluate(test_dataloader, text_env, ppo_trainer):
|
||||
test_rewards = []
|
||||
for test_batch in test_dataloader:
|
||||
_, _, _, rewards, _ = text_env.run(test_batch["query"], answers=test_batch["answer"])
|
||||
test_rewards.extend(rewards)
|
||||
test_rewards = ppo_trainer.accelerator.gather_for_metrics(
|
||||
torch.stack(test_rewards).to(ppo_trainer.accelerator.device)
|
||||
)
|
||||
return test_rewards.mean()
|
||||
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=["c_proj", "c_attn", "q_attn"],
|
||||
)
|
||||
|
||||
# set up models
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
script_args.model_name,
|
||||
use_auth_token=True,
|
||||
load_in_4bit=True,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_auth_token=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
ds = load_dataset("openai/gsm8k", "main", split="train")
|
||||
ds = ds.rename_columns({"question": "query"})
|
||||
ds = ds.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
|
||||
ds = ds.select(range(1, len(ds))) # skip the first sample which is used in prompt
|
||||
|
||||
ds_test = load_dataset("openai/gsm8k", "main", split="test")
|
||||
ds_test = ds_test.rename_columns({"question": "query"})
|
||||
ds_test = ds_test.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
|
||||
|
||||
test_dataloader = torch.utils.data.DataLoader(ds_test, batch_size=script_args.batch_size)
|
||||
|
||||
# prompt
|
||||
prompt = """\
|
||||
Example of using a Python API to solve math questions.
|
||||
|
||||
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
|
||||
|
||||
<request><PythonInterpreter>
|
||||
def solution():
|
||||
money_initial = 23
|
||||
bagels = 5
|
||||
bagel_cost = 3
|
||||
money_spent = bagels * bagel_cost
|
||||
money_left = money_initial - money_spent
|
||||
result = money_left
|
||||
return result
|
||||
print(solution())
|
||||
<call>72<response>
|
||||
|
||||
Result = 72 <submit>
|
||||
|
||||
Q: """
|
||||
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
"eos_token_id": -1,
|
||||
"max_new_tokens": script_args.max_new_tokens,
|
||||
}
|
||||
|
||||
# trainer
|
||||
ppo_config = PPOConfig(
|
||||
batch_size=script_args.batch_size,
|
||||
learning_rate=script_args.learning_rate,
|
||||
mini_batch_size=script_args.mini_batch_size,
|
||||
ppo_epochs=script_args.ppo_epochs,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
log_with="wandb",
|
||||
tracker_project_name="trl-gsm8k",
|
||||
remove_unused_columns=False,
|
||||
optimize_cuda_cache=True,
|
||||
)
|
||||
|
||||
ppo_trainer = PPOTrainer(args=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
|
||||
test_dataloader = ppo_trainer.accelerator.prepare(test_dataloader)
|
||||
|
||||
# text env
|
||||
text_env = TextEnvironment(
|
||||
model,
|
||||
tokenizer,
|
||||
[load_tool("lvwerra/python-interpreter")],
|
||||
exact_match_reward,
|
||||
prompt,
|
||||
max_turns=2,
|
||||
generation_kwargs=generation_kwargs,
|
||||
)
|
||||
|
||||
# main training loop
|
||||
for epoch in range(script_args.n_epochs):
|
||||
for step, batch in enumerate(ppo_trainer.dataloader):
|
||||
if (step == 0) and (epoch % 4 == 0): # evaluate every 4 epochs
|
||||
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
|
||||
else:
|
||||
reward_mean_test = None
|
||||
|
||||
queries, responses, masks, rewards, histories = text_env.run(batch["query"], answers=batch["answer"])
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
|
||||
# logging
|
||||
if reward_mean_test is not None:
|
||||
train_stats["env/reward_mean_test"] = reward_mean_test
|
||||
texts = {
|
||||
"query": batch["query"],
|
||||
"response": [tokenizer.decode(response) for response in responses],
|
||||
"answer": batch["answer"],
|
||||
}
|
||||
ppo_trainer.log_stats(train_stats, texts, rewards, columns_to_log=["query", "response", "answer"])
|
||||
|
||||
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
|
||||
ppo_trainer.save_pretrained(f"model/{script_args.model_name}-gsm8k")
|
@ -1,192 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, HfArgumentParser, load_tool
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
|
||||
|
||||
|
||||
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"})
|
||||
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
|
||||
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
|
||||
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
|
||||
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=16, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"})
|
||||
ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"})
|
||||
iterations: Optional[int] = field(default=1000, metadata={"help": "the number of iterations"})
|
||||
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=["c_proj", "c_attn", "q_attn"],
|
||||
)
|
||||
|
||||
# set up models
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
script_args.model_name,
|
||||
use_auth_token=True,
|
||||
trust_remote_code=True,
|
||||
load_in_4bit=True,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, use_auth_token=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# system prompt
|
||||
prompt = """\
|
||||
Answer the following question:
|
||||
|
||||
Q: In which branch of the arts is Patricia Neary famous?
|
||||
A: Ballets
|
||||
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
|
||||
Result=Ballets<submit>
|
||||
|
||||
Q: Who won Super Bowl XX?
|
||||
A: Chicago Bears
|
||||
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
|
||||
Result=Chicago Bears<submit>
|
||||
|
||||
Q: """
|
||||
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
"eos_token_id": -1,
|
||||
"max_new_tokens": script_args.max_new_tokens,
|
||||
}
|
||||
|
||||
# trainer
|
||||
config = PPOConfig(
|
||||
batch_size=script_args.batch_size,
|
||||
model_name=script_args.model_name,
|
||||
learning_rate=script_args.learning_rate,
|
||||
log_with=script_args.log_with,
|
||||
mini_batch_size=script_args.mini_batch_size,
|
||||
ppo_epochs=script_args.ppo_epochs,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
seed=script_args.seed,
|
||||
optimize_cuda_cache=True,
|
||||
)
|
||||
ppo_trainer = PPOTrainer(args=config, model=model, tokenizer=tokenizer)
|
||||
dataset = load_dataset("mandarjoshi/trivia_qa", "rc", split="train")
|
||||
local_seed = script_args.seed + ppo_trainer.accelerator.process_index * 100003 # Prime
|
||||
dataset = dataset.shuffle(local_seed)
|
||||
|
||||
|
||||
def data_generator():
|
||||
for i in range(len(dataset)):
|
||||
yield dataset[i]["question"], list(dataset[i]["answer"]["normalized_aliases"])
|
||||
|
||||
|
||||
gen = data_generator()
|
||||
gen = iter(gen)
|
||||
|
||||
|
||||
def generate_data(n):
|
||||
tasks, answers = [], []
|
||||
for _i in range(n):
|
||||
q, a = next(gen)
|
||||
tasks.append(q)
|
||||
answers.append(a)
|
||||
return tasks, answers
|
||||
|
||||
|
||||
def exact_match_reward(responses, answers=None):
|
||||
"""Reward if generated response contains correct answer."""
|
||||
rewards = []
|
||||
for response, answer in zip(responses, answers):
|
||||
reward = 0.0
|
||||
for a in answer:
|
||||
if a.lower() in response.lower():
|
||||
reward += 1.0
|
||||
break
|
||||
rewards.append(torch.tensor(reward))
|
||||
return rewards
|
||||
|
||||
|
||||
def tool_fn(x):
|
||||
# limit the amount of tokens
|
||||
return tool(x).split("\n")[1][:600]
|
||||
|
||||
|
||||
# text env
|
||||
tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
|
||||
|
||||
text_env = TextEnvironment(
|
||||
model,
|
||||
tokenizer,
|
||||
{"Wiki": tool_fn},
|
||||
exact_match_reward,
|
||||
prompt,
|
||||
generation_kwargs=generation_kwargs,
|
||||
max_tool_reponse=400,
|
||||
)
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
trainable_params = 0
|
||||
all_param = 0
|
||||
for _, param in model.named_parameters():
|
||||
all_param += param.numel()
|
||||
if param.requires_grad:
|
||||
trainable_params += param.numel()
|
||||
print(
|
||||
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
||||
)
|
||||
|
||||
|
||||
print_trainable_parameters(model)
|
||||
# main training loop
|
||||
for i in range(script_args.iterations):
|
||||
tasks, answers = generate_data(config.batch_size)
|
||||
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
response_texts = [tokenizer.decode(response) for response in responses]
|
||||
query_texts = [tokenizer.decode(query) for query in queries]
|
||||
texts = {
|
||||
"query": [qt.split("<submit>")[-1].strip() for qt in query_texts],
|
||||
"response": response_texts,
|
||||
"answer": [", ".join(item) for item in answers],
|
||||
}
|
||||
all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device))
|
||||
ppo_trainer.log_stats(train_stats, texts, list(all_rewards), columns_to_log=["query", "response", "answer"])
|
||||
if i % 100 == 0:
|
||||
ppo_trainer.save_pretrained(f"models/{script_args.model_name}_{script_args.seed}_{i}_triviaqa")
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -99,4 +99,4 @@ else:
|
||||
completions = [[c0, c1] for c0, c1 in zip(reference_completions, model_completions)]
|
||||
best_idxs = judge.judge(prompts, completions)
|
||||
model_win_rate = best_idxs.count(1) / len(best_idxs)
|
||||
print(f"Model win rate: {model_win_rate*100:.2f}%")
|
||||
print(f"Model win rate: {model_win_rate * 100:.2f}%")
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -45,7 +45,6 @@ python examples/scripts/gkd.py \
|
||||
--lora_alpha 16
|
||||
"""
|
||||
|
||||
from accelerate import PartialState
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, GenerationConfig
|
||||
|
||||
@ -106,14 +105,6 @@ if __name__ == "__main__":
|
||||
################
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
|
||||
with PartialState().local_main_process_first():
|
||||
dataset = dataset.map(
|
||||
lambda x: {
|
||||
"prompt": tokenizer.apply_chat_template(x["prompt"], tokenize=False, add_generation_prompt=True)
|
||||
},
|
||||
num_proc=training_args.dataset_num_proc,
|
||||
)
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
62
examples/scripts/sft_gemma3.py
Normal file
62
examples/scripts/sft_gemma3.py
Normal file
@ -0,0 +1,62 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Train Gemma-3 on the Codeforces COTS dataset.
|
||||
|
||||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/sft_gemma3.py
|
||||
"""
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
|
||||
def main():
|
||||
# Load dataset
|
||||
train_dataset = load_dataset("open-r1/codeforces-cots", split="train")
|
||||
train_dataset = train_dataset.remove_columns("prompt")
|
||||
|
||||
# Load model
|
||||
model_id = "google/gemma-3-12b-it"
|
||||
model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager")
|
||||
|
||||
# Train model
|
||||
training_args = SFTConfig(
|
||||
output_dir=f"{model_id}-codeforces-SFT",
|
||||
logging_steps=10,
|
||||
bf16=True,
|
||||
use_liger_kernel=True,
|
||||
gradient_checkpointing=True,
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False},
|
||||
max_length=8192,
|
||||
per_device_train_batch_size=1,
|
||||
gradient_accumulation_steps=8,
|
||||
dataset_num_proc=32,
|
||||
num_train_epochs=1,
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
args=training_args,
|
||||
model=model,
|
||||
train_dataset=train_dataset,
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
# Push to hub
|
||||
trainer.push_to_hub(dataset_name="open-r1/codeforces-cots")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
223
examples/scripts/sft_vlm_gemma3.py
Normal file
223
examples/scripts/sft_vlm_gemma3.py
Normal file
@ -0,0 +1,223 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Train Gemma-3 on the HuggingFaceH4/llava-instruct-mix-vsft dataset (single-image).
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/sft_vlm_gemma3.py \
|
||||
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
|
||||
--model_name_or_path google/gemma-3-4b-it \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--output_dir gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules all-linear \
|
||||
--attn_implementation eager
|
||||
|
||||
Train Gemma-3 on the FanqingM/MMIU-Benchmark dataset (multi-image).
|
||||
|
||||
accelerate launch \
|
||||
--config_file examples/accelerate_configs/deepspeed_zero3.yaml \
|
||||
examples/scripts/sft_vlm_gemma3.py \
|
||||
--dataset_name FanqingM/MMIU-Benchmark \
|
||||
--dataset_train_split test \
|
||||
--model_name_or_path google/gemma-3-4b-it \
|
||||
--per_device_train_batch_size 1 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--output_dir gemma-3-4b-it-trl-sft-MMIU-Benchmark \
|
||||
--bf16 \
|
||||
--torch_dtype bfloat16 \
|
||||
--use_peft \
|
||||
--lora_target_modules all-linear
|
||||
--attn_implementation eager
|
||||
"""
|
||||
|
||||
import io
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import torch
|
||||
from datasets import DatasetDict, load_dataset
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
from PIL import Image
|
||||
from transformers import AutoModelForImageTextToText, AutoProcessor
|
||||
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
SFTConfig,
|
||||
SFTTrainer,
|
||||
TrlParser,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
|
||||
# For multi-image example
|
||||
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
|
||||
image_inputs = []
|
||||
for msg in messages:
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
|
||||
for element in content:
|
||||
if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
|
||||
if "image" in element:
|
||||
image = element["image"]
|
||||
else:
|
||||
image = element
|
||||
if image is not None:
|
||||
image = Image.open(io.BytesIO(image["bytes"]))
|
||||
image_inputs.append(image.convert("RGB"))
|
||||
return image_inputs
|
||||
|
||||
|
||||
def format_data(samples: dict[str, any]) -> dict[str, list]:
|
||||
formatted_samples = {"messages": []}
|
||||
for cont in range(len(samples["question"])):
|
||||
images = []
|
||||
for img_path in samples["input_image_path"][cont]:
|
||||
try:
|
||||
with open(img_path, "rb") as f:
|
||||
img_bytes = f.read()
|
||||
image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
||||
images.append({"type": "image", "image": image})
|
||||
except Exception as e:
|
||||
print(f"Error processing image {img_path}: {e}")
|
||||
continue
|
||||
|
||||
formatted_samples["messages"].append(
|
||||
[
|
||||
{"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
|
||||
{"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
|
||||
]
|
||||
)
|
||||
return formatted_samples
|
||||
|
||||
|
||||
# For multi-image example
|
||||
def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
|
||||
all_files = list_repo_files(dataset_name, repo_type="dataset")
|
||||
zip_files = [f for f in all_files if f.endswith(".zip")]
|
||||
|
||||
for zip_filename in zip_files:
|
||||
zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
|
||||
extract_folder = zip_filename.replace(".zip", "")
|
||||
os.makedirs(extract_folder, exist_ok=True)
|
||||
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(extract_folder)
|
||||
|
||||
dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
|
||||
script_args, training_args, model_args = parser.parse_args_and_config()
|
||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
training_args.remove_unused_columns = False
|
||||
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
|
||||
|
||||
################
|
||||
# Model, Tokenizer & Processor
|
||||
################
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
|
||||
)
|
||||
processor.tokenizer.padding_side = "right"
|
||||
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
|
||||
)
|
||||
|
||||
def collate_fn(examples):
|
||||
texts = [
|
||||
processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip()
|
||||
for example in examples
|
||||
]
|
||||
if "images" in examples[0]: # single-image
|
||||
images = [[img.convert("RGB") for img in example["images"]] for example in examples]
|
||||
else: # multi-image
|
||||
images = [process_vision_info(example["messages"]) for example in examples]
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(
|
||||
text=texts, images=images, return_tensors="pt", padding=True
|
||||
) # Encode texts and images into tensors
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||
labels = batch["input_ids"].clone() # Clone input IDs for labels
|
||||
# Mask image tokens
|
||||
image_token_id = [
|
||||
processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
|
||||
]
|
||||
# Mask tokens for not being used in the loss computation
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == image_token_id] = -100
|
||||
labels[labels == 262144] = -100
|
||||
|
||||
batch["labels"] = labels
|
||||
return batch # Return the prepared batch
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
|
||||
if script_args.dataset_name == "FanqingM/MMIU-Benchmark":
|
||||
dataset = prepare_dataset(dataset, script_args.dataset_name, script_args.dataset_train_split)
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_collator=collate_fn,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
||||
processing_class=processor.tokenizer,
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save and push to hub
|
||||
trainer.save_model(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||
if trainer.accelerator.is_main_process:
|
||||
processor.push_to_hub(training_args.hub_model_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user