Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 18:43:52 +08:00
Compare commits
443 Commits
Author | SHA1 | Date
---|---|---
*(443 commits compared, from `0f9d948f7e` to `2f34a161cd`)*
76  .github/ISSUE_TEMPLATE/bug-report.yml  (vendored)

@@ -7,36 +7,7 @@ body:
       value: |
         Thanks for taking the time to fill out this bug report! 🤗
 
-        Before you submit your bug report:
-
-        - If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
-
-  - type: textarea
-    id: system-info
-    attributes:
-      label: System Info
-      description: Please share your system info with us. You can run the command `trl env` and copy-paste its output below.
-      placeholder: trl version, transformers version, platform, python version, ...
-    validations:
-      required: true
-
-  - type: checkboxes
-    id: information-scripts-examples
-    attributes:
-      label: Information
-      description: 'The problem arises when using:'
-      options:
-        - label: "The official example scripts"
-        - label: "My own modified scripts"
-
-  - type: checkboxes
-    id: information-tasks
-    attributes:
-      label: Tasks
-      description: "The tasks I am working on are:"
-      options:
-        - label: "An officially supported task in the `examples` folder"
-        - label: "My own task or dataset (give details below)"
+        🚩 If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
 
   - type: textarea
     id: reproduction
@@ -50,18 +21,47 @@ body:
         Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
         Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
 
-      placeholder: |
-        Steps to reproduce the behavior:
+      value: |
+        ```python
+        from trl import ...
+
+        1.
+        2.
+        3.
+        ```
+
+        outputs:
+
+        ```
+        Traceback (most recent call last):
+          File "example.py", line 42, in <module>
+            ...
+        ```
 
   - type: textarea
-    id: expected-behavior
+    id: system-info
     attributes:
-      label: Expected behavior
-      description: "A clear and concise description of what you would expect to happen."
+      label: System Info
+      description: |
+        Please provide information about your system: platform, Python version, PyTorch version, Transformers version, devices, TRL version, ...
+        You can get this information by running `trl env` in your terminal.
+
+      placeholder: Copy-paste the output of `trl env`
+    validations:
+      required: true
+
+  - type: checkboxes
+    id: terms
+    attributes:
+      label: Checklist
+      description: |
+        Before submitting, please confirm that you've completed each of the following.
+        If an item doesn't apply to your issue, check it anyway to show you've reviewed it.
+      options:
+        - label: "I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))"
+          required: true
+        - label: "I have included my system information"
+          required: true
+        - label: "Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
+          required: true
+        - label: "Any code provided is properly formatted in code blocks, (no screenshot, [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
+          required: true
+        - label: "Any traceback provided is complete"
+          required: true
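The checklist above asks that any code be minimal, complete, and reproducible. For TRL, a report-ready snippet might look roughly like the following sketch (the model, dataset, and config values are illustrative placeholders, not part of the template):

```python
# Hypothetical minimal reproduction: tiny dataset slice, one training step.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train[:8]")
training_args = SFTConfig(output_dir="repro", max_steps=1, report_to="none")
trainer = SFTTrainer(model="Qwen/Qwen2.5-0.5B", args=training_args, train_dataset=dataset)
trainer.train()  # the reported error appears here; include the full traceback
```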
3  .github/PULL_REQUEST_TEMPLATE.md  (vendored)

@@ -21,8 +21,7 @@ Fixes # (issue)
 Pull Request section?
 - [ ] Was this discussed/approved via a GitHub issue? Please add a link
   to it if that's the case.
-- [ ] Did you make sure to update the documentation with your changes? Here are the
-  [documentation guidelines](https://github.com/huggingface/trl/tree/main/docs).
+- [ ] Did you make sure to update the documentation with your changes?
 - [ ] Did you write any new necessary tests?
19  .github/codeql/custom-queries.qls  (vendored, new file)

@@ -0,0 +1,19 @@
import codeql

from WorkflowString interpolation, Workflow workflow
where
  interpolation.getStringValue().matches("${{ github.event.issue.title }}") or
  interpolation.getStringValue().matches("${{ github.event.issue.body }}") or
  interpolation.getStringValue().matches("${{ github.event.pull_request.title }}") or
  interpolation.getStringValue().matches("${{ github.event.pull_request.body }}") or
  interpolation.getStringValue().matches("${{ github.event.review.body }}") or
  interpolation.getStringValue().matches("${{ github.event.comment.body }}") or
  interpolation.getStringValue().matches("${{ github.event.inputs.* }}") or
  interpolation.getStringValue().matches("${{ github.event.head_commit.message }}")
  interpolation.getStringValue().matches("${{ github.event.* }}") and
  (
    step.getKey() = "run" or // Injection in run
    step.getKey() = "env" or // Injection via env
    step.getKey() = "with" // Injection via with
  )
select workflow, "🚨 Do not use directly as input of action"
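The query above is meant to flag workflows that interpolate attacker-controlled `${{ github.event.* }}` values directly into `run`, `env`, or `with` blocks. As a rough illustration of the same idea outside CodeQL, here is a minimal, hypothetical Python sketch that greps workflow files for such patterns (the regex and paths are assumptions, not part of this repository):

```python
import re
from pathlib import Path

# Flag direct interpolation of user-controlled event fields, the same
# kind of pattern the CodeQL query above matches on.
RISKY = re.compile(
    r"\$\{\{\s*github\.event\.(issue|pull_request|review|comment|inputs|head_commit)"
)

for workflow in Path(".github/workflows").glob("*.yml"):
    for lineno, line in enumerate(workflow.read_text().splitlines(), start=1):
        if RISKY.search(line):
            print(f"{workflow}:{lineno}: possible expression injection: {line.strip()}")
```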
1  .github/workflows/build_pr_documentation.yml  (vendored)

@@ -9,6 +9,7 @@ concurrency:
 
 jobs:
   build:
+    if: github.event.pull_request.draft == false
     uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
     with:
       commit_sha: ${{ github.event.pull_request.head.sha }}
26  .github/workflows/codeQL.yml  (vendored, new file)

@@ -0,0 +1,26 @@
name: "CodeQL Analysis - Workflows"

on:
  workflow_dispatch:

jobs:
  analyze:
    name: "Analyze GitHub Workflows"
    runs-on: ubuntu-latest
    permissions:
      security-events: write
      actions: read
      contents: read

    steps:
      - name: "Checkout repository"
        uses: actions/checkout@v4

      - name: "Initialize CodeQL"
        uses: github/codeql-action/init@v2
        with:
          languages: "yaml"
          queries: +security-and-quality, ./.github/codeql/custom-queries.qls

      - name: "Perform CodeQL Analysis"
        uses: github/codeql-action/analyze@v2
15  .github/workflows/issue_auto_labeller.yml  (vendored, new file)

@@ -0,0 +1,15 @@
name: "Hugging Face Issue Labeler"
on:
  issues:
    types: opened

jobs:
  triage:
    runs-on: ubuntu-latest
    permissions:
      issues: write
    steps:
      - uses: actions/checkout@v3
      - uses: August-murr/auto-labeler@main
        with:
          hf-api-key: ${{ secrets.CI_HF_API_TOKEN }}
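The action above forwards each newly opened issue to a Hugging Face-hosted model and applies the labels it suggests. Conceptually, the triage step reduces to something like this sketch; the model name, prompt, and label set here are illustrative assumptions, not what `August-murr/auto-labeler` actually uses:

```python
from huggingface_hub import InferenceClient

LABELS = ["🐛 bug", "📚 documentation", "✨ enhancement"]  # assumed label set

def suggest_labels(title: str, body: str, hf_api_key: str) -> list[str]:
    # Ask a hosted chat model which of the known labels fit this issue.
    client = InferenceClient(token=hf_api_key)
    prompt = (
        f"Pick the matching labels from {LABELS} for this GitHub issue.\n"
        f"Title: {title}\n\n{body}\n\nAnswer with a comma-separated list."
    )
    response = client.chat_completion(
        model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model choice
        messages=[{"role": "user", "content": prompt}],
        max_tokens=50,
    )
    answer = response.choices[0].message.content
    return [label for label in LABELS if label in answer]
```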
127  .github/workflows/pr_style_bot.yml  (vendored, new file)

@@ -0,0 +1,127 @@
name: PR Style Bot

on:
  workflow_dispatch:

permissions:
  contents: write
  pull-requests: write

jobs:
  run-style-bot:
    if: >
      contains(github.event.comment.body, '@bot /style') &&
      github.event.issue.pull_request != null
    runs-on: ubuntu-latest

    steps:
      - name: Extract PR details
        id: pr_info
        uses: actions/github-script@v6
        with:
          script: |
            const prNumber = context.payload.issue.number;
            const { data: pr } = await github.rest.pulls.get({
              owner: context.repo.owner,
              repo: context.repo.repo,
              pull_number: prNumber
            });

            // We capture both the branch ref and the "full_name" of the head repo
            // so that we can check out the correct repository & branch (including forks).
            core.setOutput("prNumber", prNumber);
            core.setOutput("headRef", pr.head.ref);
            core.setOutput("headRepoFullName", pr.head.repo.full_name);

      - name: Check out PR branch
        uses: actions/checkout@v3
        env:
          HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
          HEADREF: ${{ steps.pr_info.outputs.headRef }}
        with:
          # Instead of checking out the base repo, use the contributor's repo name
          repository: ${{ env.HEADREPOFULLNAME }}
          ref: ${{ env.HEADREF }}
          # You may need fetch-depth: 0 for being able to push
          fetch-depth: 0
          token: ${{ secrets.GITHUB_TOKEN }}

      - name: Debug
        env:
          HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
          HEADREF: ${{ steps.pr_info.outputs.headRef }}
          PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
        run: |
          echo "PR number: ${{ env.PRNUMBER }}"
          echo "Head Ref: ${{ env.HEADREF }}"
          echo "Head Repo Full Name: ${{ env.HEADREPOFULLNAME }}"

      - name: Set up Python
        uses: actions/setup-python@v4

      - name: Install dependencies
        run: |
          pip install ruff pre-commit

      - name: Download Makefile from main branch
        run: |
          curl -o main_Makefile https://raw.githubusercontent.com/huggingface/trl/main/Makefile

      - name: Compare Makefiles
        run: |
          if ! diff -q main_Makefile Makefile; then
            echo "Error: The Makefile has changed. Please ensure it matches the main branch."
            exit 1
          fi
          echo "No changes in Makefile. Proceeding..."
          rm -rf main_Makefile

      - name: Run make style and make quality
        run: |
          make precommit || true

      - name: Commit and push changes
        id: commit_and_push
        env:
          HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
          HEADREF: ${{ steps.pr_info.outputs.headRef }}
          PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          echo "HEADREPOFULLNAME: ${{ env.HEADREPOFULLNAME }}, HEADREF: ${{ env.HEADREF }}"
          # Configure git with the Actions bot user
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"

          # Make sure your 'origin' remote is set to the contributor's fork
          git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${{ env.HEADREPOFULLNAME }}.git"

          # If there are changes after running style/quality, commit them
          if [ -n "$(git status --porcelain)" ]; then
            git add .
            git commit -m "Apply style fixes"
            # Push to the original contributor's forked branch
            git push origin HEAD:${{ env.HEADREF }}
            echo "changes_pushed=true" >> $GITHUB_OUTPUT
          else
            echo "No changes to commit."
            echo "changes_pushed=false" >> $GITHUB_OUTPUT
          fi

      - name: Comment on PR with workflow run link
        if: steps.commit_and_push.outputs.changes_pushed == 'true'
        uses: actions/github-script@v6
        with:
          script: |
            const prNumber = parseInt(process.env.prNumber, 10);
            const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`

            await github.rest.issues.createComment({
              owner: context.repo.owner,
              repo: context.repo.repo,
              issue_number: prNumber,
              body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
            });
        env:
          prNumber: ${{ steps.pr_info.outputs.prNumber }}
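The commit step above keys off `git status --porcelain`, which prints one line per changed file and nothing when the tree is clean. The same check, written as a standalone Python sketch (the staging scope and commit message mirror the workflow but are otherwise illustrative):

```python
import subprocess

def has_uncommitted_changes() -> bool:
    # `--porcelain` emits one line per modified or untracked file, none when clean.
    result = subprocess.run(
        ["git", "status", "--porcelain"],
        capture_output=True, text=True, check=True,
    )
    return bool(result.stdout.strip())

if has_uncommitted_changes():
    subprocess.run(["git", "add", "."], check=True)
    subprocess.run(["git", "commit", "-m", "Apply style fixes"], check=True)
```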
182  .github/workflows/tests.yml  (vendored)

@@ -21,11 +21,9 @@ jobs:
   check_code_quality:
     name: Check code quality
     runs-on: ubuntu-latest
+    if: github.event.pull_request.draft == false
     steps:
       - uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
-          submodules: recursive
       - name: Set up Python 3.12
         uses: actions/setup-python@v5
         with:
@@ -38,126 +36,216 @@ jobs:
     name: Tests
     strategy:
       matrix:
-        python-version: ['3.9', '3.10', '3.11', '3.12']
-        os: ['ubuntu-latest', 'windows-latest']
+        python-version: ['3.9', '3.10', '3.11', '3.12', '3.13']
       fail-fast: false
-    runs-on: ${{ matrix.os }}
+    runs-on:
+      group: aws-g4dn-2xlarge
+    container:
+      image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
+      options: --gpus all
+    defaults:
+      run:
+        shell: bash
     if: github.event.pull_request.draft == false
     steps:
-      - uses: actions/checkout@v4
+      - name: Git checkout
+        uses: actions/checkout@v4
+
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
           cache: "pip"
           cache-dependency-path: |
             setup.py
             requirements.txt
+
+      - name: Install Make and Git
+        run: |
+          apt-get update && apt-get install -y make git curl
+
+      - name: Install uv
+        run: |
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+
+      - name: Create Python virtual environment
+        run: |
+          uv venv
+          uv pip install --upgrade setuptools wheel
+
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
-          python -m pip install ".[dev]"
+          source .venv/bin/activate
+          uv pip install ".[dev]"
 
       - name: Test with pytest
         run: |
+          source .venv/bin/activate
           make test
 
       - name: Post to Slack
         if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
         uses: huggingface/hf-workflows/.github/actions/post-slack@main
         with:
           slack_channel: ${{ env.CI_SLACK_CHANNEL }}
-          title: Results with ${{ matrix.python-version }} on ${{ matrix.os }} with lastest dependencies
+          title: Results with Python ${{ matrix.python-version }} and latest dependencies
           status: ${{ job.status }}
           slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
 
   tests_dev:
     name: Tests with dev dependencies
-    runs-on: 'ubuntu-latest'
+    runs-on:
+      group: aws-g4dn-2xlarge
+    container:
+      image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
+      options: --gpus all
+    defaults:
+      run:
+        shell: bash
     if: github.event.pull_request.draft == false
     steps:
-      - uses: actions/checkout@v4
+      - name: Git checkout
+        uses: actions/checkout@v4
+
       - name: Set up Python 3.12
         uses: actions/setup-python@v5
         with:
           python-version: '3.12'
           cache: "pip"
           cache-dependency-path: |
             setup.py
             requirements.txt
+
+      - name: Install Make and Git
+        run: |
+          apt-get update && apt-get install -y make git curl
+
+      - name: Install uv
+        run: |
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+
+      - name: Create Python virtual environment
+        run: |
+          uv venv
+          uv pip install --upgrade setuptools wheel
+
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
-          python -m pip install -U git+https://github.com/huggingface/accelerate.git
-          python -m pip install -U git+https://github.com/huggingface/datasets.git
-          python -m pip install -U git+https://github.com/huggingface/transformers.git
-          python -m pip install ".[dev]"
+          source .venv/bin/activate
+          uv pip install -U git+https://github.com/huggingface/accelerate.git
+          uv pip install -U git+https://github.com/huggingface/datasets.git
+          uv pip install -U git+https://github.com/huggingface/transformers.git
+          uv pip install ".[dev]"
 
       - name: Test with pytest
         run: |
+          source .venv/bin/activate
           make test
 
       - name: Post to Slack
         if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
         uses: huggingface/hf-workflows/.github/actions/post-slack@main
         with:
           slack_channel: ${{ env.CI_SLACK_CHANNEL }}
-          title: Results with ${{ matrix.python-version }} on ${{ matrix.os }} with dev dependencies
+          title: Results with Python 3.12 and dev dependencies
           status: ${{ job.status }}
           slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
 
   tests_wo_optional_deps:
     name: Tests without optional dependencies
-    runs-on: 'ubuntu-latest'
+    runs-on:
+      group: aws-g4dn-2xlarge
+    container:
+      image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
+      options: --gpus all
+    defaults:
+      run:
+        shell: bash
     if: github.event.pull_request.draft == false
     steps:
-      - uses: actions/checkout@v4
+      - name: Git checkout
+        uses: actions/checkout@v4
+
       - name: Set up Python 3.12
         uses: actions/setup-python@v5
         with:
           python-version: '3.12'
           cache: "pip"
           cache-dependency-path: |
             setup.py
             requirements.txt
+
+      - name: Install Make and Git
+        run: |
+          apt-get update && apt-get install -y make git curl
+
+      - name: Install uv
+        run: |
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+
+      - name: Create Python virtual environment
+        run: |
+          uv venv
+          uv pip install --upgrade setuptools wheel
+
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
-          python -m pip install ".[test]"
+          source .venv/bin/activate
+          uv pip install ".[test]"
 
       - name: Test with pytest
         run: |
+          source .venv/bin/activate
           make test
 
       - name: Post to Slack
         if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
         uses: huggingface/hf-workflows/.github/actions/post-slack@main
         with:
           slack_channel: ${{ env.CI_SLACK_CHANNEL }}
-          title: Results with ${{ matrix.python-version }} on ${{ matrix.os }} without optional dependencies
+          title: Results with Python 3.12 without optional dependencies
           status: ${{ job.status }}
           slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
 
   tests_min_versions:
     name: Tests with minimum versions
-    runs-on: 'ubuntu-latest'
+    runs-on:
+      group: aws-g4dn-2xlarge
+    container:
+      image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
+      options: --gpus all
+    defaults:
+      run:
+        shell: bash
     if: github.event.pull_request.draft == false
     steps:
-      - uses: actions/checkout@v4
+      - name: Git checkout
+        uses: actions/checkout@v4
+
       - name: Set up Python 3.12
         uses: actions/setup-python@v5
         with:
           python-version: '3.12'
           cache: "pip"
           cache-dependency-path: |
             setup.py
             requirements.txt
+
+      - name: Install Make and Git
+        run: |
+          apt-get update && apt-get install -y make git curl
+
+      - name: Install uv
+        run: |
+          curl -LsSf https://astral.sh/uv/install.sh | sh
+
+      - name: Create Python virtual environment
+        run: |
+          uv venv
+          uv pip install --upgrade setuptools wheel
+
       - name: Install dependencies
         run: |
-          python -m pip install --upgrade pip
-          python -m pip install accelerate==0.34.0
-          python -m pip install datasets==2.21.0
-          python -m pip install transformers==4.46.0
-          python -m pip install ".[dev]"
+          source .venv/bin/activate
+          uv pip install accelerate==0.34.0
+          uv pip install datasets==3.0.0
+          uv pip install transformers==4.46.0
+          uv pip install ".[dev]"
 
       - name: Test with pytest
         run: |
+          source .venv/bin/activate
           make test
 
       - name: Post to Slack
         if: github.ref == 'refs/heads/main' && always() # Check if the branch is main
         uses: huggingface/hf-workflows/.github/actions/post-slack@main
         with:
           slack_channel: ${{ env.CI_SLACK_CHANNEL }}
-          title: Results with ${{ matrix.python-version }} on ${{ matrix.os }} with minimum versions
+          title: Results with Python 3.12 and minimum dependencies versions
           status: ${{ job.status }}
           slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
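The `tests_min_versions` job pins the oldest supported `accelerate`, `datasets`, and `transformers`. To see which versions an environment actually resolved to, a quick check like this works (a generic sketch, not part of the workflow):

```python
from importlib.metadata import version

# Print the installed version of each pinned dependency.
for package in ("accelerate", "datasets", "transformers"):
    print(f"{package}=={version(package)}")
```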
66  .github/workflows/tests_latest.yml  (vendored, new file)

@@ -0,0 +1,66 @@
name: Tests latest TRL release with dev dependencies

on:
  schedule:
    - cron: '0 0 * * *' # Runs daily at midnight UTC
  workflow_dispatch:

env:
  TQDM_DISABLE: 1
  CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}

jobs:
  tests:
    name: Tests latest TRL release with dev dependencies
    runs-on:
      group: aws-g4dn-2xlarge
    container:
      image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
      options: --gpus all
    defaults:
      run:
        shell: bash
    steps:
      - name: Git checkout
        uses: actions/checkout@v4
        with: { ref: v0.17-release }

      - name: Set up Python 3.12
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install Make and Git
        run: |
          apt-get update && apt-get install -y make git curl

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Create Python virtual environment
        run: |
          uv venv
          uv pip install --upgrade setuptools wheel

      - name: Install dependencies
        run: |
          source .venv/bin/activate
          uv pip install -U git+https://github.com/huggingface/accelerate.git
          uv pip install -U git+https://github.com/huggingface/datasets.git
          uv pip install -U git+https://github.com/huggingface/transformers.git
          uv pip install ".[dev]"

      - name: Test with pytest
        run: |
          source .venv/bin/activate
          make test

      - name: Post to Slack
        uses: huggingface/hf-workflows/.github/actions/post-slack@main
        with:
          slack_channel: ${{ env.CI_SLACK_CHANNEL }}
          title: Results of latest TRL with Python 3.12 and dev dependencies
          status: ${{ job.status }}
          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
5  .github/workflows/trufflehog.yml  (vendored)

@@ -12,4 +12,7 @@ jobs:
         with:
           fetch-depth: 0
       - name: Secret Scanning
-        uses: trufflesecurity/trufflehog@main
+        uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
+        with:
+          # exclude buggy postgres detector that is causing false positives and not relevant to our codebase
+          extra_args: --results=verified,unknown --exclude-detectors=postgres
5  .gitignore  (vendored)

@@ -142,7 +142,4 @@ checklink/cookies.txt
 # wandb files
 nbs/wandb/
 examples/notebooks/wandb/
+wandb/
-
-# cli scripts that are symlinked from `examples/scripts`
-trl/commands/scripts/
-wandb/
.pre-commit-config.yaml

@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.6.3
+    rev: v0.11.3
     hooks:
       - id: ruff
         types_or: [ python, pyi ]
CITATION.cff

@@ -31,4 +31,4 @@ keywords:
   - pytorch
   - transformers
 license: Apache-2.0
-version: 0.11.1
+version: 0.17
402  CONTRIBUTING.md

@@ -23,7 +23,7 @@ There are several ways you can contribute to TRL:
 * Contribute to the examples or the documentation.
 
 If you don't know where to start, there is a special [Good First
-Issue](https://github.com/huggingface/trl/contribute) listing. It will give you a list of
+Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of
 open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
 
 For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀
@@ -33,12 +33,12 @@ For something slightly more challenging, you can also take a look at the [Good S
 Before you start contributing make sure you have installed all the dev tools:
 
 ```bash
-make dev
+pip install -e .[dev]
 ```
 
 ## Fixing outstanding issues
 
-If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#create-a-pull-request) and open a Pull Request!
+If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#submitting-a-pull-request-pr) and open a Pull Request!
 
 ## Submitting a bug-related issue or feature request
 
@@ -152,7 +152,7 @@ Follow these steps to start contributing:
 4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
 
    ```bash
-   $ make dev
+   $ pip install -e .[dev]
   ```
 
   (If TRL was already installed in the virtual environment, remove
@@ -171,8 +171,7 @@ Follow these steps to start contributing:
    $ pytest tests/<TEST_TO_RUN>.py
    ```
 
-> For the following commands leveraging the `make` utility, we recommend using the WSL system when running on
-> Windows. More information [here](https://docs.microsoft.com/en-us/windows/wsl/about).
+> For the following commands leveraging the `make` utility.
 
 You can also run the full suite with the following command.
 
@@ -256,3 +255,394 @@ That's how `make test` is implemented (without the `pip install` line)!
You can specify a smaller set of tests to test only the feature
you're working on.

### Default values guidelines

1. **Use defaults when appropriate**:

   Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should.

2. **Prioritize proven defaults**:

   Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases.

3. **Ensure safety and predictability**:

   Defaults must be safe, expected and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.

4. **Balance consistency and flexibility**:

   Aim for consistent defaults across similar functions or methods. However, consistency should not be preferred to point 2 or 3.

5. **Opt-in for new features**:

   Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt-in to use these.
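Taken together, the guidelines above suggest config objects shaped roughly like this sketch (the class and field names are illustrative, not actual TRL code):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class MyTrainerConfig:  # hypothetical config, not an actual TRL class
    # No default: the right model varies significantly by use case (guideline 1).
    model_name_or_path: Optional[str] = None
    # Proven default, taken from the method's original paper (guideline 2).
    learning_rate: float = 1e-5
    # Conservative default that avoids surprising memory usage (guideline 3).
    per_device_train_batch_size: int = 8
    # Experimental loss is opt-in, never enabled by default (guideline 5).
    use_novel_loss: bool = False
```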
### Writing documentation

High-quality documentation is crucial for maintaining a project that is easy to use, understand, and extend. When adding new features, ensure they are thoroughly documented to maintain consistency and clarity throughout the project.

To illustrate what good documentation looks like, here’s an example of a well-documented function:

````python
def replicate_str(string: str, n: int, sep: str = " ") -> str:
    r"""
    Replicate a string `n` times with a separator.

    Args:
        string (`str`):
            String to replicate.
        n (`int`):
            Number of times to replicate the string.
        sep (`str`, *optional*, defaults to `" "`):
            Separator to use between each replication.

    Returns:
        `str`: The replicated string.

    Examples:
    ```python
    >>> replicate_str("hello", 3)
    "hello hello hello"
    >>> replicate_str("hello", 3, sep=", ")
    "hello, hello, hello"
    ```
    """
    return sep.join([string] * n)
````

* **Line Wrapping:** Applied a consistent line wrap at column 120 to improve readability.
* **Definite Articles:** Removed definite articles where possible to streamline language. (Eg: Changed "The string to replicate" to "String to replicate")
* **Type Annotations:**
  * Always include type definitions, indicating if a parameter is optional and specifying the default value.
  * Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value.
    E.g., for arguments that can't be `None` and aren't required:

    ```python
    foo (`int`, *optional*, defaults to `4`):
    ```

    For arguments that can be `None` and are required:

    ```python
    foo (`Optional[int]`):
    ```

    For arguments that can be `None` and aren't required:

    ```python
    foo (`Optional[int]`, *optional*, defaults to `None`):
    ```

* **String Defaults:**
  * Ensured that default string values are wrapped in double quotes:

    ```python
    defaults to `"foo"`
    ```

* **Dictionary Typing:**
  * Replaced generic `dict` type hints with more explicit `dict[str, Any]` to clarify expected key-value pairs.
* **Default Value Formatting:**
  * Consistently surrounded default values with backticks for improved formatting:

    ```python
    defaults to `4`
    ```

* **Sub-sectioning:** When the number of arguments is large, consider breaking them into sub-sections for better readability.

  ```python
  def calculate_statistics(data: list[float], precision: int = 2, include_variance: bool = False) -> dict[str, float]:
      r"""
      Calculates basic statistics for a given dataset.

      Args:
          > Data inputs

          data (`list[float]`):
              A list of numerical values to analyze.

          > Configuration parameters

          precision (`int`, *optional*, defaults to `2`):
              Number of decimal places to round the results.
          include_variance (`bool`, *optional*, defaults to `False`):
              Whether to include the variance of the dataset in the results.

      Returns:
          `dict[str, float]`:
              A dictionary containing calculated statistics such as mean, median, and optionally variance.
      """
      ...
  ```

### Deprecation and backward compatibility

Our approach to deprecation and backward compatibility is flexible and based on the feature’s usage and impact. Each deprecation is carefully evaluated, aiming to balance innovation with user needs.

When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:

- **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
- **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.

Example:

```python
warnings.warn(
    "The `Trainer.foo` method is deprecated and will be removed in version 0.14.0. "
    "Please use the `Trainer.bar` class instead.",
    FutureWarning,
)
```

The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:

- **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.

- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.

These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.

### Working with warnings

Warnings play a critical role in guiding users toward resolving potential issues, but they should be used thoughtfully to avoid unnecessary noise. Unlike logging, which provides informational context or operational details, warnings signal conditions that require attention and action. Overusing warnings can dilute their importance, leading users to ignore them entirely.

#### Definitions

- **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
- **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.

#### Choosing the right message

- **Correct → No warning**:
  If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.

- **Correct but deserves attention → No warning, possibly a log message**:
  When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:

  ```python
  logger.info("This is an informational message about a rare but correct operation.")
  ```

- **Correct but very likely a mistake → Warning with option to disable**:
  In rare cases, you may want to issue a warning for a correct operation that’s very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:

  ```python
  def my_function(foo, bar, _warn=True):
      if foo == bar:
          if _warn:
              warnings.warn("foo and bar are the same, this is likely a mistake. Ignore this warning by setting `_warn=False`.")
          # Do something
  ```

- **Supported but not correct → Warning**:
  If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:

  ```python
  def my_function(foo, bar):
      if foo and bar:
          warnings.warn("Both `foo` and `bar` were provided, but only one is allowed. Ignoring `foo`. Please pass only one of these arguments.")
      # Do something
  ```

- **Not supported → Exception**:
  If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:

  ```python
  def my_function(foo, bar):
      if foo and bar:
          raise ValueError("Both `foo` and `bar` were provided, but only one is allowed. Please pass only one of these arguments.")
  ```

By following this classification, you ensure that warnings, information, and exceptions are used appropriately, providing clear guidance to the user without cluttering the system with unnecessary messages.

## Making a release

> [!NOTE]
> VERSION needs to be formatted following the `v{major}.{minor}.{patch}` convention. We need to follow this convention to be able to retrieve versioned scripts.

To create the package for PyPI.

#### 0. Prerequisites

- Dependencies:
  - twine: `pip install build twine`
- Create an account in (and join the `trl` project):
  - PyPI: https://pypi.org/
  - Test PyPI: https://test.pypi.org/

#### 1. Ensure your local repository is up to date with the upstream repository

```bash
git checkout main
git pull origin main
```

> [!WARNING]
> Do not merge other pull requests into `main` until the release is done. This is to ensure that the release is stable and does not include any untested changes. Announce internally (#trl-internal) to other maintainers that you are doing a release and that they must not merge PRs until the release is done.

#### 2. Create a release branch from main

```bash
git checkout -b release-v{major}.{minor}
```

#### 3. Change the version in the following files

- `.github/workflows/tests_latest.yml`:
  ```diff
  - with: { ref: v{major}.{minor-1}-release }
  + with: { ref: v{major}.{minor}-release }
  ```
- `CITATION.cff`
  ```diff
  - version: {major}.{minor-1}
  + version: {major}.{minor}
  ```
- `__init__.py`
  ```diff
  - __version__ = "{major}.{minor}.0.dev0"
  + __version__ = "{major}.{minor}.0"
  ```
- `setup.cfg`
  ```diff
  - version = {major}.{minor}.0.dev0
  + version = {major}.{minor}.0
  ```

#### 4. Commit and push these changes

```shell
git commit -m 'Release: {major}.{minor}'
git push origin release-v{major}.{minor}
```

#### 5. Create a pull request

from `release-v{major}.{minor}` to `main`, named `Release: v{major}.{minor}`, wait for tests to pass, and request a review.

#### 6. Once the pull request is approved, merge it into `main`

#### 7. Add a tag in git to mark the release

```shell
git checkout main
git pull origin main
git tag -a v{major}.{minor}.0 -m 'Adds tag v{major}.{minor}.0 for PyPI'
git push origin v{major}.{minor}.0
```

#### 8. Create a branch `v{major}.{minor}-release` for future patch releases

```shell
git checkout -b v{major}.{minor}-release
git push origin v{major}.{minor}-release
```

This ensures that future patch releases (`v{major}.{minor}.1`, `v{major}.{minor}.2`, etc.) can be made separately from `main`.

#### 9. Create the wheels for your release

These are the artifacts that will be uploaded to PyPI and installed by users via `pip install trl`.

Clean previous builds:

```shell
rm -rf build dist
```

At the root of your repo, run

```bash
python -m build .
```

This will create a folder named `dist` with the new versions of your package.

#### 10. Upload the package to PyPI Test

> [!IMPORTANT]
> Do not skip this step. It is important to test the package before uploading it to the main PyPI server.

```shell
twine upload dist/* -r testpypi
```

Then in a fresh environment containing all dependencies you need, try to install your new package from the PyPI test server.

```bash
pip install -i https://test.pypi.org/simple/ trl
```

You might get errors for missing dependencies since the PyPI test server does not contain all packages like PyPI does. To make sure you have everything you can do:

```bash
pip install trl
pip uninstall trl
```

(the second line will remove trl but keep all its dependencies).

Also make sure you can actually use the package! Run the following line:

```bash
python -c "from trl import *"
```

along with anything that tests:

- the core feature of your package
- the new features you’re adding in the release

#### 11. Publish on PyPI

> [!WARNING]
> This can't be reverted. Make sure you have tested everything before doing this step.

```shell
twine upload dist/*
```

#### 12. Create a GitHub Release

1. Go to the repo’s [releases section](https://github.com/huggingface/trl/releases) on GitHub.
2. Click **Draft a new release**.
3. Select the `v{major}.{minor}.0` tag you just created in step 7.
4. Add a title (`v{major}.{minor}.0`) and a short description of what’s new.
5. Click **Publish Release**.

#### 13. Bump to dev version

1. Create a branch `bump-dev-version-{major}.{minor+1}` from `main` and checkout to it.

   ```shell
   git checkout -b bump-dev-version-{major}.{minor+1}
   ```

2. Change the version in the following files:
   1. `__init__.py`
      ```diff
      - __version__ = "{major}.{minor}.0"
      + __version__ = "{major}.{minor+1}.0.dev0"
      ```
   2. `setup.cfg`
      ```diff
      - version = {major}.{minor}.0
      + version = {major}.{minor+1}.0.dev0
      ```

3. Commit and push these changes

   ```shell
   git add trl/__init__.py setup.cfg
   git commit -m '⬆️ Bump dev version'
   git push origin bump-dev-version-{major}.{minor+1}
   ```

4. Create a pull request from `bump-dev-version-{major}.{minor+1}` to `main`, named `⬆️ Bump dev version`, and request urgent review.

5. Once the pull request is approved, merge it into `main`.

6. The codebase is now ready for the next development cycle, inform the team in the #trl-internal channel.
2  LICENSE

@@ -186,7 +186,7 @@
    same "printed page" as the copyright notice for easier
    identification within third-party archives.
 
-   Copyright [yyyy] [name of copyright owner]
+   Copyright 2020-2025 The HuggingFace Team
 
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
MANIFEST.in

@@ -1,6 +1,6 @@
 include settings.ini
 include LICENSE
 include CONTRIBUTING.md
 include README.md
 recursive-exclude * __pycache__
-include trl/templates/*.md
+include trl/templates/*.md
+include trl/accelerate_configs/*.yaml
15  Makefile

@@ -5,24 +5,15 @@ check_dirs := examples tests trl
 ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
 COMMAND_FILES_PATH = `pwd`/commands
 
-dev:
-	[ -L "$(pwd)/trl/commands/scripts" ] && unlink "$(pwd)/trl/commands/scripts" || true
-	pip install -e ".[dev]"
-	ln -s `pwd`/examples/scripts/ `pwd`/trl/commands
-
 test:
-	python -m pytest -n auto --dist=loadfile -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' ./tests/
+	pytest -n auto -m "not slow and not low-priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/
 
 precommit:
-	pre-commit run --all-files
-	python scripts/add_copyrights.py
-
-tests_gpu:
-	python -m pytest tests/test_* $(if $(IS_GITHUB_CI),--report-log "common_tests.log",)
+	pre-commit run --all-files
 
 slow_tests:
-	python -m pytest tests/slow/test_* $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
+	pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
 
 test_examples:
 	touch temp_results_sft_tests.txt
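The reworked `test` target deselects anything marked `slow` or `low-priority`. For reference, a test opts into those groups via pytest markers along these lines (a generic sketch; marker registration in the project's pytest config is assumed):

```python
import pytest

@pytest.mark.slow  # skipped by `make test`, collected by `make slow_tests`
def test_full_training_run():
    # ... long-running end-to-end training check ...
    ...
```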
162
README.md
162
README.md
@ -1,7 +1,7 @@
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_banner_dark.png" alt="TRL Banner">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png" alt="TRL Banner">
</div>

<hr> <br>

@@ -12,8 +12,9 @@

<p align="center">
    <a href="https://github.com/huggingface/trl/blob/main/LICENSE"><img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue"></a>
    <a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
    <a href="https://huggingface.co/docs/trl/index"><img alt="Documentation" src="https://img.shields.io/website?label=documentation&url=https%3A%2F%2Fhuggingface.co%2Fdocs%2Ftrl%2Findex&down_color=red&down_message=offline&up_color=blue&up_message=online"></a>
    <a href="https://github.com/huggingface/trl/releases"><img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg"></a>
    <a href="https://huggingface.co/trl-lib"><img alt="Hugging Face Hub" src="https://img.shields.io/badge/🤗%20Hub-trl--lib-yellow"></a>
</p>

## Overview

@@ -22,16 +23,14 @@ TRL is a cutting-edge library designed for post-training foundation models using

## Highlights

- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer) and more.

- **Efficient and scalable**:
  - Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like DDP and DeepSpeed.
  - Full integration with [`PEFT`](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA.
  - Integrates [Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.
  - Leverages [🤗 Accelerate](https://github.com/huggingface/accelerate) to scale from single GPU to multi-node clusters using methods like [DDP](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) and [DeepSpeed](https://github.com/deepspeedai/DeepSpeed).
  - Full integration with [🤗 PEFT](https://github.com/huggingface/peft) enables training on large models with modest hardware via quantization and LoRA/QLoRA (a minimal sketch follows this list).
  - Integrates [🦥 Unsloth](https://github.com/unslothai/unsloth) for accelerating training using optimized kernels.

- **Command Line Interface (CLI)**: A simple interface lets you fine-tune and interact with models without needing to write code.

- **Trainers**: Various fine-tuning methods are easily accessible via trainers like [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`ORPOTrainer`](https://huggingface.co/docs/trl/orpo_trainer) and more.

- **AutoModels**: Use pre-defined model classes like [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) to simplify reinforcement learning (RL) with LLMs.
- **Command Line Interface (CLI)**: A simple interface lets you fine-tune with models without needing to write code.
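To make the PEFT bullet concrete, here is a minimal sketch of LoRA fine-tuning with `SFTTrainer`. The `peft_config` argument is part of the trainer API, while the model choice and LoRA hyperparameters below are illustrative assumptions, not recommended defaults.

```python
# A minimal sketch of the PEFT integration mentioned above; the LoRA
# hyperparameters are illustrative, not tuned recommendations.
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

peft_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")  # illustrative values

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT-LoRA"),
    train_dataset=dataset,
    peft_config=peft_config,  # only the adapter weights are trained
)
trainer.train()
```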
## Installation

@@ -59,60 +58,75 @@ If you want to use the examples you can clone the repository with the following
git clone https://github.com/huggingface/trl.git
```

## Command Line Interface (CLI)
## Quick Start

You can use the TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT) and Direct Preference Optimization (DPO), or vibe check your model with the chat CLI:

**SFT:**

```bash
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/Capybara \
    --output_dir Qwen2.5-0.5B-SFT
```

**DPO:**

```bash
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name argilla/Capybara-Preferences \
    --output_dir Qwen2.5-0.5B-DPO
```

**Chat:**

```bash
trl chat --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct
```

Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.

## How to use

For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.

### `SFTTrainer`

Here is a basic example of how to use the `SFTTrainer`:
Here is a basic example of how to use the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer):

```python
from trl import SFTConfig, SFTTrainer
from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

training_args = SFTConfig(output_dir="Qwen/Qwen2.5-0.5B-SFT")
trainer = SFTTrainer(
    args=training_args,
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)
trainer.train()
```

### `GRPOTrainer`

[`GRPOTrainer`](https://huggingface.co/docs/trl/grpo_trainer) implements the [Group Relative Policy Optimization (GRPO) algorithm](https://huggingface.co/papers/2402.03300) that is more memory-efficient than PPO and was used to train [Deepseek AI's R1](https://huggingface.co/deepseek-ai/DeepSeek-R1).

```python
from datasets import load_dataset
from trl import GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c)) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_num_unique_chars,
    train_dataset=dataset,
)
trainer.train()
```

### `DPOTrainer`

[`DPOTrainer`](https://huggingface.co/docs/trl/dpo_trainer) implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train [Llama 3](https://huggingface.co/papers/2407.21783) and many other models. Here is a basic example of how to use the `DPOTrainer`:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer
)
trainer.train()
```

### `RewardTrainer`

Here is a basic example of how to use the `RewardTrainer`:
Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):

```python
from trl import RewardConfig, RewardTrainer
@@ -137,60 +151,28 @@ trainer = RewardTrainer(
trainer.train()
```
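The middle of the example above falls outside the shown hunk. The following is a self-contained sketch of typical `RewardTrainer` usage; the model and dataset names are chosen to mirror the surrounding examples rather than the exact elided lines.

```python
# A minimal, self-contained sketch of RewardTrainer usage; model and dataset
# choices are illustrative assumptions, not the exact elided content.
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1  # single scalar reward head
)
model.config.pad_token_id = tokenizer.pad_token_id

# Preference dataset with "chosen" and "rejected" columns
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward")
trainer = RewardTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
```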
### `RLOOTrainer`
## Command Line Interface (CLI)

`RLOOTrainer` implements a [REINFORCE-style optimization](https://huggingface.co/papers/2402.14740) for RLHF that is more performant and memory-efficient than PPO. Here is a basic example of how to use the `RLOOTrainer`:
You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):

```python
from trl import RLOOConfig, RLOOTrainer, apply_chat_template
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

**SFT:**

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
ref_policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
policy = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

dataset = load_dataset("trl-lib/ultrafeedback-prompt")
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
dataset = dataset.map(lambda x: tokenizer(x["prompt"]), remove_columns="prompt")

training_args = RLOOConfig(output_dir="Qwen2.5-0.5B-RL")
trainer = RLOOTrainer(
    config=training_args,
    processing_class=tokenizer,
    policy=policy,
    ref_policy=ref_policy,
    reward_model=reward_model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)
trainer.train()
```bash
trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/Capybara \
    --output_dir Qwen2.5-0.5B-SFT
```

### `DPOTrainer`
**DPO:**

`DPOTrainer` implements the popular [Direct Preference Optimization (DPO) algorithm](https://huggingface.co/papers/2305.18290) that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the `DPOTrainer`:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
trainer = DPOTrainer(model=model, args=training_args, train_dataset=dataset, processing_class=tokenizer)
trainer.train()
```bash
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name argilla/Capybara-Preferences \
    --output_dir Qwen2.5-0.5B-DPO
```

Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.

## Development

If you want to contribute to `trl` or customize it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install:

@@ -198,7 +180,7 @@ If you want to contribute to `trl` or customize it to your needs make sure to re

```bash
git clone https://github.com/huggingface/trl.git
cd trl/
make dev
pip install -e .[dev]
```

## Citation
@@ -2,7 +2,7 @@
# This script runs a DPO example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_dpo/"
MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
MAX_STEPS=5
BATCH_SIZE=2
@@ -35,7 +35,7 @@ CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
    --num_processes $NUM_GPUS \
    --mixed_precision 'fp16' \
    `pwd`/examples/scripts/dpo.py \
    `pwd`/trl/scripts/dpo.py \
    --model_name_or_path $MODEL_NAME \
    --dataset_name $DATASET_NAME \
    --output_dir $OUTPUT_DIR \

@@ -2,7 +2,7 @@
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_sft/"
MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
DATASET_NAME="stanfordnlp/imdb"
MAX_STEPS=5
BATCH_SIZE=2
@@ -36,13 +36,13 @@ CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
    --num_processes $NUM_GPUS \
    --mixed_precision 'fp16' \
    `pwd`/examples/scripts/sft.py \
    `pwd`/trl/scripts/sft.py \
    --model_name $MODEL_NAME \
    --dataset_name $DATASET_NAME \
    --output_dir $OUTPUT_DIR \
    --max_steps $MAX_STEPS \
    --per_device_train_batch_size $BATCH_SIZE \
    --max_seq_length $SEQ_LEN \
    --max_length $SEQ_LEN \
    $EXTRA_TRAINING_ARGS
"""
@@ -5,21 +5,59 @@
    title: Installation
  - local: quickstart
    title: Quickstart
  - local: clis
    title: Get started with Command Line Interfaces (CLIs)
  title: Getting started
- sections:
  - local: dataset_formats
    title: Dataset Formats
  - local: how_to_train
    title: PPO Training FAQ
  - local: use_model
    title: Use Trained Models
  - local: customization
    title: Customize the Training
    title: Training FAQ
  - local: logging
    title: Understanding Logs
  title: Get started
  title: Conceptual Guides
- sections:
- sections: # Sort alphabetically
  - local: clis
    title: Command Line Interface (CLI)
  - local: customization
    title: Customizing the Training
  - local: reducing_memory_usage
    title: Reducing Memory Usage
  - local: speeding_up_training
    title: Speeding Up Training
  - local: distributing_training
    title: Distributing Training
  - local: use_model
    title: Using Trained Models
  title: How-to guides
- sections:
  - local: deepspeed_integration
    title: DeepSpeed
  - local: liger_kernel_integration
    title: Liger Kernel
  - local: peft_integration
    title: PEFT
  - local: unsloth_integration
    title: Unsloth
  - local: vllm_integration
    title: vLLM
  title: Integrations
- sections:
  - local: example_overview
    title: Example Overview
  - local: community_tutorials
    title: Community Tutorials
  - local: sentiment_tuning
    title: Sentiment Tuning
  - local: using_llama_models
    title: Training StackLlama
  - local: detoxifying_a_lm
    title: Detoxifying a Language Model
  - local: multi_adapter_rl
    title: Multi Adapter RLHF
  - local: training_vlm_sft
    title: Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
  title: Examples
- sections:
- sections: # Sorted alphabetically
  - local: alignprop_trainer
    title: AlignProp
  - local: bco_trainer
@@ -34,6 +72,8 @@
    title: Online DPO
  - local: gkd_trainer
    title: GKD
  - local: grpo_trainer
    title: GRPO
  - local: kto_trainer
    title: KTO
  - local: nash_md_trainer
@@ -42,6 +82,8 @@
    title: ORPO
  - local: ppo_trainer
    title: PPO
  - local: prm_trainer
    title: PRM
  - local: reward_trainer
    title: Reward
  - local: rloo_trainer
@@ -55,6 +97,8 @@
    title: Trainers
  - local: models
    title: Model Classes
  - local: model_utils
    title: Model Utilities
  - local: best_of_n
    title: Best of N Sampling
  - local: judges
@@ -63,22 +107,10 @@
    title: Callbacks
  - local: data_utils
    title: Data Utilities
  - local: text_environments
    title: Text Environments
  - local: rewards
    title: Reward Functions
  - local: script_utils
    title: Script Utilities
  - local: others
    title: Others
  title: API
- sections:
  - local: example_overview
    title: Example Overview
  - local: sentiment_tuning
    title: Sentiment Tuning
  - local: lora_tuning_peft
    title: Training with PEFT
  - local: detoxifying_a_lm
    title: Detoxifying a Language Model
  - local: using_llama_models
    title: Training StackLlama
  - local: learning_tools
    title: Learning to Use Tools
  - local: multi_adapter_rl
    title: Multi Adapter RLHF
  title: Examples
@@ -7,7 +7,7 @@

If your reward function is differentiable, directly backpropagating gradients from the reward models to the diffusion model is significantly more sample and compute efficient (25x) than a policy gradient algorithm like DDPO.
AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.

<div style="text-align: center"><img src="https://align-prop.github.io/reward_tuning.png"/></div>
<div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/reward_tuning.png"/></div>


## Getting started with `examples/scripts/alignprop.py`

@@ -16,7 +16,7 @@ The `alignprop.py` script is a working example of using the `AlignProp` trainer

**Note:** one A100 GPU is recommended to get this running. For lower memory settings, consider setting `truncated_backprop_rand` to `False`. With default settings this will do truncated backpropagation with K=1.

Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running
Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post-finetuning to HuggingFace hub. The following bash command is to be entered to get things running

```batch
python alignprop.py --hf_user_access_token <token>
@@ -26,7 +26,7 @@ To obtain the documentation of `stable_diffusion_tuning.py`, please run `python

The following are things to keep in mind (the code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)

- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater to 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps)
- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`): the first number should be greater than or equal to 0, while the second number should be less than or equal to the number of diffusion timesteps (sample_num_steps)
- The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`): the number should be less than the number of diffusion timesteps (sample_num_steps); it only matters when truncated_backprop_rand is set to False

## Setting up the image logging hook function

@@ -62,7 +62,7 @@ embedding_model = Accelerator().prepare_model(self.embedding_model)
embedding_func = partial(embed_prompt, model=embedding_model)
```

Set `prompt_sample_size` to defined how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:

```py
training_args = BCOConfig(
@@ -97,4 +97,4 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype

## BCOConfig

[[autodoc]] BCOConfig
[[autodoc]] BCOConfig
@@ -67,6 +67,6 @@ best_of_n.generate(query_tensors, device=device)

```

Furthermore, at the time of initialization you can set the seed to control repeatability of the generation process and the number of samples to generate for each query
Furthermore, at the time of initialization you can set the seed to control the repeatability of the generation process and the number of samples to generate for each query
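A minimal sketch of that initialization; it assumes `BestOfNSampler` accepts `seed`, `sample_size`, and a `queries_to_scores` callable at construction, and the toy scorer and model choice here are illustrative only.

```python
# A minimal sketch, assuming BestOfNSampler exposes `seed` and `sample_size`
# at initialization; the scoring function is a toy stand-in for a reward model.
from transformers import AutoTokenizer, GenerationConfig
from trl import AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler
from trl.extras import BestOfNSampler

model_name = "Qwen/Qwen2.5-0.5B-Instruct"  # illustrative model choice
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

def queries_to_scores(texts):
    # Toy scorer: longer completions score higher; use a reward model in practice.
    return [float(len(t)) for t in texts]

best_of_n = BestOfNSampler(
    model,
    tokenizer,
    queries_to_scores,
    length_sampler=LengthSampler(16, 32),  # sampled completion lengths
    seed=0,          # makes the generation process repeatable
    sample_size=8,   # number of candidates generated per query
    generation_config=GenerationConfig(do_sample=True, top_k=0, temperature=1.0),
)
```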
@@ -14,4 +14,8 @@

## LogCompletionsCallback

[[autodoc]] LogCompletionsCallback
[[autodoc]] LogCompletionsCallback

## MergeModelCallback

[[autodoc]] MergeModelCallback

303 docs/source/clis.md Normal file
@@ -0,0 +1,303 @@
# Command Line Interfaces (CLIs)

TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly.

Currently supported commands are:

#### Training Commands

- `trl dpo`: fine-tune an LLM with DPO
- `trl grpo`: fine-tune an LLM with GRPO
- `trl kto`: fine-tune an LLM with KTO
- `trl sft`: fine-tune an LLM with SFT

#### Other Commands

- `trl env`: get the system information
- `trl vllm-serve`: serve a model with vLLM

## Fine-Tuning with the TRL CLI

### Basic Usage

You can launch training directly from the CLI by specifying required arguments like the model and dataset:

<hfoptions id="command_line">
<hfoption id="SFT">

```bash
trl sft \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name stanfordnlp/imdb
```

</hfoption>
<hfoption id="DPO">

```bash
trl dpo \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name anthropic/hh-rlhf
```

</hfoption>
</hfoptions>

### Using Configuration Files

To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file:

<hfoptions id="config_file">
<hfoption id="SFT">

```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
```

Launch with:

```bash
trl sft --config sft_config.yaml
```

</hfoption>
<hfoption id="DPO">

```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
```

Launch with:

```bash
trl dpo --config dpo_config.yaml
```

</hfoption>
</hfoptions>

### Scaling Up with Accelerate

TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs, machines, or use advanced setups like DeepSpeed — all from the same CLI.

You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch).

<hfoptions id="launch_args">
<hfoption id="SFT inline">

```bash
trl sft \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name stanfordnlp/imdb \
    --num_processes 4
```

</hfoption>
<hfoption id="SFT w/ config file">

```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
num_processes: 4
```

Launch with:

```bash
trl sft --config sft_config.yaml
```

</hfoption>
<hfoption id="DPO inline">

```bash
trl dpo \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name anthropic/hh-rlhf \
    --num_processes 4
```

</hfoption>
<hfoption id="DPO w/ config file">

```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
num_processes: 4
```

Launch with:

```bash
trl dpo --config dpo_config.yaml
```

</hfoption>
</hfoptions>

### Using `--accelerate_config` for Accelerate Configuration

The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either:

* the name of a predefined config profile (built into TRL), or
* a path to a custom Accelerate YAML config file.

#### Predefined Config Profiles

TRL provides several ready-to-use Accelerate configs to simplify common training setups:

| Name         | Description                         |
| ------------ | ----------------------------------- |
| `fsdp1`      | Fully Sharded Data Parallel Stage 1 |
| `fsdp2`      | Fully Sharded Data Parallel Stage 2 |
| `zero1`      | DeepSpeed ZeRO Stage 1              |
| `zero2`      | DeepSpeed ZeRO Stage 2              |
| `zero3`      | DeepSpeed ZeRO Stage 3              |
| `multi_gpu`  | Multi-GPU training                  |
| `single_gpu` | Single-GPU training                 |

To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_config/`.

#### Example Usage

<hfoptions id="accelerate_config">
<hfoption id="SFT inline">

```bash
trl sft \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name stanfordnlp/imdb \
    --accelerate_config zero2  # or path/to/my/accelerate/config.yaml
```

</hfoption>
<hfoption id="SFT w/ config file">

```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
accelerate_config: zero2  # or path/to/my/accelerate/config.yaml
```

Launch with:

```bash
trl sft --config sft_config.yaml
```

</hfoption>
<hfoption id="DPO inline">

```bash
trl dpo \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name anthropic/hh-rlhf \
    --accelerate_config zero2  # or path/to/my/accelerate/config.yaml
```

</hfoption>
<hfoption id="DPO w/ config file">

```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
accelerate_config: zero2  # or path/to/my/accelerate/config.yaml
```

Launch with:

```bash
trl dpo --config dpo_config.yaml
```

</hfoption>
</hfoptions>

## Chat Interface

<Tip warning={true}>

The chat interface is deprecated and will be removed in TRL 0.19. Use `transformers-cli chat` instead. For more information, see the [Transformers documentation, chat with text generation models](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).

</Tip>

The chat CLI lets you quickly load the model and talk to it. Simply run the following:

<pre><code>$ trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
What is the best programming language?

<strong><span style="color: blue;"><Qwen/Qwen1.5-0.5B-Chat>:</span></strong>
There isn't a "best" programming language, as everyone has different style preferences, needs, and preferences. However, some people commonly use
languages like Python, Java, C++, and JavaScript, which are popular among developers for a variety of reasons, including readability, flexibility,
and scalability. Ultimately, it depends on personal preference, needs, and goals.
</code></pre>

Note that the chat interface relies on the tokenizer's [chat template](https://huggingface.co/docs/transformers/chat_templating) to format the inputs for the model. Make sure your tokenizer has a chat template defined.
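To make that requirement concrete, a minimal sketch (the model name is just an example) that checks for a chat template and renders a prompt the way the chat CLI would:

```python
# A minimal sketch: verify a tokenizer has a chat template and apply it,
# as the chat CLI does under the hood. The model name is illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")
assert tokenizer.chat_template is not None, "This tokenizer has no chat template."

messages = [{"role": "user", "content": "What is the best programming language?"}]
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,  # append the assistant header for generation
)
print(prompt)
```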
Besides talking to the model, there are a few commands you can use:

- `clear`: clears the current conversation and starts a new one
- `example {NAME}`: loads the example named `{NAME}` from the config and uses it as the user input
- `set {SETTING_NAME}={SETTING_VALUE};`: changes the system prompt or generation settings (multiple settings are separated by a `;`).
- `reset`: same as `clear` but also resets the generation configs to defaults if they have been changed by `set`
- `save` or `save {SAVE_NAME}`: saves the current chat and settings to file, by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml`, or to `{SAVE_NAME}` if provided
- `exit`: closes the interface

## Getting the System Information

You can get the system information by running the following command:

```bash
trl env
```

This will print out the system information, including the GPU information, the CUDA version, the PyTorch version, the transformers version, the TRL version, and any optional dependencies that are installed.

```txt
Copy-paste the following information when reporting an issue:

- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.1
- accelerator(s): NVIDIA H100 80GB HBM3
- Transformers version: 4.45.0.dev0
- Accelerate version: 0.34.2
- Accelerate config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: DEEPSPEED
  - mixed_precision: no
  - use_cpu: False
  - debug: False
  - num_processes: 4
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - enable_cpu_affinity: False
  - deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
- Datasets version: 3.0.0
- HF Hub version: 0.24.7
- TRL version: 0.12.0.dev0+acb4d70
- bitsandbytes version: 0.41.1
- DeepSpeed version: 0.15.1
- Diffusers version: 0.30.3
- Liger-Kernel version: 0.3.0
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.12.0
- vLLM version: not installed
```

This information is required when reporting an issue.
@@ -1,171 +0,0 @@
# Command Line Interfaces (CLIs)

You can use TRL to fine-tune your Language Model with Supervised Fine-Tuning (SFT) or Direct Policy Optimization (DPO) or even chat with your model using the TRL CLIs.

Currently supported CLIs are:

- `trl sft`: fine-tune a LLM on a text/instruction dataset
- `trl dpo`: fine-tune a LLM with DPO on a preference dataset
- `trl chat`: quickly spin up a LLM fine-tuned for chatting
- `trl env`: get the system information

## Fine-tuning with the CLI

Before getting started, pick up a Language Model from Hugging Face Hub. Supported models can be found with the filter "text-generation" within models. Also make sure to pick up a relevant dataset for your task.

Before using the `sft` or `dpo` commands make sure to run:

```bash
accelerate config
```

and pick up the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.

We also recommend you passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with `trl sft` command.

```yaml
model_name_or_path:
  trl-internal-testing/tiny-random-LlamaForCausalLM
dataset_name:
  stanfordnlp/imdb
report_to:
  none
learning_rate:
  0.0001
lr_scheduler_type:
  cosine
```

Save that config in a `.yaml` and get started immediately! An example CLI config is available as `examples/cli_configs/example_config.yaml`. Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder:

```bash
trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
```

Will force-use `cosine_with_restarts` for `lr_scheduler_type`.

### Supported Arguments

We do support all arguments from `transformers.TrainingArguments`, for loading your model, we support all arguments from `~trl.ModelConfig`:

[[autodoc]] ModelConfig

You can pass any of these arguments either to the CLI or the YAML file.

### Supervised Fine-tuning (SFT)

Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:

```bash
trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb
```

The SFT CLI is based on the `examples/scripts/sft.py` script.

### Direct Policy Optimization (DPO)

To use the DPO CLI, you need to have a dataset in the TRL format such as

* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style

These datasets always have at least three columns `prompt, chosen, rejected`:

* `prompt` is a list of strings.
* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
* `rejected` is the rejected response [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)


To do a quick start, you can run the following command:

```bash
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
```


The DPO CLI is based on the `examples/scripts/dpo.py` script.


#### Custom preference dataset

Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):

```bash
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
```

## Chat interface

The chat CLI lets you quickly load the model and talk to it. Simply run the following:

<pre><code>$ trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
What is the best programming language?

<strong><span style="color: blue;"><Qwen/Qwen1.5-0.5B-Chat>:</span></strong>
There isn't a "best" programming language, as everyone has different style preferences, needs, and preferences. However, some people commonly use
languages like Python, Java, C++, and JavaScript, which are popular among developers for a variety of reasons, including readability, flexibility,
and scalability. Ultimately, it depends on personal preference, needs, and goals.
</code></pre>

Note that the chat interface relies on the tokenizer's [chat template](https://huggingface.co/docs/transformers/chat_templating) to format the inputs for the model. Make sure your tokenizer has a chat template defined.

Besides talking to the model there are a few commands you can use:

- `clear`: clears the current conversation and start a new one
- `example {NAME}`: load example named `{NAME}` from the config and use it as the user input
- `set {SETTING_NAME}={SETTING_VALUE};`: change the system prompt or generation settings (multiple settings are separated by a `;`).
- `reset`: same as clear but also resets the generation configs to defaults if they have been changed by `set`
- `save` or `save {SAVE_NAME}`: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
- `exit`: closes the interface

The default examples are defined in `examples/scripts/config/default_chat_config.yaml` but you can pass your own with `--config CONFIG_FILE` where you can also specify the default generation parameters.

## Getting the system information

You can get the system information by running the following command:

```bash
trl env
```

This will print out the system information including the GPU information, the CUDA version, the PyTorch version, the transformers version, and the TRL version, and any optional dependencies that are installed.

```txt
Copy-paste the following information when reporting an issue:

- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.1
- CUDA device: NVIDIA H100 80GB HBM3
- Transformers version: 4.45.0.dev0
- Accelerate version: 0.34.2
- Accelerate config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: DEEPSPEED
  - mixed_precision: no
  - use_cpu: False
  - debug: False
  - num_processes: 4
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - enable_cpu_affinity: False
  - deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
- Datasets version: 3.0.0
- HF Hub version: 0.24.7
- TRL version: 0.12.0.dev0+acb4d70
- bitsandbytes version: 0.41.1
- DeepSpeed version: 0.15.1
- Diffusers version: 0.30.3
- Liger-Kernel version: 0.3.0
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.12.0
```

This information are required when reporting an issue.
31 docs/source/community_tutorials.md Normal file
@@ -0,0 +1,31 @@
# Community Tutorials

Community tutorials are made by active members of the Hugging Face community who want to share their knowledge and expertise with others. They are a great way to learn about the library and its features, and to get started with core classes and modalities.

# Language Models

| Task | Class | Description | Author | Tutorial | Colab |
| --- | --- | --- | --- | --- | --- |
| Reinforcement Learning | [`GRPOTrainer`] | Post training an LLM for reasoning with GRPO in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb) |
| Reinforcement Learning | [`GRPOTrainer`] | Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/mini-deepseek-r1) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb) |
| Instruction tuning | [`SFTTrainer`] | Fine-tuning Google Gemma LLMs using ChatML format with QLoRA | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-google-gemma) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/gemma-lora-example.ipynb) |
| Structured Generation | [`SFTTrainer`] | Fine-tuning Llama-2-7B to generate Persian product catalogs in JSON using QLoRA and PEFT | [Mohammadreza Esmaeilian](https://huggingface.co/Mohammadreza) | [Link](https://huggingface.co/learn/cookbook/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_to_generate_persian_product_catalogs_in_json_format.ipynb) |
| Preference Optimization | [`DPOTrainer`] | Align Mistral-7b using Direct Preference Optimization for human preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/Fine_tune_Mistral_7b_with_DPO.html) | [](https://colab.research.google.com/github/mlabonne/llm-course/blob/main/Fine_tune_a_Mistral_7b_model_with_DPO.ipynb) |
| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) |
| Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) |

<Youtube id="cnGyyM0vOes" />

# Vision Language Models

| Task | Class | Description | Author | Tutorial | Colab |
| --- | --- | --- | --- | --- | --- |
| Visual QA | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for visual question answering on ChartQA dataset | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_trl.ipynb) |
| Visual QA | [`SFTTrainer`] | Fine-tuning SmolVLM with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_smol_vlm_sft_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_smol_vlm_sft_trl.ipynb) |
| SEO Description | [`SFTTrainer`] | Fine-tuning Qwen2-VL-7B for generating SEO-friendly descriptions from images | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-multimodal-llms-with-trl) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-multimodal-llms-with-trl.ipynb) |
| Visual QA | [`DPOTrainer`] | PaliGemma 🤝 Direct Preference Optimization | [Merve Noyan](https://huggingface.co/merve) | [Link](https://github.com/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) | [](https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/PaliGemma_DPO.ipynb) |
| Visual QA | [`DPOTrainer`] | Fine-tuning SmolVLM using direct preference optimization (DPO) with TRL on a consumer GPU | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_vlm_dpo_smolvlm_instruct) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_vlm_dpo_smolvlm_instruct.ipynb) |

## Contributing

If you have a tutorial that you would like to add to this list, please open a PR to add it. We will review it and merge it if it is relevant to the community.
@@ -4,7 +4,7 @@

## Overview

Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high-level, CPO trains models to avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains like chat.
Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by [Haoran Xu](https://huggingface.co/haoranxu), [Amr Sharaf](https://huggingface.co/amrsharaf), [Yunmo Chen](https://huggingface.co/yunmochen), Weiting Tan, Lingfeng Shen, Benjamin Van Durme, [Kenton Murray](https://huggingface.co/Kenton), and [Young Jin Kim](https://huggingface.co/ykim362). At a high-level, CPO trains models to avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation of the DPO loss and can be applied to other domains, such as chat.

CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.

@@ -75,7 +75,7 @@ While training and evaluating we record the following reward metrics:

### Simple Preference Optimization (SimPO)

The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0` in the [`CPOConfig`].
The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`].
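For concreteness, a minimal sketch of enabling the SimPO loss; `loss_type` and `cpo_alpha` are the settings named above, while the model and dataset choices are illustrative.

```python
# A minimal sketch of switching CPOTrainer to the SimPO loss; model and
# dataset names are illustrative.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import CPOConfig, CPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = CPOConfig(
    output_dir="Qwen2.5-0.5B-SimPO",
    loss_type="simpo",  # use the SimPO loss instead of the default CPO loss
    cpo_alpha=0.0,      # disable the BC (NLL) regularization term
)
trainer = CPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()
```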
### CPO-SimPO

@@ -105,4 +105,4 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype

## CPOConfig

[[autodoc]] CPOConfig
[[autodoc]] CPOConfig
@@ -2,48 +2,6 @@

TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are some examples of how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers.

## Train on multiple GPUs / nodes

The trainers in TRL use 🤗 Accelerate to enable distributed training across multiple GPUs or nodes. To do so, first create an 🤗 Accelerate config file by running

```bash
accelerate config
```

and answering the questions according to your multi-GPU / multi-node setup. You can then launch distributed training by running:

```bash
accelerate launch your_script.py
```

We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:

```shell
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
```

Refer to the [examples page](https://github.com/huggingface/trl/tree/main/examples) for more details.

### Distributed training with DeepSpeed

All of the trainers in TRL can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run:

```shell
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_your_script.py --all_arguments_of_the_script
```

Note that for ZeRO-3, a small tweak is needed to initialize your reward model on the correct device via the `zero3_init_context_manager()` context manager. In particular, this is needed to avoid DeepSpeed hanging after a fixed number of training steps. Here is a snippet of what is involved from the [`sentiment_tuning`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) example:

```python
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
    with ds_plugin.zero3_init_context_manager(enable=False):
        sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
else:
    sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
```

Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.


## Use different optimizers and schedulers
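The body of this subsection lies outside the shown hunks. As a rough sketch of the idea it covers, assuming the `optimizers=(optimizer, scheduler)` argument that TRL trainers inherit from `transformers.Trainer` (model, dataset, and hyperparameters are illustrative):

```python
# A rough sketch of passing a custom optimizer and scheduler pair, assuming
# the Trainer-style `optimizers` argument; hyperparameters are illustrative.
from datasets import load_dataset
from torch.optim import AdamW
from transformers import AutoModelForCausalLM, AutoTokenizer, get_cosine_schedule_with_warmup
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

optimizer = AdamW(model.parameters(), lr=5e-6)  # illustrative learning rate
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=100, num_training_steps=1000
)

trainer = DPOTrainer(
    model=model,
    args=DPOConfig(output_dir="Qwen2.5-0.5B-DPO"),
    train_dataset=dataset,
    processing_class=tokenizer,
    optimizers=(optimizer, scheduler),  # override the default optimizer/scheduler
)
trainer.train()
```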
@@ -134,7 +92,7 @@ Read more about 8-bit model loading in `transformers` [here](https://huggingface

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
45 docs/source/data_utils.md Normal file
@@ -0,0 +1,45 @@
# Data Utilities

## is_conversational

[[autodoc]] is_conversational

## apply_chat_template

[[autodoc]] apply_chat_template

## maybe_apply_chat_template

[[autodoc]] maybe_apply_chat_template

## maybe_convert_to_chatml

[[autodoc]] maybe_convert_to_chatml

## extract_prompt

[[autodoc]] extract_prompt

## maybe_extract_prompt

[[autodoc]] maybe_extract_prompt

## unpair_preference_dataset

[[autodoc]] unpair_preference_dataset

## maybe_unpair_preference_dataset

[[autodoc]] maybe_unpair_preference_dataset

## pack_examples

[[autodoc]] pack_examples

## pack_dataset

[[autodoc]] pack_dataset

## truncate_dataset

[[autodoc]] truncate_dataset
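For orientation, a small sketch of one of these utilities, `unpair_preference_dataset`, on a toy paired dataset; the data is made up for illustration.

```python
# A small illustrative sketch: convert a paired preference dataset into an
# unpaired one with "completion" and "label" columns. The toy data is made up.
from datasets import Dataset
from trl import unpair_preference_dataset

dataset = Dataset.from_dict({
    "prompt": ["The sky is", "The grass is"],
    "chosen": [" blue.", " green."],
    "rejected": [" green.", " blue."],
})

unpaired = unpair_preference_dataset(dataset)
print(unpaired.column_names)  # expected: ['prompt', 'completion', 'label']
```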
@@ -1,15 +0,0 @@
## Data Utilities

[[autodoc]] is_conversational

[[autodoc]] apply_chat_template

[[autodoc]] maybe_apply_chat_template

[[autodoc]] extract_prompt

[[autodoc]] maybe_extract_prompt

[[autodoc]] unpair_preference_dataset

[[autodoc]] maybe_unpair_preference_dataset
@@ -77,6 +77,18 @@ This guide provides an overview of the dataset formats and types supported by ea
 "label": False}</code></pre>
</td>
</tr>
<tr>
<td>Stepwise supervision</td>
<td>
<pre><code>{"prompt": "Which number is larger, 9.8 or 9.11?",
 "completions": ["The fractional part of 9.8 is 0.8.",
                 "The fractional part of 9.11 is 0.11.",
                 "0.11 is greater than 0.8.",
                 "Hence, 9.11 > 9.8."],
 "labels": [True, True, False, False]}</code></pre>
</td>
<td></td>
</tr>
</table>

### Formats

@@ -87,9 +99,11 @@ The standard dataset format typically consists of plain text strings. The column

```python
# Language modeling
example = {"text": "The sky is blue."}
language_modeling_example = {"text": "The sky is blue."}
# Preference
example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
# Unpaired preference
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}
```

#### Conversational

@@ -104,18 +118,17 @@ messages = [
]
```

Just like standard datasets, the columns in conversational datasets vary depending on the task. For instance, a preference dataset would include columns like `"chosen"` and `"rejected"` to compare responses:
Just like standard datasets, the columns in conversational datasets vary depending on the task. Below are examples of conversational dataset formats for different tasks:

```python
example = {
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is green."},
    ],
# Prompt-completion
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
                             "completion": [{"role": "assistant", "content": "It is blue."}]}
# Preference
preference_example = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "chosen": [{"role": "assistant", "content": "It is blue."}],
    "rejected": [{"role": "assistant", "content": "It is green."}],
}
```

@@ -128,7 +141,13 @@ Conversational datasets are useful for training chat models, but must be convert

A language modeling dataset consists of a column `"text"` (or `"messages"` for conversational datasets) containing a full sequence of text.

```python
# Standard format
language_modeling_example = {"text": "The sky is blue."}
# Conversational format
language_modeling_example = {"messages": [
    {"role": "user", "content": "What color is the sky?"},
    {"role": "assistant", "content": "It is blue."}
]}
```

#### Prompt-only

@@ -136,9 +155,14 @@ language_modeling_example = {"text": "The sky is blue."}

In a prompt-only dataset, only the initial prompt (the question or partial sentence) is provided under the key `"prompt"`. The training typically involves generating the completion based on this prompt, where the model learns to continue or complete the given input.

```python
# Standard format
prompt_only_example = {"prompt": "The sky is"}
# Conversational format
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
```

For examples of prompt-only datasets, refer to the [Prompt-only datasets collection](https://huggingface.co/collections/trl-lib/prompt-only-datasets-677ea25245d20252cea00368).

<Tip>

While both the prompt-only and language modeling types are similar, they differ in how the input is handled. In the prompt-only type, the prompt represents a partial input that expects the model to complete or continue, while in the language modeling type, the input is treated as a complete sentence or sequence. These two types are processed differently by TRL. Below is an example showing the difference in the output of the `apply_chat_template` function for each type:
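The example referenced here falls outside the shown hunks; the following sketch illustrates the described difference. It assumes TRL's `apply_chat_template` helper and a Qwen tokenizer, and the rendered strings in the comments are indicative only, since they depend on the tokenizer's chat template.

```python
# A sketch of the comparison described above; the exact rendered strings
# depend on the tokenizer's chat template, so the comments are indicative.
from transformers import AutoTokenizer
from trl import apply_chat_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Prompt-only: a generation prompt is appended so the model can continue
prompt_only_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}]}
print(apply_chat_template(prompt_only_example, tokenizer))
# e.g. {'prompt': '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n'}

# Language modeling: the messages are rendered as a complete sequence
lm_example = {"messages": [{"role": "user", "content": "What color is the sky?"}]}
print(apply_chat_template(lm_example, tokenizer))
# e.g. {'text': '<|im_start|>user\nWhat color is the sky?<|im_end|>\n'}
```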
@ -170,21 +194,41 @@ apply_chat_template(lm_example, tokenizer)
|
||||
A prompt-completion dataset includes a `"prompt"` and a `"completion"`.
|
||||
|
||||
```python
|
||||
# Standard format
|
||||
prompt_completion_example = {"prompt": "The sky is", "completion": " blue."}
|
||||
# Conversational format
|
||||
prompt_completion_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
"completion": [{"role": "assistant", "content": "It is blue."}]}
|
||||
```
|
||||
|
||||
For examples of prompt-completion datasets, refer to the [Prompt-completion datasets collection](https://huggingface.co/collections/trl-lib/prompt-completion-datasets-677ea2bb20bbb6bdccada216).
|
||||
|
||||
#### Preference

A preference dataset is used for tasks where the model is trained to choose between two or more possible completions to the same prompt. This dataset includes a `"prompt"`, a `"chosen"` completion, and a `"rejected"` completion. The model is trained to select the `"chosen"` response over the `"rejected"` response.
Some datasets may not include the `"prompt"` column, in which case the prompt is implicit and directly included in the `"chosen"` and `"rejected"` completions. We recommend using explicit prompts whenever possible.

```python
# explicit prompt
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}  # recommended
# implicit prompt
# Standard format
## Explicit prompt (recommended)
preference_example = {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."}
# Implicit prompt
preference_example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}

# Conversational format
## Explicit prompt (recommended)
preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
                      "chosen": [{"role": "assistant", "content": "It is blue."}],
                      "rejected": [{"role": "assistant", "content": "It is green."}]}
## Implicit prompt
preference_example = {"chosen": [{"role": "user", "content": "What color is the sky?"},
                                 {"role": "assistant", "content": "It is blue."}],
                      "rejected": [{"role": "user", "content": "What color is the sky?"},
                                   {"role": "assistant", "content": "It is green."}]}
```

For examples of preference datasets, refer to the [Preference datasets collection](https://huggingface.co/collections/trl-lib/preference-datasets-677e99b581018fcad9abd82c).

Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots' DPO Collections](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) to identify preference datasets.

#### Unpaired preference
@@ -192,9 +236,30 @@ Some preference datasets can be found with [the tag `dpo` on Hugging Face Hub](h

An unpaired preference dataset is similar to a preference dataset but instead of having `"chosen"` and `"rejected"` completions for the same prompt, it includes a single `"completion"` and a `"label"` indicating whether the completion is preferred or not.

```python
# Standard format
unpaired_preference_example = {"prompt": "The sky is", "completion": " blue.", "label": True}

# Conversational format
unpaired_preference_example = {"prompt": [{"role": "user", "content": "What color is the sky?"}],
                               "completion": [{"role": "assistant", "content": "It is blue."}],
                               "label": True}
```

For examples of unpaired preference datasets, refer to the [Unpaired preference datasets collection](https://huggingface.co/collections/trl-lib/unpaired-preference-datasets-677ea22bf5f528c125b0bcdf).
#### Stepwise supervision

A stepwise (or process) supervision dataset is similar to an [unpaired preference](#unpaired-preference) dataset but includes multiple steps of completions, each with its own label. This structure is useful for tasks that need detailed, step-by-step labeling, such as reasoning tasks. By evaluating each step separately and providing targeted labels, this approach helps identify precisely where the reasoning is correct and where errors occur, allowing for targeted feedback on each part of the reasoning process.

```python
stepwise_example = {
    "prompt": "Which number is larger, 9.8 or 9.11?",
    "completions": ["The fractional part of 9.8 is 0.8, while the fractional part of 9.11 is 0.11.", "Since 0.11 is greater than 0.8, the number 9.11 is larger than 9.8."],
    "labels": [True, False]
}
```

For examples of stepwise supervision datasets, refer to the [Stepwise supervision datasets collection](https://huggingface.co/collections/trl-lib/stepwise-supervision-datasets-677ea27fd4c5941beed7a96e).
## Which dataset type to use?

Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.

@@ -205,14 +270,16 @@ Choosing the right dataset type depends on the task you are working on and the s

| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
| [`IterativeSFTTrainer`] | [Unpaired preference](#unpaired-preference) |
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
| [`PPOTrainer`] | Tokenized language modeling |
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |

<Tip>
@@ -224,12 +291,12 @@ For more information on how to work with conversational datasets, refer to the [

## Working with conversational datasets in TRL

Conversational datasets are increasingly common, especially for training chat models. However, TRL trainers (except [`SFTTrainer`]) don't support conversational datasets in their raw format. These datasets must first be converted into a standard format.
Conversational datasets are increasingly common, especially for training chat models. However, some TRL trainers don't support conversational datasets in their raw format. (For more information, see [issue #2071](https://github.com/huggingface/trl/issues/2071).) These datasets must first be converted into a standard format.
Fortunately, TRL offers tools to easily handle this conversion, which are detailed below.

### Converting a conversational dataset into a standard dataset

TRL trainers do not support conversational datasets in their raw format. To use them, you need to convert them into a standard dataset format using a chat template. This template is provided by the tokenizer of the model you use.
To convert a conversational dataset into a standard dataset, you need to _apply a chat template_ to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use.

For detailed instructions on using chat templating, refer to the [Chat templating section in the `transformers` documentation](https://huggingface.co/docs/transformers/en/chat_templating).
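As a quick illustration, here is a minimal sketch of what applying a chat template looks like at the tokenizer level (the Qwen2 checkpoint is just an example; the rendered string depends on the model's template):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

messages = [
    {"role": "user", "content": "What color is the sky?"},
    {"role": "assistant", "content": "It is blue."},
]

# tokenize=False returns the formatted string rather than token ids
text = tokenizer.apply_chat_template(messages, tokenize=False)
```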
@@ -274,7 +341,7 @@ dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})

<Tip warning={true}>

We recommend using the [`apply_chat_template`] function instead of calling `tokenizer.apply_chat_template` directly. Handling chat templates for non-language modeling datasets can be tricky and may result in errors, such as mistakenly placing a system prompt in the middle of a conversation.
For additional examples, see [#1930 (comment)](https://github.com/huggingface/trl/pull/1930#issuecomment-2292908614). The [`apply_chat_template`] function is designed to handle these intricacies and ensure the correct application of chat templates for various tasks.

</Tip>
@@ -338,14 +405,15 @@ This section provides example code to help you convert between different dataset

For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification.

| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference |
| --- | --- | --- | --- | --- | --- | --- |
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A |
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A |
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A |
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) |
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) |
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A |

| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A |
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A |
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
| Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A |
### From prompt-completion to language modeling dataset

@@ -521,6 +589,14 @@ dataset = unpair_preference_dataset(dataset)
 'label': True}
```
<Tip warning={true}>

Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can both be good, or both be bad.
Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
This can be ensured by checking the absolute rating of each completion, e.g., from a reward model.
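For instance, a minimal sketch of such a check, assuming a hypothetical `reward_score` helper backed by a reward model (the threshold and scoring logic are purely illustrative):

```python
from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": ["The sky is"],
    "chosen": [" blue."],
    "rejected": [" green."],
})

# Hypothetical helper: in practice, query a reward model for an absolute score
def reward_score(text: str) -> float:
    return 1.0 if "blue" in text else -1.0  # placeholder scoring

threshold = 0.0  # illustrative cut-off

# Keep only pairs where the chosen completion is clearly good
# and the rejected completion is clearly bad
dataset = dataset.filter(
    lambda x: reward_score(x["chosen"]) > threshold
    and reward_score(x["rejected"]) <= threshold
)
```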
</Tip>

### From preference to language modeling dataset

To convert a preference dataset into a language modeling dataset, remove the `"rejected"` column, and concatenate the prompt and the chosen completion into the `"text"` column.
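The diff elides the accompanying snippet; a minimal sketch in the same style as the other conversion examples would look like this:

```python
from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": ["The sky is", "The sun is"],
    "chosen": [" blue.", " in the sky."],
    "rejected": [" green.", " in the sea."],
})

def concat_prompt_chosen(example):
    # Keep only the preferred continuation, merged with its prompt
    return {"text": example["prompt"] + example["chosen"]}

dataset = dataset.map(concat_prompt_chosen, remove_columns=["prompt", "chosen", "rejected"])
```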
@@ -654,9 +730,17 @@ dataset = unpair_preference_dataset(dataset)
 'label': True}
```

<Tip warning={true}>

Keep in mind that the `"chosen"` and `"rejected"` completions in a preference dataset can both be good, or both be bad.
Before applying [`unpair_preference_dataset`], please ensure that all `"chosen"` completions can be labeled as good and all `"rejected"` completions as bad.
This can be ensured by checking the absolute rating of each completion, e.g., from a reward model.

</Tip>
### From unpaired preference to language modeling dataset

To convert an unpaired preference dataset into a language modeling dataset, concatenate the prompt and the completion into the `"text"` column, and remove the prompt, completion and label columns.
To convert an unpaired preference dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column, and remove the prompt, completion and label columns.

```python
from datasets import Dataset

@@ -670,7 +754,7 @@ dataset = Dataset.from_dict({

def concatenate_prompt_completion(example):
    return {"text": example["prompt"] + example["completion"]}

dataset = dataset.map(concatenate_prompt_completion).remove_columns(["prompt", "completion", "label"])
dataset = dataset.filter(lambda x: x["label"]).map(concatenate_prompt_completion).remove_columns(["prompt", "completion", "label"])
```

```python
@@ -680,7 +764,7 @@ dataset = dataset.map(concatenate_prompt_completion).remove_columns(["prompt", "

### From unpaired preference to prompt-completion dataset

To convert an unpaired preference dataset into a prompt-completion dataset, remove the label columns.
To convert an unpaired preference dataset into a prompt-completion dataset, filter for good labels, then remove the label column.

```python
from datasets import Dataset

@@ -691,7 +775,7 @@ dataset = Dataset.from_dict({
    "label": [True, True, False, False],
})

dataset = dataset.remove_columns(["label"])
dataset = dataset.filter(lambda x: x["label"]).remove_columns(["label"])
```

```python
@@ -720,6 +804,107 @@ dataset = dataset.remove_columns(["completion", "label"])
{'prompt': 'The sky is'}
```
### From stepwise supervision to language modeling dataset

To convert a stepwise supervision dataset into a language modeling dataset, concatenate prompts with good completions into the `"text"` column.

```python
from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": ["Blue light", "Water"],
    "completions": [[" scatters more in the atmosphere,", " so the sky is green."],
                    [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
    "labels": [[True, False], [True, True]],
})

def concatenate_prompt_completions(example):
    completion = "".join(example["completions"])
    return {"text": example["prompt"] + completion}

dataset = dataset.filter(lambda x: all(x["labels"])).map(concatenate_prompt_completions, remove_columns=["prompt", "completions", "labels"])
```

```python
>>> dataset[0]
{'text': 'Water forms a less dense structure in ice, which causes it to expand when it freezes.'}
```

Note that only the second example survives the filter, since the first one contains a `False` step label.
### From stepwise supervision to prompt-completion dataset

To convert a stepwise supervision dataset into a prompt-completion dataset, join the good completions and remove the labels.

```python
from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": ["Blue light", "Water"],
    "completions": [[" scatters more in the atmosphere,", " so the sky is green."],
                    [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
    "labels": [[True, False], [True, True]],
})

def join_completions(example):
    completion = "".join(example["completions"])
    return {"completion": completion}

dataset = dataset.filter(lambda x: all(x["labels"])).map(join_completions, remove_columns=["completions", "labels"])
```

```python
>>> dataset[0]
{'prompt': 'Water', 'completion': ' forms a less dense structure in ice, which causes it to expand when it freezes.'}
```

Again, the first example is filtered out by its `False` step label.
### From stepwise supervision to prompt-only dataset

To convert a stepwise supervision dataset into a prompt-only dataset, remove the completions and the labels.

```python
from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": ["Blue light", "Water"],
    "completions": [[" scatters more in the atmosphere,", " so the sky is green."],
                    [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
    "labels": [[True, False], [True, True]],
})

dataset = dataset.remove_columns(["completions", "labels"])
```

```python
>>> dataset[0]
{'prompt': 'Blue light'}
```
### From stepwise supervision to unpaired preference dataset

To convert a stepwise supervision dataset into an unpaired preference dataset, join the completions and merge the labels.

The method for merging the labels depends on the specific task. In this example, we use the logical AND operation. This means that if the step labels indicate the correctness of individual steps, the resulting label will reflect the correctness of the entire sequence.

```python
from datasets import Dataset

dataset = Dataset.from_dict({
    "prompt": ["Blue light", "Water"],
    "completions": [[" scatters more in the atmosphere,", " so the sky is green."],
                    [" forms a less dense structure in ice,", " which causes it to expand when it freezes."]],
    "labels": [[True, False], [True, True]],
})

def merge_completions_and_labels(example):
    return {"prompt": example["prompt"], "completion": "".join(example["completions"]), "label": all(example["labels"])}

dataset = dataset.map(merge_completions_and_labels, remove_columns=["completions", "labels"])
```

```python
>>> dataset[0]
{'prompt': 'Blue light', 'completion': ' scatters more in the atmosphere, so the sky is green.', 'label': False}
```
## Vision datasets

Some trainers also support fine-tuning vision-language models (VLMs) using image-text pairs. In this scenario, it's recommended to use a conversational format, as each model handles image placeholders in text differently.
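As an illustration, a conversational vision example typically pairs an `"images"` entry with messages whose content mixes image and text parts. This is a sketch; the exact placeholder convention depends on the model's processor:

```python
from PIL import Image

# Stand-in for a real photo loaded from disk or the Hub
image = Image.new("RGB", (64, 64), color="blue")

vision_example = {
    "images": [image],
    "messages": [
        {"role": "user",
         "content": [{"type": "image"},
                     {"type": "text", "text": "What color is the sky in the image?"}]},
        {"role": "assistant",
         "content": [{"type": "text", "text": "It is blue."}]},
    ],
}
```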
@@ -6,16 +6,16 @@

| Before | After DDPO finetuning |
| --- | --- |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_squirrel.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_crab.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_squirrel.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_crab.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/post_starfish.png"/></div> |
## Getting started with Stable Diffusion finetuning with reinforcement learning

The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers`
library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers.
Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to be made.

This library provides a pipeline interface that must be implemented in order to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.**
There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide.

@@ -26,7 +26,7 @@ For a more detailed look into the interface and the associated default implement

Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag is enabled by default and can be turned off by passing the flag to do so. LoRA-based training is faster, and the LoRA-associated hyperparameters responsible for model convergence aren't as finicky as in non-LoRA based training.

In addition, you are expected to provide a reward function and a prompt function. The reward function is used to evaluate the generated images, and the prompt function is used to generate the prompts used to generate the images. A rough sketch of their expected shape follows.
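The exact signatures are defined by the `DDPOTrainer` interface; treat the following as an illustrative assumption rather than the library's API:

```python
import numpy as np

def prompt_fn():
    # Returns a (prompt, metadata) pair used to generate images
    return np.random.choice(["a squirrel", "a crab", "a starfish"]), {}

def reward_fn(images, prompts, metadata):
    # Scores each generated image; here a dummy constant reward
    return [1.0 for _ in images], {}
```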
## Getting started with `examples/scripts/ddpo.py`
docs/source/deepspeed_integration.md (new file, 39 lines)

@@ -0,0 +1,39 @@
# DeepSpeed Integration

<Tip warning={true}>

Section under construction. Feel free to contribute!

</Tip>

TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more.

DeepSpeed integrates the [Zero Redundancy Optimizer (ZeRO)](https://huggingface.co/papers/1910.02054), which makes it possible to scale the model size proportionally to the number of devices with sustained high efficiency.



## Installation

To use DeepSpeed with TRL, install it using the following command:

```bash
pip install deepspeed
```

## Running Training Scripts with DeepSpeed

No modifications to your training script are required. Simply run it with the DeepSpeed configuration file:

```bash
accelerate launch --config_file <ACCELERATE_WITH_DEEPSPEED_CONFIG_FILE.yaml> train.py
```

We provide ready-to-use DeepSpeed configuration files in the [`examples/accelerate_configs`](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) directory. For example, to run training with ZeRO Stage 2, use the following command:

```bash
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml train.py
```

## Additional Resources

Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
@@ -30,7 +30,7 @@ We selected the following models for our experiments to show that TRL can be eas

* [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters)
* [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters)

For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has shown to be the "most toxic" model compared to the others. We have run a toxicity evaluation using the `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of the `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt).

| Model | Mean toxicity score |
|---|---|

@@ -45,7 +45,7 @@ When doing PPO, it is very important to design the problem efficiently so that t

### Pre-processing the dataset

The dataset consists of prompts and their continuations, and each of them has an associated `toxicity` score.

A `prompt` example:
```
@@ -83,12 +83,12 @@ As a compromise between the two we took for a context window of 10 to 15 tokens

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-long-vs-short-context.png">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-long-vs-short-context.png">
</div>

### How to deal with OOM issues

Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here are two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU:

- Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2:
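The diff elides the snippet here; a minimal sketch of such a load (the checkpoint name is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM

# Loading in bfloat16 roughly halves the memory footprint compared to float32
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
)
```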
@@ -101,15 +101,15 @@ and the optimizer will take care of computing the gradients in `bfloat16` precis

- Use shared layers: Since the PPO algorithm requires both the active and reference models to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying the `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-shared-layers.png">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-shared-layers.png">
</div>

```python
ref_policy = create_reference_model(model, num_shared_layers=6)
trainer = PPOTrainer(..., ref_policy=ref_policy)
ref_model = create_reference_model(model, num_shared_layers=6)
trainer = PPOTrainer(..., ref_model=ref_model)
```

In the example above this means that the model has the first 4 layers frozen (i.e. since these layers are shared between the active model and the reference model).
- One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower).

@@ -124,13 +124,13 @@ We have decided to keep 3 models in total that correspond to our best models:

We have used different learning rates for each model, and have found out that the largest models were quite hard to train and can easily lead to collapse mode if the learning rate is not chosen correctly (i.e. if the learning rate is too high):

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-collapse-mode.png">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-collapse-mode.png">
</div>

The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-gpt-j-final-run-2.png">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-gpt-j-final-run-2.png">
</div>

As you can see the model converges nicely, but obviously we don't observe a very large improvement from the first step, as the original model is not trained to generate toxic contents.
@@ -138,7 +138,7 @@ As you can see the model converges nicely, but obviously we don't observe a very

Also we have observed that training with a larger `mini_batch_size` leads to smoother convergence and better results on the test set:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-gpt-j-mbs-run.png">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-gpt-j-mbs-run.png">
</div>

## Results

@@ -159,7 +159,7 @@ We report the toxicity score of 400 sampled examples, compute its mean and stand

<div class="column" style="text-align:center">
<figure>
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-final-barplot.png" style="width:80%">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-final-barplot.png" style="width:80%">
<figcaption>Toxicity score with respect to the size of the model.</figcaption>
</figure>
</div>
@@ -167,7 +167,7 @@ We report the toxicity score of 400 sampled examples, compute its mean and stand

Below are a few generation examples of the `gpt-j-6b-detox` model:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-toxicity-examples.png">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-toxicity-examples.png">
</div>

The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).

@@ -176,7 +176,7 @@ The evaluation script can be found [here](https://github.com/huggingface/trl/blo

The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for the `gpt-neo-2B` model, but less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model, starting with training with a larger `mini_batch_size` and probably allowing to back-propagate through more layers (i.e. use fewer shared layers).

To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure their outputs are less toxic as well as useful.

### Limitations
docs/source/distributing_training.md (new file, 60 lines)

@@ -0,0 +1,60 @@
# Distributing Training

<Tip warning={true}>
Section under construction. Feel free to contribute!
</Tip>

## Multi-GPU Training with TRL

The trainers in TRL use [🤗 Accelerate](https://github.com/huggingface/accelerate) to enable distributed training across multiple GPUs or nodes. To do so, first create an [🤗 Accelerate](https://github.com/huggingface/accelerate) config file by running

```bash
accelerate config
```

and answering the questions according to your multi-GPU / multi-node setup. You can then launch distributed training by running:

```bash
accelerate launch train.py
```

We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:

```shell
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml train.py <SCRIPT_ARGS>
```

This automatically distributes the workload across all available GPUs.

Under the hood, [🤗 Accelerate](https://github.com/huggingface/accelerate) creates one model per GPU. Each process:
- Processes its own batch of data
- Computes the loss and gradients for that batch
- Shares gradient updates across all GPUs



The effective batch size is calculated as:

$$
\text{Batch Size} = \text{per\_device\_train\_batch\_size} \times \text{num\_devices} \times \text{gradient\_accumulation\_steps}
$$

To maintain a consistent batch size when scaling to multiple GPUs, make sure to update `per_device_train_batch_size` and `gradient_accumulation_steps` accordingly.

For example, these configurations are equivalent and should yield the same results:

| Number of GPUs | Per device batch size | Gradient accumulation steps | Comments |
| --- | --- | --- | --- |
| 1 | 32 | 1 | Possibly high memory usage, but faster training |
| 1 | 4 | 8 | Lower memory usage, slower training |
| 8 | 4 | 1 | Multi-GPU to get the best of both worlds |
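A quick sanity check of the equivalence, following the formula above:

```python
# (per_device_train_batch_size, num_devices, gradient_accumulation_steps)
configs = [(32, 1, 1), (4, 1, 8), (4, 8, 1)]
for per_device, devices, accumulation in configs:
    print(per_device * devices * accumulation)  # 32 for every configuration
```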
<Tip>

Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. Check out our [DeepSpeed Integration](deepspeed_integration.md) guide for more details.

</Tip>

## Multi-Node Training

We're working on a guide for multi-node training. Stay tuned! 🚀

@@ -1,6 +1,6 @@
# DPO Trainer

[](https://huggingface.co/models?other=dpo,trl)
[](https://huggingface.co/models?other=dpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)

## Overview

@@ -59,11 +59,11 @@ accelerate launch train_dpo.py

Distributed across 8 GPUs, the training takes approximately 3 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.



To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-DPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).

<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-DPO
<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-DPO
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
What is the best programming language?

@@ -81,7 +81,7 @@ The best programming language based on these factors is subjective and depends o

## Expected dataset type

DPO requires a [preference dataset](dataset_formats#preference). The [`DPOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.

Although the [`DPOTrainer`] supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the `"chosen"` and `"rejected"` columns. For more information, refer to the [preference style](dataset_formats#preference) section.
@@ -112,12 +112,12 @@ For a complete example of fine-tuning a vision-language model, refer to the scri

## Example script

We provide an example script to train a model using the DPO method. The script is available in [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py)
We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py)

To test the DPO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), run the following command:

```bash
accelerate launch examples/scripts/dpo.py \
accelerate launch trl/scripts/dpo.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --num_train_epochs 1 \
@@ -150,6 +150,7 @@ The DPO algorithm supports several loss functions. The loss function can be set

| `"sppo_hard"` | The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. |
| `"aot"` or `loss_type="aot_pair"` | The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. |
| `"apo_zero"` or `loss_type="apo_down"` | The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. |
| `"discopop"` | The [DiscoPOP](https://huggingface.co/papers/2406.08414) paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). |
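As a quick illustration, the loss is selected through `DPOConfig`. This is a minimal sketch; `loss_type` and `beta` are existing `DPOConfig` parameters, and the values are illustrative:

```python
from trl import DPOConfig

# Select the DiscoPOP loss; beta scales the implicit reward / KL penalty
training_args = DPOConfig(
    output_dir="Qwen2-0.5B-DPO",
    loss_type="discopop",
    beta=0.1,
)
```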
### Label smoothing

@@ -277,6 +278,6 @@ dpo_trainer = DPOTrainer(

[[autodoc]] DPOConfig

## PreferenceCollator
## DataCollatorForPreference

[[autodoc]] trainer.dpo_trainer.PreferenceCollator
[[autodoc]] trainer.dpo_trainer.DataCollatorForPreference
@@ -31,32 +31,39 @@ Then, it is encouraged to launch jobs with `accelerate launch`!

# Maintained Examples

Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly.

| File | Description |
| --- | --- |
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
| [`examples/scripts/chat.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/chat.py) | This script allows you to load and use a model as a chatbot. |
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
| [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a reward model on your own dataset. |
| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model or adapters into a target dataset. |
| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. |

| File | Description |
| --- | --- |
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
| [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py) | This script shows how to use the [`OnlineDPOTrainer`] to fine-tune a model. |
| [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. |
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py) | This script shows how to use the [`PRMTrainer`] to fine-tune a Process-supervised Reward Model (PRM). |
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train an Outcome Reward Model (ORM) on your own dataset. |
| [`examples/scripts/rloo/rloo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language. |
| [`examples/scripts/rloo/rloo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/rloo/rloo_tldr.py) | This script shows how to use the [`RLOOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
| [`examples/scripts/sft_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model. |
| [`examples/scripts/sft_video_llm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_video_llm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Video Language Model. |
| [`examples/scripts/sft_vlm_gemma3.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Gemma 3 model on vision to text tasks. |
| [`examples/scripts/sft_vlm_smol_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_smol_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a SmolVLM model. |
| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. |
| [`examples/scripts/xpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/xpo.py) | This script shows how to use the [`XPOTrainer`] to fine-tune a model. |
Here are also some easier-to-run colab notebooks that you can use to get started with TRL:

| File | Description |
| --- | --- |
| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |

We also have some other examples that are less maintained but can be used as a reference:
docs/source/grpo_trainer.md (new file, 511 lines)

@@ -0,0 +1,511 @@

# GRPO Trainer

[](https://huggingface.co/models?other=grpo,trl)

## Overview

TRL supports the GRPO Trainer for training language models, as described in the paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300) by [Zhihong Shao](https://huggingface.co/syhia), [Peiyi Wang](https://huggingface.co/peiyiwang89), [Qihao Zhu](https://huggingface.co/zqh11), Runxin Xu, [Junxiao Song](https://huggingface.co/haha-point), Mingchuan Zhang, Y. K. Li, Y. Wu, [Daya Guo](https://huggingface.co/guoday).

The abstract from the paper is the following:

> Mathematical reasoning poses a significant challenge for language models due to its complex and structured nature. In this paper, we introduce DeepSeekMath 7B, which continues pre-training DeepSeek-Coder-Base-v1.5 7B with 120B math-related tokens sourced from Common Crawl, together with natural language and code data. DeepSeekMath 7B has achieved an impressive score of 51.7% on the competition-level MATH benchmark without relying on external toolkits and voting techniques, approaching the performance level of Gemini-Ultra and GPT-4. Self-consistency over 64 samples from DeepSeekMath 7B achieves 60.9% on MATH. The mathematical reasoning capability of DeepSeekMath is attributed to two key factors: First, we harness the significant potential of publicly available web data through a meticulously engineered data selection pipeline. Second, we introduce Group Relative Policy Optimization (GRPO), a variant of Proximal Policy Optimization (PPO), that enhances mathematical reasoning abilities while concurrently optimizing the memory usage of PPO.

This post-training method was contributed by [Quentin Gallouédec](https://huggingface.co/qgallouedec).

## Quick start

This example demonstrates how to train a model using the GRPO method. We train a [Qwen 0.5B Instruct model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) with the prompts from the [TLDR dataset](https://huggingface.co/datasets/trl-lib/tldr) (the completion column is ignored!). You can view the data in the dataset here:

<iframe
  src="https://huggingface.co/datasets/trl-lib/tldr/embed/viewer/default/train?row=0"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>

Below is the script to train the model.

```python
# train_grpo.py
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=10)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

Execute the script using the following command:

```bash
accelerate launch train_grpo.py
```

Distributed across 8 GPUs, the training takes approximately 1 day.



## Looking deeper into the GRPO method

GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind the GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**.



### Generating completions

At each training step, we sample a batch of prompts and generate a set of \\( G \\) completions for each prompt (denoted as \\( o_i \\)).
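
In TRL, the group size \\( G \\) corresponds to the `num_generations` parameter of [`GRPOConfig`]. A minimal sketch (the value shown is illustrative, not a recommendation):

```python
from trl import GRPOConfig

# Sample G = 8 completions per prompt at each step (illustrative value)
training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", num_generations=8)
```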

### Computing the advantage

For each of the \\( G \\) sequences, we compute the reward using a reward model. To align with the comparative nature of reward models—typically trained on datasets of comparisons between outputs for the same question—the advantage is calculated to reflect these relative comparisons. It is normalized as follows:

$$\hat{A}_{i,t} = \frac{r_i - \text{mean}(\mathbf{r})}{\text{std}(\mathbf{r})}$$

This approach gives the method its name: **Group Relative Policy Optimization (GRPO)**.

<Tip>

It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`].

</Tip>
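
To make this concrete, here is a minimal sketch of the group-relative advantage computation for a single prompt (plain PyTorch, not the trainer's internal code):

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, scale_rewards: bool = True) -> torch.Tensor:
    """rewards: shape (G,), the rewards of the G completions generated for one prompt."""
    advantages = rewards - rewards.mean()
    if scale_rewards:  # set to False to avoid the question-level difficulty bias discussed above
        advantages = advantages / (rewards.std() + 1e-4)
    return advantages

# Example: 4 completions for one prompt
print(group_relative_advantages(torch.tensor([0.0, 1.0, 0.5, 1.5])))
```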

### Estimating the KL divergence

KL divergence is estimated using the approximator introduced by [Schulman et al. (2020)](http://joschu.net/blog/kl-approx.html). The approximator is defined as follows:

$$
\mathbb{D}_{\text{KL}}\left[\pi_\theta \|\pi_{\text{ref}}\right] = \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - \log \frac{\pi_{\text{ref}}(o_{i,t} \mid q, o_{i,<t})}{\pi_\theta(o_{i,t} \mid q, o_{i,<t})} - 1,
$$
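
In code, the estimator amounts to the following sketch over per-token log-probabilities (illustrative, not the trainer's internal implementation):

```python
import torch

def kl_estimate(ref_per_token_logps: torch.Tensor, per_token_logps: torch.Tensor) -> torch.Tensor:
    """Per-token estimate of KL(pi_theta || pi_ref): ratio - log(ratio) - 1, with ratio = pi_ref / pi_theta."""
    log_ratio = ref_per_token_logps - per_token_logps
    return torch.exp(log_ratio) - log_ratio - 1  # non-negative by construction
```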

### Computing the loss

The objective is to maximize the advantage while ensuring that the model remains close to the reference policy. Consequently, the loss is defined as follows:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$

where the first term represents the scaled advantage and the second term penalizes deviations from the reference policy through KL divergence.

<Tip>

Note that compared to the original formulation in [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300), we don't scale by \\( \frac{1}{|o_i|} \\) because it was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that this introduces a response-level length bias. More details in [loss types](#loss-types).

</Tip>

In the original paper, this formulation is generalized to account for multiple updates after each generation (denoted \\( \mu \\), which can be set with `num_iterations` in [`GRPOConfig`]) by leveraging the **clipped surrogate objective**:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} \left[ \min \left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})} \hat{A}_{i,t}, \, \text{clip}\left( \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}, 1 - \epsilon, 1 + \epsilon \right) \hat{A}_{i,t} \right) - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right] \right],
$$

where \\( \text{clip}(\cdot, 1 - \epsilon, 1 + \epsilon) \\) ensures that updates do not deviate excessively from the old policy \\( \pi_{\theta_{\text{old}}} \\) by bounding the policy ratio between \\( 1 - \epsilon \\) and \\( 1 + \epsilon \\). When \\( \mu = 1 \\) (the default in TRL), the clipped surrogate objective simplifies to the original objective.

#### Loss Types

Several formulations of the objective have been proposed in the literature. Initially, the objective of GRPO was defined as follows:

$$
\mathcal{L}_{\text{GRPO}}(\theta) = - \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} l_{i,t},
$$

where

$$
l_{i,t} = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\left[\pi_\theta(o_{i,t} \mid q, o_{i,< t})\right]_{\text{no grad}}} \hat{A}_{i,t} - \beta \mathbb{D}_{\text{KL}}\left[\pi_\theta \| \pi_{\text{ref}}\right].
$$

The DAPO paper highlights the limitations of the GRPO algorithm’s sample-level loss in long-CoT scenarios, where longer responses are under-penalized, leading to poorer-quality outputs. The proposed solution is a token-level normalization, which better handles longer sequences by assigning more balanced rewards to individual tokens, regardless of response length:

$$
\mathcal{L}_{\text{DAPO}}(\theta) = - \frac{1}{\sum_{i=1}^G |o_i|} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
$$

Furthermore, it was demonstrated in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that the initial GRPO formulation introduces a response-length bias. They show that while the DAPO formulation reduces this bias, it does not eliminate it completely. To fully remove this bias, they propose dividing by a constant instead of the sequence length, resulting in the following formulation:

$$
\mathcal{L}_{\text{Dr. GRPO}}(\theta) = - \frac{1}{LG} \sum_{i=1}^G \sum_{t=1}^{|o_i|} l_{i,t},
$$

This constant is recommended to be the maximum completion length. To use this formulation, set `loss_type="dr_grpo"` in the [`GRPOConfig`].
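
The three aggregations can be summarized in a short sketch (the `completion_mask` marks valid completion tokens; the names other than `"dr_grpo"` are illustrative, so check [`GRPOConfig`] for the exact options):

```python
import torch

def aggregate_loss(per_token_loss, completion_mask, loss_type, max_completion_length=256):
    # per_token_loss, completion_mask: shape (G, T)
    completion_mask = completion_mask.float()
    if loss_type == "grpo":  # per-sequence mean, then batch mean -> response-length bias
        return ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
    if loss_type == "dapo":  # token-level normalization over all completions
        return (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
    if loss_type == "dr_grpo":  # divide by a constant L (e.g. the maximum completion length)
        return (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * max_completion_length)
    raise ValueError(loss_type)
```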

## Logged metrics

- `num_tokens`: The total number of tokens processed so far, including both prompts and completions.
- `completions/mean_length`: The average length of generated completions.
- `completions/min_length`: The minimum length of generated completions.
- `completions/max_length`: The maximum length of generated completions.
- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS.
- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS.
- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS.
- `completions/clipped_ratio`: The ratio of truncated (clipped) completions.
- `reward/{reward_func_name}/mean`: The average reward from a specific reward function.
- `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function.
- `reward`: The overall average reward after applying reward weights.
- `reward_std`: The standard deviation of the overall reward within each batch after applying reward weights.
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
- `clip_ratio/region_mean`: The ratio of token probabilities where the GRPO objective is clipped to stay within the trust region:
  $$
  \text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,.
  $$
  A higher value means more tokens are clipped, which constrains how much the policy \\( \pi_\theta \\) can change.
- `clip_ratio/low_mean`: The average ratio of token probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
- `clip_ratio/low_min`: The minimum ratio of token probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
- `clip_ratio/high_mean`: The average ratio of token probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\)
- `clip_ratio/high_max`: The maximum ratio of token probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).

## Customization

### Speed up training with vLLM-powered generation

Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with

```shell
pip install trl[vllm]
```

We support two ways of using vLLM during training: **server mode** and **colocate mode**.

#### 🔌 Option 1: Server mode

In this mode, vLLM runs in a separate process (on separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.

1. **Start the vLLM server**:

   ```bash
   trl vllm-serve --model <model_name>
   ```

2. **Enable server mode in your training script**:

   ```python
   from trl import GRPOConfig

   training_args = GRPOConfig(
       ...,
       use_vllm=True,
       vllm_mode="server",  # default value, can be omitted
   )
   ```

<Tip warning={true}>

Make sure that the server is using different GPUs than the trainer, otherwise you may run into NCCL errors. You can specify the GPUs to use with the `CUDA_VISIBLE_DEVICES` environment variable, as in the example below.

</Tip>
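
For example, on a single 8-GPU machine you could dedicate half of the GPUs to the server and half to training (an illustrative split):

```bash
# GPUs 0-3 serve generation, GPUs 4-7 run the trainer (illustrative assignment)
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name> &
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch --num_processes 4 train_grpo.py
```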

#### 🧩 Option 2: Colocate mode

In this mode, vLLM runs inside the trainer process and shares GPU memory with the training model. This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs.

```python
from trl import GRPOConfig

training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_mode="colocate",
)
```

<Tip>

Depending on the model size and the overall GPU memory requirements for training, you may need to adjust the `vllm_gpu_memory_utilization` parameter in [`GRPOConfig`] to avoid underutilization or out-of-memory errors.

</Tip>

For more information, see [Speeding up training with vLLM](speeding_up_training#vllm-for-fast-generation-in-online-methods).

### GRPO at scale: train a 70B+ model on multiple nodes

When training large models like **Qwen2.5-72B**, you need several key optimizations to make the training efficient and scalable across multiple GPUs and nodes. These include:

- **DeepSpeed ZeRO Stage 3**: ZeRO leverages data parallelism to distribute model states (weights, gradients, optimizer states) across multiple GPUs and CPUs, reducing memory and compute requirements on each device. Since large models cannot fit on a single GPU, using ZeRO Stage 3 is required for training such models. For more details, see [DeepSpeed Integration](deepspeed_integration).
- **Accelerate**: Accelerate is a library that simplifies distributed training across multiple GPUs and nodes. It provides a simple API to launch distributed training and handles the complexities of distributed training, such as data parallelism, gradient accumulation, and distributed data loading. For more details, see [Distributing Training](distributing_training).
- **vLLM**: See the previous section on how to use vLLM to speed up generation.

Below is an example SLURM script to train a 70B model with GRPO on multiple nodes. This script trains a model on 4 nodes and uses the 5th node for vLLM-powered generation.

```sh
#!/bin/bash
#SBATCH --nodes=5
#SBATCH --gres=gpu:8

# Get the list of allocated nodes
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))

# Assign the first 4 nodes for training and the 5th node for vLLM
TRAIN_NODES=$(IFS=,; echo "${NODELIST[*]:0:4}")  # Nodes 0-3 for training, comma-separated for --nodelist
VLLM_NODE="${NODELIST[4]}"                       # Node 4 for vLLM

# Run training on the first 4 nodes (Group 1)
srun --nodes=4 --ntasks=4 --nodelist="$TRAIN_NODES" accelerate launch \
    --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    --num_processes 32 \
    --num_machines 4 \
    --main_process_ip ${NODELIST[0]} \
    --machine_rank $SLURM_PROCID \
    --rdzv_backend c10d \
    train_grpo.py \
    --vllm_server_host $VLLM_NODE &

# Run vLLM server on the 5th node (Group 2)
srun --nodes=1 --ntasks=1 --nodelist="$VLLM_NODE" trl vllm-serve --model Qwen/Qwen2.5-72B --tensor_parallel_size 8 &

wait
```

```python
import argparse

from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vllm_server_host", type=str, default="", help="The server IP")
    args = parser.parse_args()

    # Example dataset from TLDR
    dataset = load_dataset("trl-lib/tldr", split="train")

    # Dummy reward function: count the number of unique characters in the completions
    def reward_num_unique_chars(completions, **kwargs):
        return [len(set(c)) for c in completions]

    training_args = GRPOConfig(
        output_dir="Qwen2.5-72B-GRPO",
        per_device_train_batch_size=4,
        bf16=True,
        gradient_checkpointing=True,
        logging_steps=10,
        use_vllm=True,
        vllm_server_host=args.vllm_server_host.replace("ip-", "").replace("-", "."),  # from ip-X-X-X-X to X.X.X.X
    )

    trainer = GRPOTrainer(
        model="Qwen/Qwen2.5-72B",
        args=training_args,
        reward_funcs=reward_num_unique_chars,
        train_dataset=dataset,
    )
    trainer.train()

if __name__ == "__main__":
    main()
```

### Using a custom reward function

The [`GRPOTrainer`] supports using custom reward functions instead of dense reward models. To ensure compatibility, your reward function must satisfy the following requirements:

1. **Input arguments**:
   - The function must accept the following as keyword arguments:
     - `prompts` (contains the prompts),
     - `completions` (contains the generated completions),
     - `completions_ids` (contains the tokenized completions),
     - All column names (except `prompt`) that the dataset may have. For example, if the dataset contains a column named `ground_truth`, the function will be called with `ground_truth` as a keyword argument.

     The easiest way to comply with this requirement is to use `**kwargs` in the function signature.
   - Depending on the dataset format, the input will vary:
     - For [standard format](dataset_formats#standard), `prompts` and `completions` will be lists of strings.
     - For [conversational format](dataset_formats#conversational), `prompts` and `completions` will be lists of message dictionaries.

2. **Return value**: The function must return a list of floats. Each float represents the reward corresponding to a single completion.

#### Example 1: Reward longer completions

Below is an example of a reward function for a standard format that rewards longer completions:

```python
def reward_func(completions_ids, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of token count)."""
    return [float(len(ids)) for ids in completions_ids]
```

You can test it as follows:

```python
>>> prompts = ["The sky is", "The sun is"]  # not used in the reward function, but the trainer will pass it
>>> completions = [" blue.", " in the sky."]  # not used in the reward function, but the trainer will pass it
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[2.0, 4.0]
```

#### Example 1.1: Reward longer completions (based on the number of characters)

Same as the previous example, but this time the reward function is based on the number of characters instead of tokens.

```python
def reward_func(completions, **kwargs):
    """Reward function that assigns higher scores to longer completions (in terms of character count)."""
    return [float(len(completion)) for completion in completions]
```

You can test it as follows:

```python
>>> prompts = ["The sky is", "The sun is"]
>>> completions = [" blue.", " in the sky."]
>>> completions_ids = [[6303, 13], [304, 279, 12884, 13]]  # not used in the reward function, but the trainer will pass it
>>> reward_func(prompts=prompts, completions=completions, completions_ids=completions_ids)
[6.0, 12.0]
```

#### Example 2: Reward completions with specific format

Below is an example of a reward function that checks if the completion has a specific format. This example is inspired by the _format reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). It is designed for conversational format, where prompts and completions consist of structured messages.

```python
import re

def format_reward_func(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
```

You can test this function as follows:

```python
>>> prompts = [
...     [{"role": "user", "content": "What is the result of (1 + 2) * 4?"}],
...     [{"role": "user", "content": "What is the result of (3 + 1) * 2?"}],
... ]
>>> completions = [
...     [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
...     [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
... ]
>>> format_reward_func(prompts=prompts, completions=completions)
[1.0, 0.0]
```

#### Example 3: Reward completions based on a reference

Below is an example of a reward function that checks if the completion is correct. This example is inspired by the _accuracy reward_ function used in the paper [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning](https://huggingface.co/papers/2501.12948). This example is designed for [standard format](dataset_formats#standard), where the dataset contains a column named `ground_truth`.

```python
import re

def reward_func(completions, ground_truth, **kwargs):
    # Regular expression to capture content inside \boxed{}
    matches = [re.search(r"\\boxed\{(.*?)\}", completion) for completion in completions]
    contents = [match.group(1) if match else "" for match in matches]
    # Reward 1 if the content is the same as the ground truth, 0 otherwise
    return [1.0 if c == gt else 0.0 for c, gt in zip(contents, ground_truth)]
```

You can test this function as follows:

```python
>>> prompts = ["Problem: Solve the equation $2x + 3 = 7$. Solution:", "Problem: Solve the equation $3x - 5 = 10$."]
>>> completions = [r" The solution is \boxed{2}.", r" The solution is \boxed{6}."]
>>> ground_truth = ["2", "5"]
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
[1.0, 0.0]
```

#### Example 4: Multi-task reward functions

Below is an example of using multiple reward functions in the [`GRPOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works.

```python
from datasets import Dataset
from trl import GRPOTrainer

# Define a dataset that contains both math and coding problems
dataset = Dataset.from_list(
    [
        {"prompt": "What is 2+2?", "task": "math"},
        {"prompt": "Write a function that returns the sum of two numbers.", "task": "code"},
        {"prompt": "What is 3*4?", "task": "math"},
        {"prompt": "Write a function that returns the product of two numbers.", "task": "code"},
    ]
)

# Math-specific reward function
def math_reward_func(prompts, completions, task, **kwargs):
    rewards = []
    for prompt, completion, t in zip(prompts, completions, task):
        if t == "math":
            # Calculate math-specific reward
            correct = check_math_solution(prompt, completion)
            reward = 1.0 if correct else -1.0
            rewards.append(reward)
        else:
            # Return None for non-math tasks
            rewards.append(None)
    return rewards

# Coding-specific reward function
def coding_reward_func(prompts, completions, task, **kwargs):
    rewards = []
    for prompt, completion, t in zip(prompts, completions, task):
        if t == "code":  # matches the `task` values in the dataset
            # Calculate coding-specific reward
            works = test_code_solution(prompt, completion)
            reward = 1.0 if works else -1.0
            rewards.append(reward)
        else:
            # Return None for non-coding tasks
            rewards.append(None)
    return rewards

# Use both task-specific reward functions
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[math_reward_func, coding_reward_func],
    train_dataset=dataset,
)

trainer.train()
```

In this example, the `math_reward_func` and `coding_reward_func` are designed to work with a mixed dataset that contains both math and coding problems. The `task` column in the dataset is used to determine which reward function to apply to each problem. If there is no relevant reward function for a sample in the dataset, the reward function will return `None`, and the [`GRPOTrainer`] will continue with the valid functions and tasks. This allows the [`GRPOTrainer`] to handle multiple reward functions with different applicability. (The helpers `check_math_solution` and `test_code_solution` are placeholders that you would implement yourself.)

Note that the [`GRPOTrainer`] ignores the `None` rewards returned by the reward functions and only considers the rewards returned by the relevant functions, so the model is trained on each sample only by the reward functions that apply to it.
#### Passing the reward function to the trainer

To use your custom reward function, pass it to the [`GRPOTrainer`] as follows:

```python
from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=reward_func,
    ...,
)
```

If you have multiple reward functions, you can pass them as a list:

```python
from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=[reward_func1, reward_func2],
    ...,
)
```

and the reward will be computed as the sum of the rewards from each function, or the weighted sum if `reward_weights` is provided in the config.
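
For example, a sketch of weighting two reward functions via `reward_weights` (the weights follow the order of `reward_funcs`; the values shown are illustrative):

```python
from trl import GRPOConfig, GRPOTrainer

training_args = GRPOConfig(..., reward_weights=[0.7, 0.3])
trainer = GRPOTrainer(
    reward_funcs=[reward_func1, reward_func2],
    args=training_args,
    ...,
)
```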

Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details.

## GRPOTrainer

[[autodoc]] GRPOTrainer

## GRPOConfig

[[autodoc]] GRPOConfig

@@ -18,7 +18,7 @@ When training RL models, optimizing solely for reward may lead to unexpected beh

However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kl-example.png">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kl-example.png">
<p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://huggingface.co/papers/1909.08593">https://huggingface.co/papers/1909.08593</a>. </p>
</div>

72 docs/source/index.md Normal file
@@ -0,0 +1,72 @@

<div style="text-align: center">
  <img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png">
</div>

# TRL - Transformer Reinforcement Learning

TRL is a full stack library where we provide a set of tools to train transformer language models with methods like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO), Reward Modeling, and more.
The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).

You can also explore TRL-related models, datasets, and demos in the [TRL Hugging Face organization](https://huggingface.co/trl-lib).

## Learn

Learn post-training with TRL and other libraries in 🤗 [smol course](https://github.com/huggingface/smol-course).

## Contents

The documentation is organized into the following sections:

- **Getting Started**: installation and quickstart guide.
- **Conceptual Guides**: dataset formats, training FAQ, and understanding logs.
- **How-to Guides**: reducing memory usage, speeding up training, distributing training, etc.
- **Integrations**: DeepSpeed, Liger Kernel, PEFT, etc.
- **Examples**: example overview, community tutorials, etc.
- **API**: trainers, utils, etc.

## Blog posts

<div class="mt-10">
  <div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/open-r1">
      <img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/open-r1/thumbnails.png" alt="thumbnail" class="mt-0">
      <p class="text-gray-500 text-sm">Published on January 28, 2025</p>
      <p class="text-gray-700">Open-R1: a fully open reproduction of DeepSeek-R1</p>
    </a>
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo_vlm">
      <img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/dpo_vlm/thumbnail.png" alt="thumbnail" class="mt-0">
      <p class="text-gray-500 text-sm">Published on July 10, 2024</p>
      <p class="text-gray-700">Preference Optimization for Vision Language Models with TRL</p>
    </a>
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/putting_rl_back_in_rlhf_with_rloo">
      <img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/putting_rl_back_in_rlhf_with_rloo/thumbnail.png" alt="thumbnail" class="mt-0">
      <p class="text-gray-500 text-sm">Published on June 12, 2024</p>
      <p class="text-gray-700">Putting RL back in RLHF</p>
    </a>
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-ddpo">
      <img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/166_trl_ddpo/thumbnail.png" alt="thumbnail" class="mt-0">
      <p class="text-gray-500 text-sm">Published on September 29, 2023</p>
      <p class="text-gray-700">Finetune Stable Diffusion Models with DDPO via TRL</p>
    </a>
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo-trl">
      <img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/157_dpo_trl/dpo_thumbnail.png" alt="thumbnail" class="mt-0">
      <p class="text-gray-500 text-sm">Published on August 8, 2023</p>
      <p class="text-gray-700">Fine-tune Llama 2 with DPO</p>
    </a>
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/stackllama">
      <img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/138_stackllama/thumbnail.png" alt="thumbnail" class="mt-0">
      <p class="text-gray-500 text-sm">Published on April 5, 2023</p>
      <p class="text-gray-700">StackLLaMA: A hands-on guide to train LLaMA with RLHF</p>
    </a>
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-peft">
      <img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/133_trl_peft/thumbnail.png" alt="thumbnail" class="mt-0">
      <p class="text-gray-500 text-sm">Published on March 9, 2023</p>
      <p class="text-gray-700">Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU</p>
    </a>
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/rlhf">
      <img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/120_rlhf/thumbnail.png" alt="thumbnail" class="mt-0">
      <p class="text-gray-500 text-sm">Published on December 9, 2022</p>
      <p class="text-gray-700">Illustrating Reinforcement Learning from Human Feedback</p>
    </a>
  </div>
</div>

@@ -1,65 +0,0 @@

<div style="text-align: center">
  <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_banner_dark.png">
</div>

# TRL - Transformer Reinforcement Learning

TRL is a full stack library where we provide a set of tools to train transformer language models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step.
The library is integrated with 🤗 [transformers](https://github.com/huggingface/transformers).

<div style="text-align: center">
  <img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
</div>

Check the appropriate sections of the documentation depending on your needs:

## API documentation

- [Model Classes](models): *A brief overview of what each public model class does.*
- [`SFTTrainer`](sft_trainer): *Supervise Fine-tune your model easily with `SFTTrainer`*
- [`RewardTrainer`](reward_trainer): *Train easily your reward model using `RewardTrainer`.*
- [`PPOTrainer`](ppo_trainer): *Further fine-tune the supervised fine-tuned model using PPO algorithm*
- [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
- [`DPOTrainer`](dpo_trainer): *Direct Preference Optimization training using `DPOTrainer`.*
- [`TextEnvironment`](text_environments): *Text environment to train your model using tools with RL.*

## Examples

- [Sentiment Tuning](sentiment_tuning): *Fine tune your model to generate positive movie contents*
- [Training with PEFT](lora_tuning_peft): *Memory efficient RLHF training using adapters with PEFT*
- [Detoxifying LLMs](detoxifying_a_lm): *Detoxify your language model through RLHF*
- [StackLlama](using_llama_models): *End-to-end RLHF training of a Llama model on Stack exchange dataset*
- [Learning with Tools](learning_tools): *Walkthrough of using `TextEnvironments`*
- [Multi-Adapter Training](multi_adapter_rl): *Use a single base model and multiple adapters for memory efficient end-to-end training*

## Blog posts

<div class="mt-10">
  <div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo_vlm">
      <img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/dpo_vlm/thumbnail.png" alt="thumbnail">
      <p class="text-gray-700">Preference Optimization for Vision Language Models with TRL</p>
    </a>
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/rlhf">
      <img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/120_rlhf/thumbnail.png" alt="thumbnail">
      <p class="text-gray-700">Illustrating Reinforcement Learning from Human Feedback</p>
    </a>
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-peft">
      <img src="https://github.com/huggingface/blog/blob/main/assets/133_trl_peft/thumbnail.png?raw=true" alt="thumbnail">
      <p class="text-gray-700">Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU</p>
    </a>
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/stackllama">
      <img src="https://github.com/huggingface/blog/blob/main/assets/138_stackllama/thumbnail.png?raw=true" alt="thumbnail">
      <p class="text-gray-700">StackLLaMA: A hands-on guide to train LLaMA with RLHF</p>
    </a>
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo-trl">
      <img src="https://github.com/huggingface/blog/blob/main/assets/157_dpo_trl/dpo_thumbnail.png?raw=true" alt="thumbnail">
      <p class="text-gray-700">Fine-tune Llama 2 with DPO</p>
    </a>
    <a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-ddpo">
      <img src="https://github.com/huggingface/blog/blob/main/assets/166_trl_ddpo/thumbnail.png?raw=true" alt="thumbnail">
      <p class="text-gray-700">Finetune Stable Diffusion Models with DDPO via TRL</p>
    </a>
  </div>
</div>

39 docs/source/installation.md Normal file
@@ -0,0 +1,39 @@

# Installation

You can install TRL either from PyPI or from source:

## PyPI

Install the library with pip or [uv](https://docs.astral.sh/uv/):

<hfoptions id="install">
<hfoption id="uv">

uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions.

```bash
uv pip install trl
```

</hfoption>
<hfoption id="pip">

```bash
pip install trl
```

</hfoption>
</hfoptions>

## Source

You can also install the latest version from source. First clone the repo and then run the installation with `pip`:

```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .
```

If you want the development install, you can replace the pip install with the following:

```bash
pip install -e ".[dev]"
```

@@ -1,24 +0,0 @@

# Installation

You can install TRL either from pypi or from source:

## pypi

Install the library with pip:

```bash
pip install trl
```

### Source

You can also install the latest version from source. First clone the repo and then run the installation with `pip`:

```bash
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .
```

If you want the development install, you can replace the pip install with the following:

```bash
pip install -e ".[dev]"
```

139 docs/source/iterative_sft_trainer.md Normal file
@@ -0,0 +1,139 @@

# Iterative Trainer

[](https://huggingface.co/models?other=iterative-sft,trl)

Iterative fine-tuning is a training method that enables performing custom actions (for example, generation and filtering) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code.

## Quickstart

To get started quickly, you can either pass a model identifier or a pre-instantiated model to the trainer:

```python
from trl import IterativeSFTConfig, IterativeSFTTrainer

# Using a model identifier
trainer = IterativeSFTTrainer(
    "facebook/opt-350m",
    args=IterativeSFTConfig(
        max_length=512,
        output_dir="./output",
    ),
)

# Or using a pre-instantiated model
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

trainer = IterativeSFTTrainer(
    model,
    args=IterativeSFTConfig(
        max_length=512,
        output_dir="./output",
    ),
    processing_class=tokenizer,
)
```

## Usage

The [`IterativeSFTTrainer`] supports two ways of providing input data to the `step` function:

### Using a list of tensors as input:

```python
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}

trainer.step(**inputs)
```

### Using a list of strings as input:

```python
inputs = {
    "texts": texts,
    "texts_labels": texts_labels,  # Optional, defaults to texts
}

trainer.step(**inputs)
```

For causal language models, labels will automatically be created from `input_ids` or from `texts`. When using sequence-to-sequence models, you will have to provide your own labels or `text_labels`.

## Configuration

The [`IterativeSFTConfig`] class provides several parameters to customize the training:

```python
from trl import IterativeSFTConfig

config = IterativeSFTConfig(
    # Model initialization parameters
    model_init_kwargs={"torch_dtype": "bfloat16"},

    # Data preprocessing parameters
    max_length=512,
    truncation_mode="keep_end",

    # Training parameters
    output_dir="./output",
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    max_steps=1000,
    logging_steps=10,
    save_steps=100,
    optim="adamw_torch",
    report_to="wandb",
)
```

### Model Initialization

You can control how the model is initialized by passing keyword arguments to `model_init_kwargs`:

```python
config = IterativeSFTConfig(
    model_init_kwargs={
        "torch_dtype": "bfloat16",
        "device_map": "auto",
        "trust_remote_code": True,
    }
)
```

### Data Preprocessing

The trainer supports two truncation modes:

- `keep_end`: Truncates from the start of the sequence, keeping the end.
- `keep_start`: Truncates from the end of the sequence, keeping the start.

```python
config = IterativeSFTConfig(
    max_length=512,
    truncation_mode="keep_end",  # or "keep_start"
)
```

### Training Optimization

You can optimize CUDA cache usage for more memory-efficient training:

```python
config = IterativeSFTConfig(
    optimize_device_cache=True,
)
```

## IterativeSFTTrainer

[[autodoc]] IterativeSFTTrainer

## IterativeSFTConfig

[[autodoc]] IterativeSFTConfig

@@ -1,57 +0,0 @@

# Iterative Trainer

[](https://huggingface.co/models?other=iterative-sft,trl)

Iterative fine-tuning is a training method that enables performing custom actions (for example, generation and filtering) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code.

## Usage

To get started quickly, instantiate a model and a tokenizer:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import IterativeSFTTrainer

model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

trainer = IterativeSFTTrainer(
    model,
    tokenizer,
)
```

You have the choice to either provide a list of strings or a list of tensors to the step function.

#### Using a list of tensors as input:

```python
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask,
}

trainer.step(**inputs)
```

#### Using a list of strings as input:

```python
inputs = {
    "texts": texts,
}

trainer.step(**inputs)
```

For causal language models, labels will automatically be created from `input_ids` or from `texts`. When using sequence-to-sequence models, you will have to provide your own labels or `text_labels`.

## IterativeTrainer

[[autodoc]] IterativeSFTTrainer

@@ -1,11 +1,17 @@

# Judges

<Tip warning={true}>

TRL Judges is an experimental API which is subject to change at any time.

</Tip>

TRL provides judges to easily compare two completions.

Make sure to have installed the required dependencies by running:

```bash
pip install trl[llm_judge]
pip install trl[judges]
```

## Using the provided judges

@@ -46,34 +52,38 @@ judge.judge(
) # Outputs: [0, 1]
```

## BaseJudge
## Provided judges

[[autodoc]] BaseJudge

## BaseRankJudge

[[autodoc]] BaseRankJudge

## BasePairwiseJudge

[[autodoc]] BasePairwiseJudge

## RandomRankJudge

[[autodoc]] RandomRankJudge

## RandomPairwiseJudge

[[autodoc]] RandomPairwiseJudge

## PairRMJudge
### PairRMJudge

[[autodoc]] PairRMJudge

## HfPairwiseJudge
### HfPairwiseJudge

[[autodoc]] HfPairwiseJudge

## OpenAIPairwiseJudge
### OpenAIPairwiseJudge

[[autodoc]] OpenAIPairwiseJudge

### AllTrueJudge

[[autodoc]] AllTrueJudge

## Base classes

### BaseJudge

[[autodoc]] BaseJudge

### BaseBinaryJudge

[[autodoc]] BaseBinaryJudge

### BaseRankJudge

[[autodoc]] BaseRankJudge

### BasePairwiseJudge

[[autodoc]] BasePairwiseJudge

@@ -51,11 +51,11 @@ accelerate launch train_kto.py

Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.



To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).

<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-KTO
<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-KTO
<strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
What is the best programming language?

@@ -80,12 +80,12 @@ In theory, the dataset should contain at least one chosen and one rejected compl

## Example script

We provide an example script to train a model using the KTO method. The script is available in [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py)
We provide an example script to train a model using the KTO method. The script is available in [`trl/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/kto.py)

To test the KTO script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/kto-mix-14k), run the following command:

```bash
accelerate launch examples/scripts/kto.py \
accelerate launch trl/scripts/kto.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/kto-mix-14k \
    --num_train_epochs 1 \

@@ -115,20 +115,20 @@ Each choice of `beta` has a maximum learning rate it can tolerate before learnin

### Imbalanced data

The `desirable_weight` and `undesirable_weight` of the [`KTOConfig`] refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` \\(\times\\) number of positives) to (`undesirable_weight` \\(\times\\) number of negatives) is in the range 1:1 to 4:3. For example, with 200 desirable and 100 undesirable examples and `desirable_weight=1`, an `undesirable_weight` between 1.5 and 2 keeps this ratio in range.

## Logged metrics

While training and evaluating we record the following reward metrics:

- `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
- `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
- `logps/chosen`: the mean log probabilities of the chosen completions
- `logps/rejected`: the mean log probabilities of the rejected completions
- `logits/chosen`: the mean logits of the chosen completions
- `logits/rejected`: the mean logits of the rejected completions
- `kl`: the KL divergence between the policy model and the reference model
- `rewards/chosen_sum`: the sum of log probabilities of the policy model for the chosen responses scaled by beta
- `rewards/rejected_sum`: the sum of log probabilities of the policy model for the rejected responses scaled by beta
- `logps/chosen_sum`: the sum of log probabilities of the chosen completions
- `logps/rejected_sum`: the sum of log probabilities of the rejected completions
- `logits/chosen_sum`: the sum of logits of the chosen completions
- `logits/rejected_sum`: the sum of logits of the rejected completions
- `count/chosen`: the count of chosen samples in a batch
- `count/rejected`: the count of rejected samples in a batch

## KTOTrainer
@ -1,233 +0,0 @@
|
||||
# Learning Tools (Experimental 🧪)
|
||||
|
||||
Using Large Language Models (LLMs) with tools has been a popular topic recently with awesome works such as [ToolFormer](https://huggingface.co/papers/2302.04761) and [ToolBench](https://huggingface.co/papers/2305.16504). In TRL, we provide a simple example of how to teach LLM to use tools with reinforcement learning.
|
||||
|
||||
|
||||
Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):
|
||||
|
||||
| File | Description |
|
||||
|---|---|
|
||||
| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. |
|
||||
| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. |
|
||||
| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. |
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
|
||||
</Tip>
|
||||
|
||||
|
||||
## Learning to Use a Calculator
|
||||
|
||||
|
||||
The rough idea is as follows:
|
||||
|
||||
1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parse a text calculation like `"14 + 34"` and return the calulated number:
|
||||
```python
|
||||
from transformers import AutoTokenizer, load_tool
|
||||
tool = load_tool("ybelkada/simple-calculator")
|
||||
tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places
|
||||
```
|
||||
1. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.
|
||||
1. Create a prompt on how to use the tools
|
||||
```python
|
||||
# system prompt
|
||||
prompt = """\
|
||||
What is 13.1-3?
|
||||
|
||||
<request><SimpleCalculatorTool>13.1-3<call>10.1<response>
|
||||
|
||||
Result=10.1<submit>
|
||||
|
||||
What is 4*3?
|
||||
|
||||
<request><SimpleCalculatorTool>4*3<call>12<response>
|
||||
|
||||
Result=12<submit>
|
||||
|
||||
What is 12.1+1?
|
||||
|
||||
<request><SimpleCalculatorTool>12.1+1<call>13.1<response>
|
||||
|
||||
Result=13.1<submit>
|
||||
|
||||
What is 12.1-20?
|
||||
|
||||
<request><SimpleCalculatorTool>12.1-20<call>-7.9<response>
|
||||
|
||||
Result=-7.9<submit>"""
|
||||
```

4. Create a `trl.TextEnvironment` with the model:

```python
env = TextEnvironment(
    model,
    tokenizer,
    {"SimpleCalculatorTool": tool_fn},
    reward_fn,
    prompt,
    generation_kwargs=generation_kwargs,
)
```

5. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will visualize the tokens.

![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_tokens.png)

6. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss; make sure to pass that argument to `step`. A minimal sketch of the whole loop is shown below this list.
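
Putting the pieces together, the end-to-end loop looks roughly like the following. This is a minimal sketch rather than the exact `calculator.py` code: `generate_data` and `exact_match_reward` are hypothetical helpers standing in for the script's task generation and reward computation.

```python
import torch

for _ in range(100):
    # Hypothetical helper that yields task strings and their ground-truth answers
    tasks, answers = generate_data(batch_size=32)
    queries, responses, masks, rewards, histories = env.run(tasks)
    # Override the dummy rewards with a comparison against the ground truth
    rewards = [torch.tensor(exact_match_reward(response, answer))
               for response, answer in zip(responses, answers)]
    train_stats = ppo_trainer.step(queries, responses, rewards, masks)
```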

## Experiment results

We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.

```
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
    --command "python examples/research_projects/tools/calculator.py" \
    --num-seeds 10 \
    --start-seed 1 \
    --workers 10 \
    --slurm-gpus-per-task 1 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 8 \
    --slurm-template-path benchmark/trl.slurm_template
```

We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark), which generates the following plot.

```
# pip install openrlbenchmark==0.2.1a5
python -m openrlbenchmark.rlops_multi_metrics \
    --filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
        'wandb?tag=calculator_final&cl=calculator_mask' \
    --env-ids trl \
    --check-empty-runs \
    --pc.ncols 2 \
    --pc.ncols-legend 1 \
    --output-filename static/0compare \
    --scan-history
```

![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools.png)

As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near-perfect proficiency in the calculator task.

## (Early Experiments 🧪): learning to use a wiki tool for question answering

The [ToolFormer](https://huggingface.co/papers/2302.04761) paper shows an interesting use case that utilizes a Wikipedia search tool to help answer questions. In this section, we attempt to perform similar experiments but use RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.

<Tip warning={true}>

**Note that many settings are different so the results are not directly comparable.**

</Tip>

### Building a search index

Since [ToolFormer](https://huggingface.co/papers/2302.04761) did not open-source its code, we needed to first replicate the search index. Their paper mentions that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT).

Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.

```python
from pyserini.search.lucene import LuceneSearcher
import json

# Load the prebuilt BM25 index for the KILT Wikipedia dump (downloaded on first use)
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')

def search(query):
    hits = searcher.search(query, k=1)  # retrieve only the top hit
    hit = hits[0]
    contents = json.loads(hit.raw)['contents']
    return contents

print(search("tennis racket"))
```
```
Racket (sports equipment)
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.

The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
...
```

We then deployed this snippet as a Hugging Face Space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can later use the Space as a `transformers.Tool`; a loading sketch follows the screenshot below.

![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pyserini.png)

### Experiment settings

We use the following settings:

* use the `bigcode/starcoderbase` model as the base model
* use the `pyserini-wikipedia-kilt-doc` Space as the wiki tool and only use the first paragraph of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool
* test if the response contains the answer string; if so, give a reward of 1, otherwise give a reward of 0 (a sketch of this reward appears after the prompt below)
    * note that this is a simplified evaluation criterion; in [ToolFormer](https://huggingface.co/papers/2302.04761), the authors check whether the first 20 words of the response contain the correct answer
* use the following prompt that demonstrates the usage of the wiki tool:

```python
prompt = """\
Answer the following question:

Q: In which branch of the arts is Patricia Neary famous?
A: Ballets
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
Result=Ballets<submit>

Q: Who won Super Bowl XX?
A: Chicago Bears
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
Result=Chicago Bears<submit>

Q: """
```
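
For illustration, the reward described above could be implemented as follows. This is a sketch, not the exact `triviaqa.py` code; `texts` and `answers` are assumed names for the generated responses and the lists of acceptable answer strings:

```python
def reward_fn(texts, answers):
    """Give 1.0 if any acceptable answer string occurs in the generated text."""
    rewards = []
    for text, answer_aliases in zip(texts, answers):
        hit = any(alias.lower() in text.lower() for alias in answer_aliases)
        rewards.append(1.0 if hit else 0.0)
    return rewards
```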

### Result and Discussion

Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves mostly trend upward, but one of the experiments crashed.

![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/triviaqa_learning_curves.png)

The Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.

Note that the correct-answer rate of the trained model is on the low end, which could be due to the following reasons:

* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"`, if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988." But a correct search would instead surface "Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles."

![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/real_first_name.png)

* **unnecessarily long responses:** The wiki tool sometimes outputs very long sequences by default. E.g., when the wiki tool searches for "Brown Act":
    * Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
    * [ToolFormer](https://huggingface.co/papers/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies.", which is more succinct.

![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/brown-act.png)

## (Early Experiments 🧪): solving math puzzles with a Python interpreter

In this section, we attempt to teach the model to use a Python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following (note that 23 - 5 * 3 = 8, so the demonstrated tool call returns 8):

```python
prompt = """\
Example of using a Python API to solve math questions.

Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?

<request><PythonInterpreter>
def solution():
    money_initial = 23
    bagels = 5
    bagel_cost = 3
    money_spent = bagels * bagel_cost
    money_left = money_initial - money_spent
    result = money_left
    return result
print(solution())
<call>8<response>

Result = 8 <submit>

Q: """
```
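
Analogously to the calculator example, the interpreter is wired into a `TextEnvironment`. The sketch below assumes the setup from the calculator section; `run_python` is a hypothetical sandboxed execution helper (the actual script implements its own restricted evaluator):

```python
env = TextEnvironment(
    model,
    tokenizer,
    {"PythonInterpreter": run_python},  # run_python: hypothetical sandboxed exec helper
    reward_fn,
    prompt,
    generation_kwargs=generation_kwargs,
)
```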

The training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y.

![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gms8k_learning_curve.png)

docs/source/liger_kernel_integration.md (new file)
@@ -0,0 +1,7 @@
# Liger Kernel Integration

<Tip warning={true}>

Section under construction. Feel free to contribute!

</Tip>

@@ -16,8 +16,8 @@ If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir
 Here's a brief explanation for the logged metrics provided in the data:
 
 Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
-1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is sed to specifically monitor the reward model.
-1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias ``ppo/std_scores`, which is sed to specifically monitor the reward model.
+1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is used to specifically monitor the reward model.
+1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias ``ppo/std_scores`, which is used to specifically monitor the reward model.
 1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment.
 1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function.
 1. `objective/kl_dist`: The histogram distribution of the `objective/kl`.

@@ -71,4 +71,4 @@ Here are some parameters that are useful to monitor for stability (when these di
 1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
 1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well.
 1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy.
-1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.
+1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.

docs/source/model_utils.md (new file)
@@ -0,0 +1,5 @@
# Model Utilities

## get_act_offloading_ctx_manager

[[autodoc]] models.get_act_offloading_ctx_manager
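
As a rough usage sketch (the exact signature is documented by the autodoc above; here we assume the context manager is built from the model and wraps the forward pass):

```python
from trl.models import get_act_offloading_ctx_manager

# Assumed usage: run the forward pass inside the context manager so that
# activations are offloaded to CPU, then call backward as usual.
act_offloading_ctx = get_act_offloading_ctx_manager(model)
with act_offloading_ctx:
    loss = model(**batch).loss
loss.backward()
```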

@@ -51,9 +51,9 @@ accelerate launch train_nash_md.py
 
 Distributed across 8 GPUs, the training takes approximately 3 hours.
 
-To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [TRL Chat CLI](clis#chat-interface).
+To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-NashMD) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
 
-<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-NashMD
+<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-NashMD
 <strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
 What is the best programming language?
@@ -111,7 +111,7 @@ trainer.add_callback(completions_callback)
 
 This callback logs the model's generated completions directly to Weights & Biases.
 
-![Logged Completions](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/wandb_completions.png)
+![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png)
 
 ## Example script
@@ -51,11 +51,11 @@ accelerate launch train_online_dpo.py
 
 Distributed across 8 GPUs, the training takes approximately 1 hour. You can verify the training progress by checking the reward graph. An increasing trend in both the reward for rejected and chosen completions indicates that the model is improving and generating better responses over time.
 
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/online-dpo-qwen2-reward.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/online-dpo-qwen2-reward.png)
 
-To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
+To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-OnlineDPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
 
-<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-OnlineDPO
+<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-OnlineDPO
 <strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
 What is the best programming language?
@@ -110,7 +110,7 @@ trainer.add_callback(completions_callback)
 
 This callback logs the model's generated completions directly to Weights & Biases.
 
-![Logged Completions](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/wandb_completions.png)
+![Logged Completions](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/wandb_completions.png)
 
 ## Example script

@@ -265,7 +265,7 @@ plt.tight_layout()
 plt.show()
 ```
 
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/online_dpo_scaling.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/online_dpo_scaling.png)
 
 The online DPO checkpoint gets increasingly more win rate as we scale up the model sizes. This is a good sign that the online DPO implementation is working as intended.
@@ -1,6 +1,6 @@
 # ORPO Trainer
 
-[![](https://img.shields.io/badge/All_models-ORPO-blue)](https://huggingface.co/models?other=orpo,trl)
+[![](https://img.shields.io/badge/All_models-ORPO-blue)](https://huggingface.co/models?other=orpo,trl) [![](https://img.shields.io/badge/smol_course-Chapter_2-yellow)](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
 
 ## Overview

@@ -54,11 +54,11 @@ accelerate launch train_orpo.py
 
 Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
 
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/orpo-qwen2-reward-margin.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/orpo-qwen2-reward-margin.png)
 
-To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
+To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
 
-<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-ORPO
+<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-ORPO
 <strong><span style="color: red;">&lt;quentin_gallouedec&gt;:</span></strong>
 What is the best programming language?

docs/source/others.md (new file)
@@ -0,0 +1,9 @@
# Other

## profiling_decorator

[[autodoc]] extras.profiling.profiling_decorator

## profiling_context

[[autodoc]] extras.profiling.profiling_context

@@ -118,7 +118,7 @@ The `trl` library also supports naive pipeline parallelism (NPP) for large model
 This paradigm, termed as "Naive Pipeline Parallelism" (NPP) is a simple way to parallelize the model across multiple GPUs. We load the model and the adapters across multiple GPUs and the activations and gradients will be naively communicated across the GPUs. This supports `int8` models as well as other `dtype` models.
 
 <div style="text-align: center">
-<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-npp.png">
+<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-npp.png">
 </div>
 
 ### How to use NPP?

@@ -140,5 +140,5 @@ python PATH_TO_SCRIPT
 You can easily fine-tune Llama2 model using `SFTTrainer` and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB):
 
 ```bash
-python examples/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2
+python trl/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2
 ```

@@ -26,6 +26,8 @@ python examples/scripts/ppo/ppo.py \
     --gradient_accumulation_steps 1 \
     --total_episodes 10000 \
     --model_name_or_path EleutherAI/pythia-1b-deduped \
+    --sft_model_path EleutherAI/pythia-1b-deduped \
+    --reward_model_path EleutherAI/pythia-1b-deduped \
     --missing_eos_penalty 1.0
 ```
@@ -50,13 +52,13 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
 * `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
 * `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
 * `lr`: lr: The current learning rate used by the optimizer.
-* `episode`: episode: The current global step or episode count in the training process.
+* `episode`: episode: The current episode count in the training process.
 
 
 ## Cookbook
 
 * Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
-* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try undertand why this is happening and try to fix it.
+* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try understand why this is happening and try to fix it.
 * Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
 * Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
 * Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
@@ -66,7 +68,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
 
 To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations.
 
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/ppov2_completions.gif)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/ppov2_completions.gif)
 
 In the logs the sampled generations look like

@@ -210,7 +212,7 @@ The PPO checkpoint gets a 64.7% preferred rate vs the 33.0% preference rate of t
 
 Metrics:
 
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/pr-1540/ppov2.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/benchmark/pr-1540/ppov2.png)
 
 ```bash
@@ -234,4 +236,4 @@ python -m openrlbenchmark.rlops_multi_metrics \
 
 ## PPOConfig
 
-[[autodoc]] PPOConfig
+[[autodoc]] PPOConfig

docs/source/prm_trainer.md (new file)
@@ -0,0 +1,125 @@
# PRM Trainer

[![](https://img.shields.io/badge/All_models-PRM-blue)](https://huggingface.co/models?other=prm,trl)

<Tip warning={true}>

PRM Trainer is an experimental API which is subject to change at any time.

</Tip>

## Overview

Process-supervised Reward Models (PRM) were proposed in [Solving math word problems with process- and outcome-based feedback](https://huggingface.co/papers/2211.14275) by Jonathan Uesato, Nate Kushman, Ramana Kumar, Francis Song, Noah Siegel, Lisa Wang, Antonia Creswell, Geoffrey Irving, and Irina Higgins.

The abstract from the paper is the following:

> Recent work has shown that asking language models to generate reasoning steps improves performance on many reasoning tasks. When moving beyond prompting, this raises the question of how we should supervise such models: outcome-based approaches which supervise the final result, or process-based approaches which supervise the reasoning process itself? Differences between these approaches might naturally be expected not just in final-answer errors but also in reasoning errors, which can be difficult to detect and are problematic in many real-world domains such as education. We run the first comprehensive comparison between process- and outcome-based approaches trained on a natural language task, GSM8K. We find that pure outcome-based supervision produces similar final-answer error rates with less label supervision. However, for correct reasoning steps we find it necessary to use process-based supervision or supervision from learned reward models that emulate process-based feedback. In total, we improve the previous best results from 16.8% → 12.7% final-answer error and 14.0% → 3.4% reasoning error among final-answer-correct solutions.

This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Agustín Piqueres](https://huggingface.co/plaguss).

## Quick start

This example demonstrates how to train a model using the PRM method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model. We use the stepwise supervision data from the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd). You can view the data in the dataset here:

<iframe
  src="https://huggingface.co/datasets/trl-lib/math_shepherd/embed/viewer/default/train?row=0"
  frameborder="0"
  width="100%"
  height="560px"
></iframe>

Below is the script to train the model:

```python
# train_prm.py
from datasets import load_dataset
from trl import PRMConfig, PRMTrainer
from transformers import AutoModelForTokenClassification, AutoTokenizer

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B")
train_dataset = load_dataset("trl-lib/math_shepherd", split="train[:10%]")

training_args = PRMConfig(output_dir="Qwen2-0.5B-Reward-Math-Sheperd", logging_steps=10)
trainer = PRMTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
```

Execute the script using the following command:

```bash
accelerate launch train_prm.py
```
Distributed across 8 GPUs, the training takes approximately 1 hour.

To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward-Math-Sheperd) performs, you can use the following script.

```python
from datasets import load_dataset
from transformers import pipeline

pipe = pipeline("token-classification", model="trl-lib/Qwen2-0.5B-Reward-Math-Sheperd")
dataset = load_dataset("trl-lib/math_shepherd")
example = {
    "prompt": "Musa is the class teacher of a class of 45 students. He wants to split them into three groups by age. If a third of the class is under 11 years, and two-fifths are above 11 but under 13, how many students will be in the third group (13 years and above)?",
    "completions": [
        "Step 1: A third of the class is under 11 years because 11 - 1/3 = <<11-1/3=7>>7.",
        "Step 2: Two-fifths of the class are above 11 but under 13 because 2/5 * 11 = <<2/5*11=8>>8.",
        "Step 3: There are 45 students, so the third group will have 45 - 7 - 8 = <<45-7-8=20>>20 students. The answer is: 20",
    ],
    "labels": [True, False, False],
}

separator = "\n"  # It's important to use the same separator as the one used during training

for idx in range(1, len(example["completions"]) + 1):
    steps = example["completions"][0:idx]
    text = separator.join((example["prompt"], *steps)) + separator  # Add a separator between the prompt and each step
    pred_entity = pipe(text)[-1]["entity"]
    pred = {"LABEL_0": False, "LABEL_1": True}[pred_entity]
    label = example["labels"][idx - 1]
    print(f"Step {idx}\tPredicted: {pred} \tLabel: {label}")
```

```text
Step 1	Predicted: True 	Label: True
Step 2	Predicted: False 	Label: False
Step 3	Predicted: False 	Label: False
```

It's a win!

## Expected dataset type

PRM requires a [stepwise supervision](dataset_formats#stepwise-supervision) dataset.
The dataset should contain the following columns: `prompt`, `completions` and `labels`, where `completions` contains a list of reasoning steps and `labels` a list of booleans or floats indicating the correctness of each step.

The [`PRMTrainer`] only supports the [standard](dataset_formats#standard) dataset format.

## Example script

We provide an example script to train a model using the PRM method. The script is available in [`examples/scripts/prm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/prm.py).

To use the PRM script with the [Qwen2 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) on the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd), run the following command:

```bash
accelerate launch examples/scripts/prm.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/math_shepherd \
    --num_train_epochs 1 \
    --logging_steps 25 \
    --output_dir Qwen2-0.5B-Reward-Math-Sheperd
```

## PRMTrainer

[[autodoc]] PRMTrainer

## PRMConfig

[[autodoc]] PRMConfig

@@ -9,7 +9,7 @@ Fine-tuning a language model via PPO consists of roughly three steps:
 3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
 
 The full process is illustrated in the following figure:
-<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png"/>
+<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_overview.png"/>
 
 ## Minimal example

docs/source/reducing_memory_usage.md (new file)
@@ -0,0 +1,213 @@
# Reducing Memory Usage

<Tip warning={true}>

Section under construction. Feel free to contribute!

</Tip>

## Truncation

Sequence lengths in the dataset can vary widely. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/why_you_should_truncate.png" alt="Truncation prompt completion" width="600"/>
</div>

To reduce memory usage, it's important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.

<hfoptions id="dpo">
<hfoption id="DPO">

DPO truncation is applied first to the prompt and to the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_prompt_completion.png" alt="Truncation prompt completion" width="600"/>
</div>

To set the truncation parameters, use the following code snippet:

```python
from trl import DPOConfig

training_args = DPOConfig(..., max_prompt_length=..., max_length=...)
```

You can also use the `max_completion_length` parameter to truncate the completion, though this is less common since the goal is typically to preserve the completion's full length whenever possible.

```python
from trl import DPOConfig

training_args = DPOConfig(..., max_completion_length=...)
```

</hfoption>
<hfoption id="SFT">

SFT truncation is applied to the input sequence via the `max_length` parameter.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_input_ids.png" alt="Truncation input ids" width="600"/>
</div>

To set the truncation parameter, use the following code snippet:

```python
from trl import SFTConfig

training_args = SFTConfig(..., max_length=...)
```

</hfoption>
</hfoptions>

## Packing

<Tip>

This technique applies only to SFT.

</Tip>

[Truncation](#truncation) has several drawbacks:
1. **Loss of information**: Key data at the end of a sequence may be discarded.
2. **Choosing truncation length**: Too short loses data; too long undermines efficiency.

Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.10683), addresses these issues by grouping sequences instead of truncating. It concatenates and splits dataset sequences into the desired lengths.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/packing.png" alt="Packing" width="600"/>
</div>

Packing eliminates padding, preserves all sequence information, and allows for flexible sequence lengths, making it a more efficient alternative to truncation. To enable packing, use `packing=True` in the [`SFTConfig`]:

```python
from trl import SFTConfig

training_args = SFTConfig(..., packing=True, max_length=512)
```

<Tip warning={true}>

Packing may cause batch contamination, where adjacent sequences influence one another. This can be problematic for some applications. For more details, see [#1230](https://github.com/huggingface/trl/issues/1230).

</Tip>
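
Conceptually, packing does something like the following toy illustration (not TRL's implementation):

```python
# Toy illustration of packing: concatenate tokenized sequences, then split
# the resulting stream into fixed-size chunks, so no padding is needed.
def pack(sequences, chunk_size):
    stream = [tok for seq in sequences for tok in seq]  # concatenate
    return [stream[i:i + chunk_size] for i in range(0, len(stream), chunk_size)]

print(pack([[1, 2, 3], [4, 5], [6, 7, 8, 9]], chunk_size=4))
# [[1, 2, 3, 4], [5, 6, 7, 8], [9]]
```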

## Padding-free

Padding-free batching is an alternative approach for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/padding-free.png" alt="Padding-free batching" width="600"/>
</div>

<Tip warning={true}>

It's highly recommended to use padding-free batching with **Flash Attention 2**. Otherwise, you may encounter batch contamination issues.

</Tip>

<hfoptions id="padding-free">
<hfoption id="DPO">

```python
from trl import DPOConfig

training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
```

</hfoption>
<hfoption id="SFT">

```python
from trl import SFTConfig

training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
```

</hfoption>
</hfoptions>
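
Conceptually, padding-free batching does something like the following toy illustration (not TRL's implementation):

```python
# Toy illustration: flatten the batch into one sequence and keep per-sample
# position ids so attention can be restricted to each original sample
# (which is what Flash Attention 2 enables).
batch = [[101, 7, 8], [101, 9]]
input_ids = [tok for seq in batch for tok in seq]             # [101, 7, 8, 101, 9]
position_ids = [i for seq in batch for i in range(len(seq))]  # [0, 1, 2, 0, 1]
```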

## Activation offloading

Activation offloading is a memory efficiency technique that reduces GPU VRAM usage by temporarily moving activation tensors to CPU RAM during the forward pass and bringing them back only when needed for the backward pass. This significantly reduces peak memory usage at the cost of slightly increased training time.

To enable activation offloading in your SFT training configuration:

<hfoptions id="activation_offloading">
<hfoption id="SFT">

```python
from trl import SFTConfig

training_args = SFTConfig(..., activation_offloading=True)
```

</hfoption>
</hfoptions>

<Tip warning={true}>

When using activation offloading with models that use Liger kernels, you must disable Liger cross entropy due to compatibility issues. The issue occurs specifically with `use_liger_kernel=True` because Liger cross entropy performs in-place operations which conflict with activation offloading. The default setting (`use_liger_kernel=False`) works:

```python
# When using activation offloading with a model that uses Liger kernels:
from trl import SFTConfig

training_args = SFTConfig(
    activation_offloading=True,
    use_liger_kernel=False,  # Disable Liger cross entropy
    # Other parameters...
)
```

</Tip>

Under the hood, activation offloading implements PyTorch's [`saved_tensors_hooks`](https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#hooks-for-autograd-saved-tensors) to intercept activations during the forward pass. It intelligently manages which tensors to offload based on size and context, avoiding offloading output tensors which would be inefficient. For performance optimization, it can optionally use CUDA streams to overlap computation with CPU-GPU transfers.
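
As a rough, self-contained illustration of the mechanism (not TRL's actual implementation):

```python
import torch
import torch.nn as nn

model = nn.Linear(512, 512).cuda()
inputs = torch.randn(8, 512, device="cuda")

# pack_hook runs during the forward pass and moves each saved activation to
# CPU; unpack_hook runs during backward and moves it back to its device.
def pack_to_cpu(tensor):
    return tensor.device, tensor.to("cpu", non_blocking=True)

def unpack_to_device(packed):
    device, tensor = packed
    return tensor.to(device, non_blocking=True)

with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_to_device):
    loss = model(inputs).pow(2).mean()  # activations saved here live on CPU
loss.backward()                         # fetched back as backward needs them
```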
## Disabling model gathering for generation in online methods

When using DeepSpeed ZeRO-3, model weights are sharded across multiple GPUs. Online methods involve generating completions from the model as part of the training process. During this step, the model weights are temporarily gathered on a single GPU for generation. For very large models, this gathering can lead to out-of-memory (OOM) errors, as described in this issue: [#2250](https://github.com/huggingface/trl/issues/2250#issue-2598304204).

If you encounter this issue, you can disable the gathering of model weights for generation by setting the following parameter:

<hfoptions id="ds3_gather_for_generation">
<hfoption id="GRPO">

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
<hfoption id="Online DPO">

```python
from trl import OnlineDPOConfig

training_args = OnlineDPOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
<hfoption id="PPO">

```python
from trl import PPOConfig

training_args = PPOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
<hfoption id="RLOO">

```python
from trl import RLOOConfig

training_args = RLOOConfig(..., ds3_gather_for_generation=False)
```

</hfoption>
</hfoptions>

This adjustment prevents model weights from being gathered, avoiding OOM errors, but it may result in slower generation speeds.

docs/source/rewards.md (new file)
@@ -0,0 +1,9 @@
# Reward Functions

This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`].

## Format rewards

### think_format_reward

[[autodoc]] rewards.think_format_reward
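
As a rough usage sketch, assuming the GRPO reward-function interface in which a reward function receives `completions` and returns a list of floats:

```python
from trl.rewards import think_format_reward

# A conversational-style completion whose reasoning is wrapped in <think> tags
completions = [[{"role": "assistant", "content": "<think>2 + 2 = 4</think>The answer is 4."}]]
print(think_format_reward(completions=completions))  # expected: [1.0]
```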

@@ -2,7 +2,7 @@
 
 [![](https://img.shields.io/badge/All_models-RLOO-blue)](https://huggingface.co/models?other=rloo,trl)
 
-TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). The idea is that instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean scores from the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, where as PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO, when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL.
+TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). The idea is that instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean scores from the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, whereas PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO, when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL.
 
 References:
 - [Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs](https://huggingface.co/papers/2402.14740)
@@ -58,7 +58,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
 ## Cookbook
 
 * Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
-* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try undertand why this is happening and try to fix it.
+* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try understand why this is happening and try to fix it.
 * Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
 * Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
 * Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.

@@ -68,7 +68,7 @@ The logged metrics are as follows. Here is an example [tracked run at Weights an
 
 To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations.
 
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/rloo_completions.gif)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/rloo_completions.gif)
 
 In the logs the sampled generations look like
@@ -218,8 +218,8 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
     --num_ppo_epochs 2 \
     --num_mini_batches 2 \
     --learning_rate 3e-6 \
-    --per_device_train_batch_size 8 \
-    --gradient_accumulation_steps 8 \
+    --per_device_train_batch_size 16 \
+    --gradient_accumulation_steps 16 \
     --total_episodes 1000000 \
     --model_name_or_path EleutherAI/pythia-1b-deduped \
     --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \

@@ -251,7 +251,7 @@ The RLOO checkpoint gets a 51.2% preferred rate vs the 33.0% preference rate of 
 
 Metrics:
 
-![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/pr-1540/rloo.png)
+![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/benchmark/pr-1540/rloo.png)
 
 ```bash
@@ -269,6 +269,17 @@ python -m openrlbenchmark.rlops_multi_metrics \
     --scan-history
 ```
 
+## Reinforce++
+
+The [Reinforce++](https://hijkzzz.notion.site/reinforce-plus-plus) report by Jian Hu suggests several optimization tricks to enhance the performance and stability of RLHF. They include:
+
+- Clipping rewards: limiting reward values within a specific range to mitigate the impact of extreme rewards on model updates, thus preventing gradient explosion
+- Normalizing rewards: scaling rewards to have a mean of 0 and a standard deviation of 1, which helps stabilize the training process
+- Normalizing advantages: scaling advantages to have a mean of 0 and a standard deviation of 1, which helps stabilize the training process
+- Using a token-level KL penalty, as defined in equation (1) of the report, instead of the default sequence-level KL penalty
+
+These options are available via the appropriate arguments in the [`RLOOConfig`] class.
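
For illustration, enabling the Reinforce++ tricks above might look like the following sketch; the argument names here are assumptions, so check [`RLOOConfig`] for the exact names and defaults:

```python
from trl import RLOOConfig

# Assumed argument names for the tricks described above (verify against RLOOConfig)
training_args = RLOOConfig(
    ...,
    normalize_reward=True,     # zero-mean, unit-variance rewards
    reward_clip_range=10.0,    # clip rewards to a fixed range
    normalize_advantage=True,  # zero-mean, unit-variance advantages
    token_level_kl=True,       # token-level instead of sequence-level KL penalty
)
```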
 
 ## RLOOTrainer
 
@@ -276,4 +287,4 @@ python -m openrlbenchmark.rlops_multi_metrics \
 
 ## RLOOConfig
 
-[[autodoc]] RLOOConfig
+[[autodoc]] RLOOConfig

docs/source/script_utils.md (new file)
@@ -0,0 +1,12 @@
# Scripts Utilities

## ScriptArguments

[[autodoc]] ScriptArguments

## TrlParser

[[autodoc]] TrlParser
    - parse_args_and_config
    - parse_args_into_dataclasses
    - set_defaults_with_config
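
As a rough sketch of the typical pattern (parsing `ScriptArguments` together with a trainer config; `parse_args_and_config` also accepts a `--config` YAML file whose values become the new defaults):

```python
# sketch_parser.py
from trl import ScriptArguments, SFTConfig, TrlParser

parser = TrlParser((ScriptArguments, SFTConfig))
script_args, training_args = parser.parse_args_and_config()
print(script_args.dataset_name, training_args.output_dir)
```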

@@ -1,11 +1,8 @@
 # Supervised Fine-tuning Trainer
 
-[![](https://img.shields.io/badge/All_models-SFT-blue)](https://huggingface.co/models?other=sft,trl)
+[![](https://img.shields.io/badge/All_models-SFT-blue)](https://huggingface.co/models?other=sft,trl) [![](https://img.shields.io/badge/smol_course-Chapter_1-yellow)](https://github.com/huggingface/smol-course/tree/main/1_instruction_tuning)
 
-Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset.
-
-Check out a complete flexible example at [`examples/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft.py).
-Experimental support for Vision Language Models is also included in the example [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_vlm.py).
+Supervised fine-tuning (SFT) is the most common step in post-training foundation models, and also one of the most effective. In TRL, we provide a simple API to train models with SFT in a few lines of code; for a complete training script, check out [`trl/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/trl/scripts/sft.py). Experimental support for Vision Language Models is also included in [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_vlm.py).
 
 ## Quickstart
@@ -19,7 +16,7 @@ from trl import SFTConfig, SFTTrainer
 dataset = load_dataset("stanfordnlp/imdb", split="train")
 
 training_args = SFTConfig(
-    max_seq_length=512,
+    max_length=512,
     output_dir="/tmp",
 )
 trainer = SFTTrainer(
@@ -29,7 +26,7 @@ trainer = SFTTrainer(
 )
 trainer.train()
 ```
-Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
+Make sure to pass the correct value for `max_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
 
 You can also construct a model outside of the trainer and pass it as follows:
@@ -59,108 +56,9 @@ The above snippets will use the default training arguments from the [`SFTConfig`
 
 ### Train on completions only
 
-You can use the `DataCollatorForCompletionOnlyLM` to train your model on the generated prompts only. Note that this works only in the case when `packing=False`.
-To instantiate that collator for instruction data, pass a response template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on completions only on the CodeAlpaca dataset:
+To train on completions only, simply use a [prompt-completion](#prompt-completion) dataset. In this mode, loss is computed solely on the completion part.
 
-```python
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from datasets import load_dataset
-from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
-
-dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")
-
-model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
-tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
-
-def formatting_prompts_func(example):
-    output_texts = []
-    for i in range(len(example['instruction'])):
-        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
-        output_texts.append(text)
-    return output_texts
-
-response_template = " ### Answer:"
-collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
-
-trainer = SFTTrainer(
-    model,
-    train_dataset=dataset,
-    args=SFTConfig(output_dir="/tmp"),
-    formatting_func=formatting_prompts_func,
-    data_collator=collator,
-)
-
-trainer.train()
-```
-
-To instantiate that collator for assistant style conversation data, pass a response template, an instruction template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on assistant completions only on the Open Assistant Guanaco dataset:
-
-```python
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from datasets import load_dataset
-from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
-
-dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
-
-model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
-tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
-
-instruction_template = "### Human:"
-response_template = "### Assistant:"
-collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)
-
-trainer = SFTTrainer(
-    model,
-    args=SFTConfig(output_dir="/tmp"),
-    train_dataset=dataset,
-    data_collator=collator,
-)
-
-trainer.train()
-```
-
-Make sure to have a `pad_token_id` which is different from `eos_token_id` which can result in the model not properly predicting EOS (End of Sentence) tokens during generation.
-
-#### Using token_ids directly for `response_template`
-
-Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending on whether they have context or not. For example:
-
-```python
-from transformers import AutoTokenizer
-tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-
-def print_tokens_with_ids(txt):
-    tokens = tokenizer.tokenize(txt, add_special_tokens=False)
-    token_ids = tokenizer.encode(txt, add_special_tokens=False)
-    print(list(zip(tokens, token_ids)))
-
-prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?"""
-print_tokens_with_ids(prompt)  # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...]
-
-response_template = "### Assistant:"
-print_tokens_with_ids(response_template)  # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)]
-```
-
-In this case, and due to lack of context in `response_template`, the same string ("### Assistant:") is tokenized differently:
-
-- Text (with context): `[2277, 29937, 4007, 22137, 29901]`
-- `response_template` (without context): `[835, 4007, 22137, 29901]`
-
-This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text:
-
-```
-RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])
-```
-
-To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
-
-```python
-response_template_with_context = "\n### Assistant:"  # We added context here: "\n". This is enough for this tokenizer
-response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:]  # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]`
-
-data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
-```
+If you’d like to compute loss on both the prompt **and** the completion while still using a prompt-completion dataset, set `completion_only_loss=False` in the [`SFTConfig`]. This is equivalent to [converting the dataset to a language modeling](#from-prompt-completion-to-language-modeling-dataset) format.
 
 ### Add Special Tokens for Chat Format
@ -181,9 +79,11 @@ tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
|
||||
# Set up the chat format with default 'chatml' format
|
||||
model, tokenizer = setup_chat_format(model, tokenizer)
|
||||
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases it is not necessary to apply `setup_chat_format()`, as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in `SFTConfig`; for example, for `Qwen/Qwen2.5-1.5B` one should set `eos_token="<|im_end|>"`.
|
||||
|
||||
With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
|
||||
|
||||
### Dataset format support
|
||||
@@ -246,26 +146,23 @@ Let us assume your dataset has two fields, `question` and `answer`. Therefore yo
 ```python
 ...
 def formatting_prompts_func(example):
-    output_texts = []
-    for i in range(len(example['question'])):
-        text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
-        output_texts.append(text)
-    return output_texts
+    return f"### Question: {example['question']}\n ### Answer: {example['answer']}"
 
 trainer = SFTTrainer(
     model,
     args=training_args,
     train_dataset=dataset,
-    formatting_func=formatting_prompts_func,
+    formatting_func=formatting_prompt_func,
 )
 
 trainer.train()
 ```
-To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
 
-### Packing dataset ([`ConstantLengthDataset`])
+### Packing dataset
 
-[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTConfig`] constructor.
+[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTConfig`] constructor.
 
 ```python
 ...
@@ -302,7 +199,6 @@ trainer = SFTTrainer(
 
 trainer.train()
 ```
-You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTConfig`] constructor. Please refer to that class' signature for more information.
 
 ### Control over the pretrained model
@ -331,33 +227,38 @@ Note that all keyword arguments of `from_pretrained()` are supported.
|
||||
|
||||
### Training adapters
|
||||
|
||||
We also support tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model
|
||||
We also support tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model.
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
from peft import LoraConfig
|
||||
|
||||
dataset = load_dataset("stanfordnlp/imdb", split="train")
|
||||
dataset = load_dataset("trl-lib/Capybara", split="train")
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
target_modules="all-linear",
|
||||
modules_to_save=["lm_head", "embed_tokens"],
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
"EleutherAI/gpt-neo-125m",
|
||||
"Qwen/Qwen2.5-0.5B",
|
||||
train_dataset=dataset,
|
||||
args=SFTConfig(output_dir="/tmp"),
|
||||
args=SFTConfig(output_dir="Qwen2.5-0.5B-SFT"),
|
||||
peft_config=peft_config
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> If the chat template contains special tokens like `<|im_start|>` (ChatML) or `<|eot_id|>` (Llama), the embedding layer and LM head must be included in the trainable parameters via the `modules_to_save` argument. Without this, the fine-tuned model will produce unbounded or nonsense generations. If the chat template doesn't contain special tokens (e.g. Alpaca), then the `modules_to_save` argument can be ignored or set to `None`.
|
||||
|
||||
|
||||
You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without passing the `peft_config` argument.
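Below is a minimal sketch of this workflow; the adapter checkpoint path is hypothetical and the dataset is the one used above:

```python
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

# Load an existing adapter checkpoint (hypothetical path) in trainable mode
model = AutoPeftModelForCausalLM.from_pretrained("path/to/peft-checkpoint", is_trainable=True)

trainer = SFTTrainer(
    model,  # pass the PeftModel directly; no peft_config argument
    train_dataset=dataset,
    args=SFTConfig(output_dir="continued-sft"),
)
trainer.train()
```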
|
||||
|
||||
### Training adapters with base 8 bit models
|
||||
@ -426,9 +327,9 @@ Below are some numbers you can get in terms of speedup and memory efficiency, us
|
||||
|
||||
| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step |
|
||||
| ---------------- | ----------------- | ----------- | ---------- | ---------------------- |
|
||||
| x | facebook/opt-350m | 2048 | 8 | ~59.1s |
|
||||
| ✓ | facebook/opt-350m | 2048 | 8 | ~59.1s |
|
||||
| | facebook/opt-350m | 2048 | 8 | **OOM** |
|
||||
| x | facebook/opt-350m | 2048 | 4 | ~30.3s |
|
||||
| ✓ | facebook/opt-350m | 2048 | 4 | ~30.3s |
|
||||
| | facebook/opt-350m | 2048 | 4 | ~148.9s |
|
||||
|
||||
### Using Flash Attention-2
|
||||
@ -463,30 +364,30 @@ We included a utility function to create your model.
|
||||
|
||||
```python
|
||||
import torch
from transformers import AutoModelForCausalLM
from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
|
||||
model_config = ModelConfig(
|
||||
model_args = ModelConfig(
|
||||
model_name_or_path="facebook/opt-350m",
|
||||
attn_implementation=None, # or "flash_attention_2"
|
||||
)
|
||||
torch_dtype = (
|
||||
model_config.torch_dtype
|
||||
if model_config.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_config.torch_dtype)
|
||||
model_args.torch_dtype
|
||||
if model_args.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_config)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
revision=model_config.model_revision,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
attn_implementation=model_config.attn_implementation,
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,  # assumes an existing SFTConfig instance named training_args
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
|
||||
trainer = SFTTrainer(
|
||||
...,
|
||||
model=model_config.model_name_or_path,
|
||||
peft_config=get_peft_config(model_config),
|
||||
model=model,  # use the model created above
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
```
|
||||
|
||||
@ -497,7 +398,7 @@ NEFTune is a technique to boost the performance of chat models and was introduce
|
||||
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/neft-screenshot.png">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/neft-screenshot.png">
|
||||
</div>
|
||||
|
||||
To use it in `SFTTrainer`, simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to restore the original behaviour of the embedding layer.
|
||||
@ -522,7 +423,7 @@ trainer.train()
|
||||
We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssistant dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) and validated that using NEFTune led to a performance boost of ~25% on MT Bench.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-neftune-mistral-7b.png">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-neftune-mistral-7b.png">
|
||||
</div>
|
||||
|
||||
Note, however, that the amount of performance gain is _dataset dependent_; in particular, applying NEFTune to synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains.
|
||||
@ -545,12 +446,12 @@ import torch
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number
|
||||
max_length = 2048 # Supports automatic RoPE Scaling, so choose any number
|
||||
|
||||
# Load model
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/mistral-7b",
|
||||
max_seq_length=max_seq_length,
|
||||
max_seq_length=max_length,
|
||||
dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
|
||||
load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
@ -576,7 +477,7 @@ model = FastLanguageModel.get_peft_model(
|
||||
random_state=3407,
|
||||
)
|
||||
|
||||
training_args = SFTConfig(output_dir="./output", max_seq_length=max_seq_length)
|
||||
training_args = SFTConfig(output_dir="./output", max_length=max_length)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
@ -599,17 +500,17 @@ With great memory reduction, you can potentially turn off cpu_offloading or grad
|
||||
|  |  |
|
||||
|
||||
|
||||
1. To use Liger-Kernel in `SFTTrainer`, first install by
|
||||
1. To use Liger-Kernel in [`SFTTrainer`], first install it by running:
|
||||
|
||||
```bash
|
||||
pip install liger-kernel
|
||||
```
|
||||
|
||||
2. Once installed, set `use_liger` in [`SFTConfig`]. No other changes are needed!
|
||||
2. Once installed, set `use_liger_kernel` in [`SFTConfig`]. No other changes are needed!
|
||||
|
||||
```python
|
||||
training_args = SFTConfig(
|
||||
use_liger=True
|
||||
use_liger_kernel=True
|
||||
)
|
||||
```
|
||||
|
||||
@ -619,7 +520,7 @@ To learn more about Liger-Kernel, visit their [official repository](https://gith
|
||||
|
||||
Pay attention to the following best practices when training a model with that trainer:
|
||||
|
||||
- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
|
||||
- By default, [`SFTTrainer`] truncates sequences to the `max_length` argument of the [`SFTConfig`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 1024 and that value. Make sure to verify this before training.
|
||||
- For training adapters in 8-bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to either use the `prepare_in_int8_kwargs` field or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
|
||||
- For more memory-efficient training with adapters, you can load the base model in 8-bit; to do so, simply add the `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8-bit outside the trainer and pass it (see the sketch after this list).
|
||||
- If you create a model outside the trainer, make sure not to pass to the trainer any additional keyword arguments related to the `from_pretrained()` method.
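Here is a minimal sketch of the two adapter-related points above, creating the 8-bit base model and the `PeftModel` outside the trainer; the model name and LoRA settings are illustrative:

```python
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Load the base model in 8-bit and prepare it for k-bit training
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
model = prepare_model_for_kbit_training(model)

# Wrap it with LoRA adapters; the resulting PeftModel can be passed to the trainer
model = get_peft_model(model, LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM"))
```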
|
||||
@ -766,7 +667,3 @@ A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vs
|
||||
The SFTTrainer supports `datasets.IterableDataset` in addition to other dataset styles. This is useful if you are using large corpora that you do not want to save entirely to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
|
||||
|
||||
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to re-use it directly.
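A minimal sketch of what such a pre-tokenized dataset can look like (the token IDs are dummy values):

```python
from datasets import Dataset

# Each row already contains `input_ids`, so the trainer will skip tokenization and packing
dataset = Dataset.from_dict({"input_ids": [[101, 7592, 102], [101, 2088, 102]]})
```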
|
||||
|
||||
### ConstantLengthDataset
|
||||
|
||||
[[autodoc]] trainer.ConstantLengthDataset
|
73
docs/source/speeding_up_training.md
Normal file
@ -0,0 +1,73 @@
|
||||
# Speeding Up Training
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Section under construction. Feel free to contribute!
|
||||
|
||||
</Tip>
|
||||
|
||||
## vLLM for fast generation in online methods
|
||||
|
||||
Online methods such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time.
|
||||
To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed.
|
||||
|
||||
To use [vLLM](https://github.com/vllm-project/vllm), first install it using:
|
||||
|
||||
```bash
|
||||
pip install vllm
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```bash
|
||||
pip install "trl[vllm]"
|
||||
```
|
||||
|
||||
<hfoptions id="vllm examples">
|
||||
<hfoption id="Online DPO">
|
||||
|
||||
Then, enable it by passing `use_vllm=True` in the training arguments.
|
||||
|
||||
```python
|
||||
from trl import OnlineDPOConfig
|
||||
|
||||
training_args = OnlineDPOConfig(..., use_vllm=True)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="GRPO">
|
||||
|
||||
First, start a vLLM server by running:
|
||||
|
||||
```bash
|
||||
trl vllm-serve --model <model_name>
|
||||
```
|
||||
|
||||
Then, run the training script and pass `use_vllm=True` in the training arguments.
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
training_args = GRPOConfig(..., use_vllm=True)
|
||||
```
|
||||
|
||||
You can customize the server configuration by passing additional arguments. For more information, see [vLLM integration](vllm_integration).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.
|
||||
|
||||
Set GPUs **0-3** for vLLM generation:
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name>
|
||||
```
|
||||
|
||||
And GPUs **4-7** for training:
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
```
|
||||
|
||||
</Tip>
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
@ -1,197 +0,0 @@
|
||||
# Text Environments
|
||||
|
||||
Text environments provide a learning ground for language agents. They allow a language model to use tools to accomplish a task, such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the model itself but can be trivial for the appropriate tools. A good example is arithmetic with large numbers, which becomes a simple copy-paste task once you have access to a calculator.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv.png">
|
||||
</div>
|
||||
|
||||
Let's dive into how text environments work and start with tools!
|
||||
|
||||
## Tools
|
||||
|
||||
One of the core building blocks of text environments are the tools the model can use to solve tasks. In general, tools can be any Python function that takes a string as input and returns a string. The `TextEnvironment` offers two options for tools: either use predefined tools from `transformers.Tool`, or define your own function or class with a `__call__` method. Let's have a look at both!
|
||||
|
||||
### `transformers.Tool`
|
||||
|
||||
Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared:
|
||||
|
||||
```Python
|
||||
from transformers import load_tool
|
||||
|
||||
# simple calculator tool that runs +-/* operations
|
||||
calc_tool = load_tool("ybelkada/simple-calculator")
|
||||
|
||||
# python interpreter that executes program and returns outputs
|
||||
py_tool = load_tool("lvwerra/python-interpreter")
|
||||
|
||||
# wikipedia search index that returns best search match
|
||||
wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
|
||||
```
|
||||
|
||||
These tools are either loaded from the hub or from a local folder. Using the tool is as simple as calling them with a text query:
|
||||
|
||||
```Python
|
||||
calc_tool("1/2")
|
||||
>>> "0.5"
|
||||
```
|
||||
|
||||
Note that both input and return values are strings to enable easy usage with a language model.
|
||||
|
||||
### Custom Tools
|
||||
|
||||
The following is an example of a tool that adds two integers:
|
||||
|
||||
```Python
|
||||
def add(text):
|
||||
int_1, int_2 = text.split("+")
|
||||
result = int(int_1) + int(int_2)
|
||||
return str(result)
|
||||
|
||||
print(add("1+1"))
|
||||
>>> "2"
|
||||
```
|
||||
|
||||
We looked at basic examples such as a calculator, but the principle holds for more complex tools as well, such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax.
|
||||
|
||||
### Call syntax
|
||||
|
||||
In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows:
|
||||
|
||||
```python
|
||||
"<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"
|
||||
```
|
||||
|
||||
There are a few special tokens involved, so let's decompose it: First, the model can signal that it wants to use a tool by emitting the `<request>` token. After that, we want to know the name of the tool to call, which is done by enclosing the tool name in `<>` brackets. Once we know which tool to call, the tool query follows, which is in free text form. The `<call>` token signifies the end of the query and stops the model generation. At this point the model output is parsed and the query is sent to the tool. The environment appends the tool response to the string, followed by the `<response>` token to mark the end of the tool output.
|
||||
|
||||
Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later):
|
||||
|
||||
```python
|
||||
"<request><Calculator>1/2<call>0.5<response>"
|
||||
```
|
||||
|
||||
Finally, the episode ends and generation stops when the model generates `<submit>`, which marks the interaction as completed.
|
||||
|
||||
Now let's have a look at how we can create a new text environment!
|
||||
|
||||
## Create a `TextEnvironment`
|
||||
|
||||
|
||||
```python
|
||||
prompt = """\
|
||||
What is 13-3?
|
||||
<request><SimpleCalculatorTool>13-3<call>10.0<response>
|
||||
Result=10<submit>
|
||||
"""
|
||||
|
||||
def reward_fn(result, answer):
|
||||
"""Simplified reward function returning 1 if result matches answer and 0 otherwise."""
|
||||
result_parsed = result.split("=")[1].split("<")[0]
|
||||
return int(result_parsed==answer)
|
||||
|
||||
text_env = TextEnvironment(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
tools={"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
|
||||
reward_fn=reward_fn,  # the reward function defined above
|
||||
prompt=prompt,
|
||||
max_turns=1,
|
||||
max_tool_response=100,
|
||||
generation_kwargs={"do_sample": True},
|
||||
)
|
||||
```
|
||||
|
||||
Let's decompose the settings:
|
||||
|
||||
| Argument | Description |
|
||||
|:-------------------|:----------------|
|
||||
| `model` | Language model to interact with the environment and generate requests. |
|
||||
| `tokenizer` | Tokenizer of language model handling tokenization of strings. |
|
||||
| `tools` | A `list` or `dict` of tools. If a list, the tool name is inferred from the class name; if a dict, the keys define the tool names.|
|
||||
| `reward_fn` | A function that takes a string as input and returns a reward. It can have extra arguments that are passed to `.run()`, such as the ground truth.|
|
||||
| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. |
|
||||
| `max_turns` | Maximum number of interactions between model and tools before episode ends.|
|
||||
| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
|
||||
| `max_length` | The maximum number of tokens to allow in an episode. |
|
||||
| `generation_kwargs`| Generation settings used by the language model. |
|
||||
|
||||
You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!
|
||||
|
||||
|
||||
## Run an Episode
|
||||
|
||||
To run a set of queries through the text environment one can simply use the `run` method.
|
||||
|
||||
```python
|
||||
queries = ["What is 1/2?"]
|
||||
answers = ["0.5"]
|
||||
|
||||
queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)
|
||||
```
|
||||
|
||||
This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached, or the maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function.
|
||||
|
||||
There are five objects that are returned by `run`:
|
||||
|
||||
- `queries`: a list of the tokenized queries
|
||||
- `responses`: all tokens that have been generated within the environment, including both model and tool tokens
|
||||
- `masks`: masks that indicate which tokens were generated by the model and which were generated by the tool
|
||||
- `rewards`: a list of rewards for each query/response
|
||||
- `histories`: list of `TextHistory` objects, which contain all of the above as well as the text equivalents
|
||||
|
||||
The masks are crucial for training, as we don't want to optimize tokens that the model did not generate, namely the tokens produced by the tools.
|
||||
|
||||
Next, we'll train a PPO step with the generated responses!
|
||||
|
||||
|
||||
### Train
|
||||
Training on episodes from the `TextEnvironment` is straightforward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method:
|
||||
|
||||
```python
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
```
|
||||
|
||||
## `TextHistory`
|
||||
|
||||
The `TextHistory` object stores the interactions between the model and the text environment. It stores the tokens and text generated in each turn, their source (model or system), as well as the rewards. Let's go through the class attributes and methods.
|
||||
|
||||
### Attributes
|
||||
|
||||
The following table summarises the available attributes of the `TextHistory` class:
|
||||
|
||||
| Attribute | Description |
|
||||
|:-------------------|:----------------|
|
||||
| `text` | The full string of the text generated in the text environment with both model and system generated text. |
|
||||
| `text_spans` | A list of tuples with the spans for each model or system generated text segment. |
|
||||
| `system_spans` | A list of boolean values indicating if the segment is model or system generated. |
|
||||
| `tokens` | All tokens generated in text environment with both model and system generated tokens. |
|
||||
| `token_spans` | Similar to `text_spans`, the `token_spans` indicate the boundaries of model and system generated tokens. |
|
||||
| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. |
|
||||
| `completed` | Indicates if the interaction with the environment has completed. |
|
||||
| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. |
|
||||
|
||||
With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look!
|
||||
|
||||
### Visualization
|
||||
|
||||
When the model interacts inside the `TextEnvironment`, it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are two methods, [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` library](https://github.com/Textualize/rich) (make sure to install it before using these methods).
|
||||
|
||||
You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue, and in addition to the pure text output, the reward is displayed as additional text in plum. Here is an example of `show_text`:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_text.png" width=600>
|
||||
</div>
|
||||
|
||||
Sometimes there can be tricky tokenization-related issues that are hidden when showing the decoded text. Thus, `TextHistory` also offers the option to display the same highlighting on the tokens directly with `show_tokens`:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_tokens.png" width=800>
|
||||
</div>
|
||||
|
||||
Note that you can turn on the colour legend by passing `show_legend=True`.
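A minimal usage sketch, assuming `histories` was returned by `text_env.run(...)` as shown earlier:

```python
history = histories[0]
history.show_text(show_legend=True)    # highlight text segments by their source
history.show_tokens(show_legend=True)  # the same highlighting on the token level
```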
|
||||
|
||||
## API Documentation
|
||||
|
||||
[[autodoc]] TextEnvironment
|
||||
|
||||
[[autodoc]] TextHistory
|
381
docs/source/training_vlm_sft.md
Normal file
@ -0,0 +1,381 @@
|
||||
# Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset)
|
||||
|
||||

|
||||
|
||||
## Overview
|
||||
|
||||
This guide walks you through the process of fine-tuning a multimodal language model (e.g., **Gemma 3**) using **Supervised Fine-Tuning (SFT)**. We cover two cases:
|
||||
|
||||
- **Single Image + Text**
|
||||
- **Multi-Image + Text**
|
||||
|
||||
This guide serves as a **detailed walkthrough** and complements the existing [VLM SFT script](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm_gemma3.py). If you're already familiar with the concepts, you can use the script directly.
|
||||
|
||||
We demonstrate the fine-tuning process using two datasets, but these principles extend to other **Vision-Language Models (VLMs)** and datasets.
|
||||
|
||||
## Understanding the Datasets
|
||||
|
||||
To address both **Single Image + Text** and **Multi-Image + Text** scenarios, we use two datasets that are well-suited for this task.
|
||||
|
||||
### HuggingFaceH4/llava-instruct-mix-vsft Dataset (Image + Text)
|
||||
|
||||
This dataset is a reformatted version of [LLaVA Instruct Mix](https://huggingface.co/datasets/theblackcat102/llava-instruct-mix). It consists of conversations where a user provides both **text** and a **single image** as input.
|
||||
|
||||
The model (referred to as the **"assistant"**) responds based on both the **visual and textual information** shared by the user. This dataset is particularly useful for training multimodal models to **understand and generate responses based on images and text**.
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft/embed/viewer/default/train"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
### FanqingM/MMIU-Benchmark Dataset (Multi-Image + Text)
|
||||
|
||||
The **FanqingM/MMIU-Benchmark** dataset consists of:
|
||||
|
||||
- **Context:** Included in the system prompt.
|
||||
- **Question:** Provided as part of the user's input.
|
||||
- **Series of Images:** Multiple images related to the question.
|
||||
- **Answer:** The model's expected response.
|
||||
|
||||
This dataset is designed for tasks where the model must reason over multiple images to generate an informed response based on both visual and textual inputs.
|
||||
|
||||
<iframe
|
||||
src="https://huggingface.co/datasets/FanqingM/MMIU-Benchmark/embed/viewer/default/test"
|
||||
frameborder="0"
|
||||
width="100%"
|
||||
height="560px"
|
||||
></iframe>
|
||||
|
||||
## Developing a Fine-Tuning Script for Multimodal SFT
|
||||
|
||||
In this section, we build the script needed to fine-tune a multimodal model for both **Single Image + Text** and **Multi-Image + Text** use cases.
|
||||
|
||||
### Setting Up the Environment
|
||||
|
||||
Before fine-tuning, we need to install the required dependencies. Let's start by setting up the environment:
|
||||
|
||||
```bash
|
||||
# Install the required libraries. Further details: https://huggingface.co/docs/trl/installation
|
||||
pip install -U -q trl bitsandbytes peft hf_xet tensorboard
|
||||
```
|
||||
|
||||
Once all dependencies are installed, we need to log in to the **Hugging Face Hub**. Since **Gemma 3** is a gated model, access permissions are required.
|
||||
|
||||
If you haven’t requested access yet, visit the [Model Card](https://huggingface.co/google/gemma-3-4b-it) and request it.
|
||||
|
||||
To log in, you’ll need to generate an [access token](https://huggingface.co/settings/tokens) from your Hugging Face account.
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
### **Loading the Data**
|
||||
|
||||
As mentioned earlier, we will cover two possible use cases. While the specific procedure may vary based on the dataset, the core principles remain consistent.
|
||||
|
||||
This guide supports both use cases, so refer to the **Single Image + Text** or **Multi-Image + Text** sections depending on your specific scenario.
|
||||
|
||||
#### **Single Image + Text**
|
||||
|
||||

|
||||
|
||||
In this case, each sample in a batch consists of a **single image paired with text**. Since the dataset is already formatted for supervised fine-tuning (SFT), we can directly load it using `load_dataset`.
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset_name = "HuggingFaceH4/llava-instruct-mix-vsft"
|
||||
|
||||
# Load Dataset
|
||||
dataset = load_dataset(dataset_name)
|
||||
```
|
||||
|
||||
#### **Multi-Image + Text (or Interleaving)**
|
||||
|
||||

|
||||
|
||||
Gemma 3 also supports **Multi-Image + Text** scenarios, where:
|
||||
|
||||
- The model receives a **list of images** alongside a user message.
|
||||
- The model processes **interleaved images and text** within a conversation.
|
||||
|
||||
For this dataset, some preprocessing is required before training.
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset_name = "FanqingM/MMIU-Benchmark"
|
||||
|
||||
# Load Dataset
|
||||
dataset = load_dataset(dataset_name)
|
||||
```
|
||||
|
||||
After loading the dataset, we need to preprocess and format it into a conversational structure. Here’s an example of how the data might look:
|
||||
|
||||
```python
|
||||
{"role": "system", "content": [{"type": "text", "text": "You are a judge in a photography competition, and now you are given the four images. Please examine the details and tell which one of them is most likely to be a real photograph.\nSelect from the following choices.\nA: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
|
||||
{"role": "user", "content": images_list + [{"type": "text", "text": "Which image is most likely to be a real photograph?"}]},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": "A: the first image\nB: the second image\nC: the third image\nD: the fourth image"}]},
|
||||
```
|
||||
|
||||
Here, `images_list` is a list of images:
|
||||
|
||||
```python
|
||||
images_list = [
|
||||
{"type": "image", "image": <class 'PIL.Image.Image'>},
|
||||
{"type": "image", "image": <class 'PIL.Image.Image'>},
|
||||
{"type": "image", "image": <class 'PIL.Image.Image'>},
|
||||
{"type": "image", "image": <class 'PIL.Image.Image'>},
|
||||
{"type": "image", "image": <class 'PIL.Image.Image'>},
|
||||
]
|
||||
```
|
||||
|
||||
This structure can be translated into code like this:
|
||||
|
||||
```python
|
||||
import os
|
||||
import zipfile
|
||||
import io
|
||||
from datasets import DatasetDict
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
from PIL import Image
|
||||
|
||||
dataset_train_split = "test"
|
||||
|
||||
def format_data(samples: dict[str, any]) -> dict[str, list]:
|
||||
formatted_samples = {"messages": []}
|
||||
for cont in range(len(samples["question"])):
|
||||
images = []
|
||||
for img_path in samples["input_image_path"][cont]:
|
||||
try:
|
||||
with open(img_path, "rb") as f:
|
||||
img_bytes = f.read()
|
||||
image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
|
||||
images.append({"type": "image", "image": image})
|
||||
except Exception as e:
|
||||
print(f"Error processing image {img_path}: {e}")
|
||||
continue
|
||||
|
||||
formatted_samples["messages"].append(
|
||||
[
|
||||
{"role": "system", "content": [{"type": "text", "text": samples["context"][cont]}]},
|
||||
{"role": "user", "content": images + [{"type": "text", "text": samples["question"][cont]}]},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": samples["output"][cont]}]},
|
||||
]
|
||||
)
|
||||
return formatted_samples
|
||||
|
||||
# For multi-image example
|
||||
def prepare_dataset(dataset: DatasetDict, dataset_name: str, dataset_train_split: str) -> DatasetDict:
|
||||
all_files = list_repo_files(dataset_name, repo_type="dataset")
|
||||
zip_files = [f for f in all_files if f.endswith(".zip")]
|
||||
|
||||
for zip_filename in zip_files:
|
||||
zip_path = hf_hub_download(repo_id=dataset_name, filename=zip_filename, repo_type="dataset")
|
||||
extract_folder = zip_filename.replace(".zip", "")
|
||||
os.makedirs(extract_folder, exist_ok=True)
|
||||
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(extract_folder)
|
||||
|
||||
dataset = dataset.map(format_data, batched=True, batch_size=4, num_proc=16)
|
||||
return dataset
|
||||
|
||||
dataset = prepare_dataset(dataset, dataset_name, dataset_train_split)
|
||||
```
|
||||
|
||||
With this, your **Multi-Image + Text** dataset is now prepared for training.
|
||||
|
||||
### **Preparing for Training**
|
||||
|
||||
We start by loading the model and processor. In this example, we use `google/gemma-3-4b-it`, but the same process applies to its other variants and similar models.
|
||||
|
||||
To optimize memory usage, we configure `BitsAndBytes` to load the quantized version of the model.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig
|
||||
|
||||
model_id = "google/gemma-3-4b-it"
|
||||
|
||||
# BitsAndBytesConfig int-4 config
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_quant_storage=torch.bfloat16,
|
||||
)
|
||||
|
||||
# Load model and tokenizer
|
||||
model = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
device_map="auto",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="eager", # Important (Ref: https://github.com/huggingface/transformers/blob/c15a7adb283fa984a40558c7fe7bed30ae975cdd/src/transformers/models/gemma3/modeling_gemma3.py#L934)
|
||||
quantization_config=bnb_config
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
processor.tokenizer.padding_side = "right"
|
||||
```
|
||||
|
||||
Next, we set up [Quantized Low-Rank Adaptation (QLoRA)](https://huggingface.co/papers/2305.14314), an efficient fine-tuning technique for Large Language Models (LLMs) and Vision-Language Models (VLMs).
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
# Configure QLoRA
|
||||
peft_config = LoraConfig(
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.05,
|
||||
r=16,
|
||||
bias="none",
|
||||
target_modules="all-linear",
|
||||
task_type="CAUSAL_LM",
|
||||
modules_to_save=[
|
||||
"lm_head",
|
||||
"embed_tokens",
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
With QLoRA now set up, we need to define the training arguments for SFT. The [`SFTConfig`] class simplifies this process, providing an easy way to adjust parameters based on our specific needs.
|
||||
|
||||
```python
|
||||
from trl import SFTConfig
|
||||
|
||||
training_args = SFTConfig(
|
||||
output_dir="gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft", # Directory to save the model and push to the Hub. Use a specific repository id (e.g., gemma-3-4b-it-trl-sft-MMIU-Benchmark for multi-image datasets).
|
||||
num_train_epochs=1, # Set the number of epochs to train the model.
|
||||
per_device_train_batch_size=8, # Batch size for each device (e.g., GPU) during training. multi-image -> per_device_train_batch_size=1
|
||||
gradient_accumulation_steps=4, # Number of steps before performing a backward/update pass to accumulate gradients. multi-image -> gradient_accumulation_steps=1
|
||||
gradient_checkpointing=True, # Enable gradient checkpointing to reduce memory usage during training.
|
||||
optim="adamw_torch_fused", # Use the fused AdamW optimizer for better performance.
|
||||
logging_steps=10, # Frequency of logging training progress (log every 10 steps).
|
||||
save_strategy="epoch", # Save checkpoints at the end of each epoch.
|
||||
learning_rate=2e-05, # Learning rate for training.
|
||||
bf16=True, # Enable bfloat16 precision for training to save memory and speed up computations.
|
||||
push_to_hub=True, # Automatically push the fine-tuned model to Hugging Face Hub after training.
|
||||
report_to="tensorboard", # Automatically report metrics to tensorboard.
|
||||
gradient_checkpointing_kwargs={"use_reentrant": False}, # Set gradient checkpointing to non-reentrant to avoid issues.
|
||||
dataset_kwargs={"skip_prepare_dataset": True}, # Skip dataset preparation to handle preprocessing manually.
|
||||
remove_unused_columns=False, # Ensure unused columns are not removed in the collator (important for batch processing).
|
||||
)
|
||||
```
|
||||
|
||||
The `collate_fn` is responsible for processing and preparing individual examples to form a batch.
|
||||
|
||||
Each example in the batch undergoes the following steps:
|
||||
|
||||
1. The **chat template** is applied to the text.
|
||||
2. The **processor tokenizes** both `texts` and `images`, encoding them into tensors.
|
||||
3. The **labels** for training are set as the `input_ids` of the example.
|
||||
4. Certain **special tokens** are **masked (ignored)** during loss computation:
|
||||
- `pad_token_id`
|
||||
- `<image_token_id>`
|
||||
- `<image_soft_token>` (corresponding to ID `262144`)
|
||||
|
||||
This process is similar across different dataset types, with a minor variation in how images are handled:
|
||||
|
||||
- **Single Image + Text** → A **list of images** is directly processed.
|
||||
- **Multi-Image + Text** → A **list of lists of images** is used, where each batch element contains multiple images.
|
||||
|
||||
```python
|
||||
from PIL import Image
|
||||
|
||||
# For multi-image cases
|
||||
def process_vision_info(messages: list[dict]) -> list[Image.Image]:
|
||||
image_inputs = []
|
||||
for msg in messages:
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
content = [content]
|
||||
|
||||
for element in content:
|
||||
if isinstance(element, dict) and ("image" in element or element.get("type") == "image"):
|
||||
if "image" in element:
|
||||
image = element["image"]
|
||||
else:
|
||||
image = element
|
||||
if image is not None:
|
||||
image = Image.open(io.BytesIO(image["bytes"]))
|
||||
image_inputs.append(image.convert("RGB"))
|
||||
return image_inputs
|
||||
|
||||
def collate_fn(examples):
|
||||
texts = [processor.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=False).strip() for example in examples]
|
||||
if "images" in examples[0]: # single-image
|
||||
images = [
|
||||
[img.convert("RGB") for img in example["images"]]
|
||||
for example in examples
|
||||
]
|
||||
else: # multi-image
|
||||
images = [process_vision_info(example["messages"]) for example in examples]
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(
|
||||
text=texts, images=images, return_tensors="pt", padding=True
|
||||
) # Encode texts and images into tensors
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||
labels = batch["input_ids"].clone() # Clone input IDs for labels
|
||||
# Mask image tokens
|
||||
image_token_id = [
|
||||
processor.tokenizer.convert_tokens_to_ids(processor.tokenizer.special_tokens_map["boi_token"])
|
||||
]
|
||||
# Mask tokens for not being used in the loss computation
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
labels[labels == image_token_id] = -100
|
||||
labels[labels == 262144] = -100
|
||||
|
||||
batch["labels"] = labels
|
||||
return batch # Return the prepared batch
|
||||
```
|
||||
|
||||
### **Training the Model**
|
||||
|
||||
With all the components set up, we now configure the `SFTTrainer` using the previously defined settings and start the training process.
|
||||
|
||||
```python
|
||||
# Training
|
||||
from trl import SFTTrainer
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
data_collator=collate_fn,
|
||||
train_dataset=dataset["train"], # multi-image -> train_dataset=dataset["test"],
|
||||
processing_class=processor,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
# Save the final model
|
||||
trainer.save_model()
|
||||
```
|
||||
|
||||
We save the fine-tuned model to the Hub, making it easily accessible for future use. Additionally, TRL automatically logs the training results to **Weights & Biases (Wandb)** or **TensorBoard**, depending on the chosen configuration.
|
||||
|
||||
<!-- Add Wandb training results -->
|
||||
### Results
|
||||
|
||||
During and after training, we can inspect the results using **Weights & Biases (Wandb)** or **TensorBoard**. For example:
|
||||
|
||||
* [**gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft (Single Image+Text)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-llava-instruct-mix-vsft)
|
||||
|
||||
* [**gemma-3-4b-it-trl-sft-MMIU-Benchmark (Multi-Images+Text or Interleaving)**](https://huggingface.co/sergiopaniego/gemma-3-4b-it-trl-sft-MMIU-Benchmark)
|
||||
|
||||
## Limitations
|
||||
|
||||
Currently, fine-tuning Gemma has some [known limitations](https://github.com/huggingface/trl/issues/3121). We recommend following the procedure outlined in this guide to ensure the best results.
|
||||
|
||||
## References
|
||||
|
||||
For further reading and complementary resources, check out the following:
|
||||
|
||||
- [Fine-Tuning Vision-Language Models with QLoRA](https://ai.google.dev/gemma/docs/core/huggingface_vision_finetune_qlora)
|
||||
- [Fine-Tuning a Vision Language Model (Qwen2-VL-7B) with the Hugging Face Ecosystem (TRL)](https://huggingface.co/learn/cookbook/fine_tuning_vlm_trl)
|
||||
|
7
docs/source/unsloth_integration.md
Normal file
@ -0,0 +1,7 @@
|
||||
# Unsloth Integration
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Section under construction. Feel free to contribute!
|
||||
|
||||
</Tip>
|
@ -19,7 +19,7 @@ Now we can fit very large models into a single GPU, but the training might still
|
||||
The simplest strategy in this scenario is data parallelism: we replicate the same training setup on separate GPUs and pass different batches to each GPU.
|
||||
With this, you can parallelize the forward/backward passes of the model and scale with the number of GPUs.
|
||||
|
||||

|
||||

|
||||
|
||||
We use either the `transformers.Trainer` or `accelerate`, which both support data parallelism without any code changes, by simply passing arguments when calling the scripts with `torchrun` or `accelerate launch`. The following runs a training script with 8 GPUs on a single machine with `accelerate` and `torchrun`, respectively.
|
||||
|
||||
@ -38,12 +38,11 @@ The [StackExchange dataset](https://huggingface.co/datasets/HuggingFaceH4/stack-
|
||||
There is nothing special about fine-tuning the model before doing RLHF - it’s just the causal language modeling objective from pretraining that we apply here.
|
||||
To use the data efficiently, we use a technique called packing: instead of having one text per sample in the batch and then padding to either the longest text or the maximal context of the model, we concatenate a lot of texts with an EOS token in between and cut chunks of the context size to fill the batch without any padding.
|
||||
|
||||

|
||||

|
||||
|
||||
With this approach, training is much more efficient, as each token that is passed through the model is also trained, in contrast to padding tokens, which are usually masked from the loss.
|
||||
If you don't have much data and are more concerned about occasionally cutting off some tokens that overflow the context, you can also use a classical data loader.
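The packing idea can be sketched in a few lines; this is a simplified illustration that assumes a `tokenizer` is available and ignores edge cases handled by the real implementation:

```python
def pack(texts, tokenizer, chunk_size=2048):
    # Concatenate all texts with an EOS token in between
    ids = []
    for text in texts:
        ids.extend(tokenizer(text)["input_ids"] + [tokenizer.eos_token_id])
    # Cut fixed-size chunks; the incomplete tail is dropped
    return [ids[i : i + chunk_size] for i in range(0, len(ids) - chunk_size + 1, chunk_size)]
```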
|
||||
|
||||
The packing is handled by the `ConstantLengthDataset` and we can then use the `Trainer` after loading the model with `peft`. First, we load the model in int8, prepare it for training, and then add the LoRA adapters.
|
||||
|
||||
```python
|
||||
# load model in 8bit
|
185
docs/source/vllm_integration.md
Normal file
@ -0,0 +1,185 @@
|
||||
# vLLM Integration
|
||||
|
||||
This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. Let's go! 🔥
|
||||
|
||||
## 🚀 How can I use vLLM with TRL to speed up training?
|
||||
|
||||
💡 **Note**: Resources required for this specific example: a single node with 8 GPUs.
|
||||
|
||||
First, install vLLM using the following command:
|
||||
|
||||
```bash
|
||||
pip install "trl[vllm]"
|
||||
```
|
||||
|
||||
Then run the server:
|
||||
|
||||
```sh
|
||||
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2
|
||||
```
|
||||
|
||||
Once the server is running, you can use it to generate completions for training. In the example below, we are using the `GRPOTrainer` to train a model using the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs.
|
||||
|
||||
In this example, we are sharding two copies of the model across 4 GPUs. Increasing data parallelism increases throughput, while increasing tensor parallelism allows for serving larger models. Then, run the training script by passing `use_vllm=True` in the training arguments as follows:
|
||||
|
||||
Sample of a simple `train.py` script:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import GRPOTrainer, GRPOConfig
|
||||
|
||||
dataset = load_dataset("trl-lib/tldr", split="train")
|
||||
|
||||
# Dummy reward function: count the number of unique characters in the completions
|
||||
def reward_num_unique_chars(completions, **kwargs):
|
||||
return [len(set(c)) for c in completions]
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir="my_test",
|
||||
use_vllm=True,
|
||||
bf16=True,
|
||||
gradient_checkpointing=True,
|
||||
logging_steps=10,
|
||||
)
|
||||
|
||||
trainer = GRPOTrainer(
|
||||
model="Qwen/Qwen2.5-7B",
|
||||
args=training_args,
|
||||
reward_funcs=reward_num_unique_chars,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
And the train command:
|
||||
|
||||
```sh
|
||||
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
```
|
||||
|
||||
## 🎬 Flashback: Why do we need to use vLLM in online methods?
|
||||
|
||||
Online methods like GRPO or Online DPO require the model to generate completions during training, which are then used to compute reward signals. However, generation can be extremely time-consuming, especially with large or reasoning models. In the default setup (without vLLM), completions are generated using the [(unwrapped) model's `generate` method](https://github.com/huggingface/trl/blob/f3e8c2304428ef16e9ae5de9e5741ed84d533b7b/trl/trainer/grpo_trainer.py#L965C39-L965C66). This approach quickly becomes a major bottleneck — generation is slow and inefficient, particularly for large batches or models. As a result, training times increase significantly, and overall efficiency drops. To address this, we turn to vLLM, which enables much faster and more scalable generation, helping eliminate this bottleneck in online methods.
|
||||
|
||||
## 🤔 How does vLLM solve the slow generation issue?
|
||||
|
||||
If you've ever done autoregressive decoder training, you know that all the input tokens to the LLM produce their attention key and value tensors, and these tensors are kept in GPU memory to later generate subsequent tokens based on them. These cached key and value tensors are often referred to as the KV cache. However, storing the KV cache occupies a lot of memory, so vLLM uses a technique called **PagedAttention** to solve this problem. PagedAttention, inspired by the OS’s virtual memory concept, stores continuous keys and values in **non-contiguous memory space**, which is much more efficient. The details are beyond the scope of this document, but in short, it allows the model to store the keys and values in a more efficient way, reducing the memory footprint and speeding up the generation process. If you are interested, make sure to check out the [vLLM PagedAttention blog post](https://blog.vllm.ai/2023/06/20/vllm.html) for more details.
|
||||
|
||||
## 🤔 What exactly happens when you run `trl vllm-serve --model <model_name>`?
|
||||
|
||||
When you run, for example,
|
||||
|
||||
```sh
|
||||
trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 1 --data-parallel-size 4
|
||||
```
|
||||
|
||||
the following happens:
|
||||
|
||||

|
||||
|
||||
1. vLLM first spawns multiple workers to handle incoming requests in parallel. The number of workers is determined by multiplying the `--tensor-parallel-size` and `--data-parallel-size` values. In this example, it spawns 4 workers (1 × 4).
|
||||
Each worker operates independently and processes a chunk of the incoming requests — which are basically the prompts sent to the server for generation. A key point to understand is that these 4 workers are running in parallel, and each one is responsible for handling a subset of the total incoming load.
|
||||
|
||||
2. Once the incoming requests (prompts) are distributed across the workers, the model starts generating completions. Internally, the model’s weights are split across multiple GPUs based on the `--tensor-parallel-size` argument — this is how tensor parallelism is handled. Meanwhile, data parallelism (controlled by `--data-parallel-size`) ensures that different sets of requests are processed independently across the workers. In short: tensor parallelism splits the model across GPUs, and data parallelism splits the batch of requests across different model replicas.
|
||||
|
||||
3. Although the GPUs process requests independently and in parallel, they still need to communicate with each other. Remember that each GPU handles only a slice of the incoming prompts (for example, with 4 GPUs and 8 prompts using `--data-parallel-size=4`, each GPU processes 2 prompts).
|
||||
This GPU-to-GPU communication is managed efficiently by NVIDIA’s NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests — it’s lightweight and doesn’t interfere with generation itself.
|
||||
Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (like in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you would end up with 16 completions total — regardless of the number of GPUs or parallelism settings.
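To make the arithmetic concrete, here is the example from the figure as a small sketch (the values are the assumed ones above):

```python
tensor_parallel_size = 1
data_parallel_size = 4
num_workers = tensor_parallel_size * data_parallel_size  # 4 workers are spawned

num_prompts = 8
prompts_per_gpu = num_prompts // data_parallel_size      # each GPU handles 2 prompts

num_generations = 2
total_completions = num_prompts * num_generations        # 16, regardless of GPU count
```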
|
||||
|
||||
## 🥸 More detail on what happens under the hood when running the server
|
||||
|
||||
* The vLLM server starts by running the command: `trl vllm-serve --model Qwen/Qwen2.5-7B`.
|
||||
* Once the server is running, it generates completions based on requests from the client (trainer) using `vllm_client.generate` [here](https://github.com/huggingface/trl/blob/cc044e35b285be7dc062764b3364e1e684db4c7c/trl/trainer/grpo_trainer.py#L1025-L1035).
|
||||
* The client (trainer) then requests these completions from the server.
|
||||
* These completions are used to compute the reward signal.
|
||||
* Based on the reward signal and the model’s output, the loss is computed, and the backward pass is performed to update the model’s weights.
|
||||
* **Note**: The server only handles completion generation — it doesn’t train the model. Therefore, the model’s weights aren’t updated on the server. Once the backward pass is complete, the client sends the updated weights to the server using `vllm_client.update_named_param(name, param.data)`.
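A hedged sketch of the trainer-side synchronization described in the last point, assuming `vllm_client` and `model` are already set up as inside `GRPOTrainer`:

```python
# After the backward pass, push the fresh weights to the vLLM server
for name, param in model.named_parameters():
    vllm_client.update_named_param(name, param.data)
```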
|
||||
|
||||
When using vLLM, ensure the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation for training using `CUDA_VISIBLE_DEVICES`. See the example below:
|
||||
|
||||
* **Set GPUs *0–3* for vLLM generation:** Assume `CUDA_VISIBLE_DEVICES=0,1,2,3` are allocated for vLLM generation.
|
||||
|
||||
```sh
|
||||
trl vllm-serve --model <model_name> --tensor-parallel-size 1 --data-parallel-size 4
```

* **And GPUs *4–7* for training:** If the `CUDA_VISIBLE_DEVICES` environment variable is not set, the training script uses all available GPUs by default, which may lead to resource conflicts. To avoid this, specify which GPUs to use for training. For example, to train on GPUs 4–7:

```sh
CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
```

## 🍷 More customization options with vLLM?

You can customize the server configuration by passing additional arguments:

```
$ trl vllm-serve --help
usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE]
                      [--data_parallel_size DATA_PARALLEL_SIZE] [--host HOST] [--port PORT]
                      [--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE] [--max_model_len MAX_MODEL_LEN]
                      [--enable_prefix_caching ENABLE_PREFIX_CACHING] [--enforce_eager ENFORCE_EAGER] [--log_level LOG_LEVEL]

options:
  -h, --help            Show this help message and exit
  --model MODEL         Model name or path to load the model from. (default: None)
  --revision REVISION   Revision to use for the model. If not specified, the default branch will be used. (default: None)
  --tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE
                        Number of tensor parallel workers to use. (default: 1)
  --data_parallel_size DATA_PARALLEL_SIZE, --data-parallel-size DATA_PARALLEL_SIZE
                        Number of data parallel workers to use. (default: 1)
  --host HOST           Host address to run the server on. (default: 0.0.0.0)
  --port PORT           Port to run the server on. (default: 8000)
  --gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION
                        Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the device
                        dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus improve the
                        model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors during
                        initialization. (default: 0.9)
  --dtype DTYPE         Data type to use for vLLM generation. If set to 'auto', the data type will be automatically determined based on
                        the model configuration. Find the supported values in the vLLM documentation. (default: auto)
  --max_model_len MAX_MODEL_LEN, --max-model-len MAX_MODEL_LEN
                        If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced
                        `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model context
                        size, which might be much larger than the KV cache, leading to inefficiencies. (default: None)
  --enable_prefix_caching ENABLE_PREFIX_CACHING, --enable-prefix-caching ENABLE_PREFIX_CACHING
                        Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support this
                        feature. (default: None)
  --enforce_eager ENFORCE_EAGER, --enforce-eager ENFORCE_EAGER
                        Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the model
                        in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid. (default:
                        None)
  --log_level LOG_LEVEL, --log-level LOG_LEVEL
                        Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', 'trace'. (default:
                        info)
```
## 🥳 Okay, now that we have the server running, how can we use it to generate completions?

Run the training script and pass `use_vllm=True` in the training arguments:

```python
from trl import GRPOConfig

training_args = GRPOConfig(..., use_vllm=True)
```
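If the server does not run on the default host and port (`0.0.0.0:8000`, as shown in the help output above), the trainer needs to be told where to find it. A minimal sketch, assuming the `vllm_server_host`/`vllm_server_port` options of `GRPOConfig` (check the `GRPOConfig` docstring of your TRL version for the exact names):

```python
from trl import GRPOConfig

# assumption: the vLLM server was launched on another machine at 192.168.1.10:8000
training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_server_host="192.168.1.10",  # hypothetical address
    vllm_server_port=8000,
)
```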
## 💆🏻♀️ What's the best distributed setup?

![tp_dp_throughput](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput.png)

First and foremost, always remember that the optimal setup depends on:

* The model size
* The number of GPUs you have
* The GPU memory size
* The batch size you are using
* The number of requests you are sending to the server (prompts)
* The `max_model_len` you are using (this is the maximum length of the input sequence the model can process, a.k.a. the context window size)
* The number of completions you are generating for each request (`num_generations`)

Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that:

* For reasonably sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results.
* For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window. The corresponding launch commands are shown below.
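Concretely, assuming the server gets all 8 GPUs to itself (as in the experiments above) and `<model_name>` is a placeholder, the two recommended configurations are launched as:

```sh
# 3B–14B models, moderate context window: maximize data parallelism
trl vllm-serve --model <model_name> --tensor-parallel-size 1 --data-parallel-size 8

# 32B models, long context window: trade some data parallelism for tensor parallelism
trl vllm-serve --model <model_name> --tensor-parallel-size 2 --data-parallel-size 4
```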
@@ -4,7 +4,7 @@

 ## Overview

-Exploratory Preference Optimization (XPO) was proposed in the paper [Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF](https://huggingface.co/papers/2405.21046) by Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, [Corby Rosset](https://huggingface.co/corbyrosset), [Ahmed Awadallah](https://huggingface.co/AhmedAwadallah), and Alexander Rakhlin. It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus allowing the method to explore outside the support of the intitial model and human feedback data.
+Exploratory Preference Optimization (XPO) was proposed in the paper [Exploratory Preference Optimization: Harnessing Implicit Q*-Approximation for Sample-Efficient RLHF](https://huggingface.co/papers/2405.21046) by Tengyang Xie, Dylan J. Foster, Akshay Krishnamurthy, [Corby Rosset](https://huggingface.co/corbyrosset), [Ahmed Awadallah](https://huggingface.co/AhmedAwadallah), and Alexander Rakhlin. It is a simple online preference tuning method based on the DPO loss together with a reward model (RM). XPO augments the DPO objective with an exploration bonus allowing the method to explore outside the support of the initial model and human feedback data.

 The abstract from the paper is the following:

@@ -50,9 +50,9 @@ accelerate launch train_xpo.py

 Distributed across 8 GPUs, the training takes approximately 1 hour.

-To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) performs, you can use the [TRL Chat CLI](clis#chat-interface).
+To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).

-<pre><code>$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-XPO
+<pre><code>$ transformers-cli chat --model_name_or_path trl-lib/Qwen2-0.5B-XPO
 <strong><span style="color: red;"><quentin_gallouedec>:</span></strong>
 What is the best programming language?

@@ -110,7 +110,7 @@ trainer.add_callback(completions_callback)

 This callback logs the model's generated completions directly to Weights & Biases.

 ![winrate_xpo](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/winrate_xpo.png)

 ## Example script
examples/accelerate_configs/fsdp1.yaml (new file, 28 lines)
@@ -0,0 +1,28 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: FULL_SHARD
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
  fsdp_version: 1
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
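A config like this is consumed by `accelerate launch`. A typical invocation, with `train.py` standing in for whichever training script you run, would be:

```sh
accelerate launch --config_file examples/accelerate_configs/fsdp1.yaml train.py
```

The `num_processes: 8` entry means `accelerate` spawns one process per GPU on an 8-GPU machine.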
examples/accelerate_configs/fsdp2.yaml (new file, 25 lines)
@@ -0,0 +1,25 @@
# Requires accelerate 1.7.0 or higher
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_cpu_ram_efficient_loading: true
  fsdp_offload_params: false
  fsdp_reshard_after_forward: true
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_version: 2
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
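Compared with `fsdp1.yaml` above, this FSDP2 config sets `fsdp_version: 2`, drops the options FSDP2 no longer uses (`fsdp_backward_prefetch`, `fsdp_forward_prefetch`, `fsdp_sync_module_states`, `fsdp_use_orig_params`), and makes `fsdp_reshard_after_forward` a boolean rather than a sharding-strategy name. As the leading comment notes, it requires accelerate 1.7.0 or higher; it is launched the same way, by pointing `--config_file` at `examples/accelerate_configs/fsdp2.yaml`.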
@@ -1,25 +0,0 @@ (file deleted)
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
@@ -7,7 +7,7 @@
 # CUDA_VISIBLE_DEVICES: 0

-model_name_or_path: trl-internal-testing/tiny-random-LlamaForCausalLM
+model_name_or_path: Qwen/Qwen2.5-0.5B
 dataset_name: stanfordnlp/imdb
 report_to:
examples/datasets/hh-rlhf-helpful-base.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,10 +13,11 @@
 # limitations under the License.

 import re
-from dataclasses import dataclass
-from typing import Dict, List, Optional
+from dataclasses import dataclass, field
+from typing import Optional

 from datasets import load_dataset
+from huggingface_hub import ModelCard
 from transformers import HfArgumentParser

@@ -30,13 +31,20 @@ class ScriptArguments:
             Whether to push the dataset to the Hugging Face Hub.
         repo_id (`str`, *optional*, defaults to `"trl-lib/hh-rlhf-helpful-base"`):
             Hugging Face repository ID to push the dataset to.
-        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
+        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
             Number of workers to use for dataset processing.
     """

-    push_to_hub: bool = False
-    repo_id: str = "trl-lib/hh-rlhf-helpful-base"
-    dataset_num_proc: Optional[int] = None
+    push_to_hub: bool = field(
+        default=False,
+        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
+    )
+    repo_id: str = field(
+        default="trl-lib/hh-rlhf-helpful-base", metadata={"help": "Hugging Face repository ID to push the dataset to."}
+    )
+    dataset_num_proc: Optional[int] = field(
+        default=None, metadata={"help": "Number of workers to use for dataset processing."}
+    )


 def common_start(str1: str, str2: str) -> str:
@@ -51,7 +59,7 @@ def common_start(str1: str, str2: str) -> str:
     return "".join(common_chars)


-def extract_dialogue(example: str) -> List[Dict[str, str]]:
+def extract_dialogue(example: str) -> list[dict[str, str]]:
     # Extract the prompt, which corresponds to the common start of the chosen and rejected dialogues
     prompt_text = common_start(example["chosen"], example["rejected"])

@@ -79,12 +87,40 @@ def extract_dialogue(example: str) -> list[dict[str, str]]:
         prompt.append({"role": role, "content": content})

     # Remove the prompt from the chosen and rejected dialogues
-    chosen = [{"role": "assitant", "content": chosen_line}]
+    chosen = [{"role": "assistant", "content": chosen_line}]
     rejected = [{"role": "assistant", "content": rejected_line}]

     return {"prompt": prompt, "chosen": chosen, "rejected": rejected}


+model_card = ModelCard("""
+---
+tags: [trl]
+---
+
+# HH-RLHF-Helpful-Base Dataset
+
+## Summary
+
+The HH-RLHF-Helpful-Base dataset is a processed version of [Anthropic's HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset, specifically curated to train models using the [TRL library](https://github.com/huggingface/trl) for preference learning and alignment tasks. It contains pairs of text samples, each labeled as either "chosen" or "rejected," based on human preferences regarding the helpfulness of the responses. This dataset enables models to learn human preferences in generating helpful responses, enhancing their ability to assist users effectively.
+
+## Data Structure
+
+- **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational)
+- **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference)
+
+Columns:
+- `"prompt"`: The user query.
+- `"chosen"`: A response deemed helpful by human evaluators.
+- `"rejected"`: A response considered less helpful or unhelpful.
+
+This structure allows models to learn to prefer the _chosen_ response over the _rejected_ one, thereby aligning with human preferences in helpfulness.
+
+## Generation script
+
+The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/hh-rlhf-helpful-base.py).
+""")
+
 if __name__ == "__main__":
     parser = HfArgumentParser(ScriptArguments)
     script_args = parser.parse_args_into_dataclasses()[0]
@@ -94,3 +130,4 @@ if __name__ == "__main__":

     if script_args.push_to_hub:
         dataset.push_to_hub(script_args.repo_id)
+        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
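A pattern worth noting in this diff (and repeated in the dataset scripts below): plain defaults such as `push_to_hub: bool = False` become `field(default=..., metadata={"help": ...})`. The `metadata["help"]` string is what `HfArgumentParser` prints for `--help`. A minimal sketch of the mechanism (the `Args` class here is illustrative, not part of TRL):

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class Args:
    # the help string below is surfaced by `python script.py --help`
    dataset_num_proc: Optional[int] = field(
        default=None, metadata={"help": "Number of workers to use for dataset processing."}
    )


if __name__ == "__main__":
    args = HfArgumentParser(Args).parse_args_into_dataclasses()[0]
    print(args.dataset_num_proc)
```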
examples/datasets/lm-human-preferences-descriptiveness.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional

 from datasets import load_dataset
+from huggingface_hub import ModelCard
 from transformers import AutoTokenizer, HfArgumentParser

@@ -29,13 +30,22 @@ class ScriptArguments:
             Whether to push the dataset to the Hugging Face Hub.
         repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-descriptiveness"`):
             Hugging Face repository ID to push the dataset to.
-        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
+        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
             Number of workers to use for dataset processing.
     """

-    push_to_hub: bool = False
-    repo_id: str = "trl-lib/lm-human-preferences-descriptiveness"
-    dataset_num_proc: Optional[int] = None
+    push_to_hub: bool = field(
+        default=False,
+        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
+    )
+    repo_id: str = field(
+        default="trl-lib/lm-human-preferences-descriptiveness",
+        metadata={"help": "Hugging Face repository ID to push the dataset to."},
+    )
+    dataset_num_proc: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of workers to use for dataset processing."},
+    )


 # Edge cases handling: remove the cases where all samples are the same
@@ -55,6 +65,34 @@ def to_prompt_completion(example, tokenizer):
     return {"prompt": prompt, "chosen": chosen, "rejected": rejected}


+model_card = ModelCard("""
+---
+tags: [trl]
+---
+
+# LM-Human-Preferences-Descriptiveness Dataset
+
+## Summary
+
+The LM-Human-Preferences-Descriptiveness dataset is a processed subset of [OpenAI's LM-Human-Preferences](https://github.com/openai/lm-human-preferences), focusing specifically on enhancing the descriptiveness of generated text. It contains pairs of text samples, each labeled as either "chosen" or "rejected," based on human preferences regarding the level of detail and vividness in the descriptions. This dataset enables models to learn human preferences in descriptive language, improving their ability to generate rich and engaging narratives.
+
+## Data Structure
+
+- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
+- **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference)
+
+Columns:
+- `"prompt"`: The text sample.
+- `"chosen"`: A version of the text with enhanced descriptiveness.
+- `"rejected"`: A version of the text with less descriptiveness.
+
+This structure allows models to learn to prefer the _chosen_ response over the _rejected_ one, thereby aligning with human preferences in descriptive language.
+
+## Generation script
+
+The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/lm-human-preferences-descriptiveness.py).
+""")
+
 if __name__ == "__main__":
     parser = HfArgumentParser(ScriptArguments)
     script_args = parser.parse_args_into_dataclasses()[0]
@@ -79,3 +117,4 @@ if __name__ == "__main__":

     if script_args.push_to_hub:
         dataset.push_to_hub(script_args.repo_id)
+        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
examples/datasets/lm-human-preferences-sentiment.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional

 from datasets import load_dataset
+from huggingface_hub import ModelCard
 from transformers import AutoTokenizer, HfArgumentParser

@@ -29,13 +30,22 @@ class ScriptArguments:
             Whether to push the dataset to the Hugging Face Hub.
         repo_id (`str`, *optional*, defaults to `"trl-lib/lm-human-preferences-sentiment"`):
             Hugging Face repository ID to push the dataset to.
-        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
+        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
             Number of workers to use for dataset processing.
     """

-    push_to_hub: bool = False
-    repo_id: str = "trl-lib/lm-human-preferences-sentiment"
-    dataset_num_proc: Optional[int] = None
+    push_to_hub: bool = field(
+        default=False,
+        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
+    )
+    repo_id: str = field(
+        default="trl-lib/lm-human-preferences-sentiment",
+        metadata={"help": "Hugging Face repository ID to push the dataset to."},
+    )
+    dataset_num_proc: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of workers to use for dataset processing."},
+    )


 def to_prompt_completion(example, tokenizer):
@@ -50,6 +60,34 @@ def to_prompt_completion(example, tokenizer):
     return {"prompt": prompt, "chosen": chosen, "rejected": rejected}


+model_card = ModelCard("""
+---
+tags: [trl]
+---
+
+# LM-Human-Preferences-Sentiment Dataset
+
+## Summary
+
+The LM-Human-Preferences-Sentiment dataset is a processed subset of [OpenAI's LM-Human-Preferences](https://github.com/openai/lm-human-preferences), focusing specifically on sentiment analysis tasks. It contains pairs of text samples, each labeled as either "chosen" or "rejected," based on human preferences regarding the sentiment conveyed in the text. This dataset enables models to learn human preferences in sentiment expression, enhancing their ability to generate and evaluate text with desired emotional tones.
+
+## Data Structure
+
+- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
+- **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference)
+
+Columns:
+- `"prompt"`: The text sample.
+- `"chosen"`: A version of the text that conveys the desired sentiment.
+- `"rejected"`: A version of the text that does not convey the desired sentiment.
+
+This structure allows models to learn to prefer the _chosen_ response over the _rejected_ one, thereby aligning with human preferences in sentiment expression.
+
+## Generation script
+
+The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/lm-human-preferences-sentiment.py).
+""")
+
 if __name__ == "__main__":
     parser = HfArgumentParser(ScriptArguments)
     script_args = parser.parse_args_into_dataclasses()[0]
@@ -72,3 +110,4 @@ if __name__ == "__main__":

     if script_args.push_to_hub:
         dataset.push_to_hub(script_args.repo_id)
+        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
examples/datasets/math_shepherd.py (new file, 170 lines)
@@ -0,0 +1,170 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

from datasets import load_dataset
from huggingface_hub import ModelCard
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    r"""
    Arguments for the script.

    Args:
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the dataset to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/math_shepherd"`):
            Hugging Face repository ID to push the dataset to.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of workers to use for dataset processing.
    """

    push_to_hub: bool = field(
        default=False,
        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
    )
    repo_id: str = field(
        default="trl-lib/math_shepherd",
        metadata={"help": "Hugging Face repository ID to push the dataset to."},
    )
    dataset_num_proc: Optional[int] = field(
        default=None,
        metadata={"help": "Number of workers to use for dataset processing."},
    )


def process_example(example):
    # Replace "ки" with "ⶻ" so that the size of the "input" matches the size of the "label"
    inputs = example["input"].replace("ки", "ⶻ")

    # Find the indices of the "ⶻ" characters (that should match with the indexes of the "+" or "-" in the label)
    indexes = [m.start() for m in re.finditer("ⶻ", inputs)]

    # Sanity check that the label holds either "+" or "-" at every flagged index
    assert all(example["label"][idx] in ["+", "-"] for idx in indexes)

    # Get the labels
    labels = [example["label"][idx] == "+" for idx in indexes]

    # Split the inputs into steps (caution, the first step is missing here, it is the prompt)
    steps = [inputs[i:j] for i, j in zip(chain([0], indexes), chain(indexes, [None]))]

    # Remove the last step (single ⶻ)
    steps = steps[:-1]

    # Get the prompt (first part) and completions (rest)
    prompt = steps[0]
    completions = steps[1:]

    # Remove the heading "ⶻ" and the final whitespace from the completions
    assert all(completion.startswith("ⶻ") for completion in completions)
    completions = [completion[1:].strip() for completion in completions]

    # At this point, we need to retrieve the first step from the prompt.
    # First, we handle particular cases (annotation error) where we have a first label before the end of the prompt.
    if prompt.startswith(
        (
            "Mr. Rocky",
            "Parker",
            "What is the smallest positive",
            " The Myth",
            "Let $\\mathbf{a}$",
            "Find the arithmetic",
            "Determine an ordered pair",
            "Determine the ordered pair",
            "At the Quill and Scroll stationery",
            "Round to the nearest",
            r"Calculate $\sqrt{10p}",
            r"Simplify $\sqrt{28x}",
        )
    ):
        # Some spotted dataset errors where there is an annotation in the prompt: we remove it
        labels = labels[1:]

    # Then we handle the general case: we get the first step from the prompt by looking for "Step 1:" or "step 1:" or
    # (less common) "?".
    elif "Step 1:" in prompt:
        prompt, first_step = prompt.split("Step 1:")
        first_step = "Step 1:" + first_step
        completions = [first_step.strip()] + completions
    elif "step 1:" in prompt:
        prompt, first_step = prompt.split("step 1:")
        first_step = "step 1:" + first_step
        completions = [first_step.strip()] + completions
    elif "?" in prompt:
        prompt, first_step = prompt.split("?")
        prompt = prompt + "?"
        completions = [first_step.strip()] + completions
    else:
        raise ValueError(f"Prompt can't be processed: {prompt}")

    # Strip the prompt
    prompt = prompt.strip()

    # Sanity check that the length of the completions is the same as the length of the labels
    assert len(completions) == len(labels)

    return {"prompt": prompt, "completions": completions, "labels": labels}


model_card = ModelCard("""
---
tags: [trl]
---

# Math-Shepherd Dataset

## Summary

The Math-Shepherd dataset is a processed version of the [Math-Shepherd dataset](https://huggingface.co/datasets/peiyi9979/Math-Shepherd), designed to train models using the [TRL library](https://github.com/huggingface/trl) for stepwise supervision tasks. It provides step-by-step solutions to mathematical problems, enabling models to learn and verify each step of a solution, thereby enhancing their reasoning capabilities.

## Data Structure

- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
- **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision)

Columns:
- `"prompt"`: The problem statement.
- `"completions"`: A list of reasoning steps generated to solve the problem.
- `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step.

This structure allows models to learn the correctness of each step in a solution, facilitating improved reasoning and problem-solving abilities.

## Generation script

The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/math_shepherd.py).
""")

if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    dataset = load_dataset("peiyi9979/Math-Shepherd", split="train")

    dataset = dataset.map(
        process_example,
        remove_columns=["input", "label", "task"],
        num_proc=script_args.dataset_num_proc,
    )
    dataset = dataset.train_test_split(test_size=0.05, seed=42)

    if script_args.push_to_hub:
        dataset.push_to_hub(script_args.repo_id)
        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
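To make the transformation concrete, here is a toy invocation of `process_example` (the problem text is invented; real rows in `peiyi9979/Math-Shepherd` contain full math word problems with several `ки`-tagged steps):

```python
example = {
    "input": "What is 2+2? Step 1: 2+2 = 4 ки",
    "label": "What is 2+2? Step 1: 2+2 = 4 +",
}
print(process_example(example))
# {'prompt': 'What is 2+2?', 'completions': ['Step 1: 2+2 = 4'], 'labels': [True]}
```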
examples/datasets/prm800k.py (new file, 157 lines)
@@ -0,0 +1,157 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
from huggingface_hub import ModelCard
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    r"""
    Arguments for the script.

    Args:
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Whether to push the dataset to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/prm800k"`):
            Hugging Face repository ID to push the dataset to.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of workers to use for dataset processing.
    """

    push_to_hub: bool = field(
        default=False,
        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
    )
    repo_id: str = field(
        default="trl-lib/prm800k",
        metadata={"help": "Hugging Face repository ID to push the dataset to."},
    )
    dataset_num_proc: Optional[int] = field(
        default=None,
        metadata={"help": "Number of workers to use for dataset processing."},
    )


def process_example(example):
    outputs = []
    prompt = example["question"]["problem"]

    # Iterate through each step
    previous_completions = []
    previous_labels = []
    for step in example["label"]["steps"]:
        if step["completions"] is None and step["human_completion"] is None and step["chosen_completion"] is None:
            # happens sometimes
            break
        # Loop through completions
        for completion_idx, completion in enumerate(step["completions"]):
            # Every completion that is not chosen is a terminal state, so we can add it to the list of outputs.
            if completion_idx != step["chosen_completion"]:
                content = completion["text"]
                completions = previous_completions[:] + [content]
                label = completion["rating"] == 1
                labels = previous_labels[:] + [label]
                outputs.append({"prompt": prompt, "completions": completions, "labels": labels})

        # Now, expand the previous completions and labels
        if step["chosen_completion"] is not None:
            chosen_completion = step["completions"][step["chosen_completion"]]
            label = chosen_completion["rating"] == 1
        elif step["human_completion"] is not None:
            chosen_completion = step["human_completion"]
            label = True
        else:
            break
        content = chosen_completion["text"]
        previous_completions.append(content)
        previous_labels.append(label)

    # Last step: we are in a terminal state, so we can add it to the list of outputs
    outputs.append({"prompt": prompt, "completions": previous_completions, "labels": previous_labels})
    return outputs


def process_batch(examples):
    outputs = []
    batch_size = len(examples["label"])
    for idx in range(batch_size):
        example = {k: v[idx] for k, v in examples.items()}
        outputs.extend(process_example(example))
    # list of dict to dict of list
    outputs = {k: [v[k] for v in outputs] for k in outputs[0]}
    return outputs


model_card = ModelCard("""
---
tags: [trl]
---

# PRM800K Dataset

## Summary

The PRM800K dataset is a processed version of [OpenAI's PRM800K](https://github.com/openai/prm800k), designed to train models using the [TRL library](https://github.com/huggingface/trl) for stepwise supervision tasks. It contains 800,000 step-level correctness labels for model-generated solutions to problems from the MATH dataset. This dataset enables models to learn and verify each step of a solution, enhancing their reasoning capabilities.

## Data Structure

- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
- **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision)

Columns:
- `"prompt"`: The problem statement.
- `"completions"`: A list of reasoning steps generated to solve the problem.
- `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step.

This structure allows models to learn the correctness of each step in a solution, facilitating improved reasoning and problem-solving abilities.

## Generation script

The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/prm800k.py).
""")

if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    data_files = {
        "train": "https://github.com/openai/prm800k/raw/refs/heads/main/prm800k/data/phase1_train.jsonl",
        "test": "https://github.com/openai/prm800k/raw/refs/heads/main/prm800k/data/phase1_test.jsonl",
    }
    dataset = load_dataset("json", data_files=data_files)

    dataset = dataset.map(
        process_batch,
        batched=True,
        batch_size=10,
        remove_columns=[
            "labeler",
            "timestamp",
            "generation",
            "is_quality_control_question",
            "is_initial_screening_question",
            "question",
            "label",
        ],
        num_proc=script_args.dataset_num_proc,
    )

    if script_args.push_to_hub:
        dataset.push_to_hub(script_args.repo_id)
        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
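The only non-obvious part of `process_batch` is the final list-of-dicts to dict-of-lists conversion, which is the shape `datasets.map(..., batched=True)` expects. In isolation:

```python
outputs = [
    {"prompt": "p1", "completions": ["s1"], "labels": [True]},
    {"prompt": "p2", "completions": ["s2"], "labels": [False]},
]
# same comprehension as in process_batch above
batch = {k: [o[k] for o in outputs] for k in outputs[0]}
# {'prompt': ['p1', 'p2'], 'completions': [['s1'], ['s2']], 'labels': [[True], [False]]}
```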
examples/datasets/rlaif-v.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional

 from datasets import features, load_dataset
+from huggingface_hub import ModelCard
 from transformers import HfArgumentParser

@@ -29,13 +30,22 @@ class ScriptArguments:
             Whether to push the dataset to the Hugging Face Hub.
         repo_id (`str`, *optional*, defaults to `"trl-lib/rlaif-v"`):
             Hugging Face repository ID to push the dataset to.
-        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
+        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
             Number of workers to use for dataset processing.
     """

-    push_to_hub: bool = False
-    repo_id: str = "trl-lib/rlaif-v"
-    dataset_num_proc: Optional[int] = None
+    push_to_hub: bool = field(
+        default=False,
+        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
+    )
+    repo_id: str = field(
+        default="trl-lib/rlaif-v",
+        metadata={"help": "Hugging Face repository ID to push the dataset to."},
+    )
+    dataset_num_proc: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of workers to use for dataset processing."},
+    )


 def to_conversational(example):
@@ -50,6 +60,35 @@ def to_conversational(example):
     return {"prompt": prompt, "images": [example["image"]], "chosen": chosen, "rejected": rejected}


+model_card = ModelCard("""
+---
+tags: [trl]
+---
+
+# RLAIF-V Dataset
+
+## Summary
+
+The RLAIF-V dataset is a processed version of the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset#dataset-card-for-rlaif-v-dataset), specifically curated to train vision-language models using the [TRL library](https://github.com/huggingface/trl) for preference learning tasks. It contains 83,132 high-quality comparison pairs, each comprising an image and two textual descriptions: one preferred and one rejected. This dataset enables models to learn human preferences in visual contexts, enhancing their ability to generate and evaluate image captions.
+
+## Data Structure
+
+- **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational)
+- **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference)
+
+Columns:
+- `"prompt"`: The task related to the image.
+- `"images"`: The image.
+- `"chosen"`: The preferred answer.
+- `"rejected"`: An alternative answer that was not preferred.
+
+This structure allows models to learn to prefer the _chosen_ response over the _rejected_ one, thereby aligning with human preferences in visual tasks.
+
+## Generation script
+
+The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/rlaif-v.py).
+""")
+
 if __name__ == "__main__":
     parser = HfArgumentParser(ScriptArguments)
     script_args = parser.parse_args_into_dataclasses()[0]
@@ -71,3 +110,4 @@ if __name__ == "__main__":

     if script_args.push_to_hub:
         dataset.push_to_hub(script_args.repo_id)
+        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
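Based on the columns described in the model card and the return value of `to_conversational`, a processed row looks roughly like this (all values invented for illustration):

```python
{
    "prompt": [{"role": "user", "content": "What is shown in the image?"}],
    "images": [image],  # a single PIL image
    "chosen": [{"role": "assistant", "content": "A red bicycle leaning against a brick wall."}],
    "rejected": [{"role": "assistant", "content": "A blue car parked on the street."}],
}
```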
examples/datasets/tldr.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional

 from datasets import load_dataset
+from huggingface_hub import ModelCard
 from transformers import HfArgumentParser

@@ -29,13 +30,22 @@ class ScriptArguments:
             Whether to push the dataset to the Hugging Face Hub.
         repo_id (`str`, *optional*, defaults to `"trl-lib/tldr"`):
             Hugging Face repository ID to push the dataset to.
-        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
+        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
             Number of workers to use for dataset processing.
     """

-    push_to_hub: bool = False
-    repo_id: str = "trl-lib/tldr"
-    dataset_num_proc: Optional[int] = None
+    push_to_hub: bool = field(
+        default=False,
+        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
+    )
+    repo_id: str = field(
+        default="trl-lib/tldr",
+        metadata={"help": "Hugging Face repository ID to push the dataset to."},
+    )
+    dataset_num_proc: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of workers to use for dataset processing."},
+    )


 def to_prompt_completion(example):
@@ -45,6 +55,33 @@ def to_prompt_completion(example):
     return {"prompt": prompt, "completion": completion}


+model_card = ModelCard("""
+---
+tags: [trl]
+---
+
+# TL;DR Dataset
+
+## Summary
+
+The TL;DR dataset is a processed version of Reddit posts, specifically curated to train models using the [TRL library](https://github.com/huggingface/trl) for summarization tasks. It leverages the common practice on Reddit where users append "TL;DR" (Too Long; Didn't Read) summaries to lengthy posts, providing a rich source of paired text data for training summarization models.
+
+## Data Structure
+
+- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
+- **Type**: [Prompt-completion](https://huggingface.co/docs/trl/main/dataset_formats#prompt-completion)
+
+Columns:
+- `"prompt"`: The unabridged Reddit post.
+- `"completion"`: The concise "TL;DR" summary appended by the author.
+
+This structure enables models to learn the relationship between detailed content and its abbreviated form, enhancing their summarization capabilities.
+
+## Generation script
+
+The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/tldr.py).
+""")
+
 if __name__ == "__main__":
     parser = HfArgumentParser(ScriptArguments)
     script_args = parser.parse_args_into_dataclasses()[0]
@@ -65,3 +102,4 @@ if __name__ == "__main__":

     if script_args.push_to_hub:
         dataset.push_to_hub(script_args.repo_id)
+        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
examples/datasets/tldr_preference.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional

 from datasets import load_dataset
+from huggingface_hub import ModelCard
 from transformers import HfArgumentParser

@@ -29,13 +30,22 @@ class ScriptArguments:
             Whether to push the dataset to the Hugging Face Hub.
         repo_id (`str`, *optional*, defaults to `"trl-lib/tldr-preference"`):
            Hugging Face repository ID to push the dataset to.
-        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
+        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
             Number of workers to use for dataset processing.
     """

-    push_to_hub: bool = False
-    repo_id: str = "trl-lib/tldr-preference"
-    dataset_num_proc: Optional[int] = None
+    push_to_hub: bool = field(
+        default=False,
+        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
+    )
+    repo_id: str = field(
+        default="trl-lib/tldr-preference",
+        metadata={"help": "Hugging Face repository ID to push the dataset to."},
+    )
+    dataset_num_proc: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of workers to use for dataset processing."},
+    )


 def to_preference(example):
@@ -56,6 +66,34 @@ def to_preference(example):
     return {"prompt": prompt, "chosen": chosen, "rejected": rejected}


+model_card = ModelCard("""
+---
+tags: [trl]
+---
+
+# TL;DR Dataset for Preference Learning
+
+## Summary
+
+The TL;DR dataset is a processed version of Reddit posts, specifically curated to train models using the [TRL library](https://github.com/huggingface/trl) for preference learning and Reinforcement Learning from Human Feedback (RLHF) tasks. It leverages the common practice on Reddit where users append "TL;DR" (Too Long; Didn't Read) summaries to lengthy posts, providing a rich source of paired text data for training models to understand and generate concise summaries.
+
+## Data Structure
+
+- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
+- **Type**: [Preference](https://huggingface.co/docs/trl/main/dataset_formats#preference)
+
+Columns:
+- `"prompt"`: The unabridged Reddit post.
+- `"chosen"`: The concise "TL;DR" summary appended by the author.
+- `"rejected"`: An alternative summary or response that was not selected.
+
+This structure enables models to learn the relationship between detailed content and its abbreviated form, enhancing their summarization capabilities.
+
+## Generation script
+
+The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/tldr_preference.py).
+""")
+
 if __name__ == "__main__":
     parser = HfArgumentParser(ScriptArguments)
     script_args = parser.parse_args_into_dataclasses()[0]
@@ -70,3 +108,4 @@ if __name__ == "__main__":

     if script_args.push_to_hub:
         dataset.push_to_hub(script_args.repo_id)
+        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
examples/datasets/tokenize_ds.py (file deleted)
@@ -1,54 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Optional

from datasets import load_dataset
from transformers import AutoTokenizer, HfArgumentParser

from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


"""
python -i examples/datasets/tokenize_ds.py --model HuggingFaceH4/zephyr-7b-beta
python -i examples/datasets/tokenize_ds.py --model gpt2
"""


@dataclass
class ScriptArguments:
    dataset_name: str = field(
        default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", metadata={"help": "The dataset to load"}
    )
    model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"})
    dataset_num_proc: Optional[int] = field(
        default=None, metadata={"help": "The number of workers to use to tokenize the data"}
    )


if __name__ == "__main__":
    script_args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
    dataset = load_dataset(script_args.dataset_name)
    tokenizer = AutoTokenizer.from_pretrained(script_args.model)
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

    def process(row):
        row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
        row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
        return row

    dataset = dataset.map(process, num_proc=script_args.dataset_num_proc)
    print(dataset["train"][0]["chosen"])
examples/datasets/ultrafeedback-prompt.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional

 from datasets import load_dataset
+from huggingface_hub import ModelCard
 from transformers import HfArgumentParser

@@ -29,13 +30,22 @@ class ScriptArguments:
             Whether to push the dataset to the Hugging Face Hub.
         repo_id (`str`, *optional*, defaults to `"trl-lib/ultrafeedback-prompt"`):
             Hugging Face repository ID to push the dataset to.
-        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
+        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
             Number of workers to use for dataset processing.
     """

-    push_to_hub: bool = False
-    repo_id: str = "trl-lib/ultrafeedback-prompt"
-    dataset_num_proc: Optional[int] = None
+    push_to_hub: bool = field(
+        default=False,
+        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
+    )
+    repo_id: str = field(
+        default="trl-lib/ultrafeedback-prompt",
+        metadata={"help": "Hugging Face repository ID to push the dataset to."},
+    )
+    dataset_num_proc: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of workers to use for dataset processing."},
+    )


 def to_unpaired_preference(example):
@@ -50,6 +60,30 @@ def drop_long_prompt(example):
     return True


+model_card = ModelCard("""
+---
+tags: [trl]
+---
+
+# UltraFeedback - Prompts Dataset
+
+## Summary
+
+The UltraFeedback - Prompts dataset is a processed version of the [UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset for model evaluation on specific aspects like helpfulness, honesty, and instruction-following.
+
+## Data Structure
+
+- **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational)
+- **Type**: [Prompt-only](https://huggingface.co/docs/trl/main/dataset_formats#prompt-only)
+
+Column:
+- `"prompt"`: The input question or instruction provided to the model.
+
+## Generation script
+
+The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/ultrafeedback-prompt.py).
+""")
+
 if __name__ == "__main__":
     parser = HfArgumentParser(ScriptArguments)
     script_args = parser.parse_args_into_dataclasses()[0]
@@ -66,3 +100,4 @@ if __name__ == "__main__":

     if script_args.push_to_hub:
         dataset.push_to_hub(script_args.repo_id)
+        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
examples/datasets/ultrafeedback.py
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Optional

 from datasets import load_dataset
+from huggingface_hub import ModelCard
 from transformers import HfArgumentParser

@@ -27,46 +28,61 @@ class ScriptArguments:
     Args:
         model_name (`str`, *optional*, defaults to `"gpt-3.5-turbo"`):
             Language model to target. Possible values are:

             - `"alpaca-7b"`
             - `"bard"`
             - `"falcon-40b-instruct"`
             - `"gpt-3.5-turbo"` (default)
             - `"gpt-4"`
             - `"llama-2-13b-chat"`
             - `"llama-2-70b-chat"`
             - `"llama-2-7b-chat"`
             - `"mpt-30b-chat"`
             - `"pythia-12b"`
             - `"starchat"`
             - `"ultralm-13b"`
             - `"ultralm-65b"`
             - `"vicuna-33b"`
             - `"wizardlm-13b"`
             - `"wizardlm-70b"`
             - `"wizardlm-7b"`

         aspect (`str`, *optional*, defaults to `"helpfulness"`):
-            Aspect to target. Possible values are:
-
-            - `"helpfulness"` (default)
-            - `"honesty"`
-            - `"instruction-following"`
-            - `"truthfulness"`
-
+            Aspect to target.
         push_to_hub (`bool`, *optional*, defaults to `False`):
             Whether to push the dataset to the Hugging Face Hub.
         repo_id (`str`, *optional*, defaults to `"trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness"`):
             Hugging Face repository ID to push the dataset to.
-        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
+        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
             Number of workers to use for dataset processing.
     """

-    model_name: str = "gpt-3.5-turbo"
-    aspect: str = "helpfulness"
-    push_to_hub: bool = False
-    repo_id: str = "trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness"
-    dataset_num_proc: Optional[int] = None
+    model_name: str = field(
+        default="gpt-3.5-turbo",
+        metadata={
+            "help": "Language model to target.",
+            "choices": [
+                "alpaca-7b",
+                "bard",
+                "falcon-40b-instruct",
+                "gpt-3.5-turbo",
+                "gpt-4",
+                "llama-2-13b-chat",
+                "llama-2-70b-chat",
+                "llama-2-7b-chat",
+                "mpt-30b-chat",
+                "pythia-12b",
+                "starchat",
+                "ultralm-13b",
+                "ultralm-65b",
+                "vicuna-33b",
+                "wizardlm-13b",
+                "wizardlm-70b",
+                "wizardlm-7b",
+            ],
+        },
+    )
+    aspect: str = field(
+        default="helpfulness",
+        metadata={
+            "help": "Aspect to target. Possible values are: 'helpfulness' (default), 'honesty', "
+            "'instruction-following', 'truthfulness'.",
+            "choices": ["helpfulness", "honesty", "instruction-following", "truthfulness"],
+        },
+    )
+    push_to_hub: bool = field(
+        default=False,
+        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
+    )
+    repo_id: str = field(
+        default="trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness",
+        metadata={"help": "Hugging Face repository ID to push the dataset to."},
+    )
+    dataset_num_proc: Optional[int] = field(
+        default=None,
+        metadata={"help": "Number of workers to use for dataset processing."},
+    )


 def to_unpaired_preference(example, model_name, aspect):
@@ -79,6 +95,32 @@ def to_unpaired_preference(example, model_name, aspect):
     return {"prompt": prompt, "completion": completion, "label": label}


+model_card = ModelCard("""
+---
+tags: [trl]
+---
+
+# UltraFeedback GPT-3.5-Turbo Helpfulness Dataset
+
+## Summary
+
+The UltraFeedback GPT-3.5-Turbo Helpfulness dataset contains processed user-assistant interactions filtered for helpfulness, derived from the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. It is designed for fine-tuning and evaluating models in alignment tasks.
+
+## Data Structure
+
+- **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational)
+- **Type**: [Unpaired preference](https://huggingface.co/docs/trl/main/dataset_formats#unpaired-preference)
+
+Columns:
+- `"prompt"`: The input question or instruction provided to the model.
+- `"completion"`: The model's response to the prompt.
+- `"label"`: A binary value indicating whether the response is sufficiently helpful.
+
+## Generation script
+
+The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/ultrafeedback.py).
+""")
+
 if __name__ == "__main__":
     parser = HfArgumentParser(ScriptArguments)
     script_args = parser.parse_args_into_dataclasses()[0]
@@ -100,3 +142,4 @@ if __name__ == "__main__":

     if script_args.push_to_hub:
         dataset.push_to_hub(script_args.repo_id)
+        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
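Putting the arguments together, a hypothetical invocation that builds the honesty subset for `llama-2-7b-chat` and pushes it under your own namespace would look like:

```sh
python examples/datasets/ultrafeedback.py \
    --model_name llama-2-7b-chat \
    --aspect honesty \
    --push_to_hub \
    --repo_id <your-username>/ultrafeedback-llama-2-7b-chat-honesty
```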
@@ -13,7 +13,7 @@
 "metadata": {},
 "source": [
 "<div style=\"text-align: center\">\n",
-"<img src='https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gpt2_bert_training.png' width='600'>\n",
+"<img src='https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/gpt2_bert_training.png' width='600'>\n",
 "<p style=\"text-align: center;\"> <b>Figure:</b> Experiment setup to tune GPT2. The yellow arrows are outside the scope of this notebook, but the trained models are available through Hugging Face. </p>\n",
 "</div>\n",
 "\n",
examples/research_projects/layer_skip/README.md (new file, 15 lines)
@@ -0,0 +1,15 @@
# LayerSkip Training Recipe

Implements the training recipe as described in the [LayerSkip paper](https://huggingface.co/papers/2404.16710).

## Run training

```
cd scripts
python layer_skip_sft.py
```

## Run benchmark

```
cd scripts
python benchmark_layer_skip.py
```
examples/research_projects/layer_skip/scripts/benchmark_layer_skip.py (new file, 77 lines)
@@ -0,0 +1,77 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import config
import torch
from torch.utils import benchmark
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_tokens(model, inputs):
    outputs = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=64,
    )
    return outputs


def generate_tokens_with_assistance(model, inputs, assistant_early_exit):
    outputs = model.generate(
        **inputs,
        assistant_early_exit=assistant_early_exit,
        do_sample=False,
        max_new_tokens=64,
    )
    return outputs


if __name__ == "__main__":
    ckpt = config.hub_model_id

    model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(ckpt)

    prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: "

    results = []
    label = "Generation Times"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    results.append(
        benchmark.Timer(
            stmt="generate_tokens(model, inputs)",
            setup="from __main__ import generate_tokens",
            globals={"model": model, "inputs": inputs},
            num_threads=torch.get_num_threads(),
            label=label,
            sub_label="no layer skip",
            description="generation",
        ).blocked_autorange()
    )

    for i in range(1, model.config.num_hidden_layers):
        results.append(
            benchmark.Timer(
                stmt="generate_tokens_with_assistance(model, inputs, assistant_early_exit)",
                setup="from __main__ import generate_tokens_with_assistance",
                globals={"model": model, "assistant_early_exit": i, "inputs": inputs},
                num_threads=torch.get_num_threads(),
                label=label,
                sub_label=f"layer skip {i}",
                description="generation",
            ).blocked_autorange()
        )

    benchmark.Compare(results).print()
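The `assistant_early_exit` argument swept above is the `transformers` generation flag for LayerSkip-style self-speculative decoding: the first N layers draft tokens which the full model then verifies. A minimal standalone sketch, assuming a recent `transformers` release that supports the flag (the checkpoint name is illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; any model trained with early-exit losses applies.
ckpt = "facebook/layerskip-llama3.2-1B"
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

inputs = tokenizer("Once upon a time,", return_tensors="pt").to(model.device)
# Draft with the first 4 layers, verify with the full model.
outputs = model.generate(**inputs, assistant_early_exit=4, do_sample=False, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```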
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

from .ppo_config import PPOConfig


# Define an alias for PPOv2Config that raises a warning
class PPOv2Config(PPOConfig):
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "`PPOv2Config` is deprecated and has been renamed to `PPOConfig`. Please use `PPOConfig` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)

examples/research_projects/layer_skip/scripts/config.py (new file)
from huggingface_hub import whoami


model_name = "unsloth/Llama-3.2-3B"
tokenizer_name = "unsloth/Llama-3.2-3B"
dataset_name = "WillHeld/top_v2"

output_root_dir = "./checkpoints/"
hub_model_id = f"{whoami()['name']}/layerskip-{model_name.split('/')[1]}-{dataset_name.split('/')[1]}"
output_dir = f"{output_root_dir}/{hub_model_id}"

per_device_train_batch_size = 8
gradient_accumulation_steps = 1
learning_rate = 2e-5
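As a quick check of the f-string above, here is what `hub_model_id` evaluates to for a hypothetical Hub user named `alice`:

```python
# Reproduces the hub_model_id f-string with an invented username.
model_name = "unsloth/Llama-3.2-3B"
dataset_name = "WillHeld/top_v2"
name = "alice"  # stand-in for whoami()["name"]
hub_model_id = f"{name}/layerskip-{model_name.split('/')[1]}-{dataset_name.split('/')[1]}"
print(hub_model_id)  # alice/layerskip-Llama-3.2-3B-top_v2
```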
examples/research_projects/layer_skip/scripts/custom_trainer.py (new file, 48 lines)
@@ -0,0 +1,48 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from trl import SFTTrainer


class LayerSkipSFTTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.early_exit_layer = 0  # initialize with 0
        self.always_last_layer = True
        self.early_exit_loss_scale = 1.0

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        self.early_exit_layer = (
            self.early_exit_layer % (model.config.num_hidden_layers - 1)
        ) + 1  # rotates between [1, num_hidden_layers - 1]
        bs, seqlen = inputs.input_ids.shape

        labels = inputs.pop("labels")
        outputs = model(**inputs, output_hidden_states=True)

        hidden_state = outputs["hidden_states"][self.early_exit_layer].to(model.dtype)
        if self.early_exit_layer != model.config.num_hidden_layers:
            hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state)
        loss_early = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size)

        if self.always_last_layer:
            loss_last = model.loss_function(logits=outputs["logits"], labels=labels, vocab_size=model.vocab_size)
            loss = self.early_exit_loss_scale * loss_early.to(loss_last.device) + 1.0 * loss_last
            # normalize loss scales
            loss = loss / (1.0 + self.early_exit_loss_scale)
        else:
            loss = loss_early

        return loss
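With `always_last_layer=True` the trainer optimizes a weighted average of the rotating early-exit loss and the final-layer loss, loss = (s * loss_early + loss_last) / (1 + s) for s = `early_exit_loss_scale`, so the default s = 1.0 weights the two equally. A tiny numeric check (loss values invented):

```python
import torch

loss_early = torch.tensor(2.0)
loss_last = torch.tensor(1.0)
s = 1.0  # early_exit_loss_scale default above
loss = (s * loss_early + 1.0 * loss_last) / (1.0 + s)
print(loss)  # tensor(1.5000), the equal-weight average when s == 1.0
```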
examples/research_projects/layer_skip/scripts/layer_skip_sft.py (new file, 91 lines)
@@ -0,0 +1,91 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import config
import torch
from custom_trainer import LayerSkipSFTTrainer
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import DataCollatorForCompletionOnlyLM, SFTConfig


def formatting_prompts_func(example):
    text = f"### Instruction: {example['utterance']}\n ### Response: {example['semantic_parse']}"

    # Inject the eos_token as a string before tokenization, because it is not always added
    # See: https://github.com/huggingface/transformers/issues/22794 and
    # https://github.com/huggingface/trl/issues/1623
    if tokenizer.eos_token:  # usually something like "</s>" for GPT2 or "<|endoftext|>"
        text += f"{tokenizer.eos_token}"

    return text


if __name__ == "__main__":
    # load the dataset
    print("[INFO] loading the dataset...")
    train_dataset = load_dataset(config.dataset_name, split="train")

    print(f"output_root_dir: {config.output_root_dir}")
    print(f"hub_model_id: {config.hub_model_id}")

    # load the model and tokenizer
    print("[INFO] loading the model and tokenizer...")
    model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, add_eos_token=True)

    # add pad and eos tokens if not provided by the tokenizer
    if tokenizer.pad_token is None:
        # Add '[PAD]' token if it doesn't exist
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        model.resize_token_embeddings(len(tokenizer))
        model.config.pad_token_id = tokenizer.pad_token_id

    if tokenizer.eos_token is None or tokenizer.eos_token == tokenizer.bos_token:
        # Add '[EOS]' token if it doesn't exist
        tokenizer.add_special_tokens({"eos_token": "[EOS]"})
        model.resize_token_embeddings(len(tokenizer))
        model.config.eos_token_id = tokenizer.eos_token_id

    response_template = " ### Response:"
    collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

    args = SFTConfig(
        do_train=True,
        bf16=True,
        max_seq_length=None,
        per_device_train_batch_size=config.per_device_train_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        packing=False,
        num_train_epochs=1.0,
        report_to="none",
        push_to_hub=True,
        hub_model_id=config.hub_model_id,
        output_dir=config.output_dir,
        logging_steps=500,
        save_steps=1000,
        save_total_limit=2,
    )

    trainer = LayerSkipSFTTrainer(
        model,
        train_dataset=train_dataset,
        args=args,
        formatting_func=formatting_prompts_func,
        data_collator=collator,
    )

    trainer.train()
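`DataCollatorForCompletionOnlyLM` locates the tokenized `response_template` in each sequence and sets every label up to and including it to -100, so loss is computed only on response tokens. A minimal sketch of the idea (illustrative, not TRL's actual implementation):

```python
def mask_prompt_labels(input_ids: list[int], template_ids: list[int]) -> list[int]:
    # Labels mirror input_ids, except prompt positions are set to -100,
    # which the cross-entropy loss ignores.
    labels = list(input_ids)
    for start in range(len(input_ids) - len(template_ids) + 1):
        if input_ids[start : start + len(template_ids)] == template_ids:
            end = start + len(template_ids)
            return [-100] * end + labels[end:]
    return [-100] * len(labels)  # template not found: ignore the whole example


# Toy usage; token ids 3, 4 stand in for the tokenized " ### Response:".
print(mask_prompt_labels([1, 2, 3, 4, 5, 6], [3, 4]))  # [-100, -100, -100, -100, 5, 6]
```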
@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -1,4 +1,4 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union

import evaluate
import numpy as np
@@ -236,7 +236,7 @@ class RewardDataCollatorWithPadding:
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

-    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        features_j = []
        features_k = []
        for feature in features:
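These hunks are part of a wider sweep replacing `typing.List`/`typing.Dict` with the builtin generics standardized in PEP 585 (Python 3.9+); the two spellings are equivalent as annotations. A self-contained illustration:

```python
from typing import Any


def collate(features: list[dict[str, Any]]) -> dict[str, Any]:
    # Builtin generics (PEP 585) replace typing.List / typing.Dict on Python >= 3.9.
    return {"num_features": len(features)}
```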
Some files were not shown because too many files have changed in this diff.