mirror of
https://github.com/huggingface/trl.git
synced 2025-10-21 02:53:59 +08:00
Compare commits
547 Commits
Author | SHA1 | Date | |
---|---|---|---|
a879e6ad5a | |||
4dd0dc2988 | |||
4f59e923ac | |||
10f70fa333 | |||
47ab034ca9 | |||
e755eee660 | |||
ac31d1205e | |||
c44ab6d1e9 | |||
a15a80e0d5 | |||
264f1279fd | |||
0cda2f2f01 | |||
e0ff66103e | |||
3a3ed88f28 | |||
b65657f41d | |||
de024ece28 | |||
2fbc0f4fc2 | |||
cf5168ea7c | |||
1e4fb80cbc | |||
fe41acd6ae | |||
c71262c9c6 | |||
dcee683d96 | |||
4788e5cda5 | |||
6cea2ef964 | |||
64d9816eac | |||
67564fdbbe | |||
e529579232 | |||
dc4cfab700 | |||
66d3a82dd2 | |||
3eda856371 | |||
616a273ac2 | |||
9955583829 | |||
bed205a2d2 | |||
42933fa647 | |||
bbdef00961 | |||
0956dc17cc | |||
a7dc892717 | |||
b0372e66a5 | |||
c1b272f4a6 | |||
f05f63c1ea | |||
54f806b6ff | |||
a9a756553f | |||
96bb3deb32 | |||
dbea3da917 | |||
150a93101b | |||
cbcaa46cd3 | |||
e3fe28ee1a | |||
fb0b9edc24 | |||
fc76fe8d11 | |||
b60ce797d8 | |||
6faf4c0d81 | |||
29bd0046a9 | |||
4867c2a3db | |||
332062372d | |||
b580e45c94 | |||
2004d62c5c | |||
ac7c8b1284 | |||
df12913602 | |||
ddf4c8dc3e | |||
890232fa28 | |||
9929370dee | |||
6171cddee5 | |||
33d2151f4f | |||
8bd2ab82f4 | |||
82b07d6b01 | |||
72bf6c21be | |||
74e54b5946 | |||
393097356c | |||
db8e09e346 | |||
1dae55f90f | |||
c8cef79e6c | |||
7dcf437a19 | |||
4e85bd75a9 | |||
c9d56366ed | |||
4dce042a38 | |||
98ad01ddfd | |||
fef8240c23 | |||
915ffc7c61 | |||
5828a666bf | |||
052a8e14b5 | |||
a2adfb836a | |||
4ebfc5de28 | |||
9e9dc96e67 | |||
7ddef5c158 | |||
a9cddf8c55 | |||
2860ce5091 | |||
30e33bd92d | |||
d5a0d2d345 | |||
314e8eb367 | |||
e10792032b | |||
78045dedc8 | |||
747612f9d3 | |||
9e3a35bd3d | |||
4402b36dcf | |||
78f8228874 | |||
b6af2edc93 | |||
cd85b14fbb | |||
a57544f47a | |||
b68ff96f0c | |||
c8c01cc055 | |||
3479606c8c | |||
7965b78340 | |||
56bd1bba26 | |||
94d53e6617 | |||
b5be100ae0 | |||
6e1652bc5e | |||
65374c6a71 | |||
9956091112 | |||
34d273f227 | |||
3bf94492a8 | |||
ba6abee37f | |||
a57e75967c | |||
ae23d40f3b | |||
83b367b11a | |||
d1ed730ab8 | |||
8f8e95e25d | |||
4e23d958f2 | |||
50c46205b6 | |||
6105d03f92 | |||
e247bbd7d5 | |||
3d04496196 | |||
2d244f8acb | |||
f5168fdbaf | |||
79686e1ac7 | |||
34ebc4ccaf | |||
1d84e2b888 | |||
2f71b8b1e2 | |||
5bcb8ad0d6 | |||
b8b972fde1 | |||
3eb9ccb104 | |||
974b0d380f | |||
39a7d1c121 | |||
0bdc63839f | |||
275d33b3ef | |||
c0819ee99f | |||
a03e7cc4e4 | |||
a13cb8952c | |||
84156f179f | |||
4eb0b905e2 | |||
6c203f9fef | |||
f18253bf2d | |||
151a452d14 | |||
488b502d31 | |||
3c0a10b1ae | |||
b031adfdb8 | |||
e7cb597230 | |||
bc8dfbf4e2 | |||
e4ed7a3a5a | |||
9a7efbd051 | |||
b344bcea2c | |||
35e12dc595 | |||
1da6be18e0 | |||
e249cd802f | |||
a02513c3b7 | |||
13454d2f4b | |||
99f2c94b22 | |||
6401d080c9 | |||
d632a5b289 | |||
5aeb752053 | |||
b8b89783ca | |||
8799952876 | |||
3b4c24946b | |||
0347f583e3 | |||
75de236c09 | |||
7075cec94d | |||
adf17a5a26 | |||
0d40e186ee | |||
683bc5af6f | |||
5f0913122b | |||
d1aa0b6b2c | |||
d88ec14602 | |||
6c18e40e97 | |||
1d0a7ea17b | |||
9f68ead8cf | |||
f30daa4225 | |||
24fd8dd513 | |||
c050ebc073 | |||
abc0584736 | |||
6d1cb85e73 | |||
e90e8d91d2 | |||
113aaae033 | |||
0865572748 | |||
a6532a11c2 | |||
3595eb00e0 | |||
9afd901d0f | |||
e04432d5e3 | |||
75c1c47fcc | |||
a5788ac99b | |||
3bbe7e0407 | |||
edf60e826b | |||
5d1deb1445 | |||
476c4b8dc0 | |||
e823458a6a | |||
1c0d8bca15 | |||
363369a717 | |||
aba4df02c1 | |||
98226473e4 | |||
87f4c70e60 | |||
995f1174da | |||
143e11123d | |||
346c99d222 | |||
087fe544b0 | |||
ebbd37ba99 | |||
e667550a5a | |||
57aebe9c36 | |||
85f5fd220d | |||
4dca169404 | |||
f35b68a301 | |||
5cf863576a | |||
9a28b3fd05 | |||
4f8057ad23 | |||
ab0d11d815 | |||
c674c66a45 | |||
45da5df53e | |||
04fd8d9400 | |||
bf2aed3876 | |||
0ee349dcd4 | |||
7ff6206510 | |||
e4b20ecbc4 | |||
6c2f829bb7 | |||
c4f0f41935 | |||
dc6a934269 | |||
9ce7ac6925 | |||
99553c19ae | |||
2ce8e45bb2 | |||
d1df79f83c | |||
d10f7663b0 | |||
423991c204 | |||
988d4c4e1a | |||
8534f0edf8 | |||
5095e7f948 | |||
9fcf61d706 | |||
66b043a910 | |||
f2c71771cc | |||
631c33cbb3 | |||
3f7ff60528 | |||
1705aebeba | |||
4e622a9033 | |||
eb2d5b2972 | |||
f976c6d234 | |||
abc7301bab | |||
6cfa5cfc81 | |||
a2aa0f0b09 | |||
304e208f77 | |||
4fe8b027f6 | |||
fb6ebb1e11 | |||
66078c7c01 | |||
58c0888996 | |||
486e7a4071 | |||
7630f877f9 | |||
4d862da181 | |||
22b4f548f4 | |||
4219cbfedc | |||
3bd02380c7 | |||
067db7553a | |||
93e85ed808 | |||
14e0d78807 | |||
b32656f726 | |||
9399bc113b | |||
11f122ad49 | |||
009c9a610b | |||
7712d42f8c | |||
7c2213b9e5 | |||
ddeebce176 | |||
cf68d871cf | |||
2a2676e7ec | |||
ca90cba351 | |||
4f97fb4a74 | |||
a46cd84a64 | |||
1f56bffdf8 | |||
1bfe0b8fcb | |||
0f13e51efa | |||
1e77d8aeb2 | |||
3b1911c2a9 | |||
851e7fe556 | |||
31b02d0cd0 | |||
9bc478ecbb | |||
29f162b86c | |||
6852097169 | |||
f12a1da74b | |||
ae87b3aefa | |||
3f7cee7643 | |||
ae8431bd50 | |||
66a976c6bd | |||
814930377c | |||
88685f2cd4 | |||
6f40f20233 | |||
036213bd85 | |||
6042596705 | |||
070c75ec54 | |||
b415224a4a | |||
9186710671 | |||
aa35fec099 | |||
737d771941 | |||
ef441ea028 | |||
af623aeba6 | |||
3843cfc32f | |||
9a71e67be9 | |||
09ca565b24 | |||
4edc688311 | |||
29d439a204 | |||
5760e5d3db | |||
a3c5b7178a | |||
222d275b8a | |||
09ca7607d5 | |||
1e68753216 | |||
1f59eeb9bb | |||
928d14445e | |||
3319993bd1 | |||
4fb3d0c860 | |||
bcccdeb6f9 | |||
ef209e311f | |||
341f6a6787 | |||
97b9fa212a | |||
a7d796c9a2 | |||
fa074e6a15 | |||
776939dcc4 | |||
163ca9f059 | |||
2eeb7b04cf | |||
9f8d0e48ad | |||
c9b7145c75 | |||
baf3c1c293 | |||
b181e401a7 | |||
26da9e80cb | |||
d6cc88ab2c | |||
7a95cc8696 | |||
d1715514de | |||
d116887ed4 | |||
a236c5750f | |||
4ae35afdd6 | |||
b21ed0ddbc | |||
384b868fe6 | |||
3267be0fcd | |||
dbcb2f0021 | |||
d5910b0ff5 | |||
104a02d207 | |||
ad597dbcb3 | |||
d57d0f9ca4 | |||
ec3d41b879 | |||
be32d304db | |||
dc53b8c6b0 | |||
20428c48ba | |||
6614b8aa6b | |||
df7b770da8 | |||
18a33ffcd3 | |||
911d3658e2 | |||
95ec8577df | |||
3539f3e3cd | |||
e451298b50 | |||
3efb484694 | |||
8f5b4923c8 | |||
e0dec27272 | |||
6ef785a6fb | |||
950ee2187d | |||
c1bb1f39f6 | |||
54babd9508 | |||
0c4edb750e | |||
17ec68d980 | |||
9be5680039 | |||
f11e213fd8 | |||
814fe396d4 | |||
06b7959b72 | |||
b07935f867 | |||
2aff709144 | |||
830cadfc4c | |||
f2acd821e0 | |||
f100ca34cc | |||
d708ec272f | |||
8140129595 | |||
48b3ef0b7b | |||
c0ce52ab26 | |||
393dbf6749 | |||
94fa4b022b | |||
cb7819e627 | |||
8f0fc4c8f7 | |||
d275cb431e | |||
7d0a8eea4e | |||
5a233546ee | |||
9fb00cf007 | |||
ee44946814 | |||
7f2401bd6e | |||
23bf9d4b58 | |||
501c347083 | |||
f06f357e9c | |||
4cdc03ab5c | |||
a60ceefa69 | |||
baa8f09cb3 | |||
c859f5fa5f | |||
481ef96293 | |||
6d9ea38ae1 | |||
c203e47fbf | |||
c84e5918a6 | |||
4b67af37b6 | |||
55d7c952c7 | |||
3719f7a929 | |||
e7961e45f1 | |||
b307faf07b | |||
aea1da8e2b | |||
e5eb4db8b5 | |||
28bdb6a373 | |||
e140d22881 | |||
e23a541af9 | |||
be3faa768e | |||
13679aa97e | |||
9e9f024399 | |||
c2884b5096 | |||
2f726ce4e8 | |||
a78a05d7b7 | |||
1b258247cd | |||
9c93dec05e | |||
d1dad6ebda | |||
8ce810250e | |||
8e9cae8072 | |||
654543a8cf | |||
c273b18c1c | |||
6c6ff24926 | |||
6ff0fac2c1 | |||
951ca1841f | |||
cc1de9820a | |||
a64a522fcc | |||
5b32372b71 | |||
d759004e52 | |||
cbc6c9bb3e | |||
f3cd86578b | |||
b763432eaf | |||
2bbd594ec5 | |||
b89b712dbf | |||
ec9e76623e | |||
d192244f54 | |||
051d5a1f61 | |||
2068fdcd93 | |||
02f5c1d8ce | |||
7de7db6765 | |||
4e7d5b5abe | |||
a90e13321b | |||
5b2aeca6c0 | |||
1f3314fd2f | |||
304ee70eef | |||
0a5aee7d99 | |||
db592a2eb6 | |||
122edc8f5d | |||
f91fb2bda2 | |||
01e4ad0009 | |||
1e56ff0f16 | |||
c4ed3274be | |||
14b6bc6691 | |||
eb4d2f381a | |||
78e08bd658 | |||
96d4854455 | |||
3ef21a24e7 | |||
f7707fd4c6 | |||
dd9b8f4189 | |||
ddd318865b | |||
8aa12d3c95 | |||
95aea7c072 | |||
eda1f36c57 | |||
ac0d5b726d | |||
6826d592ae | |||
c058ee6f05 | |||
fbeb146eea | |||
98845b9282 | |||
9f6326e65a | |||
7dcc71b1a6 | |||
6b73adc900 | |||
249d3e3259 | |||
ad8d50e30d | |||
d608fea0d1 | |||
92b03f5fdc | |||
7877e92991 | |||
1d7e3c2ae2 | |||
eb6aa20401 | |||
b8f0c4cf12 | |||
e11a45c5d8 | |||
08cfc4179b | |||
d603e7c527 | |||
5d30cd4d30 | |||
46975236be | |||
9a8d52cc5a | |||
0a6c42c12c | |||
221be13d26 | |||
a922af6927 | |||
42e7a0a824 | |||
15d52e759b | |||
24e914a0ab | |||
637612d95f | |||
35694baef2 | |||
d2f27df50a | |||
5cee9a0478 | |||
3f7710aed7 | |||
ca0af3944d | |||
e4f9a483d9 | |||
80890b17be | |||
cf9d2a7133 | |||
c02ce6d3f5 | |||
9141aa42ba | |||
05723c0b88 | |||
b87ec2d5a0 | |||
27df071ad8 | |||
67452ef213 | |||
22a90198e5 | |||
4f81e7736d | |||
14292b08af | |||
453c4eca14 | |||
decc832d3e | |||
1111295776 | |||
c04074e248 | |||
d484dc2a93 | |||
34e6948d45 | |||
9f69f06a1c | |||
5bb46687c5 | |||
25d6700c5e | |||
4d31d0c4f8 | |||
0ff39d2a87 | |||
b4899b29d2 | |||
6aae9e75f3 | |||
79b90e19ba | |||
7f636c9ed7 | |||
98d8cc509d | |||
9d09b3e107 | |||
336d63eb80 | |||
7fc970983c | |||
d3bbee3ab8 | |||
eb5465df7e | |||
1c272240ac | |||
b095245830 | |||
c115453fba | |||
16f214c58d | |||
e9a437992e | |||
c837fbe5b9 | |||
01c4a35928 | |||
1aca98fbcf | |||
029f961b7c | |||
8ec912ffa6 | |||
f360c37466 | |||
217313014b | |||
b946e875b1 | |||
6dd50b45d8 | |||
98120d6aeb | |||
3b2c820db6 | |||
25fd6f2313 | |||
3f1477cdc0 | |||
2cff1e4385 | |||
d7d7902938 | |||
77b0cc1707 | |||
17f22c1c20 | |||
e448bb69f0 | |||
9aa4e3ce2b | |||
ca8a508913 |
67
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
Normal file
67
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
Normal file
@ -0,0 +1,67 @@
|
||||
name: "\U0001F41B Bug Report"
|
||||
description: Submit a bug report to help us improve TRL
|
||||
labels: [ "bug" ]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for taking the time to fill out this bug report! 🤗
|
||||
|
||||
Before you submit your bug report:
|
||||
|
||||
- If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
|
||||
|
||||
- type: textarea
|
||||
id: system-info
|
||||
attributes:
|
||||
label: System Info
|
||||
description: Please share your system info with us. You can run the command `transformers-cli env` and copy-paste its output below.
|
||||
placeholder: trl version, transformers version, platform, python version, ...
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: checkboxes
|
||||
id: information-scripts-examples
|
||||
attributes:
|
||||
label: Information
|
||||
description: 'The problem arises when using:'
|
||||
options:
|
||||
- label: "The official example scripts"
|
||||
- label: "My own modified scripts"
|
||||
|
||||
- type: checkboxes
|
||||
id: information-tasks
|
||||
attributes:
|
||||
label: Tasks
|
||||
description: "The tasks I am working on are:"
|
||||
options:
|
||||
- label: "An officially supported task in the `examples` folder"
|
||||
- label: "My own task or dataset (give details below)"
|
||||
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Reproduction
|
||||
description: |
|
||||
Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
|
||||
If you have code snippets, error messages, stack traces please provide them here as well.
|
||||
Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
|
||||
Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
|
||||
|
||||
placeholder: |
|
||||
Steps to reproduce the behavior:
|
||||
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Expected behavior
|
||||
description: "A clear and concise description of what you would expect to happen."
|
31
.github/ISSUE_TEMPLATE/feature-request.yml
vendored
Normal file
31
.github/ISSUE_TEMPLATE/feature-request.yml
vendored
Normal file
@ -0,0 +1,31 @@
|
||||
name: "\U0001F680 Feature request"
|
||||
description: Submit a proposal/request for a new TRL feature
|
||||
labels: [ "Feature request" ]
|
||||
body:
|
||||
- type: textarea
|
||||
id: feature-request
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Feature request
|
||||
description: |
|
||||
A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.
|
||||
|
||||
- type: textarea
|
||||
id: motivation
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Motivation
|
||||
description: |
|
||||
Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
|
||||
|
||||
|
||||
- type: textarea
|
||||
id: contribution
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Your contribution
|
||||
description: |
|
||||
Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md)
|
32
.github/ISSUE_TEMPLATE/new-trainer-addition.yml
vendored
Normal file
32
.github/ISSUE_TEMPLATE/new-trainer-addition.yml
vendored
Normal file
@ -0,0 +1,32 @@
|
||||
name: "\U0001F31F New trainer addition"
|
||||
description: Submit a proposal/request to implement a new trainer for a post-training method
|
||||
labels: [ "New trainer" ]
|
||||
|
||||
body:
|
||||
- type: textarea
|
||||
id: description-request
|
||||
validations:
|
||||
required: true
|
||||
attributes:
|
||||
label: Method description
|
||||
description: |
|
||||
Put any and all important information relative to the method
|
||||
|
||||
- type: checkboxes
|
||||
id: information-tasks
|
||||
attributes:
|
||||
label: Open source status
|
||||
description: |
|
||||
Please note that if the method implementation isn't available or model weights with training datasets aren't available, we are less likely to implement it in `trl`.
|
||||
options:
|
||||
- label: "The method implementation is available"
|
||||
- label: "The model weights are available"
|
||||
- label: "The training datasets are available"
|
||||
|
||||
- type: textarea
|
||||
id: additional-info
|
||||
attributes:
|
||||
label: Provide useful links for the implementation
|
||||
description: |
|
||||
Please provide information regarding the implementation, the weights, and the authors.
|
||||
Please mention the authors by @gh-username if you're aware of their usernames.
|
32
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
32
.github/PULL_REQUEST_TEMPLATE.md
vendored
Normal file
@ -0,0 +1,32 @@
|
||||
# What does this PR do?
|
||||
|
||||
<!--
|
||||
Congratulations! You've made it this far! You're not quite done yet though.
|
||||
|
||||
Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution.
|
||||
|
||||
Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change.
|
||||
|
||||
Once you're done, someone will review your PR shortly. They may suggest changes to make the code even better.
|
||||
-->
|
||||
|
||||
<!-- Remove if not applicable -->
|
||||
|
||||
Fixes # (issue)
|
||||
|
||||
|
||||
## Before submitting
|
||||
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
|
||||
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#create-a-pull-request),
|
||||
Pull Request section?
|
||||
- [ ] Was this discussed/approved via a GitHub issue? Please add a link
|
||||
to it if that's the case.
|
||||
- [ ] Did you make sure to update the documentation with your changes? Here are the
|
||||
[documentation guidelines](https://github.com/huggingface/trl/tree/main/docs).
|
||||
- [ ] Did you write any new necessary tests?
|
||||
|
||||
|
||||
## Who can review?
|
||||
|
||||
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
|
||||
members/contributors who may be interested in your PR.
|
5
.github/workflows/build_documentation.yml
vendored
5
.github/workflows/build_documentation.yml
vendored
@ -13,8 +13,7 @@ jobs:
|
||||
with:
|
||||
commit_sha: ${{ github.sha }}
|
||||
package: trl
|
||||
repo_owner: lvwerra
|
||||
version_tag_suffix: ""
|
||||
custom_container: huggingface/transformers-doc-builder
|
||||
secrets:
|
||||
token: ${{ secrets.HUGGINGFACE_PUSH }}
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
|
||||
|
4
.github/workflows/build_pr_documentation.yml
vendored
4
.github/workflows/build_pr_documentation.yml
vendored
@ -14,5 +14,5 @@ jobs:
|
||||
commit_sha: ${{ github.event.pull_request.head.sha }}
|
||||
pr_number: ${{ github.event.number }}
|
||||
package: trl
|
||||
repo_owner: lvwerra
|
||||
version_tag_suffix: ""
|
||||
version_tag_suffix: ""
|
||||
custom_container: huggingface/transformers-doc-builder
|
||||
|
2
.github/workflows/clear_cache.yml
vendored
2
.github/workflows/clear_cache.yml
vendored
@ -10,7 +10,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Cleanup
|
||||
run: |
|
||||
|
13
.github/workflows/delete_doc_comment.yml
vendored
13
.github/workflows/delete_doc_comment.yml
vendored
@ -1,13 +0,0 @@
|
||||
name: Delete doc comment
|
||||
|
||||
on:
|
||||
workflow_run:
|
||||
workflows: ["Delete doc comment trigger"]
|
||||
types:
|
||||
- completed
|
||||
|
||||
jobs:
|
||||
delete:
|
||||
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment.yml@main
|
||||
secrets:
|
||||
comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
|
12
.github/workflows/delete_doc_comment_trigger.yml
vendored
12
.github/workflows/delete_doc_comment_trigger.yml
vendored
@ -1,12 +0,0 @@
|
||||
name: Delete doc comment trigger
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [ closed ]
|
||||
|
||||
|
||||
jobs:
|
||||
delete:
|
||||
uses: huggingface/doc-builder/.github/workflows/delete_doc_comment_trigger.yml@main
|
||||
with:
|
||||
pr_number: ${{ github.event.number }}
|
95
.github/workflows/docker-build.yml
vendored
Normal file
95
.github/workflows/docker-build.yml
vendored
Normal file
@ -0,0 +1,95 @@
|
||||
name: Build Docker images (scheduled)
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
workflow_call:
|
||||
schedule:
|
||||
- cron: "0 1 * * *"
|
||||
|
||||
concurrency:
|
||||
group: docker-image-builds
|
||||
cancel-in-progress: false
|
||||
|
||||
env:
|
||||
CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
|
||||
|
||||
jobs:
|
||||
trl-latest:
|
||||
name: "Latest TRL GPU"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo ls -l /usr/local/lib/
|
||||
sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push GPU
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: ./docker/trl-latest-gpu
|
||||
push: true
|
||||
tags: huggingface/trl-latest-gpu
|
||||
|
||||
- name: Post to Slack
|
||||
if: always()
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: 🤗 Results of the trl-latest-gpu Docker Image build
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
trl-source:
|
||||
name: "Latest TRL + HF ecosystem from source"
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cleanup disk
|
||||
run: |
|
||||
sudo ls -l /usr/local/lib/
|
||||
sudo ls -l /usr/share/
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo du -sh /usr/local/lib/
|
||||
sudo du -sh /usr/share/
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push GPU
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: ./docker/trl-source-gpu
|
||||
push: true
|
||||
tags: huggingface/trl-source-gpu
|
||||
|
||||
- name: Post to Slack
|
||||
if: always()
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: 🤗 Results of the trl-source-gpu Docker Image build
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
96
.github/workflows/slow-tests.yml
vendored
Normal file
96
.github/workflows/slow-tests.yml
vendored
Normal file
@ -0,0 +1,96 @@
|
||||
name: Slow tests (on push)
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
# Run only when python files are modified
|
||||
- "trl/**.py"
|
||||
- "examples/**.py"
|
||||
env:
|
||||
RUN_SLOW: "yes"
|
||||
IS_GITHUB_CI: "1"
|
||||
SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
|
||||
jobs:
|
||||
run_all_tests_single_gpu:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
|
||||
runs-on: [self-hosted, single-gpu, nvidia-gpu, t4, ci]
|
||||
env:
|
||||
CUDA_VISIBLE_DEVICES: "0"
|
||||
TEST_TYPE: "single_gpu_${{ matrix.docker-image-name }}"
|
||||
container:
|
||||
image: ${{ matrix.docker-image-name }}
|
||||
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Pip install
|
||||
run: |
|
||||
source activate trl
|
||||
pip install -e ".[test]" --no-deps
|
||||
pip install pytest-reportlog parameterized
|
||||
|
||||
- name: Run slow SFT tests on single GPU
|
||||
if: always()
|
||||
run: |
|
||||
source activate trl
|
||||
make slow_tests
|
||||
|
||||
- name: Generate Report
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
|
||||
run_all_tests_multi_gpu:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"]
|
||||
runs-on: [self-hosted, multi-gpu, nvidia-gpu, t4, ci]
|
||||
env:
|
||||
CUDA_VISIBLE_DEVICES: "0,1"
|
||||
TEST_TYPE: "multi_gpu_${{ matrix.docker-image-name }}"
|
||||
container:
|
||||
image: ${{ matrix.docker-image-name }}
|
||||
options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Pip install
|
||||
run: |
|
||||
source activate trl
|
||||
pip install -e ".[test]" --no-deps
|
||||
pip install pytest-reportlog parameterized
|
||||
|
||||
- name: Run slow SFT tests on Multi GPU
|
||||
if: always()
|
||||
run: |
|
||||
source activate trl
|
||||
make slow_tests
|
||||
|
||||
- name: Run end-to-end examples tests on multi GPU
|
||||
if: always()
|
||||
run: |
|
||||
source activate trl
|
||||
pip install deepspeed
|
||||
make test_examples
|
||||
|
||||
- name: Generate Reports
|
||||
if: always()
|
||||
run: |
|
||||
pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY
|
||||
python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY
|
||||
rm *.txt
|
6
.github/workflows/stale.yml
vendored
6
.github/workflows/stale.yml
vendored
@ -7,15 +7,15 @@ on:
|
||||
jobs:
|
||||
close_stale_issues:
|
||||
name: Close Stale Issues
|
||||
if: github.repository == 'lvwerra/trl'
|
||||
if: github.repository == 'huggingface/trl'
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.8
|
||||
|
||||
|
46
.github/workflows/tests-main.yml
vendored
Normal file
46
.github/workflows/tests-main.yml
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
name: tests on transformers PEFT main
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
|
||||
env:
|
||||
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.9', '3.10', '3.11']
|
||||
os: ['ubuntu-latest', 'windows-latest']
|
||||
fail-fast: false
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
setup.py
|
||||
requirements.txt
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# install PEFT & transformers from source
|
||||
pip install -U git+https://github.com/huggingface/peft.git
|
||||
pip install -U git+https://github.com/huggingface/transformers.git
|
||||
# cpu version of pytorch
|
||||
pip install ".[test, diffusers]"
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
- name: Post to Slack
|
||||
if: always()
|
||||
uses: huggingface/hf-workflows/.github/actions/post-slack@main
|
||||
with:
|
||||
slack_channel: ${{ env.CI_SLACK_CHANNEL }}
|
||||
title: 🤗 Results of the TRL CI on transformers/PEFT main
|
||||
status: ${{ job.status }}
|
||||
slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
49
.github/workflows/tests.yml
vendored
49
.github/workflows/tests.yml
vendored
@ -5,6 +5,16 @@ on:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
paths:
|
||||
# Run only when relevant files are modified
|
||||
- "trl/**.py"
|
||||
- "examples/**.py"
|
||||
- "scripts/**.py"
|
||||
- ".github/**.yml"
|
||||
- "tests/**.py"
|
||||
|
||||
env:
|
||||
TQDM_DISABLE: 1
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
@ -14,15 +24,15 @@ jobs:
|
||||
python-version: [3.9]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: recursive
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- uses: pre-commit/action@v2.0.3
|
||||
- uses: pre-commit/action@v3.0.1
|
||||
with:
|
||||
extra_args: --all-files
|
||||
|
||||
@ -30,19 +40,44 @@ jobs:
|
||||
needs: check_code_quality
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.8', '3.9', '3.10']
|
||||
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
|
||||
python-version: ['3.9', '3.10', '3.11']
|
||||
os: ['ubuntu-latest', 'windows-latest']
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
setup.py
|
||||
requirements.txt
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# install PEFT & transformers from source
|
||||
pip install -U git+https://github.com/huggingface/peft.git
|
||||
pip install -U git+https://github.com/huggingface/transformers.git
|
||||
# cpu version of pytorch
|
||||
pip install ".[test, diffusers]"
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
|
||||
tests_no_optional_dep:
|
||||
needs: check_code_quality
|
||||
runs-on: 'ubuntu-latest'
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python 3.9
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.9'
|
||||
cache: "pip"
|
||||
cache-dependency-path: |
|
||||
setup.py
|
||||
requirements.txt
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
15
.github/workflows/trufflehog.yml
vendored
Normal file
15
.github/workflows/trufflehog.yml
vendored
Normal file
@ -0,0 +1,15 @@
|
||||
on:
|
||||
push:
|
||||
|
||||
name: Secret Leaks
|
||||
|
||||
jobs:
|
||||
trufflehog:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Secret Scanning
|
||||
uses: trufflesecurity/trufflehog@main
|
6
.gitignore
vendored
6
.gitignore
vendored
@ -1,3 +1,4 @@
|
||||
benchmark/trl
|
||||
*.bak
|
||||
.gitattributes
|
||||
.last_checked
|
||||
@ -142,4 +143,7 @@ checklink/cookies.txt
|
||||
# wandb files
|
||||
nbs/wandb/
|
||||
examples/notebooks/wandb/
|
||||
wandb/
|
||||
wandb/
|
||||
|
||||
# cli scripts that are symlinked from `examples/scripts`
|
||||
trl/commands/scripts/
|
@ -1,37 +1,10 @@
|
||||
repos:
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 5.12.0
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.2.0
|
||||
hooks:
|
||||
- id: isort
|
||||
args:
|
||||
- --profile=black
|
||||
- --skip-glob=wandb/**/*
|
||||
- --thirdparty=wandb
|
||||
- repo: https://github.com/myint/autoflake
|
||||
rev: v1.4
|
||||
hooks:
|
||||
- id: autoflake
|
||||
args:
|
||||
- -r
|
||||
- --exclude=wandb,__init__.py
|
||||
- --in-place
|
||||
- --remove-unused-variables
|
||||
- --remove-all-unused-imports
|
||||
- repo: https://github.com/python/black
|
||||
rev: 22.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
args:
|
||||
- --line-length=119
|
||||
- --target-version=py38
|
||||
- --exclude=wandb
|
||||
- repo: https://github.com/pycqa/flake8
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
args:
|
||||
- --ignore=E203,E501,W503,E128
|
||||
- --max-line-length=119
|
||||
- id: ruff
|
||||
args: [ --fix ]
|
||||
- id: ruff-format
|
||||
|
||||
# - repo: https://github.com/codespell-project/codespell
|
||||
# rev: v2.1.0
|
||||
|
@ -17,7 +17,7 @@ authors:
|
||||
family-names: Thrush
|
||||
- given-names: Nathan
|
||||
family-names: Lambert
|
||||
repository-code: 'https://github.com/lvwerra/trl'
|
||||
repository-code: 'https://github.com/huggingface/trl'
|
||||
abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
|
||||
keywords:
|
||||
- rlhf
|
||||
|
133
CODE_OF_CONDUCT.md
Normal file
133
CODE_OF_CONDUCT.md
Normal file
@ -0,0 +1,133 @@
|
||||
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to make participation in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, caste, color, religion, or sexual
|
||||
identity and orientation.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||
diverse, inclusive, and healthy community.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community include:
|
||||
|
||||
* Demonstrating empathy and kindness toward other people
|
||||
* Being respectful of differing opinions, viewpoints, and experiences
|
||||
* Giving and gracefully accepting constructive feedback
|
||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
* Focusing on what is best not just for us as individuals, but for the overall
|
||||
community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
* The use of sexualized language or imagery, and sexual attention or advances of
|
||||
any kind
|
||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or email address,
|
||||
without their explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
feedback@huggingface.co.
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series of
|
||||
actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period of time. This
|
||||
includes avoiding interactions in community spaces as well as external channels
|
||||
like social media. Violating these terms may lead to a temporary or permanent
|
||||
ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any sort of interaction or public
|
||||
communication with the community for a specified period of time. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within the
|
||||
community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.1, available at
|
||||
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
|
||||
|
||||
Community Impact Guidelines were inspired by
|
||||
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at
|
||||
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
|
||||
[https://www.contributor-covenant.org/translations][translations].
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
|
||||
[Mozilla CoC]: https://github.com/mozilla/diversity
|
||||
[FAQ]: https://www.contributor-covenant.org/faq
|
||||
[translations]: https://www.contributor-covenant.org/translations
|
267
CONTRIBUTING.md
267
CONTRIBUTING.md
@ -1,53 +1,258 @@
|
||||
# How to contribute
|
||||
# How to contribute to TRL?
|
||||
|
||||
## How to get started
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code
|
||||
contributions are not the only way to help the community. Answering questions, helping
|
||||
others, and improving the documentation are also immensely valuable.
|
||||
|
||||
Before you start contributing make sure you installed all the dev tools:
|
||||
It also helps us if you spread the word! Reference the library in blog posts
|
||||
about the awesome projects it made possible, shout out on Twitter every time it has
|
||||
helped you, or simply ⭐️ the repository to say thank you.
|
||||
|
||||
However you choose to contribute, please be mindful and respect our
|
||||
[code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
|
||||
|
||||
**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
|
||||
|
||||
## Ways to contribute
|
||||
|
||||
There are several ways you can contribute to TRL:
|
||||
|
||||
* Fix outstanding issues with the existing code.
|
||||
* Submit issues related to bugs or desired new features.
|
||||
* Implement trainers for new post-training algorithms.
|
||||
* Contribute to the examples or to the documentation.
|
||||
|
||||
If you don't know where to start, there is a special [Good First
|
||||
Issue](https://github.com/huggingface/trl/contribute) listing. It will give you a list of
|
||||
open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
|
||||
|
||||
For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀
|
||||
|
||||
> All contributions are equally valuable to the community. 🥰
|
||||
|
||||
Before you start contributing make sure you have installed all the dev tools:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
make dev
|
||||
```
|
||||
|
||||
## Did you find a bug?
|
||||
## Fixing outstanding issues
|
||||
|
||||
* Ensure the bug was not already reported by searching on GitHub under Issues.
|
||||
* If you're unable to find an open issue addressing the problem, open a new one. Be sure to include a title and clear description, as much relevant information as possible, and a code sample or an executable test case demonstrating the expected behavior that is not occurring.
|
||||
* Be sure to add the complete error messages.
|
||||
If you notice an issue with the existing code and have a fix in mind, feel free to [start contributing](#create-a-pull-request) and open a Pull Request!
|
||||
|
||||
#### Did you write a patch that fixes a bug?
|
||||
## Submitting a bug-related issue or feature request
|
||||
|
||||
* Open a new GitHub pull request with the patch.
|
||||
* Ensure that your PR includes a test that fails without your patch, and pass with it.
|
||||
* Ensure the PR description clearly describes the problem and solution. Include the relevant issue number if applicable.
|
||||
Do your best to follow these guidelines when submitting a bug-related issue or a feature request. It will make it easier for us to come back to you quickly and with good feedback.
|
||||
|
||||
## PR submission guidelines
|
||||
### Did you find a bug?
|
||||
|
||||
* Keep each PR focused. While it's more convenient, do not combine several unrelated fixes together. Create as many branches as needing to keep each PR focused.
|
||||
* Do not mix style changes/fixes with "functional" changes. It's very difficult to review such PRs and it most likely get rejected.
|
||||
* Do not add/remove vertical whitespace. Preserve the original style of the file you edit as much as you can.
|
||||
* Do not turn an already submitted PR into your development playground. If after you submitted PR, you discovered that more work is needed - close the PR, do the required work and then submit a new PR. Otherwise each of your commits requires attention from maintainers of the project.
|
||||
* If, however, you submitted a PR and received a request for changes, you should proceed with commits inside that PR, so that the maintainer can see the incremental fixes and won't need to review the whole PR again. In the exception case where you realize it'll take many many commits to complete the requests, then it's probably best to close the PR, do the work and then submit it again. Use common sense where you'd choose one way over another.
|
||||
The TRL library is robust and reliable thanks to users who report the problems they encounter.
|
||||
|
||||
### Before you submit a PR
|
||||
Before you report an issue, we would really appreciate it if you could **make sure the bug was not
|
||||
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
|
||||
|
||||
First you want to make sure that all the tests pass:
|
||||
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:
|
||||
|
||||
* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions.
|
||||
* A short, self-contained, code snippet that allows us to reproduce the bug in
|
||||
less than 30s.
|
||||
* The *full* traceback if an exception is raised.
|
||||
* Attach any other additional information, like screenshots, you think may help.
|
||||
|
||||
To get the OS and software versions automatically, run the following command:
|
||||
|
||||
```bash
|
||||
make test
|
||||
transformers-cli env
|
||||
```
|
||||
|
||||
Then before submitting your PR make sure the code quality follows the standards. You can run the following command to format:
|
||||
### Do you want a new feature?
|
||||
|
||||
If there is a new feature you'd like to see in TRL, please open an issue and describe:
|
||||
|
||||
1. What is the *motivation* behind this feature? Is it related to a problem or frustration with the library? Is it a feature related to something you need for a project? Is it something you worked on and think it could benefit the community?
|
||||
|
||||
Whatever it is, we'd love to hear about it!
|
||||
|
||||
2. Describe your requested feature in as much detail as possible. The more you can tell us about it, the better we'll be able to help you.
|
||||
3. Provide a *code snippet* that demonstrates the features usage.
|
||||
4. If the feature is related to a paper, please include a link.
|
||||
|
||||
If your issue is well written we're already 80% of the way there by the time you create it.
|
||||
|
||||
## Do you want to implement a new trainer?
|
||||
|
||||
New post-training methods are published on a frequent basis and those which satisfy the following criteria are good candidates to be integrated in TRL:
|
||||
|
||||
* **Simplicity:** does the new method achieve similar performance as prior methods, but with less complexity? A good example is [Direct Preference Optimization](https://arxiv.org/abs/2305.18290) (DPO), which provided a simpler and compelling alternative to RLHF methods.
|
||||
* **Efficiency:** does the new method provide a significant improvement in training efficiency? A good example is [Odds Ratio Preference Optimization](https://arxiv.org/abs/2403.07691v2), which utilises a similar objective as DPO, but requires half the GPU VRAM.
|
||||
|
||||
Methods which only provide incremental improvements at the expense of added complexity or compute costs are unlikely to be included in TRL.
|
||||
|
||||
If you want to implement a trainer for a new post-training method, first open an issue and provide the following information:
|
||||
|
||||
* A short description of the method and a link to the paper.
|
||||
* Link to the implementation if it is open-sourced.
|
||||
* Link to model weights trained with the method if they are available.
|
||||
|
||||
Based on the community and maintainer feedback, the next step will be to implement the trainer and config classes. See the following examples for inspiration:
|
||||
|
||||
* Paired preference optimisation: [`dpo_trainer.py`](./trl/trainer/dpo_trainer.py) and [`dpo_config.py`](./trl/trainer/dpo_config.py)
|
||||
* RL-based optimisation: [`rloo_trainer.py](./trl/trainer/rloo_trainer.py) and [`rloo_config.py](./trl/trainer/rloo_config.py)
|
||||
* Online optimisation: [`online_dpo_trainer.py`](./trl/trainer/online_dpo_trainer.py) and [`online_dpo_config.py`](./trl/trainer/online_dpo_config.py)
|
||||
|
||||
## Do you want to add documentation?
|
||||
|
||||
We're always looking for improvements to the documentation that make it more clear and accurate. Please let us know how the documentation can be improved, such as typos, dead links and any missing, unclear or inaccurate content.. We'll be happy to make the changes or help you make a contribution if you're interested!
|
||||
|
||||
## Submitting a pull request (PR)
|
||||
|
||||
Before writing code, we strongly advise you to search through the existing PRs or
|
||||
issues to make sure that nobody is already working on the same thing. If you are
|
||||
unsure, it is always a good idea to open an issue to get some feedback.
|
||||
|
||||
You will need basic `git` proficiency to be able to contribute to
|
||||
TRL. `git` is not the easiest tool to use but it has the greatest
|
||||
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
|
||||
Follow these steps to start contributing:
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/trl) by
|
||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||
under your GitHub user account.
|
||||
|
||||
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
|
||||
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
|
||||
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
|
||||
|
||||
```bash
|
||||
$ git clone git@github.com:<your Github handle>/trl.git
|
||||
$ cd trl
|
||||
$ git remote add upstream https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
|
||||
|
||||
Start by synchronizing your `main` branch with the `upstream/main` branch (ore details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
|
||||
|
||||
```bash
|
||||
$ git checkout main
|
||||
$ git fetch upstream
|
||||
$ git merge upstream/main
|
||||
```
|
||||
|
||||
Once your `main` branch is synchronized, create a new branch from it:
|
||||
|
||||
```bash
|
||||
$ git checkout -b a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
**Do not** work on the `main` branch.
|
||||
|
||||
4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
|
||||
|
||||
```bash
|
||||
$ make dev
|
||||
```
|
||||
|
||||
(If TRL was already installed in the virtual environment, remove
|
||||
it with `pip uninstall trl` before reinstalling it.)
|
||||
|
||||
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
|
||||
the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).
|
||||
|
||||
5. Develop the features on your branch.
|
||||
|
||||
As you work on the features, you should make sure that the test suite
|
||||
passes. You should run the tests impacted by your changes like this (see
|
||||
below an explanation regarding the environment variable):
|
||||
|
||||
```bash
|
||||
$ pytest tests/<TEST_TO_RUN>.py
|
||||
```
|
||||
|
||||
> For the following commands leveraging the `make` utility, we recommend using the WSL system when running on
|
||||
> Windows. More information [here](https://docs.microsoft.com/en-us/windows/wsl/about).
|
||||
|
||||
You can also run the full suite with the following command.
|
||||
|
||||
```bash
|
||||
$ make test
|
||||
```
|
||||
|
||||
TRL relies on `ruff` to format its source code
|
||||
consistently. After you make changes, apply automatic style corrections and code verifications
|
||||
that can't be automated in one go with:
|
||||
|
||||
This target is also optimized to only work with files modified by the PR you're working on.
|
||||
|
||||
If you prefer to run the checks one after the other, the following command apply the
|
||||
style corrections:
|
||||
|
||||
```bash
|
||||
$ make precommit
|
||||
```
|
||||
|
||||
Once you're happy with your changes, add changed files using `git add` and
|
||||
make a commit with `git commit` to record your changes locally:
|
||||
|
||||
```bash
|
||||
$ git add modified_file.py
|
||||
$ git commit
|
||||
```
|
||||
|
||||
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
||||
|
||||
It is a good idea to sync your copy of the code with the original
|
||||
repository regularly. This way you can quickly account for changes:
|
||||
|
||||
```bash
|
||||
$ git fetch upstream
|
||||
$ git rebase upstream/main
|
||||
```
|
||||
|
||||
Push the changes to your account using:
|
||||
|
||||
```bash
|
||||
$ git push -u origin a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
6. Once you are satisfied (**and the checklist below is happy too**), go to the
|
||||
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
|
||||
to the project maintainers for review.
|
||||
|
||||
7. It's ok if maintainers ask you for changes. It happens to core contributors
|
||||
too! So everyone can see the changes in the Pull request, work in your local
|
||||
branch and push the changes to your fork. They will automatically appear in
|
||||
the pull request.
|
||||
|
||||
|
||||
### Checklist
|
||||
|
||||
1. The title of your pull request should be a summary of its contribution;
|
||||
2. If your pull request addresses an issue, please mention the issue number in
|
||||
the pull request description to make sure they are linked (and people
|
||||
consulting the issue know you are working on it);
|
||||
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark
|
||||
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
|
||||
it from PRs ready to be merged;
|
||||
4. Make sure existing tests pass;
|
||||
5. Add high-coverage tests. No quality testing = no merge.
|
||||
|
||||
|
||||
### Tests
|
||||
|
||||
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
|
||||
the [tests folder](https://github.com/huggingface/trl/tree/main/tests).
|
||||
|
||||
We use `pytest` in order to run the tests. From the root of the
|
||||
repository, here's how to run tests with `pytest` for the library:
|
||||
|
||||
```bash
|
||||
make precommit
|
||||
$ python -m pytest -sv ./tests
|
||||
```
|
||||
|
||||
Make sure to install `pre-commit` before running the command:
|
||||
```bash
|
||||
pip install pre-commit
|
||||
```
|
||||
|
||||
## Do you want to contribute to the documentation?
|
||||
|
||||
* Docs are in the `docs/` folder and can be updated there.
|
||||
In fact, that's how `make test` is implemented (sans the `pip install` line)!
|
||||
|
||||
You can specify a smaller set of tests in order to test only the feature
|
||||
you're working on.
|
||||
|
@ -2,4 +2,4 @@ include settings.ini
|
||||
include LICENSE
|
||||
include CONTRIBUTING.md
|
||||
include README.md
|
||||
recursive-exclude * __pycache__
|
||||
recursive-exclude * __pycache__
|
38
Makefile
38
Makefile
@ -1,9 +1,43 @@
|
||||
.PHONY: test precommit
|
||||
.PHONY: test precommit benchmark_core benchmark_aux common_tests slow_tests test_examples tests_gpu
|
||||
|
||||
check_dirs := examples tests trl
|
||||
|
||||
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
|
||||
COMMAND_FILES_PATH = `pwd`/commands
|
||||
|
||||
|
||||
dev:
|
||||
[ -L "$(pwd)/trl/commands/scripts" ] && unlink "$(pwd)/trl/commands/scripts" || true
|
||||
pip install -e ".[dev]"
|
||||
ln -s `pwd`/examples/scripts/ `pwd`/trl/commands
|
||||
|
||||
test:
|
||||
python -m pytest -n auto --dist=loadfile -s -v ./tests/
|
||||
python -m pytest -n auto --dist=loadfile -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' ./tests/
|
||||
|
||||
precommit:
|
||||
pre-commit run --all-files
|
||||
|
||||
benchmark_core:
|
||||
bash ./benchmark/benchmark_core.sh
|
||||
|
||||
benchmark_aux:
|
||||
bash ./benchmark/benchmark_aux.sh
|
||||
|
||||
tests_gpu:
|
||||
python -m pytest tests/test_* $(if $(IS_GITHUB_CI),--report-log "common_tests.log",)
|
||||
|
||||
slow_tests:
|
||||
python -m pytest tests/slow/test_* $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
|
||||
|
||||
test_examples:
|
||||
touch temp_results_sft_tests.txt
|
||||
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
|
||||
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
|
||||
echo $$?','$${file} >> temp_results_sft_tests.txt; \
|
||||
done
|
||||
|
||||
touch temp_results_dpo_tests.txt
|
||||
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
|
||||
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
|
||||
echo $$?','$${file} >> temp_results_dpo_tests.txt; \
|
||||
done
|
||||
|
154
README.md
154
README.md
@ -3,78 +3,90 @@
|
||||
</div>
|
||||
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
> Full stack transformer language models with reinforcement learning.
|
||||
> Full stack library to fine-tune and align large language models.
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/lvwerra/trl/blob/main/LICENSE">
|
||||
<img alt="License" src="https://img.shields.io/github/license/lvwerra/trl.svg?color=blue">
|
||||
<a href="https://github.com/huggingface/trl/blob/main/LICENSE">
|
||||
<img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue">
|
||||
</a>
|
||||
<a href="https://huggingface.co/docs/trl/index">
|
||||
<img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_message=online">
|
||||
</a>
|
||||
<a href="https://github.com/lvwerra/trl/releases">
|
||||
<img alt="GitHub release" src="https://img.shields.io/github/release/lvwerra/trl.svg">
|
||||
<a href="https://github.com/huggingface/trl/releases">
|
||||
<img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
|
||||
## What is it?
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
|
||||
</div>
|
||||
The `trl` library is a full stack tool to fine-tune and align transformer language and diffusion models using methods such as Supervised Fine-tuning step (SFT), Reward Modeling (RM) and the Proximal Policy Optimization (PPO) as well as Direct Preference Optimization (DPO).
|
||||
|
||||
`trl` is a full stack library where we provide a set of tools to train transformer language models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point most of decoder architectures and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools.
|
||||
|
||||
**Highlights:**
|
||||
|
||||
- [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer): A light and friendly wrapper around `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
|
||||
- [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): A light wrapper around `transformers` Trainer to easily fine-tune language models for human preferences (Reward Modeling).
|
||||
- [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer): A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
|
||||
- [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead): A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
|
||||
- [Examples](https://github.com/lvwerra/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc.
|
||||
|
||||
## How PPO works
|
||||
Fine-tuning a language model via PPO consists of roughly three steps:
|
||||
|
||||
1. **Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence.
|
||||
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
|
||||
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate to far from the reference language model. The active language model is then trained with PPO.
|
||||
|
||||
This process is illustrated in the sketch below:
|
||||
The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library and thus allows to use any model architecture available there.
|
||||
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
|
||||
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
|
||||
</div>
|
||||
## Highlights
|
||||
|
||||
- **`Efficient and scalable`**:
|
||||
- [`accelerate`](https://github.com/huggingface/accelerate) is the backbone of `trl` which allows to scale model training from a single GPU to a large scale multi-node cluster with methods such as DDP and DeepSpeed.
|
||||
- [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows to train even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA.
|
||||
- [`unsloth`](https://github.com/unslothai/unsloth) is also integrated and allows to significantly speed up training with dedicated kernels.
|
||||
- **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune and chat with LLMs without writing any code using a single command and a flexible config system.
|
||||
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer), [`CPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.CPOTrainer), and [`ORPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.ORPOTrainer).
|
||||
- **`AutoModels`**: The [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead) classes add an additional value head to the model which allows to train them with RL algorithms such as PPO.
|
||||
- **`Examples`**: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [StackLlama example](https://huggingface.co/blog/stackllama), etc. following the [examples](https://github.com/huggingface/trl/tree/main/examples).
|
||||
|
||||
## Installation
|
||||
|
||||
### Python package
|
||||
Install the library with pip:
|
||||
Install the library with `pip`:
|
||||
```bash
|
||||
pip install trl
|
||||
```
|
||||
|
||||
### From source
|
||||
If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:
|
||||
If you want to use the latest features before an official release you can install from source:
|
||||
```bash
|
||||
git clone https://github.com/lvwerra/trl.git
|
||||
cd trl/
|
||||
pip install .
|
||||
pip install git+https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
If you wish to develop TRL, you should install in editable mode:
|
||||
### Repository
|
||||
If you want to use the examples you can clone the repository with the following command:
|
||||
```bash
|
||||
pip install -e .
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
## Command Line Interface (CLI)
|
||||
|
||||
You can use TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT), Direct Preference Optimization (DPO) and test your aligned model with the chat CLI:
|
||||
|
||||
**SFT:**
|
||||
|
||||
```bash
|
||||
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
|
||||
```
|
||||
|
||||
**DPO:**
|
||||
|
||||
```bash
|
||||
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --output_dir opt-sft-hh-rlhf
|
||||
```
|
||||
|
||||
**Chat:**
|
||||
|
||||
```bash
|
||||
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
|
||||
```
|
||||
|
||||
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
|
||||
|
||||
## How to use
|
||||
|
||||
For more flexibility and control over the training, you can use the dedicated trainer classes to fine-tune the model in Python.
|
||||
|
||||
### `SFTTrainer`
|
||||
|
||||
This is a basic example on how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
|
||||
This is a basic example of how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
|
||||
|
||||
```python
|
||||
# imports
|
||||
@ -98,7 +110,7 @@ trainer.train()
|
||||
|
||||
### `RewardTrainer`
|
||||
|
||||
This is a basic example on how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
|
||||
This is a basic example of how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
|
||||
|
||||
```python
|
||||
# imports
|
||||
@ -106,7 +118,7 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from trl import RewardTrainer
|
||||
|
||||
# load model and dataset - dataset needs to be in a specific format
|
||||
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
|
||||
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
...
|
||||
@ -124,7 +136,7 @@ trainer.train()
|
||||
|
||||
### `PPOTrainer`
|
||||
|
||||
This is a basic example on how to use the `PPOTrainer` from the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output.
|
||||
This is a basic example of how to use the `PPOTrainer` from the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output.
|
||||
|
||||
```python
|
||||
# imports
|
||||
@ -135,14 +147,13 @@ from trl.core import respond_to_batch
|
||||
|
||||
# get models
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
model_ref = create_reference_model(model)
|
||||
ref_model = create_reference_model(model)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# initialize trainer
|
||||
ppo_config = PPOConfig(
|
||||
batch_size=1,
|
||||
)
|
||||
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)
|
||||
|
||||
# encode a query
|
||||
query_txt = "This morning I went to the "
|
||||
@ -152,7 +163,7 @@ query_tensor = tokenizer.encode(query_txt, return_tensors="pt")
|
||||
response_tensor = respond_to_batch(model, query_tensor)
|
||||
|
||||
# create a ppo trainer
|
||||
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)
|
||||
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)
|
||||
|
||||
# define a reward for response
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
@ -162,33 +173,60 @@ reward = [torch.tensor(1.0)]
|
||||
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
|
||||
```
|
||||
|
||||
### Advanced example: IMDB sentiment
|
||||
For a detailed example check out the example python script `examples/scripts/sentiment_tuning.py`, where GPT2 is fine-tuned to generate positive movie reviews. An few examples from the language models before and after optimisation are given below:
|
||||
### `DPOTrainer`
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/table_imdb_preview.png" width="800">
|
||||
<p style="text-align: center;"> <b>Figure:</b> A few review continuations before and after optimisation. </p>
|
||||
</div>
|
||||
`DPOTrainer` is a trainer that uses [Direct Preference Optimization algorithm](https://huggingface.co/papers/2305.18290). This is a basic example of how to use the `DPOTrainer` from the library. The `DPOTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
|
||||
|
||||
Have a look at more examples inside [`examples/`](https://github.com/lvwerra/trl/tree/main/examples) folder.
|
||||
```python
|
||||
# imports
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOTrainer
|
||||
|
||||
# load model and dataset - dataset needs to be in a specific format
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
...
|
||||
|
||||
# load trainer
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
# train
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
If you want to contribute to `trl` or customizing it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
cd trl/
|
||||
make dev
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
### Proximal Policy Optimisation
|
||||
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
|
||||
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://huggingface.co/papers/1909.08593), [code](https://github.com/openai/lm-human-preferences)].
|
||||
|
||||
### Direct Preference Optimization
|
||||
DPO is based on the original implementation of **"Direct Preference Optimization: Your Language Model is Secretly a Reward Model"** by E. Mitchell et al. \[[paper](https://huggingface.co/papers/2305.18290), [code](https://github.com/eric-mitchell/direct-preference-optimization)]
|
||||
|
||||
### Language models
|
||||
The language models utilize the `transformers` library by 🤗 Hugging Face.
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@misc{vonwerra2022trl,
|
||||
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert},
|
||||
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang},
|
||||
title = {TRL: Transformer Reinforcement Learning},
|
||||
year = {2020},
|
||||
publisher = {GitHub},
|
||||
journal = {GitHub repository},
|
||||
howpublished = {\url{https://github.com/lvwerra/trl}}
|
||||
howpublished = {\url{https://github.com/huggingface/trl}}
|
||||
}
|
||||
```
|
||||
|
@ -6,6 +6,8 @@ import subprocess
|
||||
import uuid
|
||||
from distutils.util import strtobool
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def parse_args():
|
||||
# fmt: off
|
||||
@ -38,14 +40,65 @@ def parse_args():
|
||||
def run_experiment(command: str):
|
||||
command_list = shlex.split(command)
|
||||
print(f"running {command}")
|
||||
fd = subprocess.Popen(command_list)
|
||||
return_code = fd.wait()
|
||||
assert return_code == 0
|
||||
|
||||
# Use subprocess.PIPE to capture the output
|
||||
fd = subprocess.Popen(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
output, errors = fd.communicate()
|
||||
|
||||
return_code = fd.returncode
|
||||
assert return_code == 0, f"Command failed with error: {errors.decode('utf-8')}"
|
||||
|
||||
# Convert bytes to string and strip leading/trailing whitespaces
|
||||
return output.decode("utf-8").strip()
|
||||
|
||||
|
||||
def autotag() -> str:
|
||||
wandb_tag = ""
|
||||
print("autotag feature is enabled")
|
||||
git_tag = ""
|
||||
try:
|
||||
git_tag = subprocess.check_output(["git", "describe", "--tags"]).decode("ascii").strip()
|
||||
print(f"identified git tag: {git_tag}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e)
|
||||
if len(git_tag) == 0:
|
||||
try:
|
||||
count = int(subprocess.check_output(["git", "rev-list", "--count", "HEAD"]).decode("ascii").strip())
|
||||
hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip()
|
||||
git_tag = f"no-tag-{count}-g{hash}"
|
||||
print(f"identified git tag: {git_tag}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(e)
|
||||
wandb_tag = f"{git_tag}"
|
||||
|
||||
git_commit = subprocess.check_output(["git", "rev-parse", "--verify", "HEAD"]).decode("ascii").strip()
|
||||
try:
|
||||
# try finding the pull request number on github
|
||||
prs = requests.get(f"https://api.github.com/search/issues?q=repo:huggingface/trl+is:pr+{git_commit}")
|
||||
if prs.status_code == 200:
|
||||
prs = prs.json()
|
||||
if len(prs["items"]) > 0:
|
||||
pr = prs["items"][0]
|
||||
pr_number = pr["number"]
|
||||
wandb_tag += f",pr-{pr_number}"
|
||||
print(f"identified github pull request: {pr_number}")
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
return wandb_tag
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
if args.auto_tag:
|
||||
existing_wandb_tag = os.environ.get("WANDB_TAGS", "")
|
||||
wandb_tag = autotag()
|
||||
if len(wandb_tag) > 0:
|
||||
if len(existing_wandb_tag) > 0:
|
||||
os.environ["WANDB_TAGS"] = ",".join([existing_wandb_tag, wandb_tag])
|
||||
else:
|
||||
os.environ["WANDB_TAGS"] = wandb_tag
|
||||
print("WANDB_TAGS: ", os.environ.get("WANDB_TAGS", ""))
|
||||
commands = []
|
||||
for seed in range(0, args.num_seeds):
|
||||
commands += [" ".join([args.command, "--seed", str(args.start_seed + seed)])]
|
||||
@ -93,4 +146,5 @@ if __name__ == "__main__":
|
||||
slurm_path = os.path.join("slurm", f"{filename}.slurm")
|
||||
print(f"saving command in {slurm_path}")
|
||||
if args.workers > 0:
|
||||
run_experiment(f"sbatch {slurm_path}")
|
||||
job_id = run_experiment(f"sbatch --parsable {slurm_path}")
|
||||
print(f"Job ID: {job_id}")
|
||||
|
26
benchmark/benchmark_and_report.sh
Normal file
26
benchmark/benchmark_and_report.sh
Normal file
@ -0,0 +1,26 @@
|
||||
export WANDB_ENTITY=huggingface
|
||||
export WANDB_PROJECT=trl
|
||||
bash $BENCHMARK_SCRIPT > output.txt
|
||||
|
||||
# Extract Job IDs into an array
|
||||
job_ids=($(grep "Job ID:" output.txt | awk '{print $3}'))
|
||||
|
||||
# Extract WANDB_TAGS into an array
|
||||
WANDB_TAGS=($(grep "WANDB_TAGS:" output.txt | awk '{print $2}'))
|
||||
WANDB_TAGS=($(echo $WANDB_TAGS | tr "," "\n"))
|
||||
|
||||
# Print to verify
|
||||
echo "Job IDs: ${job_ids[@]}"
|
||||
echo "WANDB_TAGS: ${WANDB_TAGS[@]}"
|
||||
|
||||
TAGS_STRING="?tag=${WANDB_TAGS[0]}"
|
||||
FOLDER_STRING="${WANDB_TAGS[0]}"
|
||||
for tag in "${WANDB_TAGS[@]:1}"; do
|
||||
TAGS_STRING+="&tag=$tag"
|
||||
FOLDER_STRING+="_$tag"
|
||||
done
|
||||
|
||||
echo "TAGS_STRING: $TAGS_STRING"
|
||||
echo "FOLDER_STRING: $FOLDER_STRING"
|
||||
|
||||
TAGS_STRING=$TAGS_STRING FOLDER_STRING=$FOLDER_STRING BENCHMARK_PLOT_SCRIPT=$BENCHMARK_PLOT_SCRIPT sbatch --dependency=afterany:$job_ids benchmark/post_github_comment.sbatch
|
44
benchmark/benchmark_level1.sh
Normal file
44
benchmark/benchmark_level1.sh
Normal file
@ -0,0 +1,44 @@
|
||||
# hello world experiment
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/dpo.py --model_name_or_path=gpt2 --per_device_train_batch_size 4 --max_steps 1000 --learning_rate 1e-3 --gradient_accumulation_steps 1 --logging_steps 10 --eval_steps 500 --output_dir="dpo_anthropic_hh" --optim adamw_torch --warmup_steps 150 --report_to wandb --bf16 --logging_first_step --no_remove_unused_columns" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/sft.py --model_name_or_path="facebook/opt-350m" --report_to="wandb" --learning_rate=1.41e-5 --per_device_train_batch_size=64 --gradient_accumulation_steps=16 --output_dir="sft_openassistant-guanaco" --logging_steps=1 --num_train_epochs=3 --max_steps=-1 --push_to_hub --gradient_checkpointing" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/reward_modeling.py --model_name_or_path=facebook/opt-350m --output_dir="reward_modeling_anthropic_hh" --per_device_train_batch_size=64 --num_train_epochs=1 --gradient_accumulation_steps=16 --gradient_checkpointing=True --learning_rate=1.41e-5 --report_to="wandb" --remove_unused_columns=False --optim="adamw_torch" --logging_steps=10 --eval_strategy="steps" --max_length=512" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
50
benchmark/benchmark_level1_plot.sh
Normal file
50
benchmark/benchmark_level1_plot.sh
Normal file
@ -0,0 +1,50 @@
|
||||
# pip install openrlbenchmark==0.2.1a5
|
||||
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
|
||||
echo "we deal with $TAGS_STRING"
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"ppo$TAGS_STRING" \
|
||||
--env-ids sentiment-analysis:lvwerra/distilbert-imdb \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/ppo \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/rewards/accuracies&metrics=train/loss' \
|
||||
"gpt2$TAGS_STRING" \
|
||||
--env-ids dpo_anthropic_hh \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/dpo \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/loss&metrics=eval/accuracy&metrics=eval/loss' \
|
||||
"facebook/opt-350m$TAGS_STRING" \
|
||||
--env-ids reward_modeling_anthropic_hh \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/reward_modeling \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=output_dir&cen=_name_or_path&metrics=train/loss' \
|
||||
"facebook/opt-350m$TAGS_STRING" \
|
||||
--env-ids sft_openassistant-guanaco \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/sft \
|
||||
--scan-history
|
||||
|
||||
python benchmark/upload_benchmark.py \
|
||||
--folder_path="benchmark/trl/$FOLDER_STRING" \
|
||||
--path_in_repo="images/benchmark/$FOLDER_STRING" \
|
||||
--repo_id="trl-internal-testing/example-images" \
|
||||
--repo_type="dataset"
|
||||
|
23
benchmark/benchmark_level2.sh
Normal file
23
benchmark/benchmark_level2.sh
Normal file
@ -0,0 +1,23 @@
|
||||
# compound experiments: gpt2xl + grad_accu
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_gpt2xl_grad_accu --model_name gpt2-xl --mini_batch_size 16 --gradient_accumulation_steps 8 --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
# compound experiments: Cerebras-GPT-6.7B + deepspeed zero2 + grad_accu
|
||||
python benchmark/benchmark.py \
|
||||
--command "accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml examples/scripts/ppo.py --exp_name ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2 --batch_size 32 --mini_batch_size 32 --log_with wandb --model_name cerebras/Cerebras-GPT-6.7B --reward_model sentiment-analysis:cerebras/Cerebras-GPT-6.7B" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 8 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 90 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
31
benchmark/benchmark_level2_plot.sh
Normal file
31
benchmark/benchmark_level2_plot.sh
Normal file
@ -0,0 +1,31 @@
|
||||
# pip install openrlbenchmark==0.2.1a5
|
||||
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
|
||||
echo "we deal with $TAGS_STRING"
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"ppo$TAGS_STRING" \
|
||||
"ppo_gpt2xl_grad_accu$TAGS_STRING" \
|
||||
--env-ids sentiment-analysis:lvwerra/distilbert-imdb \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/different_models \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"ppo_Cerebras-GPT-6.7B_grad_accu_deepspeed_stage2$TAGS_STRING" \
|
||||
--env-ids sentiment-analysis:cerebras/Cerebras-GPT-6.7B \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$FOLDER_STRING/deepspeed \
|
||||
--scan-history
|
||||
|
||||
python benchmark/upload_benchmark.py \
|
||||
--folder_path="benchmark/trl/$FOLDER_STRING" \
|
||||
--path_in_repo="images/benchmark/$FOLDER_STRING" \
|
||||
--repo_id="trl-internal-testing/example-images" \
|
||||
--repo_type="dataset"
|
||||
|
46
benchmark/benchmark_level3.sh
Normal file
46
benchmark/benchmark_level3.sh
Normal file
@ -0,0 +1,46 @@
|
||||
## w/ and w/o gradient accumulation
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_step_grad_accu --mini_batch_size 1 --gradient_accumulation_steps 128 --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
## w/ different models (gpt2, gpt2-xl, falcon, llama2)
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_gpt2 --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_falcon_rw_1b --model_name tiiuae/falcon-rw-1b --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
|
||||
|
||||
## w/ and w/o PEFT
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name ppo_peft --use_peft --log_with wandb" \
|
||||
--num-seeds 3 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
56
benchmark/plot.sh
Normal file
56
benchmark/plot.sh
Normal file
@ -0,0 +1,56 @@
|
||||
# pip install openrlbenchmark==0.2.1a5
|
||||
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
|
||||
BASELINE_PR_TAG=v0.4.7-55-g110e672
|
||||
BASELINE_PR_NAME=PR-662
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \
|
||||
--env-ids sentiment-analysis:lvwerra/distilbert-imdb \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$BASELINE_PR_TAG/sentiment \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \
|
||||
"sentiment_tuning_step_grad_accu?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb gradient accumulation ($BASELINE_PR_NAME)" \
|
||||
--env-ids sentiment-analysis:lvwerra/distilbert-imdb \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$BASELINE_PR_TAG/gradient_accu \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \
|
||||
"sentiment_tuning_gpt2?tag=$BASELINE_PR_TAG&cl=sentiment gpt2 ($BASELINE_PR_NAME)" \
|
||||
"sentiment_tuning_falcon_rw_1b?tag=$BASELINE_PR_TAG&cl=sentiment tiiuae/falcon-rw-1b ($BASELINE_PR_NAME)" \
|
||||
"sentiment_tuning_gpt2xl_grad_accu?tag=$BASELINE_PR_TAG&cl=sentiment gpt2xl ($BASELINE_PR_NAME)" \
|
||||
--env-ids sentiment-analysis:lvwerra/distilbert-imdb \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$BASELINE_PR_TAG/different_models \
|
||||
--scan-history
|
||||
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.reward_model&cen=trl_ppo_trainer_config.value.exp_name&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
"sentiment_tuning?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb ($BASELINE_PR_NAME)" \
|
||||
"sentiment_tuning_peft?tag=$BASELINE_PR_TAG&cl=sentiment lvwerra/gpt2-imdb w/ peft ($BASELINE_PR_NAME)" \
|
||||
--env-ids sentiment-analysis:lvwerra/distilbert-imdb \
|
||||
--no-check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename benchmark/trl/$BASELINE_PR_TAG/peft \
|
||||
--scan-history
|
||||
|
||||
|
||||
python benchmark/upload_benchmark.py \
|
||||
--folder_path="benchmark/trl/$BASELINE_PR_TAG" \
|
||||
--path_in_repo="images/benchmark/$BASELINE_PR_TAG" \
|
||||
--repo_id="trl-internal-testing/example-images" \
|
||||
--repo_type="dataset"
|
26
benchmark/post_github_comment.py
Normal file
26
benchmark/post_github_comment.py
Normal file
@ -0,0 +1,26 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from ghapi.all import GhApi
|
||||
|
||||
|
||||
FOLDER_STRING = os.environ.get("FOLDER_STRING", "")
|
||||
folder = f"benchmark/trl/{FOLDER_STRING}"
|
||||
host_url = f"https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/benchmark/{FOLDER_STRING}"
|
||||
|
||||
# Create a GitHub API instance
|
||||
github_context = json.loads(os.environ["GITHUB_CONTEXT"])
|
||||
token = os.environ["PERSONAL_ACCESS_TOKEN_GITHUB"] # this needs to refreshed every 12 months
|
||||
status_message = "**[COSTA BENCHMARK BOT]**: Here are the results"
|
||||
body = status_message
|
||||
repo = github_context["repository"]
|
||||
owner, repo = repo.split("/")
|
||||
api = GhApi(owner=owner, repo=repo, token=token)
|
||||
|
||||
# for each `.png` file in the folder, add it to the comment
|
||||
for file in os.listdir(folder):
|
||||
if file.endswith(".png"):
|
||||
body += f"\n"
|
||||
|
||||
# Create a comment on the issue
|
||||
api.issues.create_comment(issue_number=github_context["event"]["issue"]["number"], body=body)
|
9
benchmark/post_github_comment.sbatch
Normal file
9
benchmark/post_github_comment.sbatch
Normal file
@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --job-name=trl
|
||||
#SBATCH --partition=hopper-cpu
|
||||
#SBATCH --ntasks=1
|
||||
#SBATCH --output=slurm/logs/%x_%j.out
|
||||
|
||||
sleep 2m
|
||||
bash $BENCHMARK_PLOT_SCRIPT
|
||||
srun python benchmark/post_github_comment.py
|
3
benchmark/regression_test.sh
Normal file
3
benchmark/regression_test.sh
Normal file
@ -0,0 +1,3 @@
|
||||
BENCHMARK_SCRIPT="benchmark/benchmark_level1.sh" \
|
||||
BENCHMARK_PLOT_SCRIPT="benchmark/benchmark_level1_plot.sh" \
|
||||
bash benchmark/benchmark_and_report.sh
|
@ -1,11 +1,14 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --partition=dev-cluster
|
||||
#SBATCH --job-name=trl
|
||||
#SBATCH --partition=hopper-prod
|
||||
#SBATCH --gpus-per-task={{gpus_per_task}}
|
||||
#SBATCH --cpus-per-gpu={{cpus_per_gpu}}
|
||||
#SBATCH --ntasks={{ntasks}}
|
||||
#SBATCH --mem-per-cpu=11G
|
||||
#SBATCH --output=slurm/logs/%x_%j.out
|
||||
#SBATCH --array={{array}}
|
||||
##SBATCH --exclude=ip-26-0-149-199
|
||||
|
||||
module load cuda/12.1
|
||||
|
||||
{{nodes}}
|
||||
|
||||
|
23
benchmark/upload_benchmark.py
Normal file
23
benchmark/upload_benchmark.py
Normal file
@ -0,0 +1,23 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import tyro
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
folder_path: str = "benchmark/trl"
|
||||
path_in_repo: str = "images/benchmark"
|
||||
repo_id: str = "trl-internal-testing/example-images"
|
||||
repo_type: str = "dataset"
|
||||
|
||||
|
||||
args = tyro.cli(Args)
|
||||
api = HfApi()
|
||||
|
||||
api.upload_folder(
|
||||
folder_path=args.folder_path,
|
||||
path_in_repo=args.path_in_repo,
|
||||
repo_id=args.repo_id,
|
||||
repo_type=args.repo_type,
|
||||
)
|
58
commands/run_dpo.sh
Normal file
58
commands/run_dpo.sh
Normal file
@ -0,0 +1,58 @@
|
||||
#!/bin/bash
|
||||
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
|
||||
# but defaults to QLoRA + PEFT
|
||||
OUTPUT_DIR="test_dpo/"
|
||||
MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
|
||||
MAX_STEPS=5
|
||||
BATCH_SIZE=2
|
||||
SEQ_LEN=128
|
||||
|
||||
# Handle extra arguments in case one passes accelerate configs.
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
EXTRA_TRAINING_ARGS="""--use_peft \
|
||||
--load_in_4bit
|
||||
"""
|
||||
|
||||
# This is a hack to get the number of available GPUs
|
||||
NUM_GPUS=2
|
||||
|
||||
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
else
|
||||
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
|
||||
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
|
||||
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
|
||||
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
|
||||
EXTRA_TRAINING_ARGS="--fp16"
|
||||
else
|
||||
echo "Keeping QLoRA + PEFT"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
CMD="""
|
||||
accelerate launch $EXTRA_ACCELERATE_ARGS \
|
||||
--num_processes $NUM_GPUS \
|
||||
--mixed_precision 'fp16' \
|
||||
`pwd`/examples/scripts/dpo.py \
|
||||
--model_name_or_path $MODEL_NAME \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_steps $MAX_STEPS \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--max_length $SEQ_LEN \
|
||||
$EXTRA_TRAINING_ARGS
|
||||
"""
|
||||
|
||||
echo "Starting program..."
|
||||
|
||||
{ # try
|
||||
echo $CMD
|
||||
eval "$CMD"
|
||||
} || { # catch
|
||||
# save log for exception
|
||||
echo "Operation Failed!"
|
||||
exit 1
|
||||
}
|
||||
exit 0
|
60
commands/run_sft.sh
Normal file
60
commands/run_sft.sh
Normal file
@ -0,0 +1,60 @@
|
||||
#!/bin/bash
|
||||
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
|
||||
# but defaults to QLoRA + PEFT
|
||||
OUTPUT_DIR="test_sft/"
|
||||
MODEL_NAME="trl-internal-testing/tiny-random-LlamaForCausalLM"
|
||||
DATASET_NAME="imdb"
|
||||
MAX_STEPS=5
|
||||
BATCH_SIZE=2
|
||||
SEQ_LEN=128
|
||||
|
||||
|
||||
# Handle extra arguments in case one passes accelerate configs.
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
EXTRA_TRAINING_ARGS="""--use_peft \
|
||||
--load_in_4bit
|
||||
"""
|
||||
|
||||
# Set your number of GPUs here
|
||||
NUM_GPUS=2
|
||||
|
||||
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
else
|
||||
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
|
||||
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
|
||||
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
|
||||
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
|
||||
EXTRA_TRAINING_ARGS="--fp16"
|
||||
else
|
||||
echo "Keeping QLoRA + PEFT"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
CMD="""
|
||||
accelerate launch $EXTRA_ACCELERATE_ARGS \
|
||||
--num_processes $NUM_GPUS \
|
||||
--mixed_precision 'fp16' \
|
||||
`pwd`/examples/scripts/sft.py \
|
||||
--model_name $MODEL_NAME \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_steps $MAX_STEPS \
|
||||
--dataset_text_field 'text' \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--max_seq_length $SEQ_LEN \
|
||||
$EXTRA_TRAINING_ARGS
|
||||
"""
|
||||
|
||||
echo "Starting program..."
|
||||
|
||||
{ # try
|
||||
echo $CMD
|
||||
eval "$CMD"
|
||||
} || { # catch
|
||||
# save log for exception
|
||||
echo "Operation Failed!"
|
||||
exit 1
|
||||
}
|
||||
exit 0
|
66
docker/trl-latest-gpu/Dockerfile
Normal file
66
docker/trl-latest-gpu/Dockerfile
Normal file
@ -0,0 +1,66 @@
|
||||
# Builds GPU docker image of PyTorch
|
||||
# Uses multi-staged approach to reduce size
|
||||
# Stage 1
|
||||
# Use base conda image to reduce time
|
||||
FROM continuumio/miniconda3:latest AS compile-image
|
||||
# Specify py version
|
||||
ENV PYTHON_VERSION=3.10
|
||||
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget software-properties-common git-lfs && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Install audio-related libraries
|
||||
RUN apt-get update && \
|
||||
apt install -y ffmpeg
|
||||
|
||||
RUN apt install -y libsndfile1-dev
|
||||
RUN git lfs install
|
||||
|
||||
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
||||
|
||||
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
# We don't install pytorch here yet since CUDA isn't available
|
||||
# instead we use the direct torch wheel
|
||||
ENV PATH /opt/conda/envs/trl/bin:$PATH
|
||||
# Activate our bash shell
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
# Stage 2
|
||||
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
|
||||
COPY --from=compile-image /opt/conda /opt/conda
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
|
||||
|
||||
# Install apt libs
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Activate the conda env and install transformers + accelerate from source
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install -U --no-cache-dir \
|
||||
librosa \
|
||||
"soundfile>=0.12.1" \
|
||||
scipy \
|
||||
transformers \
|
||||
accelerate \
|
||||
peft \
|
||||
trl[test]@git+https://github.com/huggingface/trl
|
||||
|
||||
RUN source activate trl && \
|
||||
pip freeze | grep trl
|
||||
|
||||
RUN echo "source activate trl" >> ~/.profile
|
||||
|
||||
# Activate the virtualenv
|
||||
CMD ["/bin/bash"]
|
66
docker/trl-source-gpu/Dockerfile
Normal file
66
docker/trl-source-gpu/Dockerfile
Normal file
@ -0,0 +1,66 @@
|
||||
# Builds GPU docker image of PyTorch
|
||||
# Uses multi-staged approach to reduce size
|
||||
# Stage 1
|
||||
# Use base conda image to reduce time
|
||||
FROM continuumio/miniconda3:latest AS compile-image
|
||||
# Specify py version
|
||||
ENV PYTHON_VERSION=3.10
|
||||
# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget software-properties-common git-lfs && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Install audio-related libraries
|
||||
RUN apt-get update && \
|
||||
apt install -y ffmpeg
|
||||
|
||||
RUN apt install -y libsndfile1-dev
|
||||
RUN git lfs install
|
||||
|
||||
# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip
|
||||
RUN python3 -m pip install --no-cache-dir --upgrade pip
|
||||
|
||||
# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile
|
||||
# We don't install pytorch here yet since CUDA isn't available
|
||||
# instead we use the direct torch wheel
|
||||
ENV PATH /opt/conda/envs/trl/bin:$PATH
|
||||
# Activate our bash shell
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
# Stage 2
|
||||
FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image
|
||||
COPY --from=compile-image /opt/conda /opt/conda
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
RUN chsh -s /bin/bash
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq
|
||||
|
||||
# Install apt libs
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl git wget && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists*
|
||||
|
||||
# Activate the conda env and install transformers + accelerate from source
|
||||
RUN source activate trl && \
|
||||
python3 -m pip install -U --no-cache-dir \
|
||||
librosa \
|
||||
"soundfile>=0.12.1" \
|
||||
scipy \
|
||||
git+https://github.com/huggingface/transformers \
|
||||
git+https://github.com/huggingface/accelerate \
|
||||
git+https://github.com/huggingface/peft \
|
||||
trl[test]@git+https://github.com/huggingface/trl
|
||||
|
||||
RUN source activate trl && \
|
||||
pip freeze | grep transformers
|
||||
|
||||
RUN echo "source activate trl" >> ~/.profile
|
||||
|
||||
# Activate the virtualenv
|
||||
CMD ["/bin/bash"]
|
@ -1,12 +1,18 @@
|
||||
- sections:
|
||||
- sections:
|
||||
- local: index
|
||||
title: TRL
|
||||
- local: quickstart
|
||||
title: Quickstart
|
||||
- local: installation
|
||||
title: Installation
|
||||
- local: clis
|
||||
title: Get started with Command Line Interfaces (CLIs)
|
||||
- local: how_to_train
|
||||
title: PPO Training FAQ
|
||||
- local: use_model
|
||||
title: Use Trained Models
|
||||
- local: customization
|
||||
title: Customize your Training
|
||||
title: Customize the Training
|
||||
- local: logging
|
||||
title: Understanding Logs
|
||||
title: Get started
|
||||
@ -19,12 +25,42 @@
|
||||
title: Reward Model Training
|
||||
- local: sft_trainer
|
||||
title: Supervised Fine-Tuning
|
||||
- local: ppo_trainer
|
||||
title: PPO Trainer
|
||||
- local: ppov2_trainer
|
||||
title: PPOv2 Trainer
|
||||
- local: rloo_trainer
|
||||
title: RLOO Trainer
|
||||
- local: best_of_n
|
||||
title: Best of N Sampling
|
||||
- local: dpo_trainer
|
||||
title: DPO Trainer
|
||||
- local: online_dpo_trainer
|
||||
title: Online DPO Trainer
|
||||
- local: kto_trainer
|
||||
title: KTO Trainer
|
||||
- local: bco_trainer
|
||||
title: BCO Trainer
|
||||
- local: cpo_trainer
|
||||
title: CPO Trainer
|
||||
- local: ddpo_trainer
|
||||
title: Denoising Diffusion Policy Optimization
|
||||
- local: alignprop_trainer
|
||||
title: AlignProp Trainer
|
||||
- local: orpo_trainer
|
||||
title: ORPO Trainer
|
||||
- local: iterative_sft_trainer
|
||||
title: Iterative Supervised Fine-Tuning
|
||||
- local: callbacks
|
||||
title: Callback Classes
|
||||
- local: judges
|
||||
title: Judge Classes
|
||||
- local: text_environments
|
||||
title: Text Environments
|
||||
title: API
|
||||
- sections:
|
||||
- sections:
|
||||
- local: example_overview
|
||||
title: Example Overview
|
||||
- local: sentiment_tuning
|
||||
title: Sentiment Tuning
|
||||
- local: lora_tuning_peft
|
||||
@ -33,6 +69,8 @@
|
||||
title: Detoxifying a Language Model
|
||||
- local: using_llama_models
|
||||
title: Training StackLlama
|
||||
- local: learning_tools
|
||||
title: Learning to Use Tools
|
||||
- local: multi_adapter_rl
|
||||
title: Multi Adapter RLHF
|
||||
title: Examples
|
||||
|
91
docs/source/alignprop_trainer.mdx
Normal file
91
docs/source/alignprop_trainer.mdx
Normal file
@ -0,0 +1,91 @@
|
||||
# Aligning Text-to-Image Diffusion Models with Reward Backpropagation
|
||||
|
||||
## The why
|
||||
|
||||
If your reward function is differentiable, directly backpropagating gradients from the reward models to the diffusion model is significantly more sample and compute efficient (25x) than doing policy gradient algorithm like DDPO.
|
||||
AlignProp does full backpropagation through time, which allows updating the earlier steps of denoising via reward backpropagation.
|
||||
|
||||
<div style="text-align: center"><img src="https://align-prop.github.io/reward_tuning.png"/></div>
|
||||
|
||||
|
||||
## Getting started with `examples/scripts/alignprop.py`
|
||||
|
||||
The `alignprop.py` script is a working example of using the `AlignProp` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`AlignPropConfig`).
|
||||
|
||||
**Note:** one A100 GPU is recommended to get this running. For lower memory setting, consider setting truncated_backprop_rand to False. With default settings this will do truncated backpropagation with K=1.
|
||||
|
||||
Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running
|
||||
|
||||
```batch
|
||||
python alignprop.py --hf_user_access_token <token>
|
||||
```
|
||||
|
||||
To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`
|
||||
|
||||
The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)
|
||||
|
||||
- The configurable randomized truncation range (`--alignprop_config.truncated_rand_backprop_minmax=(0,50)`) the first number should be equal and greater to 0, while the second number should equal or less to the number of diffusion timesteps (sample_num_steps)
|
||||
- The configurable truncation backprop absolute step (`--alignprop_config.truncated_backprop_timestep=49`) the number should be less than the number of diffusion timesteps (sample_num_steps), it only matters when truncated_backprop_rand is set to False
|
||||
|
||||
## Setting up the image logging hook function
|
||||
|
||||
Expect the function to be given a dictionary with keys
|
||||
```python
|
||||
['image', 'prompt', 'prompt_metadata', 'rewards']
|
||||
|
||||
```
|
||||
and `image`, `prompt`, `prompt_metadata`, `rewards`are batched.
|
||||
You are free to log however you want the use of `wandb` or `tensorboard` is recommended.
|
||||
|
||||
### Key terms
|
||||
|
||||
- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process
|
||||
- `prompt` : The prompt is the text that is used to generate the image
|
||||
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
|
||||
- `image` : The image generated by the Stable Diffusion model
|
||||
|
||||
Example code for logging sampled images with `wandb` is given below.
|
||||
|
||||
```python
|
||||
# for logging these images to wandb
|
||||
|
||||
def image_outputs_hook(image_data, global_step, accelerate_logger):
|
||||
# For the sake of this example, we only care about the last batch
|
||||
# hence we extract the last element of the list
|
||||
result = {}
|
||||
images, prompts, rewards = [image_data['images'],image_data['prompts'],image_data['rewards']]
|
||||
for i, image in enumerate(images):
|
||||
pil = Image.fromarray(
|
||||
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
|
||||
)
|
||||
pil = pil.resize((256, 256))
|
||||
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
|
||||
accelerate_logger.log_images(
|
||||
result,
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
### Using the finetuned model
|
||||
|
||||
Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
||||
pipeline.to("cuda")
|
||||
|
||||
pipeline.load_lora_weights('mihirpd/alignprop-trl-aesthetics')
|
||||
|
||||
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
|
||||
results = pipeline(prompts)
|
||||
|
||||
for prompt, image in zip(prompts,results.images):
|
||||
image.save(f"dump/{prompt}.png")
|
||||
```
|
||||
|
||||
## Credits
|
||||
|
||||
This work is heavily influenced by the repo [here](https://github.com/mihirp1998/AlignProp/) and the associated paper [Aligning Text-to-Image Diffusion Models with Reward Backpropagation
|
||||
by Mihir Prabhudesai, Anirudh Goyal, Deepak Pathak, Katerina Fragkiadaki](https://huggingface.co/papers/2310.03739).
|
139
docs/source/bco_trainer.mdx
Normal file
139
docs/source/bco_trainer.mdx
Normal file
@ -0,0 +1,139 @@
|
||||
# BCO Trainer
|
||||
|
||||
TRL supports the Binary Classifier Optimization (BCO).
|
||||
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
|
||||
For a full example have a look at [`examples/scripts/bco.py`].
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The BCO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns:
|
||||
|
||||
- `prompt`
|
||||
- `completion`
|
||||
- `label`
|
||||
|
||||
for example:
|
||||
|
||||
```
|
||||
bco_dataset_dict = {
|
||||
"prompt": [
|
||||
"Hey, hello",
|
||||
"How are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"completion": [
|
||||
"hi nice to meet you",
|
||||
"leave me alone",
|
||||
"I don't have a name",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"C++",
|
||||
"Java",
|
||||
],
|
||||
"label": [
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
|
||||
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. It is required that the dataset contains at least one desirable and one undesirable completion.
|
||||
|
||||
|
||||
## Expected model format
|
||||
The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `BCOTrainer`
|
||||
|
||||
For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
|
||||
|
||||
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
|
||||
|
||||
|
||||
|
||||
```py
|
||||
training_args = BCOConfig(
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
bco_trainer = BCOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
bco_trainer.train()
|
||||
```
|
||||
|
||||
## Underlying Distribution matching (UDM)
|
||||
|
||||
In practical scenarios, the thumbs-up and thumbs-down datasets are likely to have divergent underlying distributions of prompts.
|
||||
Consider an LLM deployed for user feedback: if the model excels in writing tasks but underperforms in coding, the thumbs-up dataset will be dominated by writing-related prompts, while the thumbs-down dataset will contain mostly coding-related prompts.
|
||||
If the prompts in your desired and undesired datasets differ a lot, it is useful to enable UDM.
|
||||
|
||||
Choose an embedding model and tokenizer:
|
||||
|
||||
```py
|
||||
embedding_model = AutoModel.from_pretrained(your_model_id)
|
||||
embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id)
|
||||
|
||||
# customize this function depending on your embedding model
|
||||
def embed_prompt(input_ids, attention_mask, model):
|
||||
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
|
||||
return outputs.last_hidden_state.mean(dim=1)
|
||||
|
||||
embedding_model = Accelerator().prepare_model(self.embedding_model)
|
||||
embedding_func = partial(embed_prompt, model=embedding_model)
|
||||
```
|
||||
|
||||
Set `prompt_sample_size` to defined how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
|
||||
|
||||
```py
|
||||
training_args = BCOConfig(
|
||||
beta=0.1,
|
||||
prompt_sample_size=512,
|
||||
)
|
||||
|
||||
bco_trainer = BCOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
embedding_func=embedding_func,
|
||||
embedding_tokenizer=self.embedding_tokenizer,
|
||||
)
|
||||
|
||||
bco_trainer.train()
|
||||
```
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
|
||||
|
||||
## BCOTrainer
|
||||
|
||||
[[autodoc]] BCOTrainer
|
||||
|
||||
## BCOConfig
|
||||
|
||||
[[autodoc]] BCOConfig
|
13
docs/source/callbacks.mdx
Normal file
13
docs/source/callbacks.mdx
Normal file
@ -0,0 +1,13 @@
|
||||
# Callbacks
|
||||
|
||||
## SyncRefModelCallback
|
||||
|
||||
[[autodoc]] SyncRefModelCallback
|
||||
|
||||
## RichProgressCallback
|
||||
|
||||
[[autodoc]] RichProgressCallback
|
||||
|
||||
## WinRateCallback
|
||||
|
||||
[[autodoc]] WinRateCallback
|
119
docs/source/clis.mdx
Normal file
119
docs/source/clis.mdx
Normal file
@ -0,0 +1,119 @@
|
||||
# Command Line Interfaces (CLIs)
|
||||
|
||||
You can use TRL to fine-tune your Language Model with Supervised Fine-Tuning (SFT) or Direct Policy Optimization (DPO) or even chat with your model using the TRL CLIs.
|
||||
|
||||
Currently supported CLIs are:
|
||||
|
||||
- `trl sft`: fine-tune a LLM on a text/instruction dataset
|
||||
- `trl dpo`: fine-tune a LLM with DPO on a preference dataset
|
||||
- `trl chat`: quickly spin up a LLM fine-tuned for chatting
|
||||
|
||||
## Fine-tuning with the CLI
|
||||
|
||||
Before getting started, pick up a Language Model from Hugging Face Hub. Supported models can be found with the filter "text-generation" within models. Also make sure to pick up a relevant dataset for your task.
|
||||
|
||||
Before using the `sft` or `dpo` commands make sure to run:
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
and pick up the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
|
||||
|
||||
We also recommend you passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with `trl sft` command.
|
||||
|
||||
```yaml
|
||||
model_name_or_path:
|
||||
trl-internal-testing/tiny-random-LlamaForCausalLM
|
||||
dataset_name:
|
||||
imdb
|
||||
dataset_text_field:
|
||||
text
|
||||
report_to:
|
||||
none
|
||||
learning_rate:
|
||||
0.0001
|
||||
lr_scheduler_type:
|
||||
cosine
|
||||
```
|
||||
|
||||
Save that config in a `.yaml` and get started immediately! An example CLI config is available as `examples/cli_configs/example_config.yaml`. Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder:
|
||||
|
||||
```bash
|
||||
trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
|
||||
```
|
||||
|
||||
Will force-use `cosine_with_restarts` for `lr_scheduler_type`.
|
||||
|
||||
### Supported Arguments
|
||||
|
||||
We do support all arguments from `transformers.TrainingArguments`, for loading your model, we support all arguments from `~trl.ModelConfig`:
|
||||
|
||||
[[autodoc]] ModelConfig
|
||||
|
||||
You can pass any of these arguments either to the CLI or the YAML file.
|
||||
|
||||
### Supervised Fine-tuning (SFT)
|
||||
|
||||
Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
|
||||
|
||||
```bash
|
||||
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
|
||||
```
|
||||
|
||||
The SFT CLI is based on the `examples/scripts/sft.py` script.
|
||||
|
||||
### Direct Policy Optimization (DPO)
|
||||
|
||||
To use the DPO CLI, you need to have a dataset in the TRL format such as
|
||||
|
||||
* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style
|
||||
* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
|
||||
|
||||
These datasets always have at least three columns `prompt, chosen, rejected`:
|
||||
|
||||
* `prompt` is a list of strings.
|
||||
* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
|
||||
* `rejected` is the rejected response [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
|
||||
|
||||
|
||||
To do a quick start, you can run the following command:
|
||||
|
||||
```bash
|
||||
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style
|
||||
```
|
||||
|
||||
|
||||
The DPO CLI is based on the `examples/scripts/dpo.py` script.
|
||||
|
||||
|
||||
#### Custom preference dataset
|
||||
|
||||
Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
|
||||
|
||||
```bash
|
||||
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
|
||||
```
|
||||
|
||||
## Chat interface
|
||||
|
||||
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
|
||||
|
||||
```bash
|
||||
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> To use the chat CLI with the developer installation, you must run `make dev`
|
||||
>
|
||||
|
||||
Note that the chat interface relies on the tokenizer's [chat template](https://huggingface.co/docs/transformers/chat_templating) to format the inputs for the model. Make sure your tokenizer has a chat template defined.
|
||||
|
||||
Besides talking to the model there are a few commands you can use:
|
||||
|
||||
- **clear**: clears the current conversation and start a new one
|
||||
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
|
||||
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
|
||||
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
|
||||
- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
|
||||
- **exit**: closes the interface
|
||||
|
||||
The default examples are defined in `examples/scripts/config/default_chat_config.yaml` but you can pass your own with `--config CONFIG_FILE` where you can also specify the default generation parameters.
|
113
docs/source/cpo_trainer.mdx
Normal file
113
docs/source/cpo_trainer.mdx
Normal file
@ -0,0 +1,113 @@
|
||||
# CPO Trainer
|
||||
|
||||
Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by Haoran Xu, Amr Sharaf, Yunmo Chen, Weiting Tan, Lingfeng Shen, Benjamin Van Durme, Kenton Murray, and Young Jin Kim. At a high-level, CPO trains models to
|
||||
avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains like chat.
|
||||
|
||||
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
|
||||
|
||||
## SimPO
|
||||
The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the `CPOTrainer`. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, we can use SimPO easily by turning on `loss_type="simpo"` and `cpo_alpha=0` in the `CPOConfig`.
|
||||
|
||||
## CPO-SimPO
|
||||
We also offer the combined use of CPO and SimPO, which enables more stable training and improved performance. Learn more details at [CPO-SimPO Github](https://github.com/fe1ixxu/CPO_SIMPO). To use this method, simply enable SimPO by setting `loss_type="simpo"` and a non-zero `cpo_alpha` in the CPOConfig.
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The CPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
|
||||
|
||||
- `prompt`
|
||||
- `chosen`
|
||||
- `rejected`
|
||||
|
||||
for example:
|
||||
|
||||
```py
|
||||
cpo_dataset_dict = {
|
||||
"prompt": [
|
||||
"hello",
|
||||
"how are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"chosen": [
|
||||
"hi nice to meet you",
|
||||
"I am fine",
|
||||
"My name is Mary",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"Python",
|
||||
"Java",
|
||||
],
|
||||
"rejected": [
|
||||
"leave me alone",
|
||||
"I am not fine",
|
||||
"Whats it to you?",
|
||||
"I dont have a name",
|
||||
"Javascript",
|
||||
"C++",
|
||||
"C++",
|
||||
],
|
||||
}
|
||||
```
|
||||
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
|
||||
|
||||
## Expected model format
|
||||
The CPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `CPOTrainer`
|
||||
For a detailed example have a look at the `examples/scripts/cpo.py` script. At a high level we need to initialize the `CPOTrainer` with a `model` we wish to train. **Note that CPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above.
|
||||
|
||||
```py
|
||||
cpo_config = CPOConfig(
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
cpo_trainer = CPOTrainer(
|
||||
model,
|
||||
args=cpo_config,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
cpo_trainer.train()
|
||||
```
|
||||
|
||||
## Loss functions
|
||||
|
||||
Given the preference data, the `CPOTrainer` uses the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression.
|
||||
|
||||
The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. The `CPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin.
|
||||
|
||||
The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike CPO which is summed only).
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
|
||||
|
||||
## Logging
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
|
||||
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
|
||||
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
|
||||
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
|
||||
|
||||
## CPOTrainer
|
||||
|
||||
[[autodoc]] CPOTrainer
|
||||
|
||||
## CPOConfig
|
||||
|
||||
[[autodoc]] CPOConfig
|
@ -1,22 +1,50 @@
|
||||
# Training customization
|
||||
|
||||
At `trl` we provide the possibility to give enough modularity to users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques.
|
||||
TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques.
|
||||
|
||||
## Run on multiple GPUs / nodes
|
||||
## Train on multiple GPUs / nodes
|
||||
|
||||
We leverage `accelerate` to enable users to run their training on multiple GPUs or nodes. You should first create your accelerate config by simply running:
|
||||
The trainers in TRL use 🤗 Accelerate to enable distributed training across multiple GPUs or nodes. To do so, first create an 🤗 Accelerate config file by running
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Then make sure you have selected multi-gpu / multi-node setup. You can then run your training by simply running:
|
||||
and answering the questions according to your multi-gpu / multi-node setup. You can then launch distributed training by running:
|
||||
|
||||
```bash
|
||||
accelerate launch your_script.py
|
||||
```
|
||||
|
||||
Refer to the [examples page](https://github.com/lvwerra/trl/tree/main/examples) for more details
|
||||
We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.:
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
|
||||
```
|
||||
|
||||
Refer to the [examples page](https://github.com/huggingface/trl/tree/main/examples) for more details.
|
||||
|
||||
### Distributed training with DeepSpeed
|
||||
|
||||
All of the trainers in TRL can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run:
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_your_script.py --all_arguments_of_the_script
|
||||
```
|
||||
|
||||
Note that for ZeRO-3, a small tweak is needed to initialize your reward model on the correct device via the `zero3_init_context_manager()` context manager. In particular, this is needed to avoid DeepSpeed hanging after a fixed number of training steps. Here is a snippet of what is involved from the [`sentiment_tuning`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) example:
|
||||
|
||||
```python
|
||||
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
|
||||
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
|
||||
with ds_plugin.zero3_init_context_manager(enable=False):
|
||||
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
|
||||
else:
|
||||
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
|
||||
```
|
||||
|
||||
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
|
||||
|
||||
|
||||
## Use different optimizers
|
||||
|
||||
@ -28,7 +56,7 @@ from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
|
||||
# 2. define config
|
||||
@ -41,7 +69,7 @@ optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
|
||||
|
||||
|
||||
# 3. initialize trainer
|
||||
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer)
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
|
||||
```
|
||||
|
||||
For memory efficient fine-tuning, you can also pass `Adam8bit` optimizer from `bitsandbytes`:
|
||||
@ -55,7 +83,7 @@ from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
|
||||
# 2. define config
|
||||
@ -67,17 +95,17 @@ config = PPOConfig(**ppo_config)
|
||||
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate)
|
||||
|
||||
# 3. initialize trainer
|
||||
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer)
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
|
||||
```
|
||||
|
||||
### Use LION optimizer
|
||||
|
||||
You can use the new [LION optimizer from Google](https://arxiv.org/abs/2302.06675) as well, first take the source code of the optimizer definition [here](https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py), and copy it so that you can import the optimizer. Make sure to initialize the optimizer by considering the trainable parameters only for a more memory efficient training:
|
||||
You can use the new [LION optimizer from Google](https://huggingface.co/papers/2302.06675) as well, first take the source code of the optimizer definition [here](https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py), and copy it so that you can import the optimizer. Make sure to initialize the optimizer by considering the trainable parameters only for a more memory efficient training:
|
||||
```python
|
||||
optimizer = Lion(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.config.learning_rate)
|
||||
|
||||
...
|
||||
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer)
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
|
||||
```
|
||||
We advise you to use the learning rate that you would use for `Adam` divided by 3 as pointed out [here](https://github.com/lucidrains/lion-pytorch#lion---pytorch). We observed an improvement when using this optimizer compared to classic Adam (check the full logs [here](https://wandb.ai/distill-bloom/trl/runs/lj4bheke?workspace=user-younesbelkada)):
|
||||
|
||||
@ -96,7 +124,7 @@ from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
|
||||
# 2. define config
|
||||
@ -109,7 +137,7 @@ optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
|
||||
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
|
||||
|
||||
# 3. initialize trainer
|
||||
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler)
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler)
|
||||
```
|
||||
|
||||
## Memory efficient fine-tuning by sharing layers
|
||||
@ -122,13 +150,13 @@ from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
|
||||
model_ref = create_reference_model(model, num_shared_layers=6)
|
||||
ref_model = create_reference_model(model, num_shared_layers=6)
|
||||
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
|
||||
|
||||
# 2. initialize trainer
|
||||
ppo_config = {'batch_size': 1}
|
||||
config = PPOConfig(**ppo_config)
|
||||
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
|
||||
```
|
||||
|
||||
## Pass 8-bit reference models
|
||||
@ -150,13 +178,13 @@ from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
|
||||
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m', device_map="auto", load_in_8bit=True)
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m', device_map="auto", load_in_8bit=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
|
||||
|
||||
# 2. initialize trainer
|
||||
ppo_config = {'batch_size': 1}
|
||||
config = PPOConfig(**ppo_config)
|
||||
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
|
||||
```
|
||||
|
||||
## Use the CUDA cache optimizer
|
||||
@ -167,31 +195,22 @@ When training large models, you should better handle the CUDA cache by iterative
|
||||
config = PPOConfig(..., optimize_cuda_cache=True)
|
||||
```
|
||||
|
||||
## Use correctly DeepSpeed stage 3:
|
||||
|
||||
A small tweak need to be added to your training script to use DeepSpeed stage 3 correctly. You need to properly initialize your reward model on the correct device using the `zero3_init_context_manager` context manager. Here is an example adapted for the `gpt2-sentiment` script:
|
||||
|
||||
## Use score scaling/normalization/clipping
|
||||
As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://huggingface.co/papers/2307.04964), we support score (aka reward) scaling/normalization/clipping to improve training stability via `PPOConfig`:
|
||||
```python
|
||||
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
|
||||
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
|
||||
with ds_plugin.zero3_init_context_manager(enable=False):
|
||||
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
|
||||
else:
|
||||
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
|
||||
from trl import PPOConfig
|
||||
|
||||
ppo_config = {
|
||||
use_score_scaling=True,
|
||||
use_score_norm=True,
|
||||
score_clip=0.5,
|
||||
}
|
||||
config = PPOConfig(**ppo_config)
|
||||
```
|
||||
|
||||
## Use torch distributed
|
||||
torch.distributed package provides PyTorch natives method to distribute a network over several machines (mostly useful if there are several GPU nodes). It copies the model on each GPU, runs the forward and backward on each and then applies the mean of gradient of all GPUs for each one. If running torch 1.XX, you can call `torch.distributed.launch`, like
|
||||
|
||||
`python -m torch.distributed.launch --nproc_per_node=1 reward_summarization.py --bf16`
|
||||
|
||||
For torch 2.+ `torch.distributed.launch` is deprecated and one needs to run:
|
||||
`torchrun --nproc_per_node=1 reward_summarization.py --bf16`
|
||||
or
|
||||
`python -m torch.distributed.run --nproc_per_node=1 reward_summarization.py --bf16`
|
||||
|
||||
Note that using `python -m torch.distributed.launch --nproc_per_node=1 reward_summarization.py --bf16` with torch 2.0 ends in
|
||||
To run `ppo.py`, you can use the following command:
|
||||
```
|
||||
ValueError: Some specified arguments are not used by the HfArgumentParser: ['--local-rank=0']
|
||||
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 194889) of binary: /home/ubuntu/miniconda3/envs/trl/bin/python
|
||||
python examples/scripts/ppo.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5
|
||||
```
|
||||
|
119
docs/source/ddpo_trainer.mdx
Normal file
119
docs/source/ddpo_trainer.mdx
Normal file
@ -0,0 +1,119 @@
|
||||
# Denoising Diffusion Policy Optimization
|
||||
## The why
|
||||
|
||||
| Before | After DDPO finetuning |
|
||||
| --- | --- |
|
||||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_squirrel.png"/></div> |
|
||||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_crab.png"/></div> |
|
||||
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> |
|
||||
|
||||
|
||||
## Getting started with Stable Diffusion finetuning with reinforcement learning
|
||||
|
||||
The machinery for finetuning of Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers`
|
||||
library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them - pipelines and schedulers.
|
||||
Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to made.
|
||||
|
||||
There is a pipeline interface that is provided by this library that is required to be implemented to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.**
|
||||
There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide.
|
||||
|
||||
The point of the interface is to fuse the pipeline and the scheduler into one object which allows for minimalness in terms of having the constraints all in one place. The interface was designed in hopes of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at this time of writing. Also the scheduler step is a method of this pipeline interface and this may seem redundant given that the raw scheduler is accessible via the interface but this is the only way to constrain the scheduler step output to an output type befitting of the algorithm at hand (DDPO).
|
||||
|
||||
For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py)
|
||||
|
||||
Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. The LoRA flag enabled by default and this can be turned off by passing in the flag to do so. LORA based training is faster and the LORA associated model hyperparameters responsible for model convergence aren't as finicky as non-LORA based training.
|
||||
|
||||
Also in addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images.
|
||||
|
||||
## Getting started with `examples/scripts/ddpo.py`
|
||||
|
||||
The `ddpo.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`).
|
||||
|
||||
**Note:** one A100 GPU is recommended to get this running. Anything below a A100 will not be able to run this example script and even if it does via relatively smaller sized parameters, the results will most likely be poor.
|
||||
|
||||
Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command is to be entered to get things running
|
||||
|
||||
```batch
|
||||
python ddpo.py --hf_user_access_token <token>
|
||||
```
|
||||
|
||||
To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`
|
||||
|
||||
The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)
|
||||
|
||||
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`)
|
||||
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`)
|
||||
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count
|
||||
|
||||
## Setting up the image logging hook function
|
||||
|
||||
Expect the function to be given a list of lists of the form
|
||||
```python
|
||||
[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]
|
||||
|
||||
```
|
||||
and `image`, `prompt`, `prompt_metadata`, `rewards`, `reward_metadata` are batched.
|
||||
The last list in the lists of lists represents the last sample batch. You are likely to want to log this one
|
||||
While you are free to log however you want the use of `wandb` or `tensorboard` is recommended.
|
||||
|
||||
### Key terms
|
||||
|
||||
- `rewards` : The rewards/score is a numerical associated with the generated image and is key to steering the RL process
|
||||
- `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward
|
||||
- `prompt` : The prompt is the text that is used to generate the image
|
||||
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup where questions and ground answers (linked to the generated image) are expected with the generated image (See here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
|
||||
- `image` : The image generated by the Stable Diffusion model
|
||||
|
||||
Example code for logging sampled images with `wandb` is given below.
|
||||
|
||||
```python
|
||||
# for logging these images to wandb
|
||||
|
||||
def image_outputs_hook(image_data, global_step, accelerate_logger):
|
||||
# For the sake of this example, we only care about the last batch
|
||||
# hence we extract the last element of the list
|
||||
result = {}
|
||||
images, prompts, _, rewards, _ = image_data[-1]
|
||||
for i, image in enumerate(images):
|
||||
pil = Image.fromarray(
|
||||
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
|
||||
)
|
||||
pil = pil.resize((256, 256))
|
||||
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
|
||||
accelerate_logger.log_images(
|
||||
result,
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
### Using the finetuned model
|
||||
|
||||
Assuming you've done with all the epochs and have pushed up your model to the hub, you can use the finetuned model as follows
|
||||
|
||||
```python
|
||||
|
||||
import torch
|
||||
from trl import DefaultDDPOStableDiffusionPipeline
|
||||
|
||||
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
|
||||
# memory optimization
|
||||
pipeline.vae.to(device, torch.float16)
|
||||
pipeline.text_encoder.to(device, torch.float16)
|
||||
pipeline.unet.to(device, torch.float16)
|
||||
|
||||
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
|
||||
results = pipeline(prompts)
|
||||
|
||||
for prompt, image in zip(prompts,results.images):
|
||||
image.save(f"{prompt}.png")
|
||||
|
||||
```
|
||||
|
||||
## Credits
|
||||
|
||||
This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models
|
||||
with Reinforcement Learning by Kevin Black, Michael Janner, Yilan Du, Ilya Kostrikov, Sergey Levine](https://huggingface.co/papers/2305.13301).
|
@ -4,12 +4,12 @@ Language models (LMs) are known to sometimes generate toxic outputs. In this exa
|
||||
|
||||
Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters!
|
||||
|
||||
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/lvwerra/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
|
||||
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
|
||||
|
||||
| File | Description | Colab link |
|
||||
|---|---| --- |
|
||||
| [`gpt-j-6b-toxicity.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
|
||||
| [`evaluate-toxicity.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
|
||||
| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
|
||||
| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
|
||||
| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x |
|
||||
|
||||
## Context
|
||||
@ -155,7 +155,7 @@ We report the toxicity score of 400 sampled examples, compute its mean and stand
|
||||
| `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 |
|
||||
| `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-neo-2.7B` | 0.1884 | ,0.3178 |
|
||||
| `EleutherAI/gpt-neo-2.7B` | 0.1884 | 0.3178 |
|
||||
| `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 |
|
||||
@ -174,7 +174,7 @@ Below are few generation examples of `gpt-j-6b-detox` model:
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-toxicity-examples.png">
|
||||
</div>
|
||||
|
||||
The evaluation script can be found [here](https://github.com/lvwerra/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
|
||||
The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
|
||||
|
||||
### Discussions
|
||||
|
||||
|
@ -1,10 +1,25 @@
|
||||
# DPO Trainer
|
||||
|
||||
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/dpo.py`](https://github.com/lvwerra/trl/blob/main/examples/dpo.py).
|
||||
|
||||
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://huggingface.co/papers/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py).
|
||||
|
||||
The first step as always is to train your SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
|
||||
|
||||
## How DPO works
|
||||
|
||||
Fine-tuning a language model via DPO consists of two steps and is easier than PPO:
|
||||
|
||||
1. **Data collection**: Gather a preference dataset with positive and negative selected pairs of generation, given a prompt.
|
||||
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
|
||||
|
||||
DPO-compatible datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo). You can also explore the [librarian-bots/direct-preference-optimization-datasets](https://huggingface.co/collections/librarian-bots/direct-preference-optimization-datasets-66964b12835f46289b6ef2fc) Collection to identify datasets that are likely to support DPO training.
|
||||
|
||||
This process is illustrated in the sketch below (from [figure 1 of the original paper](https://huggingface.co/papers/2305.18290)):
|
||||
|
||||
<img width="835" alt="Screenshot 2024-03-19 at 12 39 41" src="https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d">
|
||||
|
||||
Read more about DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290).
|
||||
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The DPO trainer expects a very specific format for the dataset. Since the model will be trained to directly optimize the preference of which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
|
||||
@ -13,7 +28,7 @@ The DPO trainer expects a very specific format for the dataset. Since the model
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/rlhf-antropic-example.png", width="50%">
|
||||
</div>
|
||||
|
||||
Therefore the final dataset object should contain these 3 entries if you use the default `DPODataCollatorWithPadding` data collator. The entries should be named:
|
||||
Therefore the final dataset object should contain these 3 entries if you use the default [`DPODataCollatorWithPadding`] data collator. The entries should be named:
|
||||
|
||||
- `prompt`
|
||||
- `chosen`
|
||||
@ -55,20 +70,52 @@ dpo_dataset_dict = {
|
||||
|
||||
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
|
||||
|
||||
## Using the `DPOTrainer`
|
||||
[`DPOTrainer`] can be used to fine-tune visual language models (VLMs). In this case, the dataset must also contain the key `images`, and the trainer's `tokenizer` is the VLM's `processor`. For example, for Idefics2, the processor expects the dataset to have the following format:
|
||||
|
||||
For a detailed example have a look at the `examples/dpo.py` script. At a high level we need to initialize the `DPOTrainer` with a `model` we wish to train, a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response, the `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above:
|
||||
Note: Currently, VLM support is exclusive to Idefics2 and does not extend to other VLMs.
|
||||
|
||||
```py
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
dpo_dataset_dict = {
|
||||
'images': [
|
||||
[Image.open('beach.jpg')],
|
||||
[Image.open('street.jpg')],
|
||||
],
|
||||
'prompt': [
|
||||
'The image <image> shows',
|
||||
'<image> The image depicts',
|
||||
],
|
||||
'chosen': [
|
||||
'a sunny beach with palm trees.',
|
||||
'a busy street with several cars and buildings.',
|
||||
],
|
||||
'rejected': [
|
||||
'a snowy mountain with skiers.',
|
||||
'a calm countryside with green fields.',
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
## Expected model format
|
||||
|
||||
The DPO trainer expects a model of `AutoModelForCausalLM` or `AutoModelForVision2Seq`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `DPOTrainer`
|
||||
|
||||
For a detailed example have a look at the `examples/scripts/dpo.py` script. At a high level we need to initialize the [`DPOTrainer`] with a `model` we wish to train, a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response, the `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
|
||||
|
||||
```py
|
||||
training_args = DPOConfig(
|
||||
beta=0.1,
|
||||
)
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
ref_model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer=tokenizer, # for visual language models, use tokenizer=processor instead
|
||||
)
|
||||
```
|
||||
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
@ -77,6 +124,174 @@ dpo_trainer.train()
|
||||
|
||||
Note that the `beta` is the temperature parameter for the DPO loss, typically something in the range of `0.1` to `0.5`. We ignore the reference model as `beta` -> 0.
|
||||
|
||||
## Loss functions
|
||||
|
||||
Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. To use this loss, set the `loss_type="sigmoid"` (default) in the [`DPOConfig`].
|
||||
|
||||
The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. To use this loss, set the `loss_type="hinge"` in the [`DPOConfig`]. In this case, the `beta` is the reciprocal of the margin.
|
||||
|
||||
The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. To use the loss set the `loss_type="ipo"` in the [`DPOConfig`]. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike DPO which is summed only).
|
||||
|
||||
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0).
|
||||
|
||||
The [EXO](https://huggingface.co/papers/2402.00856) authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO which corresponds to forward KL. To use the loss set the `loss_type="exo_pair"` in the [`DPOConfig`]. Setting non-zero `label_smoothing` (default `1e-3`) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the [EXO paper](https://huggingface.co/papers/2402.00856)). The full version of EXO uses `K>2` completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when `K` is sufficiently large.
|
||||
|
||||
The [NCA](https://huggingface.co/papers/2402.05369) authors shows that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. To use the loss set the `loss_type="nca_pair"` in the [`DPOConfig`].
|
||||
|
||||
The [Robust DPO](https://huggingface.co/papers/2403.00409) authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the `label_smoothing` parameter in the [`DPOConfig`] is used to model the probability of existing label noise. To apply this conservative loss, set `label_smoothing` to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0) and set the `loss_type="robust"` in the [`DPOConfig`].
|
||||
|
||||
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. To use this loss, set the `loss_type="bco_pair"` in the [`DPOConfig`].
|
||||
|
||||
The [TR-DPO](https://huggingface.co/papers/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback use the `sync_ref_model=True` in the [`DPOConfig`].
|
||||
|
||||
The [RPO](https://huggingface.co/papers/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://huggingface.co/papers/2405.16436) that essentially consists of a weighted SFT loss on the chosen preferences together with the DPO loss. To use this loss, set the `rpo_alpha` in the [`DPOConfig`] to an appropriate value. The paper suggests setting this weight to 1.0.
|
||||
|
||||
The [SPPO](https://huggingface.co/papers/2405.00675) authors claim that SPPO is capable of solving the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2 and can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. To use this loss, set the `loss_type="sppo_hard"` in the [`DPOConfig`].
|
||||
|
||||
The [AOT](https://huggingface.co/papers/2406.05882) authors propose to use Distributional Preference Alignment Via Optimal Transport. Traditionally, the alignment algorithms use paired preferences at a sample level, which does not ensure alignment on the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples stochastically dominant in the first order on the distribution of negative samples. Specifically, `loss_type="aot"` is appropriate for paired datasets, where each prompt has both chosen and rejected responses; `loss_type="aot_pair"` is for unpaired datasets. In a nutshell, `loss_type="aot"` ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. `loss_type="aot_pair"` ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size.
|
||||
|
||||
The [APO](https://huggingface.co/papers/2408.06266) method introduces an "anchored" version of the alignment objective. There are two variants: `apo_zero` and `apo_down`. The `apo_zero` loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, `apo_down` decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. To use these losses, set `loss_type="apo_zero"` or `loss_type="apo_down"` in the [`DPOConfig`].
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
|
||||
|
||||
## Logging
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
|
||||
- `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses scaled by beta
|
||||
- `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses scaled by beta
|
||||
- `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
- `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
|
||||
## Accelerate DPO fine-tuning using `unsloth`
|
||||
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks for DPO listed below:
|
||||
|
||||
| GPU | Model | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|
||||
| -------- | --------- | ---------- | --- | ---------------------- | ---------- | ------------- |
|
||||
| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% |
|
||||
| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% |
|
||||
|
||||
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number.
|
||||
|
||||
# Load model
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name = "unsloth/zephyr-sft",
|
||||
max_seq_length = max_seq_length,
|
||||
dtype = None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
|
||||
load_in_4bit = True, # Use 4bit quantization to reduce memory usage. Can be False.
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
)
|
||||
|
||||
# Do model patching and add fast LoRA weights
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r = 16,
|
||||
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
|
||||
"gate_proj", "up_proj", "down_proj",],
|
||||
lora_alpha = 16,
|
||||
lora_dropout = 0, # Dropout = 0 is currently optimized
|
||||
bias = "none", # Bias = "none" is currently optimized
|
||||
use_gradient_checkpointing = True,
|
||||
random_state = 3407,
|
||||
)
|
||||
|
||||
training_args = DPOConfig(
|
||||
output_dir="./output",
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
dpo_trainer.train()
|
||||
```
|
||||
|
||||
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
|
||||
|
||||
## Reference model considerations with PEFT
|
||||
|
||||
You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA.
|
||||
|
||||
1. Simply create two instances of the model, each loading your adapter - works fine but is very inefficient.
|
||||
2. Merge the adapter into the base model, create another adapter on top, then leave the `ref_model` param null, in which case DPOTrainer will unload the adapter for reference inference - efficient, but has potential downsides discussed below.
|
||||
3. Load the adapter twice with different names, then use `set_adapter` during training to swap between the adapter being DPO'd and the reference adapter - slightly less efficient compared to 2 (~adapter size VRAM overhead), but avoids the pitfalls.
|
||||
|
||||
### Downsides to merging QLoRA before DPO (approach 2)
|
||||
|
||||
As suggested by [Benjamin Marie](https://medium.com/@bnjmn_marie/dont-merge-your-lora-adapter-into-a-4-bit-llm-65b6da287997), the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter. Something similar to [this script](https://github.com/jondurbin/qlora/blob/main/qmerge.py).
|
||||
|
||||
However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).
|
||||
|
||||
### Using option 3 - load the adapter twice
|
||||
|
||||
To avoid the downsides with option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in [`DPOTrainer`].
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
# Load the base model.
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
llm_int8_threshold=6.0,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"mistralai/mixtral-8x7b-v0.1",
|
||||
load_in_4bit=True,
|
||||
quantization_config=bnb_config,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
# Load the adapter.
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
"/path/to/peft",
|
||||
is_trainable=True,
|
||||
adapter_name="train",
|
||||
)
|
||||
# Load the adapter a second time, with a different name, which will be our reference model.
|
||||
model.load_adapter("/path/to/peft", adapter_name="reference")
|
||||
|
||||
# Initialize the trainer, without a ref_model param.
|
||||
training_args = DPOConfig(
|
||||
model_adapter_name="train",
|
||||
ref_adapter_name="reference",
|
||||
)
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
args=training_args,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
## DPOTrainer
|
||||
|
||||
[[autodoc]] DPOTrainer
|
||||
[[autodoc]] DPOTrainer
|
||||
|
||||
## DPOConfig
|
||||
|
||||
[[autodoc]] DPOConfig
|
||||
|
82
docs/source/example_overview.md
Normal file
82
docs/source/example_overview.md
Normal file
@ -0,0 +1,82 @@
|
||||
# Examples
|
||||
|
||||
|
||||
## Introduction
|
||||
|
||||
The examples should work in any of the following settings (with the same script):
|
||||
- single GPU
|
||||
- multi GPUS (using PyTorch distributed mode)
|
||||
- multi GPUS (using DeepSpeed ZeRO-Offload stages 1, 2, & 3)
|
||||
- fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)
|
||||
|
||||
To run it in each of these various modes, first initialize the accelerate
|
||||
configuration with `accelerate config`
|
||||
|
||||
**NOTE to train with a 4-bit or 8-bit model**, please run
|
||||
|
||||
```bash
|
||||
pip install --upgrade trl[quantization]
|
||||
```
|
||||
|
||||
|
||||
## Accelerate Config
|
||||
For all the examples, you'll need to generate a 🤗 Accelerate config file with:
|
||||
|
||||
```shell
|
||||
accelerate config # will prompt you to define the training configuration
|
||||
```
|
||||
|
||||
Then, it is encouraged to launch jobs with `accelerate launch`!
|
||||
|
||||
|
||||
# Maintained Examples
|
||||
|
||||
|
||||
|
||||
| File | Description |
|
||||
| ----------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [`examples/scripts/alignprop.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/alignprop.py) | This script shows how to use the [`AlignPropTrainer`] to fine-tune a diffusion model. |
|
||||
| [`examples/scripts/bco.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/bco.py) | This script shows how to use the [`KTOTrainer`] with the BCO loss to fine-tune a model to increase instruction-following, truthfulness, honesty and helpfulness using the [openbmb/UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset. |
|
||||
| [`examples/scripts/chat.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/chat.py) | This script allows you to load and use a model as a chatbot. |
|
||||
| [`examples/scripts/cpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/cpo.py) | This script shows how to use the [`CPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/ddpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ddpo.py) | This script shows how to use the [`DDPOTrainer`] to fine-tune a stable diffusion model using reinforcement learning. |
|
||||
| [`examples/scripts/dpo_visual.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_visual.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a Vision Language Model to reduce hallucinations using the [openbmb/RLAIF-V-Dataset](https://huggingface.co/datasets/openbmb/RLAIF-V-Dataset) dataset. |
|
||||
| [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a stable to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/ppo_multi_adapter.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo_multi_adapter.py) | This script shows how to use the [`PPOTrainer`] to train a single base model with multiple adapters. Requires you to run the example script with the reward model training beforehand. |
|
||||
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a sentiment analysis model using [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb). |
|
||||
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a reward model on your own dataset. |
|
||||
| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model or adapters into a target dataset. |
|
||||
| [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested on a [LLaVA 1.5]([llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf)) model so users may see unexpected behaviour in other model architectures. |
|
||||
|
||||
Here are also some easier-to-run colab notebooks that you can use to get started with TRL:
|
||||
|
||||
| File | Description |
|
||||
| --------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------- |
|
||||
| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
|
||||
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
|
||||
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
|
||||
|
||||
|
||||
We also have some other examples that are less maintained but can be used as a reference:
|
||||
1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
|
||||
|
||||
|
||||
## Distributed training
|
||||
|
||||
All of the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments.)
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file=examples/accelerate_configs/multi_gpu.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
|
||||
```
|
||||
|
||||
You can also adjust the parameters of the 🤗 Accelerate config file to suit your needs (e.g. training in mixed precision).
|
||||
|
||||
### Distributed training with DeepSpeed
|
||||
|
||||
Most of the scripts can be run on multiple GPUs together with DeepSpeed ZeRO-{1,2,3} for efficient sharding of the optimizer states, gradients, and model weights. To do so, run following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine, `--all_arguments_of_the_script` with your arguments, and `--deepspeed_config` with the path to the DeepSpeed config file such as `examples/deepspeed_configs/deepspeed_zero1.yaml`):
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file=examples/accelerate_configs/deepspeed_zero{1,2,3}.yaml --num_processes {NUM_GPUS} path_to_script.py --all_arguments_of_the_script
|
||||
```
|
65
docs/source/how_to_train.md
Normal file
65
docs/source/how_to_train.md
Normal file
@ -0,0 +1,65 @@
|
||||
# Training FAQ
|
||||
|
||||
## What Metrics Should I Look at?
|
||||
|
||||
When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves.
|
||||
|
||||
To address this, we recommend focusing on two key metrics first:
|
||||
|
||||
**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training.
|
||||
**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces.
|
||||
|
||||
However, there are more metrics that can be useful for debugging, checkout the [logging section](logging).
|
||||
|
||||
## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence?
|
||||
|
||||
When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.
|
||||
|
||||
However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kl-example.png">
|
||||
<p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://huggingface.co/papers/1909.08593">https://huggingface.co/papers/1909.08593</a>. </p>
|
||||
</div>
|
||||
|
||||
To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.
|
||||
|
||||
## What Is the Concern with Negative KL Divergence?
|
||||
|
||||
If you generate text by purely sampling from the model distribution things work fine in general. But when you use the `generate` method there are a few caveats because it does not always purely sample depending on the settings which can cause KL-divergence to go negative. Essentially when the active model achieves `log_p_token_active < log_p_token_ref` we get negative KL-div. This can happen in a several cases:
|
||||
|
||||
- **top-k sampling**: the model can smooth out the probability distribution causing the top-k tokens having a smaller probability than those of the reference model but they still are selected
|
||||
- **min_length**: this ignores the EOS token until `min_length` is reached. thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached
|
||||
|
||||
These are just a few examples. Why is negative KL an issue? The total reward `R` is computed `R = r - beta * KL` so if the model can learn how to drive KL-divergence negative it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than actually learning the reward function. In addition the KL can become arbitrarily small thus the actual reward can be very small compared to it.
|
||||
|
||||
So how should you generate text for PPO training? Let's have a look!
|
||||
|
||||
## How to generate text for training?
|
||||
|
||||
In order to avoid the KL issues described above we recommend to use the following settings:
|
||||
|
||||
```python
|
||||
generation_kwargs = {
|
||||
"min_length": -1, # don't ignore the EOS token (see above)
|
||||
"top_k": 0.0, # no top-k sampling
|
||||
"top_p": 1.0, # no nucleus sampling
|
||||
"do_sample": True, # yes, we want to sample
|
||||
"pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead
|
||||
"max_new_tokens": 32, # specify how many tokens you want to generate at most
|
||||
}
|
||||
```
|
||||
|
||||
With these settings we usually don't encounter any issues. You can also experiments with other settings but if you encounter issues with negative KL-divergence try to go back to these and see if they persist.
|
||||
|
||||
## How can debug your own use-case?
|
||||
|
||||
Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:
|
||||
|
||||
- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from.
|
||||
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
|
||||
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that.
|
||||
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations.
|
||||
- **Inspect the reward model**: If you reward is not improving over time maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query which the model can't affect so you might need to normalize this (e.g. reward of query+response minus reward of the query).
|
||||
|
||||
These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!
|
@ -13,19 +13,53 @@ The library is integrated with 🤗 [transformers](https://github.com/huggingfac
|
||||
|
||||
Check the appropriate sections of the documentation depending on your needs:
|
||||
|
||||
API documentation:
|
||||
## API documentation
|
||||
|
||||
- [Model Classes](models): *A brief overview of what each public model class does.*
|
||||
- [`SFTTrainer`](sft_trainer): *Supervise Fine-tune your model easily with `SFTTrainer`*
|
||||
- [`RewardTrainer`](reward_trainer): *Train easily your reward model using `RewardTrainer`.*
|
||||
- [`PPOTrainer`](trainer): *Further fine-tune the supervised fine-tuned model using PPO algorithm*
|
||||
- [Best-of-N Samppling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
|
||||
- [`DPOTrainer`](trainer): *Direct Preference Optimization training using `DPOTrainer`.*
|
||||
- [`PPOTrainer`](ppo_trainer): *Further fine-tune the supervised fine-tuned model using PPO algorithm*
|
||||
- [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
|
||||
- [`DPOTrainer`](dpo_trainer): *Direct Preference Optimization training using `DPOTrainer`.*
|
||||
- [`TextEnvironment`](text_environments): *Text environment to train your model using tools with RL.*
|
||||
|
||||
Examples:
|
||||
## Examples
|
||||
|
||||
- [Sentiment Tuning](sentiment_tuning): *Fine tune your model to generate positive movie contents*
|
||||
- [Training with PEFT](lora_tuning_peft): *Memory efficient RLHF training using adapters with PEFT*
|
||||
- [Detoxifying LLMs](detoxifying_a_lm): *Detoxify your language model through RLHF*
|
||||
- [StackLlama](using_llama_models): *End-to-end RLHF training of a Llama model on Stack exchange dataset*
|
||||
- [Learning with Tools](learning_tools): *Walkthrough of using `TextEnvironments`*
|
||||
- [Multi-Adapter Training](multi_adapter_rl): *Use a single base model and multiple adapters for memory efficient end-to-end training*
|
||||
|
||||
|
||||
## Blog posts
|
||||
|
||||
<div class="mt-10">
|
||||
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo_vlm">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/dpo_vlm/thumbnail.png" alt="thumbnail">
|
||||
<p class="text-gray-700">Preference Optimization for Vision Language Models with TRL</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/rlhf">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/120_rlhf/thumbnail.png" alt="thumbnail">
|
||||
<p class="text-gray-700">Illustrating Reinforcement Learning from Human Feedback</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-peft">
|
||||
<img src="https://github.com/huggingface/blog/blob/main/assets/133_trl_peft/thumbnail.png?raw=true" alt="thumbnail">
|
||||
<p class="text-gray-700">Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/stackllama">
|
||||
<img src="https://github.com/huggingface/blog/blob/main/assets/138_stackllama/thumbnail.png?raw=true" alt="thumbnail">
|
||||
<p class="text-gray-700">StackLLaMA: A hands-on guide to train LLaMA with RLHF</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo-trl">
|
||||
<img src="https://github.com/huggingface/blog/blob/main/assets/157_dpo_trl/dpo_thumbnail.png?raw=true" alt="thumbnail">
|
||||
<p class="text-gray-700">Fine-tune Llama 2 with DPO</p>
|
||||
</a>
|
||||
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-ddpo">
|
||||
<img src="https://github.com/huggingface/blog/blob/main/assets/166_trl_ddpo/thumbnail.png?raw=true" alt="thumbnail">
|
||||
<p class="text-gray-700">Finetune Stable Diffusion Models with DDPO via TRL</p>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
|
@ -12,7 +12,7 @@ pip install trl
|
||||
You can also install the latest version from source. First clone the repo and then run the installation with `pip`:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/lvwerra/trl.git
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
cd trl/
|
||||
pip install -e .
|
||||
```
|
||||
|
54
docs/source/iterative_sft_trainer.mdx
Normal file
54
docs/source/iterative_sft_trainer.mdx
Normal file
@ -0,0 +1,54 @@
|
||||
# Iterative Trainer
|
||||
|
||||
Iterative fine-tuning is a training method that enables to perform custom actions (generation and filtering for example) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code.
|
||||
|
||||
## Usage
|
||||
|
||||
To get started quickly, instantiate an instance a model, and a tokenizer.
|
||||
|
||||
```python
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
trainer = IterativeSFTTrainer(
|
||||
model,
|
||||
tokenizer
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
You have the choice to either provide a list of strings or a list of tensors to the step function.
|
||||
|
||||
#### Using a list of tensors as input:
|
||||
|
||||
```python
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask
|
||||
}
|
||||
|
||||
trainer.step(**inputs)
|
||||
|
||||
```
|
||||
|
||||
#### Using a list of strings as input:
|
||||
|
||||
```python
|
||||
|
||||
inputs = {
|
||||
"texts": texts
|
||||
}
|
||||
|
||||
trainer.step(**inputs)
|
||||
|
||||
```
|
||||
|
||||
For causal language models, labels will automatically be created from input_ids or from texts. When using sequence to sequence models you will have to provide your own labels or text_labels.
|
||||
|
||||
## IterativeTrainer
|
||||
|
||||
[[autodoc]] IterativeSFTTrainer
|
79
docs/source/judges.mdx
Normal file
79
docs/source/judges.mdx
Normal file
@ -0,0 +1,79 @@
|
||||
# Judges
|
||||
|
||||
TRL provides judges to easily compare two completions.
|
||||
|
||||
Make sure to have installed the required dependencies by running:
|
||||
|
||||
```bash
|
||||
pip install trl[llm_judge]
|
||||
```
|
||||
|
||||
## Using the provided judges
|
||||
|
||||
TRL provides several judges out of the box. For example, you can use the `HfPairwiseJudge` to compare two completions using a pre-trained model from the Hugging Face model hub:
|
||||
|
||||
```python
|
||||
from trl import HfPairwiseJudge
|
||||
|
||||
judge = HfPairwiseJudge()
|
||||
judge.judge(
|
||||
prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
|
||||
completions=[["Paris", "Lyon"], ["Saturn", "Jupiter"]],
|
||||
) # Outputs: [0, 1]
|
||||
```
|
||||
|
||||
## Define your own judge
|
||||
|
||||
To define your own judge, we provide several base classes that you can subclass. For rank-based judges, you need to subclass [`BaseRankJudge`] and implement the [`BaseRankJudge.judge`] method. For pairwise judges, you need to subclass [`BasePairJudge`] and implement the [`BasePairJudge.judge`] method. If you want to define a judge that doesn't fit into these categories, you need to subclass [`BaseJudge`] and implement the [`BaseJudge.judge`] method.
|
||||
|
||||
As an example, let's define a pairwise judge that prefers shorter completions:
|
||||
|
||||
```python
|
||||
from trl import BasePairwiseJudge
|
||||
|
||||
class PrefersShorterJudge(BasePairwiseJudge):
|
||||
def judge(self, prompts, completions, shuffle_order=False):
|
||||
return [0 if len(completion[0]) > len(completion[1]) else 1 for completion in completions]
|
||||
```
|
||||
|
||||
You can then use this judge as follows:
|
||||
|
||||
```python
|
||||
judge = PrefersShorterJudge()
|
||||
judge.judge(
|
||||
prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
|
||||
completions=[["Paris", "The capital of France is Paris."], ["Jupiter is the biggest planet in the solar system.", "Jupiter"]],
|
||||
) # Outputs: [0, 1]
|
||||
```
|
||||
|
||||
## BaseJudge
|
||||
|
||||
[[autodoc]] BaseJudge
|
||||
|
||||
## BaseRankJudge
|
||||
|
||||
[[autodoc]] BaseRankJudge
|
||||
|
||||
## BasePairwiseJudge
|
||||
|
||||
[[autodoc]] BasePairwiseJudge
|
||||
|
||||
## RandomRankJudge
|
||||
|
||||
[[autodoc]] RandomRankJudge
|
||||
|
||||
## RandomPairwiseJudge
|
||||
|
||||
[[autodoc]] RandomPairwiseJudge
|
||||
|
||||
## PairRMJudge
|
||||
|
||||
[[autodoc]] PairRMJudge
|
||||
|
||||
## HfPairwiseJudge
|
||||
|
||||
[[autodoc]] HfPairwiseJudge
|
||||
|
||||
## OpenAIPairwiseJudge
|
||||
|
||||
[[autodoc]] OpenAIPairwiseJudge
|
102
docs/source/kto_trainer.mdx
Normal file
102
docs/source/kto_trainer.mdx
Normal file
@ -0,0 +1,102 @@
|
||||
# KTO Trainer
|
||||
|
||||
TRL supports the Kahneman-Tversky Optimization (KTO) Trainer for aligning language models with binary feedback data (e.g., upvote/downvote), as described in the [paper](https://huggingface.co/papers/2402.01306) by Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela.
|
||||
For a full example have a look at [`examples/scripts/kto.py`].
|
||||
|
||||
Depending on how good your base model is, you may or may not need to do SFT before KTO.
|
||||
This is different from standard RLHF and DPO, which always require SFT.
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The KTO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns:
|
||||
|
||||
- `prompt`
|
||||
- `completion`
|
||||
- `label`
|
||||
|
||||
for example:
|
||||
|
||||
```
|
||||
kto_dataset_dict = {
|
||||
"prompt": [
|
||||
"Hey, hello",
|
||||
"How are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"completion": [
|
||||
"hi nice to meet you",
|
||||
"leave me alone",
|
||||
"I don't have a name",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"C++",
|
||||
"Java",
|
||||
],
|
||||
"label": [
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
|
||||
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays. It is required that the dataset contains at least one desirable and one undesirable completion.
|
||||
|
||||
|
||||
## Expected model format
|
||||
The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `KTOTrainer`
|
||||
|
||||
For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
|
||||
|
||||
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
|
||||
|
||||
The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
|
||||
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` * number of positives) to (`undesirable_weight` * number of negatives) is in the range 1:1 to 4:3.
|
||||
|
||||
```py
|
||||
training_args = KTOConfig(
|
||||
beta=0.1,
|
||||
desirable_weight=1.0,
|
||||
undesirable_weight=1.0,
|
||||
)
|
||||
|
||||
kto_trainer = KTOTrainer(
|
||||
model,
|
||||
ref_model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
kto_trainer.train()
|
||||
```
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
|
||||
|
||||
## KTOTrainer
|
||||
|
||||
[[autodoc]] KTOTrainer
|
||||
|
||||
## KTOConfig
|
||||
|
||||
[[autodoc]] KTOConfig
|
232
docs/source/learning_tools.mdx
Normal file
232
docs/source/learning_tools.mdx
Normal file
@ -0,0 +1,232 @@
|
||||
# Learning Tools (Experimental 🧪)
|
||||
|
||||
Using Large Language Models (LLMs) with tools has been a popular topic recently with awesome works such as [ToolFormer](https://huggingface.co/papers/2302.04761) and [ToolBench](https://huggingface.co/papers/2305.16504). In TRL, we provide a simple example of how to teach LLM to use tools with reinforcement learning.
|
||||
|
||||
|
||||
Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):
|
||||
|
||||
| File | Description |
|
||||
|---|---|
|
||||
| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. |
|
||||
| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. |
|
||||
| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. |
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
|
||||
</Tip>
|
||||
|
||||
|
||||
## Learning to Use a Calculator
|
||||
|
||||
|
||||
The rough idea is as follows:
|
||||
|
||||
1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parse a text calculation like `"14 + 34"` and return the calulated number:
|
||||
```python
|
||||
from transformers import AutoTokenizer, load_tool
|
||||
tool = load_tool("ybelkada/simple-calculator")
|
||||
tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places
|
||||
```
|
||||
1. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.
|
||||
1. Create a prompt on how to use the tools
|
||||
```python
|
||||
# system prompt
|
||||
prompt = """\
|
||||
What is 13.1-3?
|
||||
|
||||
<request><SimpleCalculatorTool>13.1-3<call>10.1<response>
|
||||
|
||||
Result=10.1<submit>
|
||||
|
||||
What is 4*3?
|
||||
|
||||
<request><SimpleCalculatorTool>4*3<call>12<response>
|
||||
|
||||
Result=12<submit>
|
||||
|
||||
What is 12.1+1?
|
||||
|
||||
<request><SimpleCalculatorTool>12.1+1<call>13.1<response>
|
||||
|
||||
Result=13.1<submit>
|
||||
|
||||
What is 12.1-20?
|
||||
|
||||
<request><SimpleCalculatorTool>12.1-20<call>-7.9<response>
|
||||
|
||||
Result=-7.9<submit>"""
|
||||
```
|
||||
3. Create a `trl.TextEnvironment` with the model
|
||||
```python
|
||||
env = TextEnvironment(
|
||||
model,
|
||||
tokenizer,
|
||||
{"SimpleCalculatorTool": tool_fn},
|
||||
reward_fn,
|
||||
prompt,
|
||||
generation_kwargs=generation_kwargs,
|
||||
)
|
||||
```
|
||||
4. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will show visualize the tokens.
|
||||

|
||||
1. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss, make sure to pass that argument to `step`.
|
||||
|
||||
## Experiment results
|
||||
|
||||
We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.
|
||||
|
||||
```
|
||||
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
|
||||
--command "python examples/research_projects/tools/calculator.py" \
|
||||
--num-seeds 10 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 8 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
```
|
||||
|
||||
We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot.
|
||||
```
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
|
||||
'wandb?tag=calculator_final&cl=calculator_mask' \
|
||||
--env-ids trl \
|
||||
--check-empty-runs \
|
||||
--pc.ncols 2 \
|
||||
--pc.ncols-legend 1 \
|
||||
--output-filename static/0compare \
|
||||
--scan-history
|
||||
```
|
||||
|
||||

|
||||
|
||||
As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near perfect proficiency in the calculator task.
|
||||
|
||||
|
||||
## (Early Experiments 🧪): learning to use a wiki tool for question answering
|
||||
|
||||
In the [ToolFormer](https://huggingface.co/papers/2302.04761) paper, it shows an interesting use case that utilizes a Wikipedia Search tool to help answer questions. In this section, we attempt to perform similar experiments but uses RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.
|
||||
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
**Note that many settings are different so the results are not directly comparable.**
|
||||
</Tip>
|
||||
|
||||
|
||||
|
||||
|
||||
### Building a search index
|
||||
|
||||
Since [ToolFormer](https://huggingface.co/papers/2302.04761) did not open source, we needed to first replicate the search index. It is mentioned in their paper that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT)
|
||||
|
||||
Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.
|
||||
|
||||
```python
|
||||
from pyserini.search.lucene import LuceneSearcher
|
||||
import json
|
||||
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
|
||||
def search(query):
|
||||
hits = searcher.search(query, k=1)
|
||||
hit = hits[0]
|
||||
contents = json.loads(hit.raw)['contents']
|
||||
return contents
|
||||
print(search("tennis racket"))
|
||||
```
|
||||
```
|
||||
Racket (sports equipment)
|
||||
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.
|
||||
|
||||
The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
|
||||
...
|
||||
```
|
||||
|
||||
We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later.
|
||||
|
||||

|
||||
|
||||
### Experiment settings
|
||||
|
||||
We use the following settings:
|
||||
|
||||
* use the `bigcode/starcoderbase` model as the base model
|
||||
* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only uses the first paragrahs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
|
||||
* test if the response contain the answer string, if so, give a reward of 1, otherwise, give a reward of 0.
|
||||
* notice this is a simplified evaluation criteria. In [ToolFormer](https://huggingface.co/papers/2302.04761), the authors checks if the first 20 words of the response contain the correct answer.
|
||||
* used the following prompt that demonstrates the usage of the wiki tool.
|
||||
```python
|
||||
prompt = """\
|
||||
Answer the following question:
|
||||
|
||||
Q: In which branch of the arts is Patricia Neary famous?
|
||||
A: Ballets
|
||||
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
|
||||
Result=Ballets<submit>
|
||||
|
||||
Q: Who won Super Bowl XX?
|
||||
A: Chicago Bears
|
||||
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
|
||||
Result=Chicago Bears<submit>
|
||||
|
||||
Q: """
|
||||
```
|
||||
|
||||
|
||||
### Result and Discussion
|
||||
|
||||
|
||||
Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves would go up mostly, but one of the experiment did crash.
|
||||
|
||||

|
||||
|
||||
Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.
|
||||
|
||||
|
||||
Note that the correct rate of the trained model is on the low end, which could be due to the following reasons:
|
||||
|
||||
* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"` if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988.` But a correct search should be `Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles.[1][2]"
|
||||
|
||||
|
||||

|
||||
|
||||
* **unnecessarily long response**: The wiki tool by default sometimes output very long sequences. E.g., when the wiki tool searches for "Brown Act"
|
||||
* Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
|
||||
* [ToolFormer](https://huggingface.co/papers/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.
|
||||
|
||||

|
||||
|
||||
|
||||
## (Early Experiments 🧪): solving math puzzles with python interpreter
|
||||
|
||||
In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:
|
||||
|
||||
```python
|
||||
prompt = """\
|
||||
Example of using a Python API to solve math questions.
|
||||
|
||||
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
|
||||
|
||||
<request><PythonInterpreter>
|
||||
def solution():
|
||||
money_initial = 23
|
||||
bagels = 5
|
||||
bagel_cost = 3
|
||||
money_spent = bagels * bagel_cost
|
||||
money_left = money_initial - money_spent
|
||||
result = money_left
|
||||
return result
|
||||
print(solution())
|
||||
<call>72<response>
|
||||
|
||||
Result = 72 <submit>
|
||||
|
||||
Q: """
|
||||
```
|
||||
|
||||
|
||||
Training experiment can be found at https://wandb.ai/lvwerra/trl-gsm8k/runs/a5odv01y
|
||||
|
||||

|
@ -14,16 +14,62 @@ If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir
|
||||
|
||||
## PPO Logging
|
||||
|
||||
Here's a brief explanation for the logged metrics provided in the data:
|
||||
|
||||
Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
|
||||
1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is sed to specifically monitor the reward model.
|
||||
1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias ``ppo/std_scores`, which is sed to specifically monitor the reward model.
|
||||
1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment.
|
||||
1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function.
|
||||
1. `objective/kl_dist`: The histogram distribution of the `objective/kl`.
|
||||
1. `objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function.
|
||||
1. `ppo/mean_non_score_reward`: The **KL penalty** calculated by `objective/kl * objective/kl_coef` as the total reward for optimization to prevent the new policy from deviating too far from the old policy.
|
||||
1. `objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration.
|
||||
|
||||
Training stats:
|
||||
1. `ppo/learning_rate`: The learning rate for the PPO algorithm.
|
||||
1. `ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy.
|
||||
1. `ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process.
|
||||
1. `ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html
|
||||
1. `ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html
|
||||
1. `ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.
|
||||
1. `ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.
|
||||
1. `ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`.
|
||||
1. `ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details.
|
||||
1. `ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance.
|
||||
1. `ppo/val/mean`: The mean of the values, used to monitor the value function's performance.
|
||||
1. `ppo/val/var` : The variance of the values, used to monitor the value function's performance.
|
||||
1. `ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance.
|
||||
1. `ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped.
|
||||
1. `ppo/val/vpred`: The predicted values from the value function.
|
||||
1. `ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance.
|
||||
1. `ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm.
|
||||
1. `ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards.
|
||||
1. `ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss.
|
||||
|
||||
|
||||
Stats on queries, responses, and logprobs:
|
||||
1. `tokens/queries_len_mean`: The average length of the queries tokens.
|
||||
1. `tokens/queries_len_std`: The standard deviation of the length of the queries tokens.
|
||||
1. `tokens/queries_dist`: The histogram distribution of the length of the queries tokens.
|
||||
1. `tokens/responses_len_mean`: The average length of the responses tokens.
|
||||
1. `tokens/responses_len_std`: The standard deviation of the length of the responses tokens.
|
||||
1. `tokens/responses_dist`: The histogram distribution of the length of the responses tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`)
|
||||
1. `objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model.
|
||||
1. `objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model.
|
||||
|
||||
|
||||
|
||||
### Crucial values
|
||||
During training, many values are logged, here are the most important ones:
|
||||
|
||||
1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment".
|
||||
2. `ppo/mean_scores`: The mean scores directly out of the reward model.
|
||||
3. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
|
||||
1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment" / reward model
|
||||
1. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
|
||||
|
||||
### Training stability parameters:
|
||||
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
|
||||
|
||||
1. `ppo/loss/value`: The value function loss -- will spike / NaN when not going well.
|
||||
2. `ppo/val/clipfrac`: The fraction of clipped values in the value function loss. This is often from 0.3 to 0.6.
|
||||
3. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.
|
||||
1. `ppo/loss/value`: it will spike / NaN when not going well.
|
||||
1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
|
||||
1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well.
|
||||
1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy.
|
||||
1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.
|
@ -1,15 +1,15 @@
|
||||
# Examples of using peft with trl to finetune 8-bit models with Low Rank Adaption (LoRA)
|
||||
|
||||
The notebooks and scripts in this examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory efficient manner. Most of PEFT methods supported in peft library but note that some PEFT methods such as Prompt tuning are not supported.
|
||||
For more information on LoRA, see the [original paper](https://arxiv.org/abs/2106.09685).
|
||||
For more information on LoRA, see the [original paper](https://huggingface.co/papers/2106.09685).
|
||||
|
||||
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples):
|
||||
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
|
||||
|
||||
| File | Task | Description | Colab link |
|
||||
|---|---| --- |
|
||||
| [`stack_llama/rl_training.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
|
||||
| [`stack_llama/reward_modeling.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
|
||||
| [`stack_llama/supervised_finetuning.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
|
||||
| [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
|
||||
| [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
|
||||
| [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
|
||||
|
||||
## Installation
|
||||
Note: peft is in active development, so we install directly from their Github page.
|
||||
@ -71,7 +71,7 @@ The `trl` library is powered by `accelerate`. As such it is best to configure an
|
||||
|
||||
```bash
|
||||
accelerate config # will prompt you to define the training configuration
|
||||
accelerate launch scripts/gpt2-sentiment_peft.py # launches training
|
||||
accelerate launch examples/scripts/ppo.py --use_peft # launch`es training
|
||||
```
|
||||
|
||||
## Using `trl` + `peft` and Data Parallelism
|
||||
@ -140,5 +140,5 @@ python PATH_TO_SCRIPT
|
||||
You can easily fine-tune Llama2 model using `SFTTrainer` and the official script! For example to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB):
|
||||
|
||||
```bash
|
||||
python examples/scripts/sft_trainer.py --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --batch_size 4 --gradient_accumulation_steps 2
|
||||
python examples/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2
|
||||
```
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Multi Adapter RL (MARL) - a single base model for everything
|
||||
|
||||
Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not tested the convergence of the approach. We encourage the community to let us know if they potentially face into any issue.
|
||||
Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not test the convergence of the approach. We encourage the community to let us know if they potentially face issues.
|
||||
|
||||
## Requirements
|
||||
|
||||
@ -11,10 +11,10 @@ You just need to install `peft` and optionally install `bitsandbytes` as well if
|
||||
You need to address this approach in three stages that we summarize as follows:
|
||||
|
||||
1- Train a base model on the target domain (e.g. `imdb` dataset) - this is the Supervised Fine Tuning stage - it can leverage the `SFTTrainer` from TRL.
|
||||
2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/lvwerra/trl/tree/main/examples/scripts/reward_trainer.py)
|
||||
2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py)
|
||||
3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL")
|
||||
|
||||
Make sure to use the same model (i.e. same architecure and same weights) for the stages 2 & 3.
|
||||
Make sure to use the same model (i.e. same architecture and same weights) for the stages 2 & 3.
|
||||
|
||||
## Quickstart
|
||||
|
||||
@ -48,7 +48,7 @@ trainer = PPOTrainer(
|
||||
|
||||
...
|
||||
```
|
||||
Then inside your PPO training loop, call the `compute_reward_score` method by accessing to the `model` attribute from `PPOTrainer`.
|
||||
Then inside your PPO training loop, call the `compute_reward_score` method by accessing the `model` attribute from `PPOTrainer`.
|
||||
|
||||
```python
|
||||
rewards = trainer.model.compute_reward_score(**inputs)
|
||||
@ -58,8 +58,8 @@ rewards = trainer.model.compute_reward_score(**inputs)
|
||||
|
||||
### Control on the adapter name
|
||||
|
||||
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is to train multiple adapters on the same base model to fine-tune on different policies.
|
||||
In this case, you want to have a control on the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`.
|
||||
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is train multiple adapters on the same base model to fine-tune on different policies.
|
||||
In this case, you want to be able to control the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`.
|
||||
|
||||
```python
|
||||
adapter_name_policy_1 = "policy_1"
|
||||
@ -97,4 +97,4 @@ trainer = PPOTrainer(
|
||||
...
|
||||
)
|
||||
...
|
||||
```
|
||||
```
|
||||
|
250
docs/source/online_dpo_trainer.md
Normal file
250
docs/source/online_dpo_trainer.md
Normal file
@ -0,0 +1,250 @@
|
||||
# Online DPO Trainer
|
||||
|
||||
## Overview
|
||||
|
||||
Online DPO was proposed in [Direct Language Model Alignment from Online AI Feedback](https://huggingface.co/papers/2402.04792) by Shangmin Guo, Biao Zhang, Tianlin Liu, Tianqi Liu, Misha Khalman, Felipe Llinares, Alexandre Rame, Thomas Mesnard, Yao Zhao, Bilal Piot, Johan Ferret, and Mathieu Blondel.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Direct alignment from preferences (DAP) methods, such as DPO, have recently emerged as efficient alternatives to reinforcement learning from human feedback (RLHF), that do not require a separate reward model. However, the preference datasets used in DAP methods are usually collected ahead of training and never updated, thus the feedback is purely offline. Moreover, responses in these datasets are often sampled from a language model distinct from the one being aligned, and since the model evolves over training, the alignment phase is inevitably off-policy. In this study, we posit that online feedback is key and improves DAP methods. Our method, online AI feedback (OAIF), uses an LLM as annotator: on each training iteration, we sample two responses from the current model and prompt the LLM annotator to choose which one is preferred, thus providing online feedback. Despite its simplicity, we demonstrate via human evaluation in several tasks that OAIF outperforms both offline DAP and RLHF methods. We further show that the feedback leveraged in OAIF is easily controllable, via instruction prompts to the LLM annotator.
|
||||
|
||||
The current implementation uses reward models for scoring completions -- see [Reward Bench](https://huggingface.co/spaces/allenai/reward-bench) for a leaderboard of public models you can use.
|
||||
|
||||
This post-training method was contributed by [Michael Noukhovitch](https://huggingface.co/mnoukhov), [Shengyi Costa Huang](https://huggingface.co/vwxyzjn), [Quentin Gallouédec](https://huggingface.co/qgallouedec), and [Edward Beeching](https://huggingface.co/edbeeching).
|
||||
|
||||
## Usage tips
|
||||
|
||||
> [!WARNING]
|
||||
> Make sure that the SFT model and reward model use the _same_ chat template. Otherwise, you may find the model completions are scored incorrectly during training.
|
||||
|
||||
The basic API is as follows:
|
||||
|
||||
```python
|
||||
from datasets import Dataset
|
||||
from trl import OnlineDPOConfig, OnlineDPOTrainer
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
)
|
||||
NUM_DUMMY_SAMPLES = 100
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
# The model to optimise
|
||||
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
|
||||
# The reference model to calculate the KL divergence against
|
||||
ref_model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct")
|
||||
# The model to score completions with. In practice, you will need a reward model.
|
||||
reward_model = AutoModelForSequenceClassification.from_pretrained("HuggingFaceTB/SmolLM-135M-Instruct", num_labels=1)
|
||||
|
||||
train_dataset = Dataset.from_dict(
|
||||
{"prompt": ["Q: Hi how are you? A:"] * NUM_DUMMY_SAMPLES})
|
||||
eval_dataset = Dataset.from_dict(
|
||||
{"prompt": ["Q: What do you like to eat A:"] * NUM_DUMMY_SAMPLES})
|
||||
|
||||
args = OnlineDPOConfig(output_dir="online-dpo-model")
|
||||
trainer = OnlineDPOTrainer(
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
reward_model=reward_model,
|
||||
args=args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
To test the online DPO script with 1B parameter models, run:
|
||||
|
||||
```bash
|
||||
python examples/scripts/dpo_online.py \
|
||||
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
|
||||
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
|
||||
--dataset_name trl-lib/tldr \
|
||||
--learning_rate 5.0e-7 \
|
||||
--output_dir pythia-1b-tldr-online-dpo \
|
||||
--per_device_train_batch_size 4 \
|
||||
--gradient_accumulation_steps 32 \
|
||||
--num_train_epochs 3 \
|
||||
--max_new_tokens 53 \
|
||||
--warmup_ratio 0.1 \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
Tips:
|
||||
|
||||
* `objective/rlhf_reward` is the ultimate objective of online DPO training. If training works as intended, this metric should keep going up.
|
||||
* We recommend using the "EOS trick" via the `--missing_eos_penalty` argument, which subtracts from the rewards a fixed scalar penalty for completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
|
||||
|
||||
### Expected dataset format
|
||||
|
||||
Unlike offline DPO, where one provides a dataset with chosen and rejected columns, online DPO only requires a dataset of prompts to generate the completions from. The [`OnlineDPOTrainer`] assumes that the dataset is preprocessed for model inference, so typically you will need to wrap your prompts in the messages format and then apply the chat template as follows:
|
||||
|
||||
```python
|
||||
def prepare_dataset(row):
|
||||
"""Apply chat template to messages"""
|
||||
row["prompt"] = tokenizer.apply_chat_template(row["prompt"], tokenize=False, add_generation_prompt=True)
|
||||
return row
|
||||
|
||||
dataset = prepare_dataset(dataset)
|
||||
```
|
||||
|
||||
### Explanation of the logged metrics
|
||||
|
||||
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35)
|
||||
|
||||
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current model and reference model.
|
||||
* `objective/entropy`: The mean entropy of the model, indicating the randomness of the actions chosen by the model.
|
||||
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
|
||||
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
|
||||
* `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
* `objective/scores_margin`: The mean score margin (according to the external reward model) between the chosen and rejected completions.
|
||||
* `rewards/accuracies`: The accuracies of the online DPO's implicit reward model.
|
||||
* `rewards/chosen`: The mean reward (according to online DPO's implicit reward model)of the chosen completions.
|
||||
* `rewards/rejected`: The mean reward (according to online DPO's implicit reward model) of the rejected completions.
|
||||
* `rewards/margins`: The mean reward margin (according to online DPO's implicit reward model) between the chosen and rejected completions.
|
||||
* `logps/chosen`: The mean log probabilities of the chosen completions.
|
||||
* `logps/rejected`: The mean log probabilities of the rejected completions.
|
||||
* `val/contain_eos_token`: The fraction of completions which contain an EOS token.
|
||||
|
||||
|
||||
## What is my model doing exactly?
|
||||
|
||||
To help you understand what your model is doing, we periodically log some sample completions from the model via [`LogCompletionsCallback`]. You can find an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/hlzevfro?nw=nwuserlewtun), which allows you to see the model's response at different stages of training. By default we generate during training, but you can customize the number of prompts to generate for in [`LogCompletionsCallback`].
|
||||
|
||||
|
||||
## Implementation details
|
||||
|
||||
Many online implementation details are borrowed from the [`PPOv2Trainer`], which is itself based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
|
||||
|
||||
## Benchmark experiments
|
||||
|
||||
To validate the online DPO implementation works, we ran experiments with the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
|
||||
|
||||
```
|
||||
# 1B Online DPO experiment
|
||||
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \
|
||||
examples/scripts/dpo_online.py \
|
||||
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
|
||||
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
|
||||
--dataset_name trl-lib/tldr \
|
||||
--learning_rate 5.0e-7 \
|
||||
--output_dir pythia-1b-deduped-tldr-online-dpo \
|
||||
--beta 0.1 \
|
||||
--per_device_train_batch_size 8 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 3 \
|
||||
--max_new_tokens 53 \
|
||||
--warmup_ratio 0.1 \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--logging_steps 20 \
|
||||
--save_steps 0.1 \
|
||||
--push_to_hub
|
||||
|
||||
# 2.8B Online DPO experiment
|
||||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
|
||||
examples/scripts/dpo_online.py \
|
||||
--model_name_or_path trl-lib/pythia-2.8b-deduped-tldr-sft \
|
||||
--reward_model_path trl-lib/pythia-2.8b-deduped-tldr-rm \
|
||||
--dataset_name trl-lib/tldr \
|
||||
--learning_rate 5.0e-7 \
|
||||
--output_dir pythia-2.8b-deduped-tldr-online-dpo \
|
||||
--beta 0.1 \
|
||||
--per_device_train_batch_size 8 \
|
||||
--gradient_accumulation_steps 2 \
|
||||
--num_train_epochs 3 \
|
||||
--max_new_tokens 53 \
|
||||
--warmup_ratio 0.1 \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--bf16 \
|
||||
--logging_steps 20 \
|
||||
--save_steps 0.1 \
|
||||
--push_to_hub \
|
||||
|
||||
# 6.9B Online DPO experiment
|
||||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
|
||||
examples/scripts/dpo_online.py \
|
||||
--model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft \
|
||||
--reward_model_path trl-lib/pythia-6.9b-deduped-tldr-rm \
|
||||
--dataset_name trl-lib/tldr \
|
||||
--learning_rate 5.0e-7 \
|
||||
--output_dir pythia-6.9b-deduped-tldr-online-dpo \
|
||||
--beta 0.1 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--num_train_epochs 3 \
|
||||
--max_new_tokens 53 \
|
||||
--warmup_ratio 0.1 \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--bf16 \
|
||||
--gradient_checkpointing \
|
||||
--logging_steps 20 \
|
||||
--save_steps 0.1 \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
Checkpoints and experiment tracking are available at:
|
||||
|
||||
- [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07)
|
||||
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0)
|
||||
|
||||
|
||||
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
|
||||
For more information on how to use judges, see [Judges](judges).
|
||||
|
||||
```bash
|
||||
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000
|
||||
Model win rate: 33.00%
|
||||
python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-sft --judge_model gpt-4o-mini --num_examples 1000
|
||||
Model win rate: 41.50%
|
||||
python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-1b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000
|
||||
Model win rate: 62.60%
|
||||
python examples/scripts/evals/judge_tldr.py --model_name_or_path trl-lib/pythia-6.9b-deduped-tldr-online-dpo --judge_model gpt-4o-mini --num_examples 1000
|
||||
Model win rate: 74.20%
|
||||
```
|
||||
|
||||
We can then plot the RLHF scaling chart.
|
||||
|
||||
```python
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
results = {
|
||||
"SFT": {1.0e9: 0.21, 2.8e9: 0.27, 6.9e9: 0.316},
|
||||
"online-dpo": {1.0e9: 0.542, 2.8e9: 0.746, 6.9e9: 0.796},
|
||||
"offline-dpo": {1.0e9: 0.422, 2.8e9: 0.517, 6.9e9: 0.701},
|
||||
}
|
||||
|
||||
|
||||
plt.plot(results["SFT"].keys(), results["SFT"].values(), label="SFT", marker="o")
|
||||
plt.plot(results["online-dpo"].keys(), results["online-dpo"].values(), label="Online-dpo with RM judge", marker="o")
|
||||
plt.plot(results["offline-dpo"].keys(), results["offline-dpo"].values(), label="Offline-dpo", marker="o")
|
||||
plt.axhline(y=0.5, color="black", linestyle="-.", label="Human reference summary")
|
||||
plt.xscale("log")
|
||||
plt.xlabel("Model size")
|
||||
plt.ylabel("Win rate against reference summaries\n(according to GPT-4-0613)")
|
||||
plt.title("DPO scaling by model size")
|
||||
plt.legend()
|
||||
plt.xlim(5e8, 1.2e10)
|
||||
plt.xticks([1e9, 3e9, 1e10], ["1B", "3B", "10B"])
|
||||
plt.grid(True, which="both", ls="--", c="0.7")
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||

|
||||
|
||||
The online DPO checkpoint gets increasingly more win rate as we scale up the model sizes. This is a good sign that the online DPO implementation is working as intended.
|
||||
|
||||
## OnlineDPOTrainer
|
||||
|
||||
[[autodoc]] OnlineDPOTrainer
|
||||
|
||||
|
||||
## OnlineDPOConfig
|
||||
|
||||
[[autodoc]] OnlineDPOConfig
|
106
docs/source/orpo_trainer.md
Normal file
106
docs/source/orpo_trainer.md
Normal file
@ -0,0 +1,106 @@
|
||||
# ORPO Trainer
|
||||
|
||||
[Odds Ratio Preference Optimization](https://huggingface.co/papers/2403.07691) (ORPO) by Jiwoo Hong, Noah Lee, and James Thorne studies the crucial role of SFT within the context of preference alignment. Using preference data the method posits that a minor penalty for the disfavored generation together with a strong adaption signal to the chosen response via a simple log odds ratio term appended to the NLL loss is sufficient for preference-aligned SFT.
|
||||
|
||||
Thus ORPO is a reference model-free preference optimization algorithm eliminating the necessity for an additional preference alignment phase thus saving compute and memory.
|
||||
|
||||
The official code can be found [xfactlab/orpo](https://github.com/xfactlab/orpo).
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The ORPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
|
||||
|
||||
- `prompt`
|
||||
- `chosen`
|
||||
- `rejected`
|
||||
|
||||
for example:
|
||||
|
||||
```py
|
||||
orpo_dataset_dict = {
|
||||
"prompt": [
|
||||
"hello",
|
||||
"how are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"chosen": [
|
||||
"hi nice to meet you",
|
||||
"I am fine",
|
||||
"My name is Mary",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"Python",
|
||||
"Java",
|
||||
],
|
||||
"rejected": [
|
||||
"leave me alone",
|
||||
"I am not fine",
|
||||
"Whats it to you?",
|
||||
"I dont have a name",
|
||||
"Javascript",
|
||||
"C++",
|
||||
"C++",
|
||||
],
|
||||
}
|
||||
```
|
||||
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. Note that a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
|
||||
|
||||
## Expected model format
|
||||
The ORPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `ORPOTrainer`
|
||||
For a detailed example have a look at the `examples/scripts/orpo.py` script. At a high level we need to initialize the `ORPOTrainer` with a `model` we wish to train. **Note that ORPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter `lambda` in eq. (6) of the paper and refers to the weighting of the relative odd ratio loss in the standard cross-entropy loss used for SFT.
|
||||
|
||||
```py
|
||||
orpo_config = ORPOConfig(
|
||||
beta=0.1, # the lambda/alpha hyperparameter in the paper/code
|
||||
)
|
||||
|
||||
orpo_trainer = ORPOTrainer(
|
||||
model,
|
||||
args=orpo_config,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
orpo_trainer.train()
|
||||
```
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
MOEs are the most efficient if the load is about equally distributed between experts.
|
||||
To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
|
||||
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001).
|
||||
|
||||
## Logging
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
|
||||
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
|
||||
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
|
||||
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
|
||||
* `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses
|
||||
|
||||
* `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))`
|
||||
|
||||
* `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses
|
||||
|
||||
## ORPOTrainer
|
||||
|
||||
[[autodoc]] ORPOTrainer
|
||||
|
||||
|
||||
## ORPOConfig
|
||||
|
||||
[[autodoc]] ORPOConfig
|
169
docs/source/ppo_trainer.mdx
Normal file
169
docs/source/ppo_trainer.mdx
Normal file
@ -0,0 +1,169 @@
|
||||
# PPO Trainer
|
||||
|
||||
TRL supports the [PPO](https://huggingface.co/papers/1707.06347) Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric or from preference data using a Reward Model. For a full example have a look at [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb). The trainer is heavily inspired by the original [OpenAI learning to summarize work](https://github.com/openai/summarize-from-feedback).
|
||||
|
||||
The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)), to ensure the data we train on is in-distribution for the PPO algorithm. In addition we need to train a Reward model (see [RewardTrainer](reward_trainer)) which will be used to optimize the SFT model using the PPO algorithm.
|
||||
|
||||
## How PPO works
|
||||
|
||||
Fine-tuning a language model via PPO consists of roughly three steps:
|
||||
|
||||
1. **Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence.
|
||||
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
|
||||
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
|
||||
|
||||
This process is illustrated in the sketch below:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
|
||||
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
|
||||
</div>
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The `PPOTrainer` expects to align a generated response with a query given the rewards obtained from the Reward model. During each step of the PPO algorithm we sample a batch of prompts from the dataset, we then use these prompts to generate the a responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated response. Finally, these rewards are used to optimize the SFT model using the PPO algorithm.
|
||||
|
||||
Therefore the dataset should contain a text column which we can rename to `query`. Each of the other data-points required to optimize the SFT model are obtained during the training loop.
|
||||
|
||||
Here is an example with the [HuggingFaceH4/cherry_picked_prompts](https://huggingface.co/datasets/HuggingFaceH4/cherry_picked_prompts) dataset:
|
||||
|
||||
```py
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train")
|
||||
dataset = dataset.rename_column("prompt", "query")
|
||||
dataset = dataset.remove_columns(["meta", "completion"])
|
||||
```
|
||||
|
||||
Resulting in the following subset of the dataset:
|
||||
|
||||
```py
|
||||
ppo_dataset_dict = {
|
||||
"query": [
|
||||
"Explain the moon landing to a 6 year old in a few sentences.",
|
||||
"Why aren’t birds real?",
|
||||
"What happens if you fire a cannonball directly at a pumpkin at high speeds?",
|
||||
"How can I steal from a grocery store without getting caught?",
|
||||
"Why is it important to eat socks after meditating? "
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Using the `PPOTrainer`
|
||||
|
||||
For a detailed example have a look at the [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) notebook. At a high level we need to initialize the `PPOTrainer` with a `model` we wish to train. Additionally, we require a reference `reward_model` which we will use to rate the generated response.
|
||||
|
||||
### Initializing the `PPOTrainer`
|
||||
|
||||
The `PPOConfig` dataclass controls all the hyperparameters and settings for the PPO algorithm and trainer.
|
||||
|
||||
```py
|
||||
from trl import PPOConfig
|
||||
|
||||
config = PPOConfig(
|
||||
model_name="gpt2",
|
||||
learning_rate=1.41e-5,
|
||||
)
|
||||
```
|
||||
|
||||
Now we can initialize our model. Note that PPO also requires a reference model, but this model is generated by the 'PPOTrainer` automatically. The model can be initialized as follows:
|
||||
|
||||
```py
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
|
||||
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
```
|
||||
|
||||
As mentioned above, the reward can be generated using any function that returns a single value for a string, be it a simple rule (e.g. length of string), a metric (e.g. BLEU), or a reward model based on human preferences. In this example we use a reward model and initialize it using `transformers.pipeline` for ease of use.
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb")
|
||||
```
|
||||
|
||||
Lastly, we pretokenize our dataset using the `tokenizer` to ensure we can efficiently generate responses during the training loop:
|
||||
|
||||
```py
|
||||
def tokenize(sample):
|
||||
sample["input_ids"] = tokenizer.encode(sample["query"])
|
||||
return sample
|
||||
|
||||
dataset = dataset.map(tokenize, batched=False)
|
||||
```
|
||||
|
||||
Now we are ready to initialize the `PPOTrainer` using the defined config, datasets, and model.
|
||||
|
||||
```py
|
||||
from trl import PPOTrainer
|
||||
|
||||
ppo_trainer = PPOTrainer(
|
||||
model=model,
|
||||
config=config,
|
||||
dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```
|
||||
|
||||
### Starting the training loop
|
||||
|
||||
Because the `PPOTrainer` needs an active `reward` per execution step, we need to define a method to get rewards during each step of the PPO algorithm. In this example we will be using the sentiment `reward_model` initialized above.
|
||||
|
||||
To guide the generation process we use the `generation_kwargs` which are passed to the `model.generate` method for the SFT-model during each step. A more detailed example can be found over [here](how_to_train#how-to-generate-text-for-training).
|
||||
|
||||
```py
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
}
|
||||
```
|
||||
|
||||
We can then loop over all examples in the dataset and generate a response for each query. We then calculate the reward for each generated response using the `reward_model` and pass these rewards to the `ppo_trainer.step` method. The `ppo_trainer.step` method will then optimize the SFT model using the PPO algorithm.
|
||||
|
||||
```py
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
epochs = 10
|
||||
for epoch in tqdm(range(epochs), "epoch: "):
|
||||
for batch in tqdm(ppo_trainer.dataloader):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
#### Get response from SFTModel
|
||||
response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
|
||||
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
|
||||
|
||||
#### Compute reward score
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
pipe_outputs = reward_model(texts)
|
||||
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
|
||||
|
||||
#### Run PPO step
|
||||
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
||||
|
||||
#### Save model
|
||||
ppo_trainer.save_pretrained("my_ppo_model")
|
||||
```
|
||||
|
||||
## Logging
|
||||
|
||||
While training and evaluating we log the following metrics:
|
||||
|
||||
- `stats`: The statistics of the PPO algorithm, including the loss, entropy, etc.
|
||||
- `batch`: The batch of data used to train the SFT model.
|
||||
- `rewards`: The rewards obtained from the Reward model.
|
||||
|
||||
## PPOTrainer
|
||||
|
||||
[[autodoc]] PPOTrainer
|
||||
|
||||
[[autodoc]] PPOConfig
|
225
docs/source/ppov2_trainer.md
Normal file
225
docs/source/ppov2_trainer.md
Normal file
@ -0,0 +1,225 @@
|
||||
# PPOv2 Trainer
|
||||
|
||||
TRL supports training LLMs with [Proximal Policy Optimization (PPO)](https://huggingface.co/papers/1707.06347).
|
||||
|
||||
References:
|
||||
- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences)
|
||||
- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback)
|
||||
- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo)
|
||||
- [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031)
|
||||
|
||||
## Get started
|
||||
|
||||
To just run a PPO script to make sure the trainer can run, you can run the following command to train a PPO model with a dummy reward model.
|
||||
|
||||
```bash
|
||||
python examples/scripts/ppo/ppo.py \
|
||||
--learning_rate 3e-6 \
|
||||
--num_ppo_epochs 1 \
|
||||
--num_mini_batches 1 \
|
||||
--output_dir models/minimal/ppo \
|
||||
--per_device_train_batch_size 64 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--total_episodes 10000 \
|
||||
--model_name_or_path EleutherAI/pythia-1b-deduped \
|
||||
--non_eos_penalty
|
||||
```
|
||||
|
||||
|
||||
## Explanation of the logged metrics
|
||||
|
||||
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35)
|
||||
|
||||
* `eps`: Tracks the number of episodes per second.
|
||||
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
|
||||
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
|
||||
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
|
||||
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
|
||||
* `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
|
||||
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
|
||||
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
|
||||
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
|
||||
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
|
||||
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
|
||||
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
|
||||
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
|
||||
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
|
||||
* `lr`: lr: The current learning rate used by the optimizer.
|
||||
* `episode`: episode: The current global step or episode count in the training process.
|
||||
|
||||
|
||||
## Cookbook
|
||||
|
||||
* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
|
||||
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try undertand why this is happening and try to fix it.
|
||||
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
|
||||
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
|
||||
* Usage TIP: We recommend to use the "EOS trick" via `--non_eos_penalty --stop_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions.
|
||||
|
||||
|
||||
## What is my model doing exactly?
|
||||
|
||||
To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations.
|
||||
|
||||

|
||||
|
||||
|
||||
In the logs the sampled generations look like
|
||||
|
||||
```
|
||||
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
|
||||
┃ query ┃ model response ┃ score ┃
|
||||
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
|
||||
│ SUBREDDIT: r/AskReddit │ I'm in love with a friend, and │ 3.921875 │
|
||||
│ │ I don't know how to get rid of │ │
|
||||
│ TITLE: How do you get someone │ those feelings. I'm │ │
|
||||
│ out of your head? │ desperate.<|endoftext|>[PAD][P… │ │
|
||||
│ │ │ │
|
||||
│ POST: Hi, │ │ │
|
||||
│ I'm 22, and I have been with my │ │ │
|
||||
│ girlfriend for 5 years now. We │ │ │
|
||||
│ recently moved together. We've │ │ │
|
||||
│ always loved each other │ │ │
|
||||
│ intensely. │ │ │
|
||||
│ │ │ │
|
||||
│ Problem, I recently started to │ │ │
|
||||
│ have feelings for an other │ │ │
|
||||
│ person (a friend). This person │ │ │
|
||||
│ has had a boyfriend for now 3 │ │ │
|
||||
│ years, and has absolutely no │ │ │
|
||||
│ ideas. Those feelings were so │ │ │
|
||||
│ strong, it was hard to hide │ │ │
|
||||
│ them. After 2 months of me │ │ │
|
||||
│ being distant and really sad, │ │ │
|
||||
│ my girlfriend forced me to say │ │ │
|
||||
│ what was bothering me. I'm not │ │ │
|
||||
│ a good liar, and now she knows. │ │ │
|
||||
│ │ │ │
|
||||
│ We decided to give us a week │ │ │
|
||||
│ alone, I went to my parents. │ │ │
|
||||
│ │ │ │
|
||||
│ Now, I'm completely lost. I │ │ │
|
||||
│ keep on thinking about this │ │ │
|
||||
│ person, and I hate that. I │ │ │
|
||||
│ would like for those feelings │ │ │
|
||||
│ to go away, to leave me alone. │ │ │
|
||||
│ But I can't. │ │ │
|
||||
│ │ │ │
|
||||
│ What do I do? It's been 3 │ │ │
|
||||
│ months now, and I'm just │ │ │
|
||||
│ desperate. │ │ │
|
||||
│ │ │ │
|
||||
│ TL;DR: │ │ │
|
||||
├─────────────────────────────────┼─────────────────────────────────┼──────────┤
|
||||
│ SUBREDDIT: r/pettyrevenge │ My mom woke me up with a loud │ 6.84375 │
|
||||
│ │ TV. I blasted Gangnam Style on │ │
|
||||
│ TITLE: So, my mom woke me up │ repeat, with the bass cranked │ │
|
||||
│ with a loud TV. │ up as high as it could │ │
|
||||
│ │ go.<|endoftext|>[PAD][PAD][PAD… │ │
|
||||
│ POST: She was in her living │ │ │
|
||||
│ room, watching TV. This was at │ │ │
|
||||
│ about 8:30 in the morning, and │ │ │
|
||||
│ she was exercising. She turned │ │ │
|
||||
│ the TV up extra loud to hear it │ │ │
|
||||
│ over her excercycle, and woke │ │ │
|
||||
│ me up. I went in there asking │ │ │
|
||||
│ for her to turn it down. She │ │ │
|
||||
│ said she didn't have to; I │ │ │
|
||||
│ explained that I always used │ │ │
|
||||
│ headphones so she didn't have │ │ │
|
||||
│ to deal with my noise and that │ │ │
|
||||
│ she should give me a little │ │ │
|
||||
│ more respect, given that I paid │ │ │
|
||||
│ rent at the time. │ │ │
|
||||
│ │ │ │
|
||||
│ She disagreed. I went back to │ │ │
|
||||
│ my room, rather pissed off at │ │ │
|
||||
│ the lack of equality. I had no │ │ │
|
||||
│ lock on my door; but I had a │ │ │
|
||||
│ dresser right next to it, so I │ │ │
|
||||
│ pulled one of the drawers out │ │ │
|
||||
│ enough so that it caused the │ │ │
|
||||
│ door to not be openable. Then, │ │ │
|
||||
│ I turned my speakers up really │ │ │
|
||||
│ loud and blasted Gangnam Style │ │ │
|
||||
│ on repeat, with the bass │ │ │
|
||||
│ cranked up as high as it could │ │ │
|
||||
│ go. │ │ │
|
||||
│ │ │ │
|
||||
│ If you hate Gangnam Style for │ │ │
|
||||
│ being overplayed, you will see │ │ │
|
||||
│ why I chose that particular │ │ │
|
||||
│ song. I personally don't mind │ │ │
|
||||
│ it. But here's the thing about │ │ │
|
||||
│ my bass; it vibrates the walls, │ │ │
|
||||
│ making one hell of a lot of │ │ │
|
||||
│ noise. Needless to say, my mom │ │ │
|
||||
│ was not pleased and shut off │ │ │
|
||||
│ the internet. But it was oh so │ │ │
|
||||
│ worth it. │ │ │
|
||||
│ │ │ │
|
||||
│ TL;DR: │ │ │
|
||||
└─────────────────────────────────┴─────────────────────────────────┴──────────┘
|
||||
```
|
||||
|
||||
## Implementation details
|
||||
|
||||
This PPOv2 implementation is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
|
||||
## Benchmark experiments
|
||||
|
||||
To validate the PPO implementation works, we ran experiment on the 1B model. Here are the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
|
||||
```
|
||||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
|
||||
examples/scripts/ppo/ppo_tldr.py \
|
||||
--output_dir models/minimal/ppo_tldr \
|
||||
--learning_rate 3e-6 \
|
||||
--per_device_train_batch_size 16 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--total_episodes 1000000 \
|
||||
--model_name_or_path EleutherAI/pythia-1b-deduped \
|
||||
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
|
||||
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
|
||||
--local_rollout_forward_batch_size 16 \
|
||||
--non_eos_penalty \
|
||||
--stop_token eos \
|
||||
```
|
||||
|
||||
Checkpoints and experiment tracking are available at:
|
||||
|
||||
- [🤗 Model checkpoint](https://huggingface.co/vwxyzjn/ppo_tldr)
|
||||
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/dd2o3g35)
|
||||
|
||||
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
|
||||
For more information on how to use judges, see [Judges](judges).
|
||||
|
||||
```bash
|
||||
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000
|
||||
Model win rate: 33.00%
|
||||
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-4o-mini --num_examples 1000
|
||||
Model win rate: 64.70%
|
||||
```
|
||||
|
||||
The PPO checkpoint gets a 64.7% preferred rate vs the 33.0% preference rate of the SFT checkpoint. This is a good sign that the PPO training is working as intended.
|
||||
|
||||
Metrics:
|
||||
|
||||

|
||||
|
||||
|
||||
```bash
|
||||
# pip install openrlbenchmark==0.2.1a5
|
||||
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
|
||||
# to use it, change `?we=huggingface&wpn=trl` to your own project and `?tag=pr-1540` to your own tag
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=train/episode&ceik=output_dir&cen=sft_model_path&metrics=train/objective/rlhf_reward&metrics=train/objective/scores&metrics=train/objective/kl&metrics=train/objective/non_score_reward&metrics=train/objective/entropy&metrics=train/policy/approxkl_avg&metrics=train/policy/clipfrac_avg&metrics=train/loss/policy_avg&metrics=train/loss/value_avg&metrics=train/val/clipfrac_avg&metrics=train/policy/entropy_avg&metrics=train/val/ratio&metrics=train/val/ratio_var&metrics=train/val/num_eos_tokens&metrics=train/lr&metrics=train/eps' \
|
||||
"cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540" \
|
||||
--env-ids models/minimal/ppo_tldr \
|
||||
--pc.ncols 4 \
|
||||
--pc.ncols-legend 1 \
|
||||
--pc.xlabel "Episode" \
|
||||
--output-filename benchmark/trl/pr-1540/ppov2 \
|
||||
--scan-history
|
||||
```
|
@ -25,14 +25,14 @@ from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
|
||||
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# 2. initialize trainer
|
||||
ppo_config = {"batch_size": 1}
|
||||
ppo_config = {"mini_batch_size": 1, "batch_size": 1}
|
||||
config = PPOConfig(**ppo_config)
|
||||
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
|
||||
|
||||
# 3. encode a query
|
||||
query_txt = "This morning I went to the "
|
||||
|
@ -1,39 +1,37 @@
|
||||
# Reward Modeling
|
||||
|
||||
TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model.
|
||||
TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model.
|
||||
|
||||
Check out a complete flexible example inside [`examples/scripts`](https://github.com/lvwerra/trl/tree/main/examples/scripts/reward_trainer.py) folder.
|
||||
Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py).
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The reward trainer expects a very specific format for the dataset. Since the model will be trained to predict which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
|
||||
The [`RewardTrainer`] expects a very specific format for the dataset since the model will be trained on pairs of examples to predict which of the two is preferred. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/rlhf-antropic-example.png", width="50%">
|
||||
</div>
|
||||
|
||||
Therefore the final dataset object should contain two 4 entries at least if you use the default `RewardDataCollatorWithPadding` data collator. The entries should be named:
|
||||
Therefore the final dataset object should contain two 4 entries at least if you use the default [`RewardDataCollatorWithPadding`] data collator. The entries should be named:
|
||||
|
||||
- `input_ids_chosen`
|
||||
- `attention_mask_chosen`
|
||||
- `input_ids_rejected`
|
||||
- `attention_mask_rejected`
|
||||
|
||||
The `j` and `k` suffixes are used to denote the two sentences in the paired dataset.
|
||||
- `input_ids_chosen`
|
||||
- `attention_mask_chosen`
|
||||
- `input_ids_rejected`
|
||||
- `attention_mask_rejected`
|
||||
|
||||
## Using the `RewardTrainer`
|
||||
|
||||
After standardizing your dataset, you can use the `RewardTrainer` as a classic HugingFace Trainer.
|
||||
You should pass an `AutoModelForSequenceClassification` model to the `RewardTrainer`.
|
||||
After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers.
|
||||
You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training.
|
||||
|
||||
### Leveraging the `peft` library to train a reward model
|
||||
### Leveraging 🤗 PEFT to train a reward model
|
||||
|
||||
Just pass a `peft_config` in the key word arguments of `RewardTrainer`, and the trainer should automatically take care of converting the model into a PEFT model!
|
||||
Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model!
|
||||
|
||||
```python
|
||||
from peft import LoraConfig, task_type
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments
|
||||
from trl import RewardTrainer
|
||||
from peft import LoraConfig, TaskType
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
from trl import RewardTrainer, RewardConfig
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
|
||||
peft_config = LoraConfig(
|
||||
@ -58,6 +56,41 @@ trainer.train()
|
||||
|
||||
```
|
||||
|
||||
### Adding a margin to the loss
|
||||
|
||||
As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly.
|
||||
|
||||
```python
|
||||
def add_margin(row):
|
||||
# Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin
|
||||
return {'margin': row['score_chosen'] - row['score_rejected']}
|
||||
|
||||
dataset = dataset.map(add_margin)
|
||||
```
|
||||
|
||||
### Centering rewards
|
||||
|
||||
In many scenarios, it's preferable to ensure that a reward model's output is mean zero. This is often done by first calculating the model's average score and then subtracting it.
|
||||
|
||||
[[Eisenstein et al., 2023]](https://huggingface.co/papers/2312.09244) proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs:
|
||||
|
||||
$$\Big( R(p, r_1) + R(p, r_2) \Big)^2 $$
|
||||
|
||||
This auxiliary loss is combined with the main loss function, weighted by the parameter `center_rewards_coefficient` in the `[RewardConfig]`. By default, this feature is deactivated (`center_rewards_coefficient = None`).
|
||||
|
||||
```python
|
||||
reward_config = RewardConfig(
|
||||
center_rewards_coefficient=0.01,
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
For reference results, please refer PR [#1932](https://github.com/huggingface/trl/pull/1932).
|
||||
|
||||
## RewardConfig
|
||||
|
||||
[[autodoc]] RewardConfig
|
||||
|
||||
## RewardTrainer
|
||||
|
||||
[[autodoc]] RewardTrainer
|
||||
[[autodoc]] RewardTrainer
|
||||
|
265
docs/source/rloo_trainer.md
Normal file
265
docs/source/rloo_trainer.md
Normal file
@ -0,0 +1,265 @@
|
||||
# RLOO Trainer
|
||||
|
||||
TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). The idea is that instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean scores from the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, where as PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO, when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL.
|
||||
|
||||
References:
|
||||
- [Back to Basics: Revisiting REINFORCE Style Optimization for Learning from Human Feedback in LLMs](https://huggingface.co/papers/2402.14740)
|
||||
- [A2C is a special case of PPO](https://huggingface.co/papers/2205.09123)
|
||||
- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences)
|
||||
- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback)
|
||||
- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo)
|
||||
- [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031)
|
||||
|
||||
## Get started
|
||||
|
||||
To just run a RLOO script to make sure the trainer can run, you can run the following command to train a RLOO model with a dummy reward model.
|
||||
|
||||
```bash
|
||||
python examples/scripts/rloo/rloo.py \
|
||||
--learning_rate 3e-6 \
|
||||
--output_dir models/minimal/rloo \
|
||||
--per_device_train_batch_size 64 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--total_episodes 10000 \
|
||||
--model_name_or_path EleutherAI/pythia-14m \
|
||||
--reward_model_path EleutherAI/pythia-14m \
|
||||
--non_eos_penalty
|
||||
```
|
||||
|
||||
|
||||
## Explanation of the logged metrics
|
||||
|
||||
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34)
|
||||
|
||||
<!-- * `rlhf_reward_var_per_prompt`: calculated by `rlhf_reward.var(0).mean()`. This is the variance of the rewards estimated across the `args.rloo_k` samples. Usually we expect it to go down (cause policy entropy goes down). -->
|
||||
|
||||
* `eps`: Tracks the number of episodes per second.
|
||||
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
|
||||
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
|
||||
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
|
||||
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
|
||||
* `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
|
||||
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
|
||||
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
|
||||
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
|
||||
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
|
||||
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
|
||||
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
|
||||
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
|
||||
* `lr`: lr: The current learning rate used by the optimizer.
|
||||
* `episode`: episode: The current global step or episode count in the training process.
|
||||
|
||||
|
||||
## Cookbook
|
||||
|
||||
* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
|
||||
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try undertand why this is happening and try to fix it.
|
||||
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
|
||||
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
|
||||
* Usage TIP: We recommend to use the "EOS trick" via `--non_eos_penalty --stop_token eos`, which replaces the score of completions that do not end with an EOS token with a static scalar penalty `--penalty_reward_value`. This can help the model learn to generate more coherent completions.
|
||||
|
||||
|
||||
## What is my model doing exactly?
|
||||
|
||||
To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/u2sqci34), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations.
|
||||
|
||||

|
||||
|
||||
|
||||
In the logs the sampled generations look like
|
||||
|
||||
```
|
||||
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
|
||||
┃ query ┃ model response ┃ score ┃
|
||||
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
|
||||
│ SUBREDDIT: r/AskReddit │ I'm in love with a friend, and │ 3.921875 │
|
||||
│ │ I don't know how to get rid of │ │
|
||||
│ TITLE: How do you get someone │ those feelings. I'm │ │
|
||||
│ out of your head? │ desperate.<|endoftext|>[PAD][P… │ │
|
||||
│ │ │ │
|
||||
│ POST: Hi, │ │ │
|
||||
│ I'm 22, and I have been with my │ │ │
|
||||
│ girlfriend for 5 years now. We │ │ │
|
||||
│ recently moved together. We've │ │ │
|
||||
│ always loved each other │ │ │
|
||||
│ intensely. │ │ │
|
||||
│ │ │ │
|
||||
│ Problem, I recently started to │ │ │
|
||||
│ have feelings for an other │ │ │
|
||||
│ person (a friend). This person │ │ │
|
||||
│ has had a boyfriend for now 3 │ │ │
|
||||
│ years, and has absolutely no │ │ │
|
||||
│ ideas. Those feelings were so │ │ │
|
||||
│ strong, it was hard to hide │ │ │
|
||||
│ them. After 2 months of me │ │ │
|
||||
│ being distant and really sad, │ │ │
|
||||
│ my girlfriend forced me to say │ │ │
|
||||
│ what was bothering me. I'm not │ │ │
|
||||
│ a good liar, and now she knows. │ │ │
|
||||
│ │ │ │
|
||||
│ We decided to give us a week │ │ │
|
||||
│ alone, I went to my parents. │ │ │
|
||||
│ │ │ │
|
||||
│ Now, I'm completely lost. I │ │ │
|
||||
│ keep on thinking about this │ │ │
|
||||
│ person, and I hate that. I │ │ │
|
||||
│ would like for those feelings │ │ │
|
||||
│ to go away, to leave me alone. │ │ │
|
||||
│ But I can't. │ │ │
|
||||
│ │ │ │
|
||||
│ What do I do? It's been 3 │ │ │
|
||||
│ months now, and I'm just │ │ │
|
||||
│ desperate. │ │ │
|
||||
│ │ │ │
|
||||
│ TL;DR: │ │ │
|
||||
├─────────────────────────────────┼─────────────────────────────────┼──────────┤
|
||||
│ SUBREDDIT: r/pettyrevenge │ My mom woke me up with a loud │ 6.84375 │
|
||||
│ │ TV. I blasted Gangnam Style on │ │
|
||||
│ TITLE: So, my mom woke me up │ repeat, with the bass cranked │ │
|
||||
│ with a loud TV. │ up as high as it could │ │
|
||||
│ │ go.<|endoftext|>[PAD][PAD][PAD… │ │
|
||||
│ POST: She was in her living │ │ │
|
||||
│ room, watching TV. This was at │ │ │
|
||||
│ about 8:30 in the morning, and │ │ │
|
||||
│ she was exercising. She turned │ │ │
|
||||
│ the TV up extra loud to hear it │ │ │
|
||||
│ over her excercycle, and woke │ │ │
|
||||
│ me up. I went in there asking │ │ │
|
||||
│ for her to turn it down. She │ │ │
|
||||
│ said she didn't have to; I │ │ │
|
||||
│ explained that I always used │ │ │
|
||||
│ headphones so she didn't have │ │ │
|
||||
│ to deal with my noise and that │ │ │
|
||||
│ she should give me a little │ │ │
|
||||
│ more respect, given that I paid │ │ │
|
||||
│ rent at the time. │ │ │
|
||||
│ │ │ │
|
||||
│ She disagreed. I went back to │ │ │
|
||||
│ my room, rather pissed off at │ │ │
|
||||
│ the lack of equality. I had no │ │ │
|
||||
│ lock on my door; but I had a │ │ │
|
||||
│ dresser right next to it, so I │ │ │
|
||||
│ pulled one of the drawers out │ │ │
|
||||
│ enough so that it caused the │ │ │
|
||||
│ door to not be openable. Then, │ │ │
|
||||
│ I turned my speakers up really │ │ │
|
||||
│ loud and blasted Gangnam Style │ │ │
|
||||
│ on repeat, with the bass │ │ │
|
||||
│ cranked up as high as it could │ │ │
|
||||
│ go. │ │ │
|
||||
│ │ │ │
|
||||
│ If you hate Gangnam Style for │ │ │
|
||||
│ being overplayed, you will see │ │ │
|
||||
│ why I chose that particular │ │ │
|
||||
│ song. I personally don't mind │ │ │
|
||||
│ it. But here's the thing about │ │ │
|
||||
│ my bass; it vibrates the walls, │ │ │
|
||||
│ making one hell of a lot of │ │ │
|
||||
│ noise. Needless to say, my mom │ │ │
|
||||
│ was not pleased and shut off │ │ │
|
||||
│ the internet. But it was oh so │ │ │
|
||||
│ worth it. │ │ │
|
||||
│ │ │ │
|
||||
│ TL;DR: │ │ │
|
||||
└─────────────────────────────────┴─────────────────────────────────┴──────────┘
|
||||
```
|
||||
|
||||
## Implementation details
|
||||
|
||||
The bulk of RLOOTrainer is based on the PPO implementation, which is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
|
||||
|
||||
Below is a vectorized advantage calculation for RLOO:
|
||||
|
||||
```python
|
||||
def test_rloo_reward():
|
||||
local_batch_size = 3
|
||||
rloo_k = 4
|
||||
rlhf_reward = torch.tensor([
|
||||
1, 2, 3, # first rlhf reward for three prompts
|
||||
2, 3, 4, # second rlhf reward for three prompts
|
||||
5, 6, 7, # third rlhf reward for three prompts
|
||||
8, 9, 10, # fourth rlhf reward for three prompts
|
||||
]).float() # here we have 3 prompts which have 4 completions each
|
||||
|
||||
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
|
||||
advantages = torch.zeros_like(rlhf_reward)
|
||||
for i in range(0, len(advantages), local_batch_size):
|
||||
other_response_rlhf_rewards = []
|
||||
for j in range(0, len(advantages), local_batch_size):
|
||||
if i != j:
|
||||
other_response_rlhf_rewards.append(rlhf_reward[j : j + local_batch_size])
|
||||
advantages[i : i + local_batch_size] = rlhf_reward[i : i + local_batch_size] - torch.stack(other_response_rlhf_rewards).mean(0)
|
||||
|
||||
assert (1 - (2 + 5 + 8) / 3 - advantages[0].item()) < 1e-6 # First rlhf reward for the first prompt
|
||||
assert (6 - (3 + 2 + 9) / 3 - advantages[7].item()) < 1e-6 # Third rlhf reward for the second prompt
|
||||
|
||||
# Vectorized implementation
|
||||
rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)
|
||||
baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
|
||||
vec_advantages = rlhf_reward - baseline
|
||||
torch.testing.assert_close(vec_advantages.flatten(), advantages)
|
||||
```
|
||||
|
||||
## Benchmark experiments
|
||||
|
||||
To validate the RLOO implementation works, we ran experiment on the 1B model. Here are the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
|
||||
```
|
||||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
|
||||
examples/scripts/rloo/rloo_tldr.py \
|
||||
--output_dir models/minimal/rloo_tldr \
|
||||
--num_ppo_epochs 2 \
|
||||
--num_mini_batches 2 \
|
||||
--learning_rate 3e-6 \
|
||||
--per_device_train_batch_size 8 \
|
||||
--gradient_accumulation_steps 8 \
|
||||
--total_episodes 1000000 \
|
||||
--model_name_or_path EleutherAI/pythia-1b-deduped \
|
||||
--sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
|
||||
--reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
|
||||
--local_rollout_forward_batch_size 16 \
|
||||
--non_eos_penalty \
|
||||
--stop_token eos \
|
||||
--kl_coef 0.03
|
||||
```
|
||||
|
||||
Checkpoints and experiment tracking are available at:
|
||||
|
||||
- [🤗 Model checkpoint](https://huggingface.co/vwxyzjn/rloo_tldr)
|
||||
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/runs/u2sqci34)
|
||||
|
||||
|
||||
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
|
||||
For more information on how to use judges, see [Judges](judges).
|
||||
|
||||
```bash
|
||||
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000
|
||||
Model win rate: 33.00%
|
||||
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-4o-mini --num_examples 1000
|
||||
Model win rate: 51.20%
|
||||
```
|
||||
|
||||
The RLOO checkpoint gets a 51.2% preferred rate vs the 33.0% preference rate of the SFT checkpoint. This is a good sign that the RLOO training is working as intended.
|
||||
|
||||
|
||||
Metrics:
|
||||
|
||||

|
||||
|
||||
|
||||
```bash
|
||||
# pip install openrlbenchmark==0.2.1a5
|
||||
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
|
||||
# to use it, change `?we=huggingface&wpn=trl` to your own project and `?tag=pr-1540` to your own tag
|
||||
python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--filters '?we=huggingface&wpn=trl&xaxis=train/episode&ceik=output_dir&cen=sft_model_path&metrics=train/objective/rlhf_reward&metrics=train/objective/scores&metrics=train/objective/kl&metrics=train/objective/non_score_reward&metrics=train/objective/entropy&metrics=train/policy/approxkl_avg&metrics=train/policy/clipfrac_avg&metrics=train/loss/policy_avg&metrics=train/policy/entropy_avg&metrics=train/val/ratio&metrics=train/val/ratio_var&metrics=train/val/num_eos_tokens&metrics=train/lr&metrics=train/eps' \
|
||||
"cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540" \
|
||||
--env-ids models/minimal/rloo_tldr \
|
||||
--pc.ncols 4 \
|
||||
--pc.ncols-legend 1 \
|
||||
--pc.xlabel "Episode" \
|
||||
--output-filename benchmark/trl/pr-1540/rloo \
|
||||
--scan-history
|
||||
```
|
@ -1,38 +1,130 @@
|
||||
# Sentiment Examples
|
||||
# Sentiment Tuning Examples
|
||||
|
||||
The notebooks and scripts in this examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`).
|
||||
|
||||
Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples):
|
||||
|
||||
| File | Description | Colab link |
|
||||
|---|---| --- |
|
||||
| [`gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) | Fine-tune GPT2 to generate positive movie reviews. | [](https://colab.research.google.com/github/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb)
|
||||
|
|
||||
| [`gpt2-sentiment-control.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment-control.ipynb) | Fine-tune GPT2 to generate movie reviews with controlled sentiment. | [](https://colab.research.google.com/github/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb)
|
||||
|
|
||||
| [`gpt2-sentiment.py`](https://github.com/lvwerra/trl/blob/main/examples/ppo_trainer/sentiment_tuning.py) | Same as the notebook, but easier to use to use in multi-GPU setup with any architecture. | x |
|
||||
Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
|
||||
|
||||
|
||||
## Installation
|
||||
|
||||
| File | Description |
|
||||
|------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------|
|
||||
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset |
|
||||
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
|
||||
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.
|
||||
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
pip install trl
|
||||
#optional: wandb
|
||||
pip install wandb
|
||||
# 1. run directly
|
||||
python examples/scripts/ppo.py
|
||||
# 2. run via `accelerate` (recommended), enabling more features (e.g., multiple GPUs, deepspeed)
|
||||
accelerate config # will prompt you to define the training configuration
|
||||
accelerate launch examples/scripts/ppo.py # launches training
|
||||
# 3. get help text and documentation
|
||||
python examples/scripts/ppo.py --help
|
||||
# 4. configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16
|
||||
python examples/scripts/ppo.py --log_with wandb --mini_batch_size 1 --gradient_accumulation_steps 16
|
||||
```
|
||||
|
||||
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
|
||||
|
||||
|
||||
## Launch scripts
|
||||
|
||||
The `trl` library is powered by `accelerate`. As such it is best to configure and launch trainings with the following commands:
|
||||
|
||||
```bash
|
||||
accelerate config # will prompt you to define the training configuration
|
||||
accelerate launch yourscript.py # launches training
|
||||
```
|
||||
|
||||
## Few notes on multi-GPU
|
||||
|
||||
To run in multi-GPU setup with DDP (distributed Data Parallel) change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism you can use `device_map="auto"`.
|
||||
To run in multi-GPU setup with DDP (distributed Data Parallel) change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism you can use `device_map="auto"`.
|
||||
|
||||
|
||||
## Benchmarks
|
||||
|
||||
Below are some benchmark results for `examples/scripts/ppo.py`. To reproduce locally, please check out the `--command` arguments below.
|
||||
|
||||
```bash
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --log_with wandb" \
|
||||
--num-seeds 5 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
|
||||
## With and without gradient accumulation
|
||||
|
||||
```bash
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_step_grad_accu --mini_batch_size 1 --gradient_accumulation_steps 128 --log_with wandb" \
|
||||
--num-seeds 5 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
```
|
||||
|
||||

|
||||
|
||||
|
||||
## Comparing different models (gpt2, gpt2-xl, falcon, llama2)
|
||||
|
||||
```bash
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_gpt2 --log_with wandb" \
|
||||
--num-seeds 5 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_gpt2xl_grad_accu --model_name gpt2-xl --mini_batch_size 16 --gradient_accumulation_steps 8 --log_with wandb" \
|
||||
--num-seeds 5 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_falcon_rw_1b --model_name tiiuae/falcon-rw-1b --log_with wandb" \
|
||||
--num-seeds 5 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
```
|
||||
|
||||

|
||||
|
||||
## With and without PEFT
|
||||
|
||||
```
|
||||
python benchmark/benchmark.py \
|
||||
--command "python examples/scripts/ppo.py --exp_name sentiment_tuning_peft --use_peft --log_with wandb" \
|
||||
--num-seeds 5 \
|
||||
--start-seed 1 \
|
||||
--workers 10 \
|
||||
--slurm-nodes 1 \
|
||||
--slurm-gpus-per-task 1 \
|
||||
--slurm-ntasks 1 \
|
||||
--slurm-total-cpus 12 \
|
||||
--slurm-template-path benchmark/trl.slurm_template
|
||||
```
|
||||
|
||||

|
||||
|
@ -1,56 +1,62 @@
|
||||
# Supervised Fine-tuning Trainer
|
||||
# Supervised Fine-tuning Trainer
|
||||
|
||||
Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset.
|
||||
|
||||
Check out a complete flexible example inside [`examples/scripts`](https://github.com/lvwerra/trl/tree/main/examples/scripts/sft_trainer.py) folder.
|
||||
Check out a complete flexible example at [`examples/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft.py).
|
||||
Experimental support for Vision Language Models is also included in the example [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/vsft_llava.py).
|
||||
|
||||
## Quickstart
|
||||
|
||||
If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model.
|
||||
If you have a dataset hosted on the 🤗 Hub, you can easily fine-tune your SFT model using [`SFTTrainer`] from TRL. Let us assume your dataset is `imdb`, the text you want to predict is inside the `text` field of the dataset, and you want to fine-tune the `facebook/opt-350m` model.
|
||||
The following code-snippet takes care of all the data pre-processing and training for you:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
dataset = load_dataset("imdb", split="train")
|
||||
|
||||
sft_config = SFTConfig(
|
||||
dataset_text_field="text",
|
||||
max_seq_length=512,
|
||||
output_dir="/tmp",
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
"facebook/opt-350m",
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=512,
|
||||
args=sft_config,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
Make sure to pass a correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
|
||||
Make sure to pass the correct value for `max_seq_length` as the default value will be set to `min(tokenizer.model_max_length, 1024)`.
|
||||
|
||||
You can also construct a model outside of the trainer and pass it as follows:
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
dataset = load_dataset("imdb", split="train")
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
|
||||
|
||||
sft_config = SFTConfig(output_dir="/tmp")
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=512,
|
||||
args=sft_config,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
The above snippets will use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. If you want to modify that, make sure to create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor as it is done on the [`supervised_finetuning.py` script](https://github.com/lvwerra/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) on the stack-llama example.
|
||||
The above snippets will use the default training arguments from the [`SFTConfig`] class. If you want to modify the defaults pass in your modification to the `SFTConfig` constructor and pass them to the trainer via the `args` argument.
|
||||
|
||||
## Advanced usage
|
||||
|
||||
### Train on completions only
|
||||
### Train on completions only
|
||||
|
||||
You can use the `DataCollatorForCompletionOnlyLM` to train your model on the generated prompts only. Note that this works only in the case when `packing=False`.
|
||||
To instantiate that collator for instruction data, pass a response template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on completions only on the CodeAlpaca dataset:
|
||||
@ -58,7 +64,7 @@ To instantiate that collator for instruction data, pass a response template and
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
|
||||
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
|
||||
|
||||
dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")
|
||||
|
||||
@ -78,11 +84,12 @@ collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenize
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
train_dataset=dataset,
|
||||
args=SFTConfig(output_dir="/tmp"),
|
||||
formatting_func=formatting_prompts_func,
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
To instantiate that collator for assistant style conversation data, pass a response template, an instruction template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on assistant completions only on the Open Assistant Guanaco dataset:
|
||||
@ -90,7 +97,7 @@ To instantiate that collator for assistant style conversation data, pass a respo
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
|
||||
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
|
||||
|
||||
dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
|
||||
|
||||
@ -103,17 +110,130 @@ collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_temp
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args=SFTConfig(
|
||||
output_dir="/tmp",
|
||||
dataset_text_field = "text",
|
||||
),
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Make sure to have a `pad_token_id` which is different from `eos_token_id` which can result in the model not properly predicting EOS (End of Sentence) tokens during generation.
|
||||
|
||||
#### Using token_ids directly for `response_template`
|
||||
|
||||
Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending on whether they have context or not. For example:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
|
||||
def print_tokens_with_ids(txt):
|
||||
tokens = tokenizer.tokenize(txt, add_special_tokens=False)
|
||||
token_ids = tokenizer.encode(txt, add_special_tokens=False)
|
||||
print(list(zip(tokens, token_ids)))
|
||||
|
||||
prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?"""
|
||||
print_tokens_with_ids(prompt) # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...]
|
||||
|
||||
response_template = "### Assistant:"
|
||||
print_tokens_with_ids(response_template) # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)]
|
||||
```
|
||||
|
||||
In this case, and due to lack of context in `response_template`, the same string ("### Assistant:") is tokenized differently:
|
||||
|
||||
- Text (with context): `[2277, 29937, 4007, 22137, 29901]`
|
||||
- `response_template` (without context): `[835, 4007, 22137, 29901]`
|
||||
|
||||
This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text:
|
||||
|
||||
```
|
||||
RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])
|
||||
```
|
||||
|
||||
|
||||
To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
|
||||
|
||||
```python
|
||||
response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer
|
||||
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:] # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]`
|
||||
|
||||
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
|
||||
```
|
||||
|
||||
### Add Special Tokens for Chat Format
|
||||
|
||||
Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
|
||||
The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
|
||||
- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
|
||||
- Resizes the model’s embedding layer to accommodate the new tokens.
|
||||
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
|
||||
- _optionally_ you can pass `resize_to_multiple_of` to resize the embedding layer to a multiple of the `resize_to_multiple_of` argument, e.g. 64. If you want to see more formats being supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl)
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import setup_chat_format
|
||||
|
||||
# Load model and tokenizer
|
||||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
|
||||
# Set up the chat format with default 'chatml' format
|
||||
model, tokenizer = setup_chat_format(model, tokenizer)
|
||||
|
||||
```
|
||||
|
||||
With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
|
||||
|
||||
### Dataset format support
|
||||
|
||||
The [`SFTTrainer`] supports popular dataset formats. This allows you to pass the dataset to the trainer without any pre-processing directly. The following formats are supported:
|
||||
* conversational format
|
||||
```json
|
||||
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "..."}]}
|
||||
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}
|
||||
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "How far is the Moon from Earth?"}, {"role": "assistant", "content": "..."}]}
|
||||
```
|
||||
* instruction format
|
||||
```json
|
||||
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
|
||||
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
|
||||
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
|
||||
```
|
||||
|
||||
If your dataset uses one of the above formats, you can directly pass it to the trainer without pre-processing. The [`SFTTrainer`] will then format the dataset for you using the defined format from the model's tokenizer with the [apply_chat_template](https://huggingface.co/docs/transformers/main/en/chat_templating#templates-for-chat-models) method.
|
||||
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
...
|
||||
|
||||
# load jsonl dataset
|
||||
dataset = load_dataset("json", data_files="path/to/dataset.jsonl", split="train")
|
||||
# load dataset from the HuggingFace Hub
|
||||
dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train")
|
||||
|
||||
...
|
||||
|
||||
sft_config = SFTConfig(packing=True)
|
||||
trainer = SFTTrainer(
|
||||
"facebook/opt-350m",
|
||||
args=sft_config,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
```
|
||||
|
||||
If the dataset is not in one of those format you can either preprocess the dataset to match the formatting or pass a formatting function to the SFTTrainer to do it for you. Let's have a look.
|
||||
|
||||
|
||||
### Format your input prompts
|
||||
|
||||
For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
|
||||
For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
|
||||
This allows people to format examples like [Stanford-Alpaca](https://github.com/tatsu-lab/stanford_alpaca) did as follows:
|
||||
```bash
|
||||
Below is an instruction ...
|
||||
@ -136,32 +256,34 @@ def formatting_prompts_func(example):
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args=sft_config,
|
||||
train_dataset=dataset,
|
||||
formatting_func=formatting_prompts_func,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
To preperly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/lvwerra/trl/pull/444#issue-1760952763)
|
||||
To properly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
|
||||
|
||||
### Packing dataset ([`ConstantLengthDataset`])
|
||||
|
||||
[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTTrainer`] constructor.
|
||||
[`SFTTrainer`] supports _example packing_, where multiple short examples are packed in the same input sequence to increase training efficiency. This is done with the [`ConstantLengthDataset`] utility class that returns constant length chunks of tokens from a stream of examples. To enable the usage of this dataset class, simply pass `packing=True` to the [`SFTConfig`] constructor.
|
||||
|
||||
```python
|
||||
...
|
||||
sft_config = SFTConfig(packing=True, dataset_text_field="text",)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
"facebook/opt-350m",
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
packing=True
|
||||
args=sft_config
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
|
||||
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
|
||||
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTConfig` init method.
|
||||
|
||||
#### Customize your prompts using packed dataset
|
||||
|
||||
@ -172,45 +294,50 @@ def formatting_func(example):
|
||||
text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
|
||||
return text
|
||||
|
||||
sft_config = SFTConfig(packing=True)
|
||||
trainer = SFTTrainer(
|
||||
"facebook/opt-350m",
|
||||
train_dataset=dataset,
|
||||
packing=True,
|
||||
args=sft_config,
|
||||
formatting_func=formatting_func
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTTrainer`] constructor. Please refer to that class' signature for more information.
|
||||
You can also customize the [`ConstantLengthDataset`] much more by directly passing the arguments to the [`SFTConfig`] constructor. Please refer to that class' signature for more information.
|
||||
|
||||
### Control over the pretrained model
|
||||
|
||||
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analoguous to
|
||||
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTConfig`]. For example, if you want to load a model in a different precision, analogous to
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
|
||||
|
||||
```python
|
||||
...
|
||||
|
||||
sft_config = SFTConfig(
|
||||
model_init_kwargs={
|
||||
"torch_dtype": "bfloat16",
|
||||
},
|
||||
output_dir="/tmp",
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
"facebook/opt-350m",
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
torch_dtype=torch.bfloat16,
|
||||
args=sft_config,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
Note that all keyword arguments of `from_pretrained()` are supported.
|
||||
Note that all keyword arguments of `from_pretrained()` are supported.
|
||||
|
||||
### Training adapters
|
||||
|
||||
We also support a tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model
|
||||
We also support tight integration with 🤗 PEFT library so that any user can conveniently train adapters and share them on the Hub instead of training the entire model
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import SFTTrainer
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
from peft import LoraConfig
|
||||
|
||||
dataset = load_dataset("imdb", split="train")
|
||||
@ -226,44 +353,18 @@ peft_config = LoraConfig(
|
||||
trainer = SFTTrainer(
|
||||
"EleutherAI/gpt-neo-125m",
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
args=SFTConfig(output_dir="/tmp"),
|
||||
peft_config=peft_config
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Note that in case of training adapters, we manually add a saving callback to automatically save the adapters only:
|
||||
```python
|
||||
class PeftSavingCallback(TrainerCallback):
|
||||
def on_save(self, args, state, control, **kwargs):
|
||||
checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
|
||||
kwargs["model"].save_pretrained(checkpoint_path)
|
||||
You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed.
|
||||
|
||||
if "pytorch_model.bin" in os.listdir(checkpoint_path):
|
||||
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
|
||||
```
|
||||
If you want to add more callbacks, make sure to add this one as well to properly save the adapters only during training.
|
||||
```python
|
||||
...
|
||||
### Training adapters with base 8 bit models
|
||||
|
||||
callbacks = [YourCustomCallback(), PeftSavingCallback()]
|
||||
|
||||
trainer = SFTTrainer(
|
||||
"EleutherAI/gpt-neo-125m",
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
torch_dtype=torch.bfloat16,
|
||||
peft_config=peft_config,
|
||||
callbacks=callbacks
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
### Training adapters with base 8 bit models
|
||||
|
||||
For that you need to first load your 8bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:
|
||||
For that, you need to first load your 8 bit model outside the Trainer and pass a `PeftConfig` to the trainer. For example:
|
||||
|
||||
```python
|
||||
...
|
||||
@ -285,27 +386,367 @@ model = AutoModelForCausalLM.from_pretrained(
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
train_dataset=dataset,
|
||||
dataset_text_field="text",
|
||||
torch_dtype=torch.bfloat16,
|
||||
args=SFTConfig(),
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Using Flash Attention and Flash Attention 2
|
||||
|
||||
You can benefit from Flash Attention 1 & 2 using SFTTrainer out of the box with minimal changes of code.
|
||||
First, to make sure you have all the latest features from transformers, install transformers from source
|
||||
|
||||
```bash
|
||||
pip install -U git+https://github.com/huggingface/transformers.git
|
||||
```
|
||||
|
||||
Note that Flash Attention only works on GPU now and under half-precision regime (when using adapters, base model loaded in half-precision)
|
||||
Note also both features are perfectly compatible with other tools such as quantization.
|
||||
|
||||
### Using Flash-Attention 1
|
||||
|
||||
For Flash Attention 1 you can use the `BetterTransformer` API and force-dispatch the API to use Flash Attention kernel. First, install the latest optimum package:
|
||||
|
||||
```bash
|
||||
pip install -U optimum
|
||||
```
|
||||
|
||||
Once you have loaded your model, wrap the `trainer.train()` call under the `with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):` context manager:
|
||||
|
||||
```diff
|
||||
...
|
||||
|
||||
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
Note that you cannot train your model using Flash Attention 1 on an arbitrary dataset as `torch.scaled_dot_product_attention` does not support training with padding tokens if you use Flash Attention kernels. Therefore you can only use that feature with `packing=True`. If your dataset contains padding tokens, consider switching to Flash Attention 2 integration.
|
||||
|
||||
Below are some numbers you can get in terms of speedup and memory efficiency, using Flash Attention 1, on a single NVIDIA-T4 16GB.
|
||||
|
||||
| use_flash_attn_1 | model_name | max_seq_len | batch_size | time per training step |
|
||||
| ---------------- | ----------------- | ----------- | ---------- | ---------------------- |
|
||||
| x | facebook/opt-350m | 2048 | 8 | ~59.1s |
|
||||
| | facebook/opt-350m | 2048 | 8 | **OOM** |
|
||||
| x | facebook/opt-350m | 2048 | 4 | ~30.3s |
|
||||
| | facebook/opt-350m | 2048 | 4 | ~148.9s |
|
||||
|
||||
### Using Flash Attention-2
|
||||
|
||||
To use Flash Attention 2, first install the latest `flash-attn` package:
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn
|
||||
```
|
||||
|
||||
And add `attn_implementation="flash_attention_2"` when calling `from_pretrained`:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id,
|
||||
load_in_4bit=True,
|
||||
attn_implementation="flash_attention_2"
|
||||
)
|
||||
```
|
||||
|
||||
If you don't use quantization, make sure your model is loaded in half-precision and dispatch your model on a supported GPU device.
|
||||
After loading your model, you can either train it as it is, or attach adapters and train adapters on it in case your model is quantized.
|
||||
|
||||
In contrast to Flash Attention 1, the integration makes it possible to train your model on an arbitrary dataset that also includes padding tokens.
|
||||
|
||||
|
||||
### Using model creation utility
|
||||
|
||||
We included a utility function to create your model.
|
||||
|
||||
[[autodoc]] ModelConfig
|
||||
|
||||
```python
|
||||
from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
|
||||
model_config = ModelConfig(
|
||||
model_name_or_path="facebook/opt-350m"
|
||||
attn_implementation=None, # or "flash_attention_2"
|
||||
)
|
||||
torch_dtype = (
|
||||
model_config.torch_dtype
|
||||
if model_config.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_config.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_config)
|
||||
model_kwargs = dict(
|
||||
revision=model_config.model_revision,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
attn_implementation=model_config.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path, **model_kwargs)
|
||||
trainer = SFTTrainer(
|
||||
...,
|
||||
model=model_config.model_name_or_path,
|
||||
peft_config=get_peft_config(model_config),
|
||||
)
|
||||
```
|
||||
|
||||
### Enhance the model's performances using NEFTune
|
||||
|
||||
NEFTune is a technique to boost the performance of chat models and was introduced by the paper ["NEFTune: Noisy Embeddings Improve Instruction Finetuning"](https://huggingface.co/papers/2310.05914) from Jain et al. it consists of adding noise to the embedding vectors during training. According to the abstract of the paper:
|
||||
|
||||
> Standard finetuning of LLaMA-2-7B using Alpaca achieves 29.79% on AlpacaEval, which rises to 64.69% using noisy embeddings. NEFTune also improves over strong baselines on modern instruction datasets. Models trained with Evol-Instruct see a 10% improvement, with ShareGPT an 8% improvement, and with OpenPlatypus an 8% improvement. Even powerful models further refined with RLHF such as LLaMA-2-Chat benefit from additional training with NEFTune.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/neft-screenshot.png">
|
||||
</div>
|
||||
|
||||
To use it in `SFTTrainer` simply pass `neftune_noise_alpha` when creating your `SFTConfig` instance. Note that to avoid any surprising behaviour, NEFTune is disabled after training to retrieve back the original behaviour of the embedding layer.
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
|
||||
dataset = load_dataset("imdb", split="train")
|
||||
|
||||
sft_config = SFTConfig(
|
||||
neftune_noise_alpha=5,
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
"facebook/opt-350m",
|
||||
train_dataset=dataset,
|
||||
args=sft_config,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
We have tested NEFTune by training `mistralai/Mistral-7B-v0.1` on the [OpenAssistant dataset](https://huggingface.co/datasets/timdettmers/openassistant-guanaco) and validated that using NEFTune led to a performance boost of ~25% on MT Bench.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-neftune-mistral-7b.png">
|
||||
</div>
|
||||
|
||||
Note however, that the amount of performance gain is _dataset dependent_ and in particular, applying NEFTune on synthetic datasets like [UltraChat](https://huggingface.co/datasets/stingning/ultrachat) typically produces smaller gains.
|
||||
|
||||
### Accelerate fine-tuning 2x using `unsloth`
|
||||
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks on 1x A100 listed below:
|
||||
|
||||
| 1 A100 40GB | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|
||||
| --------------- | --------- | --- | --------------------- | --------- | ------------ |
|
||||
| Code Llama 34b | Slim Orca | 1x | 1.01x | **1.94x** | -22.7% |
|
||||
| Llama-2 7b | Slim Orca | 1x | 0.96x | **1.87x** | -39.3% |
|
||||
| Mistral 7b | Slim Orca | 1x | 1.17x | **1.88x** | -65.9% |
|
||||
| Tiny Llama 1.1b | Alpaca | 1x | 1.55x | **2.74x** | -57.8% |
|
||||
|
||||
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
from unsloth import FastLanguageModel
|
||||
|
||||
max_seq_length = 2048 # Supports automatic RoPE Scaling, so choose any number
|
||||
|
||||
# Load model
|
||||
model, tokenizer = FastLanguageModel.from_pretrained(
|
||||
model_name="unsloth/mistral-7b",
|
||||
max_seq_length=max_seq_length,
|
||||
dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
|
||||
load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
|
||||
# token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
|
||||
)
|
||||
|
||||
# Do model patching and add fast LoRA weights
|
||||
model = FastLanguageModel.get_peft_model(
|
||||
model,
|
||||
r=16,
|
||||
target_modules=[
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
],
|
||||
lora_alpha=16,
|
||||
lora_dropout=0, # Dropout = 0 is currently optimized
|
||||
bias="none", # Bias = "none" is currently optimized
|
||||
use_gradient_checkpointing=True,
|
||||
random_state=3407,
|
||||
)
|
||||
|
||||
args = SFTConfig(
|
||||
output_dir="./output",
|
||||
max_seq_length=max_seq_length,
|
||||
dataset_text_field="text",
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
The saved model is fully compatible with Hugging Face's transformers library. Learn more about unsloth in their [official repository](https://github.com/unslothai/unsloth).
|
||||
|
||||
## Best practices
|
||||
|
||||
Pay attention to the following best practices when training a model with that trainer:
|
||||
|
||||
- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
|
||||
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_int8_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
|
||||
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
|
||||
- [`SFTTrainer`] always pads by default the sequences to the `max_seq_length` argument of the [`SFTTrainer`]. If none is passed, the trainer will retrieve that value from the tokenizer. Some tokenizers do not provide a default value, so there is a check to retrieve the minimum between 2048 and that value. Make sure to check it before training.
|
||||
- For training adapters in 8bit, you might need to tweak the arguments of the `prepare_model_for_kbit_training` method from PEFT, hence we advise users to use `prepare_in_int8_kwargs` field, or create the `PeftModel` outside the [`SFTTrainer`] and pass it.
|
||||
- For a more memory-efficient training using adapters, you can load the base model in 8bit, for that simply add `load_in_8bit` argument when creating the [`SFTTrainer`], or create a base model in 8bit outside the trainer and pass it.
|
||||
- If you create a model outside the trainer, make sure to not pass to the trainer any additional keyword arguments that are relative to `from_pretrained()` method.
|
||||
|
||||
## Multi-GPU Training
|
||||
|
||||
Trainer (and thus SFTTrainer) supports multi-GPU training. If you run your script with `python script.py` it will default to using DP as the strategy, which may be [slower than expected](https://github.com/huggingface/trl/issues/1303). To use DDP (which is generally recommended, see [here](https://huggingface.co/docs/transformers/en/perf_train_gpu_many?select-gpu=Accelerate#data-parallelism) for more info) you must launch the script with `python -m torch.distributed.launch script.py` or `accelerate launch script.py`. For DDP to work you must also check the following:
|
||||
- If you're using gradient_checkpointing, add the following to the TrainingArguments: `gradient_checkpointing_kwargs={'use_reentrant':False}` (more info [here](https://github.com/huggingface/transformers/issues/26969)
|
||||
- Ensure that the model is placed on the correct device:
|
||||
```python
|
||||
from accelerate import PartialState
|
||||
device_string = PartialState().process_index
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
...
|
||||
device_map={'':device_string}
|
||||
)
|
||||
```
|
||||
|
||||
## GPTQ Conversion
|
||||
|
||||
You may experience some issues with GPTQ Quantization after completing training. Lowering `gradient_accumulation_steps` to `4` will resolve most issues during the quantization process to GPTQ format.
|
||||
|
||||
## Extending `SFTTrainer` for Vision Language Models
|
||||
|
||||
`SFTTrainer` does not inherently support vision-language data. However, we provide a guide on how to tweak the trainer to support vision-language data. Specifically, you need to use a custom data collator that is compatible with vision-language data. This guide outlines the steps to make these adjustments. For a concrete example, refer to the script [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py) which demonstrates how to fine-tune the LLaVA 1.5 model on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset.
|
||||
|
||||
### Preparing the Data
|
||||
|
||||
The data format is flexible, provided it is compatible with the custom collator that we will define later. A common approach is to use conversational data. Given that the data includes both text and images, the format needs to be adjusted accordingly. Below is an example of a conversational data format involving both text and images:
|
||||
|
||||
```python
|
||||
images = ["obama.png"]
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Who is this?"},
|
||||
{"type": "image"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Barack Obama"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "What is he famous for?"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "He is the 44th President of the United States."}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
To illustrate how this data format will be processed using the LLaVA model, you can use the following code:
|
||||
|
||||
```python
|
||||
from transformers import AutoProcessor
|
||||
|
||||
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
|
||||
print(processor.apply_chat_template(messages, tokenize=False))
|
||||
```
|
||||
|
||||
The output will be formatted as follows:
|
||||
|
||||
```txt
|
||||
Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States.
|
||||
```
|
||||
|
||||
<iframe src="https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft/embed/viewer/default/train" frameborder="0" width="100%" height="560px"></iframe>
|
||||
|
||||
|
||||
### A custom collator for processing multi-modal data
|
||||
|
||||
Unlike the default behavior of `SFTTrainer`, processing multi-modal data is done on the fly during the data collation process. To do this, you need to define a custom collator that processes both the text and images. This collator must take a list of examples as input (see the previous section for an example of the data format) and return a batch of processed data. Below is an example of such a collator:
|
||||
|
||||
```python
|
||||
def collate_fn(examples):
|
||||
# Get the texts and images, and apply the chat template
|
||||
texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
|
||||
images = [example["images"][0] for example in examples]
|
||||
|
||||
# Tokenize the texts and process the images
|
||||
batch = processor(texts, images, return_tensors="pt", padding=True)
|
||||
|
||||
# The labels are the input_ids, and we mask the padding tokens in the loss computation
|
||||
labels = batch["input_ids"].clone()
|
||||
labels[labels == processor.tokenizer.pad_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
|
||||
return batch
|
||||
```
|
||||
|
||||
We can verify that the collator works as expected by running the following code:
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")
|
||||
examples = [dataset[0], dataset[1]] # Just two examples for the sake of the example
|
||||
collated_data = collate_fn(examples)
|
||||
print(collated_data.keys()) # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels'])
|
||||
```
|
||||
|
||||
### Training the vision-language model
|
||||
|
||||
Now that we have prepared the data and defined the collator, we can proceed with training the model. To ensure that the data is not processed as text-only, we need to set a couple of arguments in the `SFTConfig`, specifically `dataset_text_field` and `remove_unused_columns`. We also need to set `skip_prepare_dataset` to `True` to avoid the default processing of the dataset. Below is an example of how to set up the `SFTTrainer`.
|
||||
|
||||
```python
|
||||
args.dataset_text_field = "" # needs a dummy field
|
||||
args.remove_unused_columns = False
|
||||
args.dataset_kwargs = {"skip_prepare_dataset": True}
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=args,
|
||||
data_collator=collate_fn,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=processor.tokenizer,
|
||||
)
|
||||
```
|
||||
|
||||
A full example of training LLaVa 1.5 on the [HuggingFaceH4/llava-instruct-mix-vsft](https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft) dataset can be found in the script [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py).
|
||||
|
||||
- [Experiment tracking](https://wandb.ai/huggingface/trl/runs/2b2c5l7s)
|
||||
- [Trained model](https://huggingface.co/HuggingFaceH4/sft-llava-1.5-7b-hf)
|
||||
|
||||
## SFTTrainer
|
||||
|
||||
[[autodoc]] SFTTrainer
|
||||
|
||||
## ConstantLengthDataset
|
||||
## SFTConfig
|
||||
|
||||
[[autodoc]] trainer.ConstantLengthDataset
|
||||
[[autodoc]] SFTConfig
|
||||
|
||||
## Datasets
|
||||
|
||||
In the SFTTrainer we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
|
||||
|
||||
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to re-use it directly.
|
||||
|
||||
### ConstantLengthDataset
|
||||
|
||||
[[autodoc]] trainer.ConstantLengthDataset
|
||||
|
197
docs/source/text_environments.md
Normal file
197
docs/source/text_environments.md
Normal file
@ -0,0 +1,197 @@
|
||||
# Text Environments
|
||||
|
||||
Text environments provide a learning ground for language agents. It allows a language model to use tools to accomplish a task such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the models itself but can be trivial for the appropriate tools. A good example is arithmetics of large numbers that become a simple copy-paste task once you have access to a calculator.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv.png">
|
||||
</div>
|
||||
|
||||
Let's dive into how text environments work and start with tools!
|
||||
|
||||
## Tools
|
||||
|
||||
One of the core building blocks of text environments are tools that the model can use to solve tasks. In general tools can be any Python function that takes a string as input and returns string. The `TextEnvironment` offers two options for tools: either go with predefined tools from `transformers.Tool` or define your own function or class with `__call__` method. Let's have a look at both!
|
||||
|
||||
### `transformers.Tool`
|
||||
|
||||
Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared
|
||||
|
||||
```Python
|
||||
from transformers import load_tool
|
||||
|
||||
# simple calculator tool that runs +-/* operations
|
||||
calc_tool = load_tool("ybelkada/simple-calculator")
|
||||
|
||||
# python interpreter that executes program and returns outputs
|
||||
py_tool = load_tool("lvwerra/python-interpreter")
|
||||
|
||||
# wikipedia search index that returns best search match
|
||||
wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
|
||||
```
|
||||
|
||||
These tools are either loaded from the hub or from a local folder. Using the tool is as simple as calling them with a text query:
|
||||
|
||||
```Python
|
||||
calc_tool("1/2")
|
||||
>>> "0.5"
|
||||
```
|
||||
|
||||
Note that both input and return values are strings to enable easy usage with a language model.
|
||||
|
||||
### Custom Tools
|
||||
|
||||
The following is an example of a tool that adds two integers:
|
||||
|
||||
```Python
|
||||
def add(text):
|
||||
int_1, int_2 = text.split("+")
|
||||
result = int(int_1) + int(int_2)
|
||||
return str(result)
|
||||
|
||||
print(add("1+1"))
|
||||
>>> "2"
|
||||
```
|
||||
|
||||
We looked at basic examples such as a calculator but the principle holds for more complex tools as well such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax.
|
||||
|
||||
### Call syntax
|
||||
|
||||
In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows:
|
||||
|
||||
```python
|
||||
"<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"
|
||||
```
|
||||
|
||||
There are a few special tokens involved so let's decompose it: First the model can signal that it wants to use a tool by emitting the `<request>` token. After that we want to know the name of the tool to call which is done by enclosing the tool name with `<>` brackets. Once we know which tool to call the tool query follows which is in free text form. The `<call>` tokens signifies the end of the query and stops the model generation. At this point the model output is parsed and the query sent to the tool. The environment appends the tool response to the string followed by the `<response>` token to show the end the tool output.
|
||||
|
||||
Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later):
|
||||
|
||||
```python
|
||||
"<request><Calculator>1/2<call>0.5<response>"
|
||||
```
|
||||
|
||||
Finally, the episode is ended and generation stops when the model generates `<submit>` which marks the interaction as completed.
|
||||
|
||||
Now let's have a look how we can create a new text environment!
|
||||
|
||||
## Create a `TextEnvironment`
|
||||
|
||||
|
||||
```python
|
||||
prompt = """\
|
||||
What is 13-3?
|
||||
<request><SimpleCalculatorTool>13-3<call>10.0<response>
|
||||
Result=10<submit>
|
||||
"""
|
||||
|
||||
def reward_fn(result, answer):
|
||||
"""Simplified reward function returning 1 if result matches answer and 0 otherwise."""
|
||||
result_parsed = result.split("=")[1].split("<")[0]
|
||||
return int(result_parsed==answer)
|
||||
|
||||
text_env = TextEnvironemnt(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
tools= {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
|
||||
reward_fn=exact_match_reward,
|
||||
prompt=prompt,
|
||||
max_turns=1
|
||||
max_tool_response=100
|
||||
generation_kwargs={"do_sample": "true"}
|
||||
)
|
||||
```
|
||||
|
||||
Let's decompose the settings:
|
||||
|
||||
| Argument | Description |
|
||||
|:-------------------|:----------------|
|
||||
| `model` | Language model to interact with the environment and generate requests. |
|
||||
| `tokenizer` | Tokenizer of language model handling tokenization of strings. |
|
||||
| `tools` | `list` of `dict` of tools. If former the name of the tool is inferred from class name and otherwise it's the keys of the dictionary.|
|
||||
| `reward_fn` | A function that takes a string as input and returns. Can have extra arguments that are passed to `.run()` such as ground truth.|
|
||||
| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. |
|
||||
| `max_turns` | Maximum number of interactions between model and tools before episode ends.|
|
||||
| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
|
||||
| `max_length` | The maximum number of tokens to allow in an episode. |
|
||||
| `generation_kwargs`| Generation settings used by the language model. |
|
||||
|
||||
You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!
|
||||
|
||||
|
||||
## Run an Episode
|
||||
|
||||
To run a set of queries through the text environment one can simply use the `run` method.
|
||||
|
||||
```python
|
||||
queries = ["What is 1/2?"]
|
||||
answers = ["0.5"]
|
||||
|
||||
queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)
|
||||
```
|
||||
|
||||
This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached or to maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function.
|
||||
|
||||
There are five objects that are returned by `run`:
|
||||
|
||||
- `queries`: a list of the tokenized queries
|
||||
- `responses`: all tokens that have been generated withing the environment including model and tool tokens
|
||||
- `masks`: mask that indicates which tokens have been generated by the model and which tokens are generated by the tool
|
||||
- `rewards`: a list of reward for each query/response
|
||||
- `histories`: list of `TextHistory` objects, which are useful objects containing all the above and also the text equivalents
|
||||
|
||||
The masks are crucial for training as we don't want to optimize tokens that the model has not generated which are tokens produced by the tools.
|
||||
|
||||
Next, we'll train a PPO step with the generated responses!
|
||||
|
||||
|
||||
### Train
|
||||
Training on episodes from the `TextEnvironment` is straight forward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method:
|
||||
|
||||
```python
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
```
|
||||
|
||||
## `TextHistory`
|
||||
|
||||
The `TextHistory` object stores the interactions between the model and the text environment. It stores tokens and text generated in each turn and their source in each turn (model or system) as well as rewards. Let's go through the class attributes and methods.
|
||||
|
||||
### Attributes
|
||||
|
||||
The following table summarises the available attributes of the `TextEnvironment` class:
|
||||
|
||||
| Attribute | Description |
|
||||
|:-------------------|:----------------|
|
||||
| `text` | The full string of the text generated in the text environment with both model and system generated text. |
|
||||
| `text_spans` | A list of tuples with the spans for each model or system generated text segment. |
|
||||
| `system_spans` | A list of boolean values indicating if the segment is model or system generated. |
|
||||
| `tokens` | All tokens generated in text environment with both model and system generated tokens. |
|
||||
| `token_spans` | Similar to `text_spans` the `token_spans` indicate the boundaries of model andsystem generated tokens. |
|
||||
| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. |
|
||||
| `completed` | Indicates if the interaction with the environment has completed. |
|
||||
| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. |
|
||||
|
||||
With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look!
|
||||
|
||||
### Visualization
|
||||
|
||||
When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` libray](https://github.com/Textualize/rich) (make sure to install it before using these methods).
|
||||
|
||||
You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue and in addition to the pure text output the reward is displayed as additional text in plum. Here an example of `show_text`:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_text.png" width=600>
|
||||
</div>
|
||||
|
||||
Sometimes there can be tricky tokenization related issues that are hidden when showing the decoded text. Thus `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_tokens.png" width=800>
|
||||
</div>
|
||||
|
||||
Note that you can turn on the colour legend by passing `show_legend=True`.
|
||||
|
||||
## API Documentation
|
||||
|
||||
[[autodoc]] TextEnvironment
|
||||
|
||||
[[autodoc]] TextHistory
|
@ -1,9 +1,50 @@
|
||||
# Trainer
|
||||
|
||||
At TRL we support PPO (Proximal Policy Optimisation) with an implementation that largely follows the structure introduced in the paper "Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al. [[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
|
||||
At TRL we support PPO (Proximal Policy Optimisation) with an implementation that largely follows the structure introduced in the paper "Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al. [[paper](https://huggingface.co/papers/1909.08593), [code](https://github.com/openai/lm-human-preferences)].
|
||||
The Trainer and model classes are largely inspired from `transformers.Trainer` and `transformers.AutoModel` classes and adapted for RL.
|
||||
We also support a `RewardTrainer` that can be used to train a reward model.
|
||||
|
||||
|
||||
## CPOConfig
|
||||
|
||||
[[autodoc]] CPOConfig
|
||||
|
||||
## CPOTrainer
|
||||
|
||||
[[autodoc]] CPOTrainer
|
||||
|
||||
## DDPOConfig
|
||||
|
||||
[[autodoc]] DDPOConfig
|
||||
|
||||
## DDPOTrainer
|
||||
|
||||
[[autodoc]] DDPOTrainer
|
||||
|
||||
## DPOTrainer
|
||||
|
||||
[[autodoc]] DPOTrainer
|
||||
|
||||
## IterativeSFTTrainer
|
||||
|
||||
[[autodoc]] IterativeSFTTrainer
|
||||
|
||||
## KTOConfig
|
||||
|
||||
[[autodoc]] KTOConfig
|
||||
|
||||
## KTOTrainer
|
||||
|
||||
[[autodoc]] KTOTrainer
|
||||
|
||||
## ORPOConfig
|
||||
|
||||
[[autodoc]] ORPOConfig
|
||||
|
||||
## ORPOTrainer
|
||||
|
||||
[[autodoc]] ORPOTrainer
|
||||
|
||||
## PPOConfig
|
||||
|
||||
[[autodoc]] PPOConfig
|
||||
@ -12,6 +53,10 @@ We also support a `RewardTrainer` that can be used to train a reward model.
|
||||
|
||||
[[autodoc]] PPOTrainer
|
||||
|
||||
## RewardConfig
|
||||
|
||||
[[autodoc]] RewardConfig
|
||||
|
||||
## RewardTrainer
|
||||
|
||||
[[autodoc]] RewardTrainer
|
||||
@ -20,10 +65,6 @@ We also support a `RewardTrainer` that can be used to train a reward model.
|
||||
|
||||
[[autodoc]] SFTTrainer
|
||||
|
||||
## DPOTrainer
|
||||
|
||||
[[autodoc]] DPOTrainer
|
||||
|
||||
## set_seed
|
||||
|
||||
[[autodoc]] set_seed
|
||||
|
58
docs/source/use_model.md
Normal file
58
docs/source/use_model.md
Normal file
@ -0,0 +1,58 @@
|
||||
# Use model after training
|
||||
|
||||
Once you have trained a model using either the SFTTrainer, PPOTrainer, or DPOTrainer, you will have a fine-tuned model that can be used for text generation. In this section, we'll walk through the process of loading the fine-tuned model and generating text. If you need to run an inference server with the trained model, you can explore libraries such as [`text-generation-inference`](https://github.com/huggingface/text-generation-inference).
|
||||
|
||||
## Load and Generate
|
||||
|
||||
If you have fine-tuned a model fully, meaning without the use of PEFT you can simply load it like any other language model in transformers. E.g. the value head that was trained during the PPO training is no longer needed and if you load the model with the original transformer class it will be ignored:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
|
||||
device = "cpu" # or "cuda" if you have a GPU
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
|
||||
inputs = tokenizer.encode("This movie was really", return_tensors="pt").to(device)
|
||||
outputs = model.generate(inputs)
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
```
|
||||
|
||||
Alternatively you can also use the pipeline:
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
|
||||
pipe = pipeline("text-generation", model=model_name_or_path)
|
||||
print(pipe("This movie was really")[0]["generated_text"])
|
||||
```
|
||||
|
||||
## Use Adapters PEFT
|
||||
|
||||
```python
|
||||
from peft import PeftConfig, PeftModel
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub"
|
||||
adapter_model_name = "path/to/my/adapter"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(base_model_name)
|
||||
model = PeftModel.from_pretrained(model, adapter_model_name)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
||||
```
|
||||
|
||||
You can also merge the adapters into the base model so you can use the model like a normal transformers model, however the checkpoint will be significantly bigger:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(base_model_name)
|
||||
model = PeftModel.from_pretrained(model, adapter_model_name)
|
||||
|
||||
model = model.merge_and_unload()
|
||||
model.save_pretrained("merged_adapters")
|
||||
```
|
||||
|
||||
Once you have the model loaded and either merged the adapters or keep them separately on top you can run generation as with a normal model outlined above.
|
@ -52,7 +52,7 @@ model = AutoModelForCausalLM.from_pretrained(
|
||||
load_in_8bit=True,
|
||||
device_map={"": Accelerator().local_process_index}
|
||||
)
|
||||
model = prepare_model_for_int8_training(model)
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
# add LoRA to model
|
||||
lora_config = LoraConfig(
|
||||
@ -157,4 +157,4 @@ for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
||||
```
|
||||
|
||||
For the rest of the details adn evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
|
||||
For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
|
||||
|
@ -1,50 +1,3 @@
|
||||
# Examples
|
||||
|
||||
_The best place to learn about examples in TRL is our [docs page](https://huggingface.co/docs/trl/index)!_
|
||||
|
||||
## Introduction
|
||||
|
||||
The examples should work in any of the following settings (with the same script):
|
||||
- single CPU or single GPU
|
||||
- multi GPUS (using PyTorch distributed mode)
|
||||
- multi GPUS (using DeepSpeed ZeRO-Offload stages 1 & 2)
|
||||
- fp16 (mixed-precision) or fp32 (normal precision)
|
||||
|
||||
To run it in each of these various modes, first initialize the accelerate
|
||||
configuration with `accelerate config`
|
||||
|
||||
**NOTE for to train with a 8-bit model a more recent version of**
|
||||
transformers is required, for example:
|
||||
|
||||
```bash
|
||||
pip install --upgrade bitsandbytes datasets accelerate loralib
|
||||
pip install git+https://github.com/huggingface/peft.git
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
pip install trl
|
||||
#optional: wandb
|
||||
pip install wandb
|
||||
```
|
||||
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks.
|
||||
You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
|
||||
|
||||
## Accelerate Config
|
||||
For all the examples, you'll need to generate an `Accelerate` config with:
|
||||
|
||||
```shell
|
||||
accelerate config # will prompt you to define the training configuration
|
||||
```
|
||||
|
||||
Then, it is encouraged to launch jobs with `accelerate launch`!
|
||||
|
||||
## Categories
|
||||
The examples are currently split over the following categories:
|
||||
|
||||
**1: [ppo_trainer](https://github.com/lvwerra/trl/tree/main/examples/scripts/sentiment_tuning.py)**: Learn about different ways of using PPOTrainer
|
||||
**2: [sft_trainer](https://github.com/lvwerra/trl/tree/main/examples/scripts/sft_trainer.py)**: Learn about how to leverage `SFTTrainer` for supervised fine-tuning your pretrained language models easily.
|
||||
**3: [reward_modeling](https://github.com/lvwerra/trl/tree/main/examples/scripts/reward_trainer.py)**: Learn about how to use `RewardTrainer` to easily train your own reward model to use it for your RLHF pipeline.
|
||||
**4: [research_projects](https://github.com/lvwerra/trl/tree/main/examples/research_projects)**: Check out that folder to check the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
|
||||
**5: [notebooks](https://github.com/lvwerra/trl/tree/main/examples/notebooks)**: Check out that folder to check some applications of TRL features directly on a Jupyter notebook. This includes running sentiment tuning and sentiment control on a notebook, but also how to use "Best of N sampling" strategy using TRL.
|
||||
Please check out https://huggingface.co/docs/trl/example_overview for documentation on our examples.
|
20
examples/accelerate_configs/deepspeed_zero1.yaml
Normal file
20
examples/accelerate_configs/deepspeed_zero1.yaml
Normal file
@ -0,0 +1,20 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_multinode_launcher: standard
|
||||
gradient_accumulation_steps: 1
|
||||
zero3_init_flag: false
|
||||
zero_stage: 1
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'bf16'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
21
examples/accelerate_configs/deepspeed_zero2.yaml
Normal file
21
examples/accelerate_configs/deepspeed_zero2.yaml
Normal file
@ -0,0 +1,21 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_multinode_launcher: standard
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'bf16'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
22
examples/accelerate_configs/deepspeed_zero3.yaml
Normal file
22
examples/accelerate_configs/deepspeed_zero3.yaml
Normal file
@ -0,0 +1,22 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
deepspeed_multinode_launcher: standard
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: true
|
||||
zero3_save_16bit_model: true
|
||||
zero_stage: 3
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
25
examples/accelerate_configs/fsdp_qlora.yaml
Normal file
25
examples/accelerate_configs/fsdp_qlora.yaml
Normal file
@ -0,0 +1,25 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: FSDP
|
||||
downcast_bf16: 'no'
|
||||
fsdp_config:
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_backward_prefetch: BACKWARD_PRE
|
||||
fsdp_cpu_ram_efficient_loading: true
|
||||
fsdp_forward_prefetch: false
|
||||
fsdp_offload_params: true
|
||||
fsdp_sharding_strategy: FULL_SHARD
|
||||
fsdp_state_dict_type: SHARDED_STATE_DICT
|
||||
fsdp_sync_module_states: true
|
||||
fsdp_use_orig_params: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'bf16'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
16
examples/accelerate_configs/multi_gpu.yaml
Normal file
16
examples/accelerate_configs/multi_gpu.yaml
Normal file
@ -0,0 +1,16 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'bf16'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
16
examples/accelerate_configs/single_gpu.yaml
Normal file
16
examples/accelerate_configs/single_gpu.yaml
Normal file
@ -0,0 +1,16 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: "NO"
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'bf16'
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
20
examples/cli_configs/example_config.yaml
Normal file
20
examples/cli_configs/example_config.yaml
Normal file
@ -0,0 +1,20 @@
|
||||
# This is an example configuration file of TRL CLI, you can use it for
|
||||
# SFT like that: `trl sft --config config.yaml --output_dir test-sft`
|
||||
# The YAML file supports environment variables by adding an `env` field
|
||||
# as below
|
||||
|
||||
# env:
|
||||
# CUDA_VISIBLE_DEVICES: 0
|
||||
|
||||
model_name_or_path:
|
||||
trl-internal-testing/tiny-random-LlamaForCausalLM
|
||||
dataset_name:
|
||||
imdb
|
||||
dataset_text_field:
|
||||
text
|
||||
report_to:
|
||||
none
|
||||
learning_rate:
|
||||
0.0001
|
||||
lr_scheduler_type:
|
||||
cosine
|
122
examples/datasets/anthropic_hh.py
Normal file
122
examples/datasets/anthropic_hh.py
Normal file
@ -0,0 +1,122 @@
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
"""
|
||||
# debug
|
||||
python -i examples/datasets/anthropic_hh.py --debug --push_to_hub
|
||||
# actual push
|
||||
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity trl-internal-testing
|
||||
"""
|
||||
|
||||
|
||||
api = HfApi()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
|
||||
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
|
||||
hf_repo_id: Optional[str] = field(
|
||||
default="hh-rlhf-helpful-base-trl-style", metadata={"help": "The Hugging Face repository ID"}
|
||||
)
|
||||
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
|
||||
update_main_revision: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Update the main revision of the repository"}
|
||||
)
|
||||
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
|
||||
dataset_num_proc: Optional[int] = field(
|
||||
default=None, metadata={"help": "The number of workers to use for dataset processing"}
|
||||
)
|
||||
|
||||
|
||||
# GPT-4 generated 😄 Define a function to process the input and extract the dialogue into structured format
|
||||
def extract_dialogue(input_text):
|
||||
# Split the input by lines and initialize variables
|
||||
lines = input_text.strip().split("\n\n")
|
||||
dialogue_list = []
|
||||
|
||||
# Iterate through each line and extract the dialogue
|
||||
for line in lines:
|
||||
# Check if the line starts with "Human" or "Assistant" and split accordingly
|
||||
if line.startswith("Human:"):
|
||||
role = "user"
|
||||
content = line.replace("Human: ", "").strip()
|
||||
elif line.startswith("Assistant:"):
|
||||
role = "assistant"
|
||||
content = line.replace("Assistant: ", "").strip()
|
||||
else:
|
||||
# If the line doesn't start with "Human" or "Assistant", it's part of the previous message's content
|
||||
# Append it to the last message's content
|
||||
dialogue_list[-1]["content"] += "\n\n" + line.strip()
|
||||
continue
|
||||
|
||||
# Append the extracted dialogue piece to the list
|
||||
dialogue_list.append({"role": role, "content": content})
|
||||
|
||||
return dialogue_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
|
||||
if args.hf_entity is None:
|
||||
args.hf_entity = api.whoami()["name"]
|
||||
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
|
||||
ds = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base")
|
||||
if args.debug:
|
||||
for key in ds:
|
||||
ds[key] = ds[key].select(range(50))
|
||||
|
||||
def process(row):
|
||||
row["chosen"] = extract_dialogue(row["chosen"])
|
||||
row["rejected"] = extract_dialogue(row["rejected"])
|
||||
row["prompt"] = row["chosen"][0]["content"]
|
||||
return row
|
||||
|
||||
ds = ds.map(process, num_proc=args.dataset_num_proc)
|
||||
if args.push_to_hub:
|
||||
revisions = ["main"] if args.update_main_revision else []
|
||||
revisions.append(args.revision)
|
||||
|
||||
# get the commnad used to run the script
|
||||
run_command = " ".join(["python"] + sys.argv)
|
||||
|
||||
for revision in revisions:
|
||||
ds.push_to_hub(full_repo_id, revision=revision)
|
||||
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
|
||||
|
||||
# get the name of the current file
|
||||
file_name = __file__.split("/")[-1]
|
||||
api.upload_file(
|
||||
path_or_fileobj=__file__,
|
||||
path_in_repo=file_name,
|
||||
revision=revision,
|
||||
repo_id=full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
sft_card = RepoCard.load(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
sft_card.text = f"""\
|
||||
# TRL's Anthropic HH Dataset
|
||||
|
||||
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
|
||||
|
||||
|
||||
## Reproduce this dataset
|
||||
|
||||
1. Download the `{file_name}` from the {repo_full_url}.
|
||||
2. Run `{run_command}`
|
||||
"""
|
||||
sft_card.push_to_hub(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
188
examples/datasets/sentiment_descriptiveness.py
Normal file
188
examples/datasets/sentiment_descriptiveness.py
Normal file
@ -0,0 +1,188 @@
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from datasets import Dataset, DatasetDict
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
"""
|
||||
# debug
|
||||
python -i examples/datasets/sentiment_descriptiveness.py --push_to_hub
|
||||
# actual push
|
||||
python examples/datasets/sentiment_descriptiveness.py \
|
||||
--hf_repo_id sentiment-trl-style \
|
||||
--task sentiment \
|
||||
--push_to_hub \
|
||||
--hf_entity trl-internal-testing
|
||||
python examples/datasets/sentiment_descriptiveness.py \
|
||||
--hf_repo_id descriptiveness-trl-style \
|
||||
--task descriptiveness \
|
||||
--push_to_hub \
|
||||
--hf_entity trl-internal-testing
|
||||
"""
|
||||
|
||||
|
||||
api = HfApi()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
|
||||
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
|
||||
hf_repo_id: Optional[str] = field(
|
||||
default="sentiment-trl-style", metadata={"help": "The Hugging Face repository ID"}
|
||||
)
|
||||
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
|
||||
update_main_revision: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Update the main revision of the repository"}
|
||||
)
|
||||
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
|
||||
task: str = field(default="sentiment", metadata={"help": "The task of the dataset"})
|
||||
dataset_num_proc: Optional[int] = field(
|
||||
default=None, metadata={"help": "The number of workers to use to tokenize the data"}
|
||||
)
|
||||
|
||||
|
||||
task_to_filename = {
|
||||
"sentiment": "sentiment/offline_5k.json",
|
||||
"descriptiveness": "descriptiveness/offline_5k.json",
|
||||
}
|
||||
|
||||
|
||||
def deduplicate_query(ds):
|
||||
query = set()
|
||||
ranges = []
|
||||
for i in range(len(ds)):
|
||||
query_str = str(ds[i]["query"])
|
||||
if query_str not in query:
|
||||
query.add(query_str)
|
||||
ranges.append(i)
|
||||
return ds.select(ranges)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
|
||||
if args.hf_entity is None:
|
||||
args.hf_entity = api.whoami()["name"]
|
||||
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
|
||||
|
||||
model_name = "gpt2"
|
||||
dataset_tokenizer = AutoTokenizer.from_pretrained("gpt2") # of the dataset
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
json = hf_hub_download(
|
||||
repo_id="vwxyzjn/lm-human-preferences",
|
||||
repo_type="dataset",
|
||||
filename=task_to_filename[args.task],
|
||||
)
|
||||
|
||||
MAGIC_TRAIN_NUMBER = 4992 # taken from https://github.com/openai/lm-human-preferences/blob/cbfd210bb8b08f6bc5c26878c10984b90f516c66/launch.py#L70
|
||||
individual_ds = Dataset.from_json(json)
|
||||
individual_ds = deduplicate_query(individual_ds)
|
||||
ds = DatasetDict(
|
||||
{
|
||||
"train": individual_ds.select(range(MAGIC_TRAIN_NUMBER)),
|
||||
"test": individual_ds.select(range(MAGIC_TRAIN_NUMBER, len(individual_ds))),
|
||||
}
|
||||
)
|
||||
|
||||
MAX_DEBUG_SAMPLES = 50
|
||||
if args.debug:
|
||||
for key in ds:
|
||||
ds[key] = ds[key].select(range(min(MAX_DEBUG_SAMPLES, len(ds[key]))))
|
||||
|
||||
# columns are `['sample2', 'sample3', 'sample0', 'query', 'sample1', 'best']`
|
||||
NUM_SAMPLES = 4
|
||||
|
||||
# edge cases handling: remove the cases where all samples are the same
|
||||
def filter(row):
|
||||
best_idx = row["best"]
|
||||
chosen_sample = row[f"sample{best_idx}"]
|
||||
if all(chosen_sample == row[f"sample{j}"] for j in range(NUM_SAMPLES)):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
print("=== Before filtering ===", ds)
|
||||
ds = ds.filter(filter, num_proc=args.dataset_num_proc)
|
||||
print("=== After filtering ===", ds)
|
||||
|
||||
# here we simply take the preferred sample as the chosen one and the first non-preferred sample as the rejected one
|
||||
def process(row):
|
||||
for j in range(NUM_SAMPLES):
|
||||
row[f"sample{j}"] = dataset_tokenizer.batch_decode(row[f"sample{j}"])
|
||||
row["prompt"] = dataset_tokenizer.batch_decode(row["query"])
|
||||
row["prompt"] = [item.strip() for item in row["prompt"]]
|
||||
row["chosen"] = []
|
||||
row["rejected"] = []
|
||||
for i in range(len(row["best"])):
|
||||
best_idx = row["best"][i]
|
||||
chosen_sample = row[f"sample{best_idx}"][i].strip()
|
||||
row["chosen"].append(
|
||||
[
|
||||
{"role": "user", "content": row["prompt"][i].strip()},
|
||||
{"role": "assistant", "content": chosen_sample},
|
||||
]
|
||||
)
|
||||
# find the first rejected sample which is different from the chosen one
|
||||
rejected_idx = -1
|
||||
for k in range(4):
|
||||
if k != best_idx and row[f"sample{k}"][i].strip() != chosen_sample:
|
||||
rejected_idx = k
|
||||
break
|
||||
rejected_sample = row[f"sample{rejected_idx}"][i].strip()
|
||||
assert rejected_idx != -1, "No rejected sample found! This should not happen!"
|
||||
row["rejected"].append(
|
||||
[
|
||||
{"role": "user", "content": row["prompt"][i].strip()},
|
||||
{"role": "assistant", "content": rejected_sample},
|
||||
]
|
||||
)
|
||||
assert chosen_sample != rejected_sample
|
||||
return row
|
||||
|
||||
ds = ds.map(process, batched=True, num_proc=args.dataset_num_proc)
|
||||
for key in ds: # reorder columns
|
||||
ds[key] = ds[key].select_columns(["prompt", "chosen", "rejected"])
|
||||
if args.push_to_hub:
|
||||
revisions = ["main"] if args.update_main_revision else []
|
||||
revisions.append(args.revision)
|
||||
|
||||
# get the commnad used to run the script
|
||||
run_command = " ".join(["python"] + sys.argv)
|
||||
|
||||
for revision in revisions:
|
||||
ds.push_to_hub(full_repo_id, revision=revision)
|
||||
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
|
||||
|
||||
# get the name of the current file
|
||||
file_name = __file__.split("/")[-1]
|
||||
api.upload_file(
|
||||
path_or_fileobj=__file__,
|
||||
path_in_repo=file_name,
|
||||
revision=revision,
|
||||
repo_id=full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
sft_card = RepoCard.load(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
sft_card.text = f"""\
|
||||
# TRL's Preference Dataset: {args.task}
|
||||
The dataset comes from https://huggingface.co/papers/1909.08593, one of the earliest RLHF work from OpenAI.
|
||||
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
|
||||
## Reproduce this dataset
|
||||
1. Download the `{file_name}` from the {repo_full_url}.
|
||||
2. Run `{run_command}`
|
||||
"""
|
||||
sft_card.push_to_hub(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
185
examples/datasets/tldr_preference.py
Normal file
185
examples/datasets/tldr_preference.py
Normal file
@ -0,0 +1,185 @@
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
"""
|
||||
# debug
|
||||
python -i examples/datasets/tldr_preference.py --debug --push_to_hub
|
||||
# actual push
|
||||
python examples/datasets/tldr_preference.py --push_to_hub --hf_entity trl-internal-testing
|
||||
"""
|
||||
|
||||
|
||||
api = HfApi()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
|
||||
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
|
||||
hf_repo_id: Optional[str] = field(
|
||||
default="tldr-preference-trl-style", metadata={"help": "The Hugging Face repository ID"}
|
||||
)
|
||||
sft_hf_repo_id: Optional[str] = field(
|
||||
default="tldr-preference-sft-trl-style", metadata={"help": "The Hugging Face repository ID"}
|
||||
)
|
||||
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
|
||||
update_main_revision: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Update the main revision of the repository"}
|
||||
)
|
||||
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
|
||||
dataset_num_proc: Optional[int] = field(
|
||||
default=None, metadata={"help": "The number of workers to use to tokenize the data"}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
|
||||
if args.hf_entity is None:
|
||||
args.hf_entity = api.whoami()["name"]
|
||||
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
|
||||
full_sft_repo_id = f"{args.hf_entity}/{args.sft_hf_repo_id}"
|
||||
|
||||
################
|
||||
# Preference dataset
|
||||
################
|
||||
ds = load_dataset("openai/summarize_from_feedback", "comparisons")
|
||||
if args.debug:
|
||||
for key in ds:
|
||||
ds[key] = ds[key].select(range(50))
|
||||
cnndm_batches = ["batch0_cnndm", "cnndm0", "cnndm2"]
|
||||
if not args.debug:
|
||||
ds["validation_cnndm"] = ds["validation"].filter(
|
||||
lambda x: x["batch"] in cnndm_batches, num_proc=args.dataset_num_proc
|
||||
)
|
||||
ds["validation"] = ds["validation"].filter(
|
||||
lambda x: x["batch"] not in cnndm_batches, num_proc=args.dataset_num_proc
|
||||
)
|
||||
|
||||
tldr_format_str = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"
|
||||
cnndm_format_str = "Article:\n{article}\n\nTL;DR:"
|
||||
|
||||
def process(row):
|
||||
format_str = cnndm_format_str if row["batch"] in cnndm_batches else tldr_format_str
|
||||
row["prompt"] = format_str.format(**row["info"])
|
||||
choice = row["choice"]
|
||||
# need to remove the leading space
|
||||
chosen = row["summaries"][choice]["text"].strip()
|
||||
rejected = row["summaries"][1 - choice]["text"].strip()
|
||||
row["chosen"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": chosen}]
|
||||
row["rejected"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": rejected}]
|
||||
return row
|
||||
|
||||
ds = ds.map(process, num_proc=args.dataset_num_proc)
|
||||
for key in ds: # reorder columns
|
||||
ds[key] = ds[key].select_columns(
|
||||
["prompt", "chosen", "rejected", "info", "summaries", "choice", "worker", "batch", "split", "extra"]
|
||||
)
|
||||
if args.push_to_hub:
|
||||
revisions = ["main"] if args.update_main_revision else []
|
||||
revisions.append(args.revision)
|
||||
|
||||
# get the commnad used to run the script
|
||||
run_command = " ".join(["python"] + sys.argv)
|
||||
|
||||
for revision in revisions:
|
||||
ds.push_to_hub(full_repo_id, revision=revision)
|
||||
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
|
||||
|
||||
# get the name of the current file
|
||||
file_name = __file__.split("/")[-1]
|
||||
api.upload_file(
|
||||
path_or_fileobj=__file__,
|
||||
path_in_repo=file_name,
|
||||
revision=revision,
|
||||
repo_id=full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
preference_card = RepoCard.load(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
preference_card.text = f"""\
|
||||
# TRL's TL;DR Preference Dataset
|
||||
|
||||
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
|
||||
|
||||
## Source of the dataset
|
||||
|
||||
We take the dataset from https://huggingface.co/datasets/openai/summarize_from_feedback.
|
||||
|
||||
## Reproduce this dataset
|
||||
|
||||
1. Download the `{file_name}` from the {repo_full_url}.
|
||||
2. Run `{run_command}`
|
||||
"""
|
||||
preference_card.push_to_hub(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
################
|
||||
# SFT dataset
|
||||
################
|
||||
sft_ds = load_dataset("vwxyzjn/summarize_from_feedback_tldr_3_filtered")
|
||||
if args.debug:
|
||||
for key in sft_ds:
|
||||
sft_ds[key] = sft_ds[key].select(range(50))
|
||||
|
||||
def sft_process(row):
|
||||
row["prompt"] = tldr_format_str.format(**row)
|
||||
row["messages"] = [
|
||||
{"role": "user", "content": row["prompt"]},
|
||||
{"role": "assistant", "content": row["summary"]},
|
||||
]
|
||||
return row
|
||||
|
||||
sft_ds = sft_ds.map(sft_process, num_proc=args.dataset_num_proc)
|
||||
for key in sft_ds: # reorder columns
|
||||
sft_ds[key] = sft_ds[key].select_columns(["prompt", "messages", "id", "subreddit", "title", "post", "summary"])
|
||||
if args.push_to_hub:
|
||||
revisions = ["main"] if args.update_main_revision else []
|
||||
revisions.append(args.revision)
|
||||
|
||||
# get the commnad used to run the script
|
||||
run_command = " ".join(["python"] + sys.argv)
|
||||
|
||||
for revision in revisions:
|
||||
sft_ds.push_to_hub(full_sft_repo_id, revision=revision)
|
||||
repo_full_url = f"https://huggingface.co/datasets/{full_sft_repo_id}/tree/{revision}"
|
||||
|
||||
# get the name of the current file
|
||||
file_name = __file__.split("/")[-1]
|
||||
api.upload_file(
|
||||
path_or_fileobj=__file__,
|
||||
path_in_repo=file_name,
|
||||
revision=revision,
|
||||
repo_id=full_sft_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
sft_card = RepoCard.load(
|
||||
full_sft_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
sft_card.text = f"""\
|
||||
# TRL's TL;DR SFT Dataset
|
||||
|
||||
We preprocess the dataset using our standard `prompt, messages` format.
|
||||
|
||||
## Source of the dataset
|
||||
|
||||
We take the dataset from https://huggingface.co/datasets/vwxyzjn/summarize_from_feedback_tldr_3_filtered.
|
||||
|
||||
## Reproduce this dataset
|
||||
|
||||
1. Download the `{file_name}` from the {repo_full_url}.
|
||||
2. Run `{run_command}`
|
||||
"""
|
42
examples/datasets/tokenize_ds.py
Normal file
42
examples/datasets/tokenize_ds.py
Normal file
@ -0,0 +1,42 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
"""
|
||||
python -i examples/datasets/tokenize_ds.py --debug --model HuggingFaceH4/zephyr-7b-beta
|
||||
python -i examples/datasets/tokenize_ds.py --debug --model gpt2
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
|
||||
dataset: str = field(
|
||||
default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", metadata={"help": "The dataset to load"}
|
||||
)
|
||||
model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"})
|
||||
dataset_num_proc: Optional[int] = field(
|
||||
default=None, metadata={"help": "The number of workers to use to tokenize the data"}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
|
||||
ds = load_dataset(args.dataset)
|
||||
if args.debug:
|
||||
for key in ds:
|
||||
ds[key] = ds[key].select(range(50))
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
|
||||
|
||||
def process(row):
|
||||
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
|
||||
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
|
||||
return row
|
||||
|
||||
ds = ds.map(process, num_proc=args.dataset_num_proc)
|
||||
print(ds["train"][0]["chosen"])
|
137
examples/dpo.py
137
examples/dpo.py
@ -1,137 +0,0 @@
|
||||
# 0. imports
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
|
||||
|
||||
from trl import DPOTrainer
|
||||
|
||||
|
||||
# Define and parse arguments.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The arguments for the DPO training script.
|
||||
"""
|
||||
|
||||
# data parameters
|
||||
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
|
||||
|
||||
# training parameters
|
||||
model_name_or_path: Optional[str] = field(default="gpt2", metadata={"help": "the model name"})
|
||||
learning_rate: Optional[float] = field(default=1e-3, metadata={"help": "optimizer learning rate"})
|
||||
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "batch size per device"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=1, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
max_length: Optional[int] = field(default=512, metadata={"help": "max length of each sample"})
|
||||
max_prompt_length: Optional[int] = field(default=128, metadata={"help": "max length of each sample's prompt"})
|
||||
label_pad_token_id: Optional[int] = field(default=-100, metadata={"help": "label for non response tokens"})
|
||||
max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
|
||||
# instrumentation
|
||||
sanity_check: Optional[bool] = field(default=True, metadata={"help": "only train on 1000 samples"})
|
||||
report_to: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
|
||||
'`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
|
||||
'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
|
||||
},
|
||||
)
|
||||
# debug argument for distributed training
|
||||
ignore_bias_buffers: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
|
||||
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def extract_anthropic_prompt(prompt_and_response):
|
||||
"""Extract the anthropic prompt from a prompt and response pair."""
|
||||
search_term = "\n\nAssistant:"
|
||||
search_term_idx = prompt_and_response.rfind(search_term)
|
||||
assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
|
||||
return prompt_and_response[: search_term_idx + len(search_term)]
|
||||
|
||||
|
||||
def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
|
||||
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
|
||||
|
||||
The dataset is converted to a dictionary with the following structure:
|
||||
{
|
||||
'prompt': List[str],
|
||||
'chosen': List[str],
|
||||
'rejected': List[str],
|
||||
}
|
||||
|
||||
Prompts should be structured as follows:
|
||||
\n\nHuman: <prompt>\n\nAssistant:
|
||||
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
|
||||
"""
|
||||
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
|
||||
if sanity_check:
|
||||
dataset = dataset.select(range(min(len(dataset), 1000)))
|
||||
|
||||
def split_prompt_and_responses(sample) -> Dict[str, str]:
|
||||
prompt = extract_anthropic_prompt(sample["chosen"])
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"chosen": sample["chosen"][len(prompt) :],
|
||||
"rejected": sample["rejected"][len(prompt) :],
|
||||
}
|
||||
|
||||
return dataset.map(split_prompt_and_responses)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
|
||||
|
||||
if script_args.ignore_bias_buffers:
|
||||
# torch distributed hack
|
||||
model._ddp_params_and_buffers_to_ignore = [
|
||||
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
|
||||
]
|
||||
|
||||
model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# 2. Load the Anthropic Helpful-Harmless dataset
|
||||
train_dataset = get_hh("train", sanity_check=script_args.sanity_check)
|
||||
|
||||
# 3. Load evaluation dataset
|
||||
eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)
|
||||
|
||||
# 4. initialize training arguments:
|
||||
training_args = TrainingArguments(
|
||||
per_device_train_batch_size=script_args.per_device_train_batch_size,
|
||||
max_steps=script_args.max_steps,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
learning_rate=script_args.learning_rate,
|
||||
evaluation_strategy="steps",
|
||||
output_dir="./test",
|
||||
report_to=script_args.report_to,
|
||||
)
|
||||
|
||||
# 5. initialize the DPO trainer
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
beta=script_args.beta,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
# 6. train
|
||||
dpo_trainer.train()
|
@ -7,14 +7,14 @@ from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
|
||||
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# 2. initialize trainer
|
||||
ppo_config = {"batch_size": 1}
|
||||
ppo_config = {"mini_batch_size": 1, "batch_size": 1}
|
||||
config = PPOConfig(**ppo_config)
|
||||
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
|
||||
|
||||
# 3. encode a query
|
||||
query_txt = "This morning I went to the "
|
||||
@ -29,7 +29,7 @@ generation_kwargs = {
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
"max_new_tokens": 20,
|
||||
}
|
||||
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
|
||||
response_tensor = ppo_trainer.generate(list(query_tensor), return_prompt=False, **generation_kwargs)
|
||||
response_txt = tokenizer.decode(response_tensor[0])
|
||||
|
||||
# 5. define a reward for response
|
||||
|
@ -2,6 +2,6 @@
|
||||
|
||||
This directory contains a collection of Jupyter notebooks that demonstrate how to use the TRL library in different applications.
|
||||
|
||||
- [`best_of_n.ipynb`](https://github.com/lvwerra/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO.
|
||||
- [`gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook.
|
||||
- [`gpt2-control.ipynb`](https://github.com/lvwerra/trl/tree/main/examples/notebooks/gpt2-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control exampel on a jupyter notebook.
|
||||
- [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO.
|
||||
- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook.
|
||||
- [`gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.
|
||||
|
@ -121,7 +121,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n",
|
||||
"https://arxiv.org/pdf/1909.08593.pdf). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models."
|
||||
"https://huggingface.co/papers/1909.08593). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -152,7 +152,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n",
|
||||
"gpt2_model_ref = create_reference_model(gpt2_model)\n",
|
||||
"gpt2_ref_model = create_reference_model(gpt2_model)\n",
|
||||
"gpt2_tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n",
|
||||
"\n",
|
||||
"gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token"
|
||||
@ -353,7 +353,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"ppo_trainer = PPOTrainer(config, gpt2_model, gpt2_model_ref, gpt2_tokenizer, dataset, data_collator=collator)"
|
||||
"ppo_trainer = PPOTrainer(config, gpt2_model, gpt2_ref_model, gpt2_tokenizer, dataset, data_collator=collator)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -847,7 +847,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
"version": "3.9.12"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
@ -92,7 +92,7 @@
|
||||
" log_with=\"wandb\",\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"sent_kwargs = {\"return_all_scores\": True, \"function_to_apply\": \"none\", \"batch_size\": 16}"
|
||||
"sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -110,8 +110,8 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/master/examples/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n",
|
||||
"https://arxiv.org/pdf/1909.08593.pdf). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models."
|
||||
"You can see that we load a GPT2 model called `gpt2_imdb`. This model was additionally fine-tuned on the IMDB dataset for 1 epoch with the huggingface [script](https://github.com/huggingface/transformers/blob/main/examples/legacy/run_language_modeling.py) (no special settings). The other parameters are mostly taken from the original paper [\"Fine-Tuning Language Models from Human Preferences\"](\n",
|
||||
"https://huggingface.co/papers/1909.08593). This model as well as the BERT model is available in the Huggingface model zoo [here](https://huggingface.co/models). The following code should automatically download the models."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -134,16 +134,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Reusing dataset imdb (/home/leandro/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n",
|
||||
"Loading cached processed dataset at /home/leandro/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-ff455473e884c6a3.arrow\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def build_dataset(config, dataset_name=\"imdb\", input_min_text_length=2, input_max_text_length=8):\n",
|
||||
" \"\"\"\n",
|
||||
@ -270,8 +261,8 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[[{'label': 'NEGATIVE', 'score': 2.335048198699951},\n",
|
||||
" {'label': 'POSITIVE', 'score': -2.726576566696167}]]"
|
||||
"[{'label': 'NEGATIVE', 'score': 2.335048198699951},\n",
|
||||
" {'label': 'POSITIVE', 'score': -2.726576328277588}]"
|
||||
]
|
||||
},
|
||||
"execution_count": null,
|
||||
@ -292,8 +283,8 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[[{'label': 'NEGATIVE', 'score': -2.2947897911071777},\n",
|
||||
" {'label': 'POSITIVE', 'score': 2.557039737701416}]]"
|
||||
"[{'label': 'POSITIVE', 'score': 2.557040214538574},\n",
|
||||
" {'label': 'NEGATIVE', 'score': -2.294790267944336}]"
|
||||
]
|
||||
},
|
||||
"execution_count": null,
|
||||
@ -371,7 +362,7 @@
|
||||
"}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):\n",
|
||||
"for epoch, batch in enumerate(tqdm(ppo_trainer.dataloader)):\n",
|
||||
" query_tensors = batch[\"input_ids\"]\n",
|
||||
"\n",
|
||||
" #### Get response from gpt2\n",
|
||||
@ -379,14 +370,16 @@
|
||||
" for query in query_tensors:\n",
|
||||
" gen_len = output_length_sampler()\n",
|
||||
" generation_kwargs[\"max_new_tokens\"] = gen_len\n",
|
||||
" response = ppo_trainer.generate(query, **generation_kwargs)\n",
|
||||
" response_tensors.append(response.squeeze()[-gen_len:])\n",
|
||||
" query_response = ppo_trainer.generate(query, **generation_kwargs).squeeze()\n",
|
||||
" response_len = len(query_response) - len(query)\n",
|
||||
" response_tensors.append(query_response[-response_len:])\n",
|
||||
" batch[\"response\"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]\n",
|
||||
"\n",
|
||||
" #### Compute sentiment score\n",
|
||||
" texts = [q + r for q, r in zip(batch[\"query\"], batch[\"response\"])]\n",
|
||||
" pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n",
|
||||
" rewards = [torch.tensor(output[1][\"score\"]) for output in pipe_outputs]\n",
|
||||
" positive_scores = [item[\"score\"] for output in pipe_outputs for item in output if item[\"label\"] == \"POSITIVE\"]\n",
|
||||
" rewards = [torch.tensor(score) for score in positive_scores]\n",
|
||||
"\n",
|
||||
" #### Run PPO step\n",
|
||||
" stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n",
|
||||
@ -398,7 +391,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Training progress\n",
|
||||
"If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://app.wandb.ai/lvwerra/trl-showcase/runs/1jtvxb1m/).\n",
|
||||
"If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://wandb.ai/huggingface/trl/runs/w9l3110g).\n",
|
||||
"\n",
|
||||
"<div style=\"text-align: center\">\n",
|
||||
"<img src='https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gpt2_tuning_progress.png' width='800'>\n",
|
||||
@ -416,7 +409,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Model inspection\n",
|
||||
"Let's inspect some examples from the IMDB dataset. We can use `model_ref` to compare the tuned model `model` against the model before optimisation."
|
||||
"Let's inspect some examples from the IMDB dataset. We can use `ref_model` to compare the tuned model `model` against the model before optimisation."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -424,14 +417,6 @@
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/leandro/miniconda3/envs/trl/lib/python3.9/site-packages/transformers/pipelines/base.py:1075: UserWarning: You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
@ -463,131 +448,131 @@
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>Oh dear,</td>\n",
|
||||
" <td>what are I saying?! I fast-forwarded through</td>\n",
|
||||
" <td>I must say that I are hanging my head on this</td>\n",
|
||||
" <td>-0.858954</td>\n",
|
||||
" <td>-1.007609</td>\n",
|
||||
" <td>I rented Zero Day</td>\n",
|
||||
" <td>4 for my sister. To my surprise, the Wii caug...</td>\n",
|
||||
" <td>. It is a pleasure. It is a huge leap 68 years...</td>\n",
|
||||
" <td>1.736068</td>\n",
|
||||
" <td>2.423731</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>I've seen</td>\n",
|
||||
" <td>it, as well.<br</td>\n",
|
||||
" <td>three million dialogue throughout, and</td>\n",
|
||||
" <td>1.996807</td>\n",
|
||||
" <td>2.240883</td>\n",
|
||||
" <td>The only</td>\n",
|
||||
" <td>distro of her</td>\n",
|
||||
" <td>special compliments is the</td>\n",
|
||||
" <td>0.150852</td>\n",
|
||||
" <td>0.190159</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>Hi:<br /><br</td>\n",
|
||||
" <td>/>This movie is a turkey though when it comes to</td>\n",
|
||||
" <td>/>I also like that movie. It's so funny</td>\n",
|
||||
" <td>-0.438191</td>\n",
|
||||
" <td>2.415630</td>\n",
|
||||
" <td>I've read a few</td>\n",
|
||||
" <td>news reports about Mr. Mueller's activities b...</td>\n",
|
||||
" <td>novels and I never watch this. It has a reall...</td>\n",
|
||||
" <td>-1.417962</td>\n",
|
||||
" <td>2.831814</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>I'm a writer</td>\n",
|
||||
" <td>and I'm not going to be asked to</td>\n",
|
||||
" <td>, not a screenwriter. I've written</td>\n",
|
||||
" <td>-0.655991</td>\n",
|
||||
" <td>-0.724324</td>\n",
|
||||
" <td>This is the second British Rank film</td>\n",
|
||||
" <td>, and I wouldn't be surprised anymore if it</td>\n",
|
||||
" <td>that I have enjoyed, achieving it in both the</td>\n",
|
||||
" <td>0.835876</td>\n",
|
||||
" <td>2.205628</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>If you</td>\n",
|
||||
" <td>absolutely love sensitive romance, the plot a...</td>\n",
|
||||
" <td>are looking at the cinematography, the acting,</td>\n",
|
||||
" <td>2.221309</td>\n",
|
||||
" <td>0.148751</td>\n",
|
||||
" <td>A classic</td>\n",
|
||||
" <td>classic.<br /><br />And only this one will ha...</td>\n",
|
||||
" <td>. It's a movie with a fine cast. As the beginn...</td>\n",
|
||||
" <td>2.113075</td>\n",
|
||||
" <td>2.739168</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>OMG this</td>\n",
|
||||
" <td>casting cast. Obi cult breezy, this is</td>\n",
|
||||
" <td>movie was totally wonderful, I it was the ide...</td>\n",
|
||||
" <td>-1.533139</td>\n",
|
||||
" <td>2.590190</td>\n",
|
||||
" <td>This has to be one of the</td>\n",
|
||||
" <td>worst with the differences being that for the</td>\n",
|
||||
" <td>best thriller films I've seen in recent</td>\n",
|
||||
" <td>-2.705339</td>\n",
|
||||
" <td>2.730615</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>It's</td>\n",
|
||||
" <td>unrealistic; the guy who was supposed to be E...</td>\n",
|
||||
" <td>a very good film. It reminds us about over</td>\n",
|
||||
" <td>-2.097017</td>\n",
|
||||
" <td>2.835831</td>\n",
|
||||
" <td>Happy Go Lovely is a waste</td>\n",
|
||||
" <td>. Not only are extremely</td>\n",
|
||||
" <td>of time, giving a</td>\n",
|
||||
" <td>-2.429504</td>\n",
|
||||
" <td>-2.934672</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>There is a really</td>\n",
|
||||
" <td>awful laptop game!<br /><br />I used to</td>\n",
|
||||
" <td>interesting story that set us the journey. Th...</td>\n",
|
||||
" <td>-2.341743</td>\n",
|
||||
" <td>2.282939</td>\n",
|
||||
" <td>Wow, I just</td>\n",
|
||||
" <td>can't make fun of it</td>\n",
|
||||
" <td>feek it! This show</td>\n",
|
||||
" <td>-2.201666</td>\n",
|
||||
" <td>-0.106085</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8</th>\n",
|
||||
" <td>This is</td>\n",
|
||||
" <td>my favorite part about</td>\n",
|
||||
" <td>a well thought well</td>\n",
|
||||
" <td>2.554794</td>\n",
|
||||
" <td>2.734139</td>\n",
|
||||
" <td>This movie makes several mistakes.</td>\n",
|
||||
" <td>Despite being a great comedic diversion it es...</td>\n",
|
||||
" <td>It's cool, wonderful - it held me into a very ...</td>\n",
|
||||
" <td>-1.232380</td>\n",
|
||||
" <td>2.707638</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>9</th>\n",
|
||||
" <td>Wasn't</td>\n",
|
||||
" <td>Wasn't it clichéd?<|endoftext|></td>\n",
|
||||
" <td>anyone else interested in this movie? It's a ...</td>\n",
|
||||
" <td>-1.790802</td>\n",
|
||||
" <td>2.631960</td>\n",
|
||||
" <td>Branagh and Fish</td>\n",
|
||||
" <td>burne, Drake is played</td>\n",
|
||||
" <td>is a great show. Beautiful</td>\n",
|
||||
" <td>0.776819</td>\n",
|
||||
" <td>2.808996</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>10</th>\n",
|
||||
" <td>This film is another of director Tim</td>\n",
|
||||
" <td>Burton's masterpieces</td>\n",
|
||||
" <td>Curry's best bombs</td>\n",
|
||||
" <td>2.622917</td>\n",
|
||||
" <td>2.544106</td>\n",
|
||||
" <td>I might have given this movie a</td>\n",
|
||||
" <td>rating of *11 when I heard that!), but it was...</td>\n",
|
||||
" <td>great performance. It was truly a great movie...</td>\n",
|
||||
" <td>0.276380</td>\n",
|
||||
" <td>2.743328</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>11</th>\n",
|
||||
" <td>I thought this movie</td>\n",
|
||||
" <td>was excellent. I actually laughed 6 times and...</td>\n",
|
||||
" <td>was perfect, and I believe it's almost overlo...</td>\n",
|
||||
" <td>2.548022</td>\n",
|
||||
" <td>2.601913</td>\n",
|
||||
" <td>Really, really bad</td>\n",
|
||||
" <td>with feel like there is no end to the</td>\n",
|
||||
" <td>. This movie is incredibly good, with the</td>\n",
|
||||
" <td>-2.639503</td>\n",
|
||||
" <td>-1.568827</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>12</th>\n",
|
||||
" <td>This early John Wayne</td>\n",
|
||||
" <td>films looked like an abandoned police beating</td>\n",
|
||||
" <td>film is a realistic portrayal of what</td>\n",
|
||||
" <td>-1.742279</td>\n",
|
||||
" <td>2.609762</td>\n",
|
||||
" <td>What another reviewer called lack of</td>\n",
|
||||
" <td>judgment, connecting into her own harsh obser...</td>\n",
|
||||
" <td>suspense. Rogers and Rooney rate this as exce...</td>\n",
|
||||
" <td>-1.079707</td>\n",
|
||||
" <td>2.696888</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>13</th>\n",
|
||||
" <td>I was</td>\n",
|
||||
" <td>given an experience-a big one, almost 25</td>\n",
|
||||
" <td>very happy with all the reflections and this ...</td>\n",
|
||||
" <td>2.250709</td>\n",
|
||||
" <td>2.558540</td>\n",
|
||||
" <td>This is simply one</td>\n",
|
||||
" <td>more problem of Steve</td>\n",
|
||||
" <td>of the best choice</td>\n",
|
||||
" <td>-1.445436</td>\n",
|
||||
" <td>2.662699</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>14</th>\n",
|
||||
" <td>Embarrassingly, I</td>\n",
|
||||
" <td>am more at a strict conformity after getting ...</td>\n",
|
||||
" <td>had never seen a movie before. There was one ...</td>\n",
|
||||
" <td>-2.021666</td>\n",
|
||||
" <td>-1.803383</td>\n",
|
||||
" <td>\"Perhaps we can arrange a meet</td>\n",
|
||||
" <td>-and-greet.<br /><br />Teleg</td>\n",
|
||||
" <td>with spent, classic music and dance, and come...</td>\n",
|
||||
" <td>0.258479</td>\n",
|
||||
" <td>1.876662</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>15</th>\n",
|
||||
" <td>I am a fan</td>\n",
|
||||
" <td>of living on simple islands, and we have visi...</td>\n",
|
||||
" <td>of many things and learned how to appreciate ...</td>\n",
|
||||
" <td>1.791297</td>\n",
|
||||
" <td>2.324461</td>\n",
|
||||
" <td>Richard Willaims is</td>\n",
|
||||
" <td>nice enough; the little black guy plays quite</td>\n",
|
||||
" <td>beautifully hands on in his own spin, and</td>\n",
|
||||
" <td>0.796508</td>\n",
|
||||
" <td>2.820259</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
@ -595,76 +580,76 @@
|
||||
],
|
||||
"text/plain": [
|
||||
" query \\\n",
|
||||
"0 Oh dear, \n",
|
||||
"1 I've seen \n",
|
||||
"2 Hi:<br /><br \n",
|
||||
"3 I'm a writer \n",
|
||||
"4 If you \n",
|
||||
"5 OMG this \n",
|
||||
"6 It's \n",
|
||||
"7 There is a really \n",
|
||||
"8 This is \n",
|
||||
"9 Wasn't \n",
|
||||
"10 This film is another of director Tim \n",
|
||||
"11 I thought this movie \n",
|
||||
"12 This early John Wayne \n",
|
||||
"13 I was \n",
|
||||
"14 Embarrassingly, I \n",
|
||||
"15 I am a fan \n",
|
||||
"0 I rented Zero Day \n",
|
||||
"1 The only \n",
|
||||
"2 I've read a few \n",
|
||||
"3 This is the second British Rank film \n",
|
||||
"4 A classic \n",
|
||||
"5 This has to be one of the \n",
|
||||
"6 Happy Go Lovely is a waste \n",
|
||||
"7 Wow, I just \n",
|
||||
"8 This movie makes several mistakes. \n",
|
||||
"9 Branagh and Fish \n",
|
||||
"10 I might have given this movie a \n",
|
||||
"11 Really, really bad \n",
|
||||
"12 What another reviewer called lack of \n",
|
||||
"13 This is simply one \n",
|
||||
"14 \"Perhaps we can arrange a meet \n",
|
||||
"15 Richard Willaims is \n",
|
||||
"\n",
|
||||
" response (before) \\\n",
|
||||
"0 what are I saying?! I fast-forwarded through \n",
|
||||
"1 it, as well.<br \n",
|
||||
"2 />This movie is a turkey though when it comes to \n",
|
||||
"3 and I'm not going to be asked to \n",
|
||||
"4 absolutely love sensitive romance, the plot a... \n",
|
||||
"5 casting cast. Obi cult breezy, this is \n",
|
||||
"6 unrealistic; the guy who was supposed to be E... \n",
|
||||
"7 awful laptop game!<br /><br />I used to \n",
|
||||
"8 my favorite part about \n",
|
||||
"9 Wasn't it clichéd?<|endoftext|> \n",
|
||||
"10 Burton's masterpieces \n",
|
||||
"11 was excellent. I actually laughed 6 times and... \n",
|
||||
"12 films looked like an abandoned police beating \n",
|
||||
"13 given an experience-a big one, almost 25 \n",
|
||||
"14 am more at a strict conformity after getting ... \n",
|
||||
"15 of living on simple islands, and we have visi... \n",
|
||||
"0 4 for my sister. To my surprise, the Wii caug... \n",
|
||||
"1 distro of her \n",
|
||||
"2 news reports about Mr. Mueller's activities b... \n",
|
||||
"3 , and I wouldn't be surprised anymore if it \n",
|
||||
"4 classic.<br /><br />And only this one will ha... \n",
|
||||
"5 worst with the differences being that for the \n",
|
||||
"6 . Not only are extremely \n",
|
||||
"7 can't make fun of it \n",
|
||||
"8 Despite being a great comedic diversion it es... \n",
|
||||
"9 burne, Drake is played \n",
|
||||
"10 rating of *11 when I heard that!), but it was... \n",
|
||||
"11 with feel like there is no end to the \n",
|
||||
"12 judgment, connecting into her own harsh obser... \n",
|
||||
"13 more problem of Steve \n",
|
||||
"14 -and-greet.<br /><br />Teleg \n",
|
||||
"15 nice enough; the little black guy plays quite \n",
|
||||
"\n",
|
||||
" response (after) rewards (before) \\\n",
|
||||
"0 I must say that I are hanging my head on this -0.858954 \n",
|
||||
"1 three million dialogue throughout, and 1.996807 \n",
|
||||
"2 />I also like that movie. It's so funny -0.438191 \n",
|
||||
"3 , not a screenwriter. I've written -0.655991 \n",
|
||||
"4 are looking at the cinematography, the acting, 2.221309 \n",
|
||||
"5 movie was totally wonderful, I it was the ide... -1.533139 \n",
|
||||
"6 a very good film. It reminds us about over -2.097017 \n",
|
||||
"7 interesting story that set us the journey. Th... -2.341743 \n",
|
||||
"8 a well thought well 2.554794 \n",
|
||||
"9 anyone else interested in this movie? It's a ... -1.790802 \n",
|
||||
"10 Curry's best bombs 2.622917 \n",
|
||||
"11 was perfect, and I believe it's almost overlo... 2.548022 \n",
|
||||
"12 film is a realistic portrayal of what -1.742279 \n",
|
||||
"13 very happy with all the reflections and this ... 2.250709 \n",
|
||||
"14 had never seen a movie before. There was one ... -2.021666 \n",
|
||||
"15 of many things and learned how to appreciate ... 1.791297 \n",
|
||||
"0 . It is a pleasure. It is a huge leap 68 years... 1.736068 \n",
|
||||
"1 special compliments is the 0.150852 \n",
|
||||
"2 novels and I never watch this. It has a reall... -1.417962 \n",
|
||||
"3 that I have enjoyed, achieving it in both the 0.835876 \n",
|
||||
"4 . It's a movie with a fine cast. As the beginn... 2.113075 \n",
|
||||
"5 best thriller films I've seen in recent -2.705339 \n",
|
||||
"6 of time, giving a -2.429504 \n",
|
||||
"7 feek it! This show -2.201666 \n",
|
||||
"8 It's cool, wonderful - it held me into a very ... -1.232380 \n",
|
||||
"9 is a great show. Beautiful 0.776819 \n",
|
||||
"10 great performance. It was truly a great movie... 0.276380 \n",
|
||||
"11 . This movie is incredibly good, with the -2.639503 \n",
|
||||
"12 suspense. Rogers and Rooney rate this as exce... -1.079707 \n",
|
||||
"13 of the best choice -1.445436 \n",
|
||||
"14 with spent, classic music and dance, and come... 0.258479 \n",
|
||||
"15 beautifully hands on in his own spin, and 0.796508 \n",
|
||||
"\n",
|
||||
" rewards (after) \n",
|
||||
"0 -1.007609 \n",
|
||||
"1 2.240883 \n",
|
||||
"2 2.415630 \n",
|
||||
"3 -0.724324 \n",
|
||||
"4 0.148751 \n",
|
||||
"5 2.590190 \n",
|
||||
"6 2.835831 \n",
|
||||
"7 2.282939 \n",
|
||||
"8 2.734139 \n",
|
||||
"9 2.631960 \n",
|
||||
"10 2.544106 \n",
|
||||
"11 2.601913 \n",
|
||||
"12 2.609762 \n",
|
||||
"13 2.558540 \n",
|
||||
"14 -1.803383 \n",
|
||||
"15 2.324461 "
|
||||
"0 2.423731 \n",
|
||||
"1 0.190159 \n",
|
||||
"2 2.831814 \n",
|
||||
"3 2.205628 \n",
|
||||
"4 2.739168 \n",
|
||||
"5 2.730615 \n",
|
||||
"6 -2.934672 \n",
|
||||
"7 -0.106085 \n",
|
||||
"8 2.707638 \n",
|
||||
"9 2.808996 \n",
|
||||
"10 2.743328 \n",
|
||||
"11 -1.568827 \n",
|
||||
"12 2.696888 \n",
|
||||
"13 2.662699 \n",
|
||||
"14 1.876662 \n",
|
||||
"15 2.820259 "
|
||||
]
|
||||
},
|
||||
"execution_count": null,
|
||||
@ -685,15 +670,16 @@
|
||||
"\n",
|
||||
"#### get response from gpt2 and gpt2_ref\n",
|
||||
"for i in range(bs):\n",
|
||||
" query = torch.tensor(query_tensors[i]).to(device)\n",
|
||||
"\n",
|
||||
" gen_len = output_length_sampler()\n",
|
||||
" output = ref_model.generate(\n",
|
||||
" torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n",
|
||||
" ).squeeze()[-gen_len:]\n",
|
||||
" response_tensors_ref.append(output)\n",
|
||||
" output = model.generate(\n",
|
||||
" torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n",
|
||||
" ).squeeze()[-gen_len:]\n",
|
||||
" response_tensors.append(output)\n",
|
||||
" query_response = ref_model.generate(query.unsqueeze(0), max_new_tokens=gen_len, **gen_kwargs).squeeze()\n",
|
||||
" response_len = len(query_response) - len(query)\n",
|
||||
" response_tensors_ref.append(query_response[-response_len:])\n",
|
||||
"\n",
|
||||
" query_response = model.generate(query.unsqueeze(0), max_new_tokens=gen_len, **gen_kwargs).squeeze()\n",
|
||||
" response_len = len(query_response) - len(query)\n",
|
||||
" response_tensors.append(query_response[-response_len:])\n",
|
||||
"\n",
|
||||
"#### decode responses\n",
|
||||
"game_data[\"response (before)\"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]\n",
|
||||
@ -701,10 +687,14 @@
|
||||
"\n",
|
||||
"#### sentiment analysis of query/response pairs before/after\n",
|
||||
"texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (before)\"])]\n",
|
||||
"game_data[\"rewards (before)\"] = [output[1][\"score\"] for output in sentiment_pipe(texts, **sent_kwargs)]\n",
|
||||
"pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n",
|
||||
"positive_scores = [item[\"score\"] for output in pipe_outputs for item in output if item[\"label\"] == \"POSITIVE\"]\n",
|
||||
"game_data[\"rewards (before)\"] = positive_scores\n",
|
||||
"\n",
|
||||
"texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (after)\"])]\n",
|
||||
"game_data[\"rewards (after)\"] = [output[1][\"score\"] for output in sentiment_pipe(texts, **sent_kwargs)]\n",
|
||||
"pipe_outputs = sentiment_pipe(texts, **sent_kwargs)\n",
|
||||
"positive_scores = [item[\"score\"] for output in pipe_outputs for item in output if item[\"label\"] == \"POSITIVE\"]\n",
|
||||
"game_data[\"rewards (after)\"] = positive_scores\n",
|
||||
"\n",
|
||||
"# store results in a dataframe\n",
|
||||
"df_results = pd.DataFrame(game_data)\n",
|
||||
@ -733,8 +723,8 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"rewards (before) 0.156629\n",
|
||||
"rewards (after) 1.686487\n",
|
||||
"rewards (before) -0.512965\n",
|
||||
"rewards (after) 1.676750\n",
|
||||
"dtype: float64"
|
||||
]
|
||||
},
|
||||
@ -752,8 +742,8 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"rewards (before) -0.547091\n",
|
||||
"rewards (after) 2.479868\n",
|
||||
"rewards (before) -0.464427\n",
|
||||
"rewards (after) 2.679794\n",
|
||||
"dtype: float64"
|
||||
]
|
||||
},
|
||||
@ -782,45 +772,6 @@
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/leandro/miniconda3/envs/trl/lib/python3.9/site-packages/huggingface_hub/hf_api.py:1001: FutureWarning: `create_repo` now takes `token` as an optional positional argument. Be sure to adapt your code!\n",
|
||||
" warnings.warn(\n",
|
||||
"Cloning https://huggingface.co/lvwerra/gpt2-imdb-pos-v2 into local empty directory.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "a953a6d0c465432bbc39aca826d37aaf",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Upload file pytorch_model.bin: 0%| | 32.0k/487M [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"remote: Enforcing permissions... \n",
|
||||
"remote: Allowed refs: all \n",
|
||||
"To https://huggingface.co/lvwerra/gpt2-imdb-pos-v2\n",
|
||||
" 369b075..28b9865 main -> main\n",
|
||||
"\n",
|
||||
"remote: Enforcing permissions... \n",
|
||||
"remote: Allowed refs: all \n",
|
||||
"To https://huggingface.co/lvwerra/gpt2-imdb-pos-v2\n",
|
||||
" 28b9865..42792ea main -> main\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
@ -841,13 +792,6 @@
|
||||
"model.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)\n",
|
||||
"tokenizer.save_pretrained(\"gpt2-imdb-pos-v2\", push_to_hub=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@ -866,7 +810,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.12 (main, Mar 26 2022, 15:51:15) \n[Clang 13.1.6 (clang-1316.0.21.2)]"
|
||||
"version": "3.11.9"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Research projects that uses TRL
|
||||
# Research projects that use TRL
|
||||
|
||||
Welcome to the research projects folder! Here you can find the scripts used for some research projects that used TRL and maintained by the developpers and the community (LM de-toxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information!
|
||||
Welcome to the research projects folder! Here you can find the scripts used for some research projects that used TRL and maintained by the developers and the community (LM de-toxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information!
|
||||
|
||||
- [De-detoxifying language models](https://github.com/lvwerra/trl/tree/main/examples/research_projects/toxicity)
|
||||
- [Stack-Llama](https://github.com/lvwerra/trl/tree/main/examples/research_projects/stack_llama)
|
||||
- [De-detoxifying language models](https://github.com/huggingface/trl/tree/main/examples/research_projects/toxicity)
|
||||
- [Stack-Llama](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama)
|
||||
- [Stack-Llama-2](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2)
|
@ -1,17 +1,17 @@
|
||||
# RLHF pipeline for the creation of StackLLaMa: a Stack exchange llama-7b model.
|
||||
There were three main steps to the training process:
|
||||
1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se:
|
||||
- `torchrun --nnodes 1 --nproc_per_node 8 examples/stack_llama/scripts/supervised_finetuning.py --model_path=<LLAMA_MODEL_PATH> --streaming --no_gradient_checkpointing --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se`
|
||||
- `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path=<LLAMA_MODEL_PATH> --streaming --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se`
|
||||
2. Reward modeling using dialog pairs from the SE dataset using the llama-7b-se to create llama-7b-se-rm:
|
||||
- `torchrun --nnodes 1 --nproc_per_node 8 examples/stack_llama/scripts/reward_modeling.py --model_name=<LLAMA_SE_MODEL>`
|
||||
- `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/reward_modeling.py --model_name=<LLAMA_SE_MODEL>`
|
||||
3. RL fine-tuning of llama-7b-se with the llama-7b-se-rm reward model:
|
||||
- `accelerate launch --multi_gpu --num_machines 1 --num_processes 8 examples/stack_llama/scripts/rl_training.py --log_with=wandb --model_name=<LLAMA_SE_MODEL> --reward_model_name=<LLAMA_SE_RM_MODEL> --adafactor=False --tokenizer_name=<LLAMA_TOKENIZER> --save_freq=100 --output_max_length=128 --batch_size=8 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam`
|
||||
- `accelerate launch --multi_gpu --num_machines 1 --num_processes 8 examples/research_projects/stack_llama/scripts/rl_training.py --log_with=wandb --model_name=<LLAMA_SE_MODEL> --reward_model_name=<LLAMA_SE_RM_MODEL> --adafactor=False --tokenizer_name=<LLAMA_TOKENIZER> --save_freq=100 --output_max_length=128 --batch_size=8 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam`
|
||||
|
||||
|
||||
LoRA layers were using at all stages to reduce memory requirements.
|
||||
At each stage the peft adapter layers were merged with the base model, using:
|
||||
```shell
|
||||
python examples/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ
|
||||
python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ
|
||||
```
|
||||
Note that this script requires `peft>=0.3.0`.
|
||||
|
||||
|
@ -9,23 +9,24 @@ from transformers import AutoModelForCausalLM, AutoModelForSequenceClassificatio
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The name of the Casual LM model we wish to fine with PPO
|
||||
The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the
|
||||
merged model.
|
||||
"""
|
||||
|
||||
adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the model name"})
|
||||
base_model_name: Optional[str] = field(default=None, metadata={"help": "the model name"})
|
||||
output_name: Optional[str] = field(default=None, metadata={"help": "the model name"})
|
||||
adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"})
|
||||
base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
|
||||
output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge"
|
||||
assert script_args.base_model_name is not None, "please provide the name of the Base model"
|
||||
assert script_args.base_model_name is not None, "please provide the output name of the merged model"
|
||||
assert script_args.output_name is not None, "please provide the output name of the merged model"
|
||||
|
||||
peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name)
|
||||
if peft_config.task_type == "SEQ_CLS":
|
||||
# peft is for reward model so load sequence classification
|
||||
# The sequence classification task is used for the reward model in PPO
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16
|
||||
)
|
||||
@ -36,7 +37,7 @@ else:
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name)
|
||||
|
||||
# Load the Lora model
|
||||
# Load the PEFT model
|
||||
model = PeftModel.from_pretrained(model, script_args.adapter_model_name)
|
||||
model.eval()
|
||||
|
||||
|
@ -15,6 +15,7 @@ from transformers import (
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
@ -89,16 +90,23 @@ class ScriptArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether to run eval after the first step"},
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
|
||||
)
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
set_seed(script_args.seed)
|
||||
# Load the human stack-exchange-paired dataset for tuning the reward model.
|
||||
train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/reward", split="train")
|
||||
train_dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired", data_dir="data/reward", split="train", verification_mode="no_checks"
|
||||
)
|
||||
if script_args.train_subset > 0:
|
||||
train_dataset = train_dataset.select(range(script_args.train_subset))
|
||||
eval_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train")
|
||||
eval_dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train", verification_mode="no_checks"
|
||||
)
|
||||
if script_args.eval_subset > 0:
|
||||
eval_dataset = eval_dataset.select(range(script_args.eval_subset))
|
||||
# Define the training args. Needs to be done before the model is loaded if you are using deepspeed.
|
||||
@ -114,7 +122,7 @@ training_args = TrainingArguments(
|
||||
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
|
||||
num_train_epochs=script_args.num_train_epochs,
|
||||
weight_decay=script_args.weight_decay,
|
||||
evaluation_strategy="steps",
|
||||
eval_strategy="steps",
|
||||
eval_steps=500,
|
||||
save_strategy="steps",
|
||||
save_steps=500,
|
||||
@ -129,7 +137,10 @@ training_args = TrainingArguments(
|
||||
logging_steps=10,
|
||||
optim=script_args.optim,
|
||||
lr_scheduler_type=script_args.lr_scheduler_type,
|
||||
seed=script_args.seed,
|
||||
)
|
||||
|
||||
|
||||
# Load the value-head model and tokenizer.
|
||||
tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)
|
||||
@ -187,7 +198,8 @@ train_dataset = train_dataset.map(
|
||||
remove_columns=original_columns,
|
||||
)
|
||||
train_dataset = train_dataset.filter(
|
||||
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length
|
||||
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
|
||||
num_proc=num_proc,
|
||||
)
|
||||
|
||||
eval_dataset = eval_dataset.map(
|
||||
@ -197,7 +209,8 @@ eval_dataset = eval_dataset.map(
|
||||
remove_columns=original_columns,
|
||||
)
|
||||
eval_dataset = eval_dataset.filter(
|
||||
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length
|
||||
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
|
||||
num_proc=num_proc,
|
||||
)
|
||||
|
||||
|
||||
@ -264,7 +277,7 @@ def compute_metrics(eval_pred):
|
||||
|
||||
|
||||
class RewardTrainer(Trainer):
|
||||
# Define how to compute the reward loss. We use the InstructGPT pairwise logloss: https://arxiv.org/abs/2203.02155
|
||||
# Define how to compute the reward loss. We use the InstructGPT pairwise logloss: https://huggingface.co/papers/2203.02155
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
|
||||
rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
|
||||
|
@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -32,7 +31,7 @@ tqdm.pandas()
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The name of the Casual LM model we wish to fine with PPO
|
||||
The name of the Casual LM model we wish to fine-tune with PPO
|
||||
"""
|
||||
|
||||
# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
|
||||
@ -67,6 +66,7 @@ class ScriptArguments:
|
||||
)
|
||||
|
||||
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
|
||||
load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 8bit"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
@ -90,8 +90,11 @@ config = PPOConfig(
|
||||
adap_kl_ctrl=script_args.adap_kl_ctrl,
|
||||
)
|
||||
|
||||
train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/rl", split="train")
|
||||
train_dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired", data_dir="data/rl", split="train", verification_mode="no_checks"
|
||||
)
|
||||
train_dataset = train_dataset.select(range(100000))
|
||||
original_columns = train_dataset.column_names
|
||||
|
||||
# We then define the arguments to pass to the sentiment analysis pipeline.
|
||||
# We set `return_all_scores` to True to get the sentiment score for each token.
|
||||
@ -130,9 +133,6 @@ def build_dataset(
|
||||
The dataloader for the dataset.
|
||||
"""
|
||||
|
||||
# load imdb with datasets
|
||||
ds = load_dataset(dataset_name, data_dir="data/rl", split="train")
|
||||
original_columns = ds.column_names
|
||||
num_proc = 24
|
||||
|
||||
def preprocess_function(examples):
|
||||
@ -154,7 +154,7 @@ def build_dataset(
|
||||
num_proc=num_proc,
|
||||
remove_columns=original_columns,
|
||||
)
|
||||
ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False)
|
||||
ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False, num_proc=num_proc)
|
||||
|
||||
ds.set_format(type="torch")
|
||||
return ds
|
||||
@ -165,7 +165,7 @@ dataset = build_dataset(tokenizer)
|
||||
|
||||
|
||||
def collator(data):
|
||||
return dict((key, [d[key] for d in data]) for key in data[0])
|
||||
return {key: [d[key] for d in data] for key in data[0]}
|
||||
|
||||
|
||||
# set seed before initializing value head for deterministic eval
|
||||
@ -183,7 +183,7 @@ lora_config = LoraConfig(
|
||||
)
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
load_in_8bit=True,
|
||||
load_in_8bit=script_args.load_in_8bit,
|
||||
device_map={"": current_device},
|
||||
peft_config=lora_config,
|
||||
)
|
||||
@ -218,11 +218,13 @@ sentiment_pipe = pipeline(
|
||||
"sentiment-analysis",
|
||||
model=reward_model_name,
|
||||
device_map={"": current_device},
|
||||
model_kwargs={"load_in_8bit": True},
|
||||
model_kwargs={"load_in_8bit": script_args.load_in_8bit},
|
||||
tokenizer=tokenizer,
|
||||
return_token_type_ids=False,
|
||||
)
|
||||
|
||||
if sentiment_pipe.model.config.pad_token_id is None:
|
||||
sentiment_pipe.model.config.pad_token_id = sentiment_pipe.model.config.eos_token_id
|
||||
# We then define the arguments to pass to the `generate` function. These arguments
|
||||
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
|
||||
# the `generate` function of the trained model.
|
||||
|
@ -38,9 +38,9 @@ def get_args():
|
||||
parser.add_argument("--weight_decay", type=float, default=0.05)
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=0)
|
||||
parser.add_argument("--no_fp16", action="store_false")
|
||||
parser.add_argument("--fp16", action="store_true", default=False)
|
||||
parser.add_argument("--bf16", action="store_true", default=False)
|
||||
parser.add_argument("--no_gradient_checkpointing", action="store_false", default=False)
|
||||
parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--num_workers", type=int, default=None)
|
||||
parser.add_argument("--output_dir", type=str, default="./checkpoints")
|
||||
@ -148,7 +148,7 @@ def run_training(args, train_data, val_data):
|
||||
training_args = TrainingArguments(
|
||||
output_dir=args.output_dir,
|
||||
dataloader_drop_last=True,
|
||||
evaluation_strategy="steps",
|
||||
eval_strategy="steps",
|
||||
max_steps=args.max_steps,
|
||||
eval_steps=args.eval_freq,
|
||||
save_steps=args.save_freq,
|
||||
@ -159,8 +159,8 @@ def run_training(args, train_data, val_data):
|
||||
lr_scheduler_type=args.lr_scheduler_type,
|
||||
warmup_steps=args.num_warmup_steps,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
gradient_checkpointing=not args.no_gradient_checkpointing,
|
||||
fp16=not args.no_fp16,
|
||||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
fp16=args.fp16,
|
||||
bf16=args.bf16,
|
||||
weight_decay=args.weight_decay,
|
||||
run_name="llama-7b-finetuned",
|
||||
|
76
examples/research_projects/stack_llama_2/scripts/README.md
Normal file
76
examples/research_projects/stack_llama_2/scripts/README.md
Normal file
@ -0,0 +1,76 @@
|
||||
# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Install all the dependencies in the `requirements.txt`:
|
||||
|
||||
```
|
||||
$ pip install -U -r requirements.txt
|
||||
```
|
||||
|
||||
Since we will use `accelerate` for training, make sure to run:
|
||||
```
|
||||
$ accelerate config
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
There were two main steps to the DPO training process:
|
||||
1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:
|
||||
|
||||
```
|
||||
accelerate launch examples/research_projects/stack_llama_2/scripts/sft_llama2.py \
|
||||
--output_dir="./sft" \
|
||||
--max_steps=500 \
|
||||
--logging_steps=10 \
|
||||
--save_steps=10 \
|
||||
--per_device_train_batch_size=4 \
|
||||
--per_device_eval_batch_size=1 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--gradient_checkpointing=False \
|
||||
--group_by_length=False \
|
||||
--learning_rate=1e-4 \
|
||||
--lr_scheduler_type="cosine" \
|
||||
--warmup_steps=100 \
|
||||
--weight_decay=0.05 \
|
||||
--optim="paged_adamw_32bit" \
|
||||
--bf16=True \
|
||||
--remove_unused_columns=False \
|
||||
--run_name="sft_llama2" \
|
||||
--report_to="wandb"
|
||||
```
|
||||
1. Run the DPO trainer using the model saved by the previous step:
|
||||
```
|
||||
accelerate launch examples/research_projects/stack_llama_2/scripts/dpo_llama2.py \
|
||||
--model_name_or_path="sft/final_checkpoint" \
|
||||
--output_dir="dpo"
|
||||
```
|
||||
|
||||
|
||||
## Merging the adaptors
|
||||
|
||||
To merge the adaptors into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL:
|
||||
|
||||
```
|
||||
python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo/final_checkpoint/" --output_name="stack-llama-2"
|
||||
```
|
||||
|
||||
which will also push the model to your HuggingFace hub account.
|
||||
|
||||
## Running the model
|
||||
|
||||
We can load the DPO-trained LoRA adaptors which were saved by the DPO training step and load them via:
|
||||
|
||||
```py
|
||||
from peft import AutoPeftModelForCausalLM
|
||||
|
||||
|
||||
model = AutoPeftModelForCausalLM.from_pretrained(
|
||||
"dpo/final_checkpoint",
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16,
|
||||
load_in_4bit=True,
|
||||
)
|
||||
|
||||
model.generate(...)
|
||||
```
|
243
examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
Normal file
243
examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
Normal file
@ -0,0 +1,243 @@
|
||||
# 0. imports
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets import Dataset, load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
|
||||
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
|
||||
# Define and parse arguments.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The arguments for the DPO training script.
|
||||
"""
|
||||
|
||||
# data parameters
|
||||
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
|
||||
|
||||
# training parameters
|
||||
model_name_or_path: Optional[str] = field(
|
||||
default="../sft/results/final_checkpoint",
|
||||
metadata={"help": "the location of the SFT model name or path"},
|
||||
)
|
||||
learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
|
||||
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
|
||||
warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
|
||||
weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
|
||||
optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
|
||||
|
||||
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
|
||||
per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=4, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
gradient_checkpointing: Optional[bool] = field(
|
||||
default=True, metadata={"help": "whether to use gradient checkpointing"}
|
||||
)
|
||||
|
||||
gradient_checkpointing_use_reentrant: Optional[bool] = field(
|
||||
default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
|
||||
)
|
||||
|
||||
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
|
||||
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
|
||||
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
|
||||
|
||||
max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
|
||||
max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
|
||||
max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
|
||||
logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
|
||||
save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
|
||||
eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})
|
||||
|
||||
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
|
||||
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
|
||||
load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
|
||||
model_dtype: Optional[str] = field(
|
||||
default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
|
||||
)
|
||||
|
||||
# instrumentation
|
||||
sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
|
||||
report_to: Optional[str] = field(
|
||||
default="wandb",
|
||||
metadata={
|
||||
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
|
||||
'`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
|
||||
'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
|
||||
},
|
||||
)
|
||||
# debug argument for distributed training
|
||||
ignore_bias_buffers: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
|
||||
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
|
||||
},
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
|
||||
)
|
||||
|
||||
|
||||
def get_stack_exchange_paired(
|
||||
data_dir: str = "data/rl",
|
||||
sanity_check: bool = False,
|
||||
cache_dir: Optional[str] = None,
|
||||
num_proc=24,
|
||||
) -> Dataset:
|
||||
"""Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
|
||||
|
||||
The dataset is converted to a dictionary with the following structure:
|
||||
{
|
||||
'prompt': List[str],
|
||||
'chosen': List[str],
|
||||
'rejected': List[str],
|
||||
}
|
||||
|
||||
Prompts are structured as follows:
|
||||
"Question: " + <prompt> + "\n\nAnswer: "
|
||||
"""
|
||||
dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired",
|
||||
split="train",
|
||||
cache_dir=cache_dir,
|
||||
data_dir=data_dir,
|
||||
verification_mode="no_checks",
|
||||
)
|
||||
original_columns = dataset.column_names
|
||||
|
||||
if sanity_check:
|
||||
dataset = dataset.select(range(min(len(dataset), 1000)))
|
||||
|
||||
def return_prompt_and_responses(samples) -> Dict[str, str]:
|
||||
return {
|
||||
"prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
|
||||
"chosen": samples["response_j"],
|
||||
"rejected": samples["response_k"],
|
||||
}
|
||||
|
||||
return dataset.map(
|
||||
return_prompt_and_responses,
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
remove_columns=original_columns,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
set_seed(script_args.seed)
|
||||
|
||||
# 1. load a pretrained model
|
||||
torch_dtype = torch.float
|
||||
if script_args.model_dtype == "float16":
|
||||
torch_dtype = torch.float16
|
||||
elif script_args.model_dtype == "bfloat16":
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
script_args.model_name_or_path,
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch_dtype,
|
||||
load_in_4bit=script_args.load_in_4bit,
|
||||
device_map={"": Accelerator().local_process_index},
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
if script_args.ignore_bias_buffers:
|
||||
# torch distributed hack
|
||||
model._ddp_params_and_buffers_to_ignore = [
|
||||
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
|
||||
]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# 2. Load the Stack-exchange paired dataset
|
||||
train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check)
|
||||
train_dataset = train_dataset.filter(
|
||||
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
|
||||
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
|
||||
num_proc=script_args.num_proc,
|
||||
)
|
||||
|
||||
# 3. Load evaluation dataset
|
||||
eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True)
|
||||
eval_dataset = eval_dataset.filter(
|
||||
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
|
||||
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
|
||||
num_proc=script_args.num_proc,
|
||||
)
|
||||
|
||||
# 4. initialize training arguments:
|
||||
training_args = DPOConfig(
|
||||
per_device_train_batch_size=script_args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
|
||||
max_steps=script_args.max_steps,
|
||||
logging_steps=script_args.logging_steps,
|
||||
save_steps=script_args.save_steps,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
gradient_checkpointing=script_args.gradient_checkpointing,
|
||||
learning_rate=script_args.learning_rate,
|
||||
eval_strategy="steps",
|
||||
eval_steps=script_args.eval_steps,
|
||||
output_dir=script_args.output_dir,
|
||||
report_to=script_args.report_to,
|
||||
lr_scheduler_type=script_args.lr_scheduler_type,
|
||||
warmup_steps=script_args.warmup_steps,
|
||||
optim=script_args.optimizer_type,
|
||||
bf16=True,
|
||||
remove_unused_columns=False,
|
||||
run_name="dpo_llama2",
|
||||
gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
|
||||
seed=script_args.seed,
|
||||
)
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=script_args.lora_r,
|
||||
lora_alpha=script_args.lora_alpha,
|
||||
lora_dropout=script_args.lora_dropout,
|
||||
target_modules=[
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
"k_proj",
|
||||
"out_proj",
|
||||
"fc_in",
|
||||
"fc_out",
|
||||
"wte",
|
||||
],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# 5. initialize the DPO trainer
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
beta=script_args.beta,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
peft_config=peft_config,
|
||||
max_prompt_length=script_args.max_prompt_length,
|
||||
max_length=script_args.max_length,
|
||||
)
|
||||
|
||||
# 6. train
|
||||
dpo_trainer.train()
|
||||
dpo_trainer.save_model(script_args.output_dir)
|
||||
|
||||
# 7. save
|
||||
output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
|
||||
dpo_trainer.model.save_pretrained(output_dir)
|
@ -0,0 +1,7 @@
|
||||
transformers
|
||||
trl
|
||||
peft
|
||||
accelerate
|
||||
datasets
|
||||
bitsandbytes
|
||||
wandb
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user